netgen/libsrc/general/mpi_interface.hpp

416 lines
11 KiB
C++
Raw Normal View History

2009-01-13 04:40:13 +05:00
#ifndef FILE_PARALLEL
#define FILE_PARALLEL
#ifdef VTRACE
#include "vt_user.h"
#else
#define VT_USER_START(n)
#define VT_USER_END(n)
#define VT_TRACER(n)
#endif
2009-07-20 14:36:36 +06:00
namespace netgen
{
2019-01-14 17:04:27 +05:00
using ngcore::id;
using ngcore::ntasks;
2009-01-13 04:40:13 +05:00
#ifndef PARALLEL
2019-01-31 00:55:45 +05:00
/** without MPI, we need a dummy typedef **/
typedef int MPI_Comm;
#endif
/** This is the "standard" communicator that will be used for netgen-objects. **/
2019-02-01 20:12:30 +05:00
extern DLL_HEADER MPI_Comm ng_comm;
2011-07-07 03:08:58 +06:00
#ifdef PARALLEL
inline int MyMPI_GetNTasks (MPI_Comm comm = ng_comm)
{
int ntasks;
MPI_Comm_size(comm, &ntasks);
return ntasks;
}
inline int MyMPI_GetId (MPI_Comm comm = ng_comm)
{
int id;
MPI_Comm_rank(comm, &id);
return id;
}
2019-01-31 00:55:45 +05:00
#else
enum { MPI_COMM_WORLD = 12345, MPI_COMM_NULL = 0};
inline int MyMPI_GetNTasks (MPI_Comm comm = ng_comm) { return 1; }
inline int MyMPI_GetId (MPI_Comm comm = ng_comm) { return 0; }
#endif
2019-01-31 00:55:45 +05:00
#ifdef PARALLEL
// For python wrapping of communicators.
// Holds an MPI communicator; if owns_comm is set, the communicator is
// released with MPI_Comm_free when the wrapper is destroyed.
struct PyMPI_Comm {
  MPI_Comm comm;
  bool owns_comm;
  PyMPI_Comm (MPI_Comm _comm, bool _owns_comm = false) : comm(_comm), owns_comm(_owns_comm) { }
  // Non-copyable: a copy (by construction OR assignment) would share an
  // owned communicator and free it twice. The implicitly generated copy
  // assignment is not suppressed by deleting the copy constructor alone.
  PyMPI_Comm (const PyMPI_Comm & c) = delete;
  PyMPI_Comm & operator= (const PyMPI_Comm & c) = delete;
  ~PyMPI_Comm () {
    // MPI_Comm_free on MPI_COMM_NULL is erroneous, so skip it.
    if (owns_comm && comm != MPI_COMM_NULL)
      MPI_Comm_free(&comm);
  }
  /// Rank of the calling process in the wrapped communicator.
  inline int Rank() const { return MyMPI_GetId(comm); }
  /// Number of processes in the wrapped communicator.
  inline int Size() const { return MyMPI_GetNTasks(comm); }
};
#else
// dummy without MPI: behaves like a single-process communicator
struct PyMPI_Comm {
  MPI_Comm comm = 0;
  PyMPI_Comm (MPI_Comm _comm, bool _owns_comm = false) { }
  ~PyMPI_Comm () { }
  inline int Rank() const { return 0; }
  inline int Size() const { return 1; }
};
#endif
2009-01-13 04:40:13 +05:00
2019-01-31 00:55:45 +05:00
#ifdef PARALLEL
/// Map a C++ type to its MPI datatype.
/// The generic version is reached for unsupported types: it prints an
/// error to cerr and returns 0, which is not a valid MPI datatype.
template <class T>
inline MPI_Datatype MyGetMPIType ( )
{ cerr << "ERROR in GetMPIType() -- no type found" << endl;return 0; }
template <>
inline MPI_Datatype MyGetMPIType<int> ( )
{ return MPI_INT; }
template <>
inline MPI_Datatype MyGetMPIType<double> ( )
{ return MPI_DOUBLE; }
template <>
inline MPI_Datatype MyGetMPIType<char> ( )
{ return MPI_CHAR; }
template<>
inline MPI_Datatype MyGetMPIType<size_t> ( )
{
  // size_t is 64 bits on common 64-bit targets but 32 bits on 32-bit ones;
  // hard-coding MPI_UINT64_T would mis-size every buffer on the latter.
  static_assert (sizeof(size_t) == 8 || sizeof(size_t) == 4,
                 "MyGetMPIType<size_t>: unsupported size_t width");
  return (sizeof(size_t) == 8) ? MPI_UINT64_T : MPI_UINT32_T;
}
#else
// Serial build: datatypes are dummies.
typedef int MPI_Datatype;
template <class T> inline MPI_Datatype MyGetMPIType ( ) { return 0; }
#endif
#ifdef PARALLEL
/// Create a communicator containing the ranks listed in @p procs
/// (ranks relative to @p comm), using MPI_Comm_create_group with tag 6969.
/// MPI_Comm_create_group is collective over the processes in the new group.
inline MPI_Comm MyMPI_SubCommunicator(MPI_Comm comm, Array<int> & procs)
{
  MPI_Comm subcomm;
  MPI_Group gcomm, gsubcomm;
  MPI_Comm_group(comm, &gcomm);
  MPI_Group_incl(gcomm, procs.Size(), &(procs[0]), &gsubcomm);
  MPI_Comm_create_group(comm, gsubcomm, 6969, &subcomm);
  // The groups are only needed to build the communicator. Freeing them does
  // not affect subcomm, and not freeing them leaks two handles per call.
  MPI_Group_free(&gsubcomm);
  MPI_Group_free(&gcomm);
  return subcomm;
}
#else
// Serial build: the "sub-communicator" is just the communicator itself.
inline MPI_Comm MyMPI_SubCommunicator(MPI_Comm comm, Array<int> & procs)
{ return comm; }
#endif
#ifdef PARALLEL
// Message tag constants used with the MyMPI_* send/receive helpers.
// NOTE(review): the names suggest command, mesh and visualization traffic
// respectively; confirm against the callers.
enum { MPI_TAG_CMD = 110 };
enum { MPI_TAG_MESH = 210 };
enum { MPI_TAG_VIS = 310 };
inline void MyMPI_Send (int i, int dest, int tag, MPI_Comm comm = ng_comm)
2009-01-13 04:40:13 +05:00
{
int hi = i;
MPI_Send( &hi, 1, MPI_INT, dest, tag, comm);
2009-01-13 04:40:13 +05:00
}
inline void MyMPI_Recv (int & i, int src, int tag, MPI_Comm comm = ng_comm)
2009-01-13 04:40:13 +05:00
{
MPI_Status status;
MPI_Recv( &i, 1, MPI_INT, src, tag, comm, &status);
2009-01-13 04:40:13 +05:00
}
inline void MyMPI_Send (const string & s, int dest, int tag, MPI_Comm comm = ng_comm)
2009-01-13 04:40:13 +05:00
{
MPI_Send( const_cast<char*> (s.c_str()), s.length(), MPI_CHAR, dest, tag, comm);
2009-01-13 04:40:13 +05:00
}
/// Blocking receive of a string from rank @p src with tag @p tag.
/// @p s is resized to the length of the incoming message.
inline void MyMPI_Recv (string & s, int src, int tag, MPI_Comm comm = ng_comm)
{
  MPI_Status status;
  int len;
  // Probe on the caller's communicator. Probing MPI_COMM_WORLD would size the
  // buffer from (or block on) a message on a different communicator.
  MPI_Probe (src, tag, comm, &status);
  MPI_Get_count (&status, MPI_CHAR, &len);
  s.assign (len, ' ');
  // Receive exactly the probed message. If src/tag are wildcards, a wildcard
  // receive could match a different message than the one probed.
  MPI_Recv( &s[0], len, MPI_CHAR, status.MPI_SOURCE, status.MPI_TAG, comm, &status);
}
/// Blocking send of all entries of @p s to rank @p dest.
/// The MPI datatype is obtained from MyGetMPIType<T>().
template <class T, int BASE>
inline void MyMPI_Send (FlatArray<T, BASE> s, int dest, int tag, MPI_Comm comm = ng_comm)
{
  const MPI_Datatype type = MyGetMPIType<T>();
  MPI_Send (&s.First(), s.Size(), type, dest, tag, comm);
}
/// Blocking receive from rank @p src into the existing storage of @p s.
/// Exactly s.Size() entries of MyGetMPIType<T>() are posted for the receive.
template <class T, int BASE>
inline void MyMPI_Recv ( FlatArray<T, BASE> s, int src, int tag, MPI_Comm comm = ng_comm)
{
  const MPI_Datatype type = MyGetMPIType<T>();
  MPI_Recv (&s.First(), s.Size(), type, src, tag, comm, MPI_STATUS_IGNORE);
}
/// Blocking receive from rank @p src with tag @p tag into @p s.
/// @p s is resized to the number of entries in the incoming message.
template <class T, int BASE>
inline void MyMPI_Recv ( Array <T, BASE> & s, int src, int tag, MPI_Comm comm = ng_comm)
{
  MPI_Status status;
  int len;
  MPI_Probe (src, tag, comm, &status);
  MPI_Get_count (&status, MyGetMPIType<T>(), &len);
  s.SetSize (len);
  // Receive exactly the probed message. With MPI_ANY_SOURCE/MPI_ANY_TAG, a
  // wildcard receive could match a different (e.g. larger) message than the
  // one used to size the buffer.
  MPI_Recv( &s.First(), len, MyGetMPIType<T>(), status.MPI_SOURCE, status.MPI_TAG, comm, &status);
}
template <class T, int BASE>
inline int MyMPI_Recv ( Array <T, BASE> & s, int tag, MPI_Comm comm = ng_comm)
2009-01-13 04:40:13 +05:00
{
MPI_Status status;
int len;
MPI_Probe (MPI_ANY_SOURCE, tag, comm, &status);
2009-01-13 04:40:13 +05:00
int src = status.MPI_SOURCE;
MPI_Get_count (&status, MyGetMPIType<T>(), &len);
s.SetSize (len);
MPI_Recv( &s.First(), len, MyGetMPIType<T>(), src, tag, comm, &status);
2009-01-13 04:40:13 +05:00
return src;
}
2011-07-15 03:36:19 +06:00
/*
2009-01-13 04:40:13 +05:00
template <class T, int BASE>
2011-07-07 03:08:58 +06:00
inline void MyMPI_ISend (FlatArray<T, BASE> s, int dest, int tag, MPI_Request & request)
2009-01-13 04:40:13 +05:00
{
2011-07-07 03:08:58 +06:00
MPI_Isend( &s.First(), s.Size(), MyGetMPIType<T>(), dest, tag, MPI_COMM_WORLD, & request);
2009-01-13 04:40:13 +05:00
}
template <class T, int BASE>
2011-07-07 03:08:58 +06:00
inline void MyMPI_IRecv (FlatArray<T, BASE> s, int dest, int tag, MPI_Request & request)
2009-01-13 04:40:13 +05:00
{
2011-07-07 03:08:58 +06:00
MPI_Irecv( &s.First(), s.Size(), MyGetMPIType<T>(), dest, tag, MPI_COMM_WORLD, & request);
2009-01-13 04:40:13 +05:00
}
2011-07-07 03:08:58 +06:00
*/
2009-01-13 04:40:13 +05:00
/// Non-blocking send of all entries of @p s to rank @p dest.
/// Returns the MPI request; the caller must complete it (e.g. MPI_Waitall)
/// and keep the data viewed by @p s alive until then.
template <class T, int BASE>
inline MPI_Request MyMPI_ISend (FlatArray<T, BASE> s, int dest, int tag, MPI_Comm comm = ng_comm)
{
  MPI_Request req;
  const MPI_Datatype type = MyGetMPIType<T>();
  MPI_Isend (&s.First(), s.Size(), type, dest, tag, comm, &req);
  return req;
}
/// Non-blocking receive of s.Size() entries into @p s.
/// Note: despite its name, @p dest is the rank the message comes FROM.
/// Returns the MPI request; the caller must complete it before using @p s.
template <class T, int BASE>
inline MPI_Request MyMPI_IRecv (FlatArray<T, BASE> s, int dest, int tag, MPI_Comm comm = ng_comm)
{
  MPI_Request req;
  const MPI_Datatype type = MyGetMPIType<T>();
  MPI_Irecv (&s.First(), s.Size(), type, dest, tag, comm, &req);
  return req;
}
2011-07-04 18:29:18 +06:00
2011-07-15 03:36:19 +06:00
/*
2011-07-04 18:29:18 +06:00
template <class T, int BASE>
2011-07-07 03:08:58 +06:00
inline void MyMPI_ISend (FlatArray<T, BASE> s, int dest, int tag)
2011-07-04 18:29:18 +06:00
{
MPI_Request request;
2011-07-07 03:08:58 +06:00
MPI_Isend( &s.First(), s.Size(), MyGetMPIType<T>(), dest, tag, MPI_COMM_WORLD, &request);
2011-07-04 18:29:18 +06:00
MPI_Request_free (&request);
}
template <class T, int BASE>
2011-07-07 03:08:58 +06:00
inline void MyMPI_IRecv (FlatArray<T, BASE> s, int dest, int tag)
2011-07-04 18:29:18 +06:00
{
MPI_Request request;
2011-07-07 03:08:58 +06:00
MPI_Irecv( &s.First(), s.Size(), MyGetMPIType<T>(), dest, tag, MPI_COMM_WORLD, &request);
2011-07-04 18:29:18 +06:00
MPI_Request_free (&request);
}
2011-07-15 03:36:19 +06:00
*/
2009-01-13 04:40:13 +05:00
2012-06-16 18:03:36 +06:00
/*
  Send one table entry to each of the processes in the group;
  the corresponding receive-table entries will be set.
*/
2012-09-03 15:49:18 +06:00
2012-08-31 01:24:20 +06:00
/*
2012-06-16 18:03:36 +06:00
template <typename T>
2012-08-20 20:10:23 +06:00
inline void MyMPI_ExchangeTable (TABLE<T> & send_data,
TABLE<T> & recv_data, int tag,
2012-06-16 18:03:36 +06:00
MPI_Comm comm = MPI_COMM_WORLD)
{
int ntasks, rank;
MPI_Comm_size(comm, &ntasks);
MPI_Comm_rank(comm, &rank);
Array<MPI_Request> requests;
for (int dest = 0; dest < ntasks; dest++)
if (dest != rank)
2012-08-20 20:10:23 +06:00
requests.Append (MyMPI_ISend (send_data[dest], dest, tag, comm));
2012-06-16 18:03:36 +06:00
for (int i = 0; i < ntasks-1; i++)
{
MPI_Status status;
MPI_Probe (MPI_ANY_SOURCE, tag, comm, &status);
int size, src = status.MPI_SOURCE;
MPI_Get_count (&status, MPI_INT, &size);
2012-09-03 15:49:18 +06:00
recv_data.SetEntrySize (src, size, sizeof(T));
2012-08-20 20:10:23 +06:00
requests.Append (MyMPI_IRecv (recv_data[src], src, tag, comm));
2012-06-16 18:03:36 +06:00
}
2012-08-20 20:10:23 +06:00
MPI_Barrier (comm);
2012-06-16 18:03:36 +06:00
MPI_Waitall (requests.Size(), &requests[0], MPI_STATUS_IGNORE);
}
2012-08-31 01:24:20 +06:00
*/
/// All-to-all exchange of table rows over @p comm.
/// Row i of @p send_data is sent to rank i. Row j of @p recv_data is resized
/// to the size rank j announced and filled with the data rank j sent.
/// The row sizes are exchanged first via MPI_Alltoall. The calling rank's own
/// row in @p recv_data is resized but not filled (no self-send is posted).
template <typename T>
inline void MyMPI_ExchangeTable (TABLE<T> & send_data,
                                 TABLE<T> & recv_data, int tag,
                                 MPI_Comm comm = ng_comm)
{
  int rank = MyMPI_GetId(comm);
  int ntasks = MyMPI_GetNTasks(comm);

  Array<int> send_sizes(ntasks);
  Array<int> recv_sizes(ntasks);
  for (int i = 0; i < ntasks; i++)
    send_sizes[i] = send_data[i].Size();
  MPI_Alltoall (&send_sizes[0], 1, MPI_INT,
                &recv_sizes[0], 1, MPI_INT, comm);

  // in-place is buggy !
  // MPI_Alltoall (MPI_IN_PLACE, 1, MPI_INT,
  //               &recv_sizes[0], 1, MPI_INT, comm);

  for (int i = 0; i < ntasks; i++)
    recv_data.SetEntrySize (i, recv_sizes[i], sizeof(T));

  Array<MPI_Request> requests;
  for (int dest = 0; dest < ntasks; dest++)
    if (dest != rank && send_data[dest].Size())
      requests.Append (MyMPI_ISend (send_data[dest], dest, tag, comm));
  for (int dest = 0; dest < ntasks; dest++)
    if (dest != rank && recv_data[dest].Size())
      requests.Append (MyMPI_IRecv (recv_data[dest], dest, tag, comm));
  // MPI_Barrier (comm);

  // With a single process or only empty rows no request was posted, and
  // &requests[0] would index an empty array. MPI_Waitall takes an array of
  // statuses, so the matching "ignore" constant is MPI_STATUSES_IGNORE.
  if (requests.Size())
    MPI_Waitall (requests.Size(), &requests[0], MPI_STATUSES_IGNORE);
}
2012-06-16 18:03:36 +06:00
2012-06-16 22:58:46 +06:00
extern void MyMPI_SendCmd (const char * cmd);
extern string MyMPI_RecvCmd ();
2009-01-13 04:40:13 +05:00
2011-11-01 18:54:07 +06:00
2012-06-16 21:22:46 +06:00
2009-01-13 04:40:13 +05:00
template <class T>
inline void MyMPI_Bcast (T & s, MPI_Comm comm = ng_comm)
2009-01-13 04:40:13 +05:00
{
MPI_Bcast (&s, 1, MyGetMPIType<T>(), 0, comm);
}
/// Broadcast array @p s from rank 0 to all ranks of @p comm.
/// Non-root ranks resize @p s to rank 0's size before receiving.
template <class T>
inline void MyMPI_Bcast (Array<T, 0> & s, MPI_Comm comm = ng_comm)
{
  int size = s.Size();
  MyMPI_Bcast (size, comm);
  if (MyMPI_GetId(comm) != 0) s.SetSize (size);
  // After the size broadcast every rank holds the same size, so all ranks
  // skip the data broadcast together. This avoids taking &s[0] of an
  // empty array.
  if (!size) return;
  MPI_Bcast (&s[0], size, MyGetMPIType<T>(), 0, comm);
}
/// Broadcast array @p s from rank @p root to all ranks of @p comm.
/// Non-root ranks resize @p s to the root's size first; an empty array
/// broadcasts only its size.
template <class T>
inline void MyMPI_Bcast (Array<T, 0> & s, int root, MPI_Comm comm = ng_comm)
{
  int myrank;
  MPI_Comm_rank (comm, &myrank);

  int n = s.Size();
  MPI_Bcast (&n, 1, MPI_INT, root, comm);
  if (myrank != root)
    s.SetSize (n);

  if (n == 0)
    return;
  MPI_Bcast (&s[0], n, MyGetMPIType<T>(), root, comm);
}
template <class T, class T2>
inline void MyMPI_Allgather (const T & send, FlatArray<T2> recv, MPI_Comm comm = ng_comm)
2009-01-13 04:40:13 +05:00
{
MPI_Allgather( const_cast<T*> (&send), 1, MyGetMPIType<T>(), &recv[0], 1, MyGetMPIType<T2>(), comm);
}
/// Personalized all-to-all: send[i] goes to rank i, and recv[j] receives the
/// single element rank j sent to this rank.
template <class T, class T2>
inline void MyMPI_Alltoall (FlatArray<T> send, FlatArray<T2> recv, MPI_Comm comm = ng_comm)
{
  T * sendbuf = &send[0];
  T2 * recvbuf = &recv[0];
  MPI_Alltoall (sendbuf, 1, MyGetMPIType<T>(), recvbuf, 1, MyGetMPIType<T2>(), comm);
}
// template <class T, class T2>
// inline void MyMPI_Alltoall_Block (FlatArray<T> send, FlatArray<T2> recv, int blocklen, MPI_Comm comm = ng_comm)
2009-01-13 04:40:13 +05:00
// {
// MPI_Alltoall( &send[0], blocklen, MyGetMPIType<T>(), &recv[0], blocklen, MyGetMPIType<T2>(), comm);
// }
2011-07-15 03:36:19 +06:00
/*
2011-07-14 00:27:17 +06:00
inline void MyMPI_Send ( int *& s, int len, int dest, int tag)
2009-01-13 04:40:13 +05:00
{
2011-07-14 00:27:17 +06:00
int hlen = len;
2011-07-14 00:32:11 +06:00
MPI_Send( &hlen, 1, MPI_INT, dest, tag, MPI_COMM_WORLD);
2011-07-14 00:27:17 +06:00
MPI_Send( s, len, MPI_INT, dest, tag, MPI_COMM_WORLD);
2009-01-13 04:40:13 +05:00
}
2011-07-07 03:08:58 +06:00
inline void MyMPI_Recv ( int *& s, int & len, int src, int tag)
2009-01-13 04:40:13 +05:00
{
MPI_Status status;
2011-07-07 03:08:58 +06:00
MPI_Recv( &len, 1, MPI_INT, src, tag, MPI_COMM_WORLD, &status);
2009-01-13 04:40:13 +05:00
if ( s )
delete [] s;
s = new int [len];
2011-07-07 03:08:58 +06:00
MPI_Recv( s, len, MPI_INT, src, tag, MPI_COMM_WORLD, &status);
2009-01-13 04:40:13 +05:00
}
2011-07-07 03:08:58 +06:00
inline void MyMPI_Send ( double * s, int len, int dest, int tag)
2009-01-13 04:40:13 +05:00
{
2011-07-07 03:08:58 +06:00
MPI_Send( &len, 1, MPI_INT, dest, tag, MPI_COMM_WORLD);
MPI_Send( s, len, MPI_DOUBLE, dest, tag, MPI_COMM_WORLD);
2009-01-13 04:40:13 +05:00
}
2011-07-07 03:08:58 +06:00
inline void MyMPI_Recv ( double *& s, int & len, int src, int tag)
2009-01-13 04:40:13 +05:00
{
MPI_Status status;
2011-07-07 03:08:58 +06:00
MPI_Recv( &len, 1, MPI_INT, src, tag, MPI_COMM_WORLD, &status);
2009-01-13 04:40:13 +05:00
if ( s )
delete [] s;
s = new double [len];
2011-07-07 03:08:58 +06:00
MPI_Recv( s, len, MPI_DOUBLE, src, tag, MPI_COMM_WORLD, &status);
2009-01-13 04:40:13 +05:00
}
2011-07-15 03:36:19 +06:00
*/
2009-01-13 04:40:13 +05:00
#endif // PARALLEL
2009-07-20 14:36:36 +06:00
}
2009-01-13 04:40:13 +05:00
#endif