diff --git a/libsrc/core/mpi_wrapper.hpp b/libsrc/core/mpi_wrapper.hpp
index b42c551e..4407173d 100644
--- a/libsrc/core/mpi_wrapper.hpp
+++ b/libsrc/core/mpi_wrapper.hpp
@@ -6,6 +6,7 @@
 #include <mpi.h>
 #endif
 
+#include "array.hpp"
 #include "exception.hpp"
 
 namespace ngcore
@@ -127,11 +128,32 @@ namespace ngcore
       MPI_Send (&val, 1, GetMPIType<T>(), dest, tag, comm);
     }
 
+    template<typename T, typename T2 = decltype(GetMPIType<T>())>
+    void Send(FlatArray<T> s, int dest, int tag) const {
+      MPI_Send (s.Data(), s.Size(), GetMPIType<T>(), dest, tag, comm);
+    }
+
     template<typename T, typename T2 = decltype(GetMPIType<T>())>
     void Recv (T & val, int src, int tag) const {
       MPI_Recv (&val, 1, GetMPIType<T>(), src, tag, comm, MPI_STATUS_IGNORE);
     }
 
+    template <typename T, typename T2 = decltype(GetMPIType<T>())>
+    void Recv (FlatArray<T> s, int src, int tag) const {
+      MPI_Recv (s.Data(), s.Size(), GetMPIType<T> (), src, tag, comm, MPI_STATUS_IGNORE);
+    }
+
+    template <typename T, typename T2 = decltype(GetMPIType<T>())>
+    void Recv (Array<T> & s, int src, int tag) const
+    {
+      MPI_Status status;
+      int len;
+      const MPI_Datatype MPI_T = GetMPIType<T> ();
+      MPI_Probe (src, tag, comm, &status);
+      MPI_Get_count (&status, MPI_T, &len);
+      s.SetSize (len);
+      MPI_Recv (s.Data(), len, MPI_T, src, tag, comm, MPI_STATUS_IGNORE);
+    }
 
     /** --- non-blocking P2P --- **/
 
@@ -144,13 +166,22 @@ namespace ngcore
     }
 
     template<typename T, typename T2 = decltype(GetMPIType<T>())>
-    MPI_Request IRecv (T & val, int dest, int tag) const
+    MPI_Request IRecv (T & val, int src, int tag) const
     {
       MPI_Request request;
-      MPI_Irecv (&val, 1, GetMPIType<T>(), dest, tag, comm, &request);
+      MPI_Irecv (&val, 1, GetMPIType<T>(), src, tag, comm, &request);
       return request;
     }
 
+    template<typename T, typename T2 = decltype(GetMPIType<T>())>
+    MPI_Request IRecv (const FlatArray<T> & s, int src, int tag) const
+    {
+      MPI_Request request;
+      MPI_Irecv (s.Data(), s.Size(), GetMPIType<T>(), src, tag, comm, &request);
+      return request;
+    }
+
+
     /** --- collectives --- **/
 
     template <typename T, typename T2 = decltype(GetMPIType<T>())>
@@ -188,10 +219,21 @@ namespace ngcore
       MPI_Bcast (&s[0], len, MPI_CHAR, root, comm);
     }
 
-
-  };
+  }; // class NgMPI_Comm
 
+  NETGEN_INLINE void MyMPI_WaitAll (FlatArray<MPI_Request> requests)
+  {
+    if (!requests.Size()) return;
+    MPI_Waitall (requests.Size(), requests.Data(), MPI_STATUSES_IGNORE);
+  }
+
+  NETGEN_INLINE int MyMPI_WaitAny (FlatArray<MPI_Request> requests)
+  {
+    int nr;
+    MPI_Waitany (requests.Size(), requests.Data(), &nr, MPI_STATUS_IGNORE);
+    return nr;
+  }
+
 #else // PARALLEL
 class MPI_Comm {
   int nr;
@@ -223,24 +265,44 @@ namespace ngcore
     void Send( T & val, int dest, int tag) const { ; }
 
     template<typename T>
-    void MyMPI_Recv (T & val, int src, int tag) const { ; }
+    void Send(FlatArray<T> s, int dest, int tag) const { ; }
+
+    template<typename T>
+    void Recv (T & val, int src, int tag) const { ; }
+
+    template<typename T>
+    void Recv (FlatArray<T> s, int src, int tag) const { ; }
+
+    template<typename T>
+    void Recv (Array<T> & s, int src, int tag) const { ; }
 
     template<typename T>
     MPI_Request ISend (T & val, int dest, int tag) const { return 0; }
 
+    template<typename T>
+    MPI_Request ISend (const FlatArray<T> & s, int dest, int tag) const { return 0; }
+
     template<typename T>
     MPI_Request IRecv (T & val, int dest, int tag) const { return 0; }
 
+    template<typename T>
+    MPI_Request IRecv (const FlatArray<T> & s, int src, int tag) const { return 0; }
+
     template <typename T>
-    T Reduce (T d, const MPI_Op & op, int root = 0) { return d; }
+    T Reduce (T d, const MPI_Op & op, int root = 0) const { return d; }
 
     template <typename T>
     T AllReduce (T d, const MPI_Op & op) const { return d; }
 
     template <typename T>
     void Bcast (T & s, int root = 0) const { ; }
+
+    NgMPI_Comm SubCommunicator (FlatArray<int> procs) const
+    { return *this; }
   };
-  
+
+  NETGEN_INLINE void MyMPI_WaitAll (FlatArray<MPI_Request> requests) { ; }
+
 #endif // PARALLEL