diff --git a/libsrc/core/mpi_wrapper.hpp b/libsrc/core/mpi_wrapper.hpp index 29f05cf1..4b49266a 100644 --- a/libsrc/core/mpi_wrapper.hpp +++ b/libsrc/core/mpi_wrapper.hpp @@ -44,6 +44,7 @@ namespace ngcore class NgMPI_Comm { + protected: MPI_Comm comm; int * refcount; int rank, size; @@ -108,6 +109,7 @@ namespace ngcore } + /** --- blocking P2P --- **/ template())> void Send (T & val, int dest, int tag) const { @@ -115,13 +117,41 @@ namespace ngcore } template())> - void MyMPI_Recv (T & val, int src, int tag) const { + void Recv (T & val, int src, int tag) const { MPI_Recv (&val, 1, GetMPIType(), src, tag, comm, MPI_STATUS_IGNORE); } + /** --- non-blocking P2P --- **/ + + template())> + MPI_Request ISend (T & val, int dest, int tag) const + { + MPI_Request request; + MPI_Isend (&val, 1, GetMPIType(), dest, tag, comm, &request); + return request; + } + + template())> + MPI_Request IRecv (T & val, int dest, int tag) const + { + MPI_Request request; + MPI_Irecv (&val, 1, GetMPIType(), dest, tag, comm, &request); + return request; + } + /** --- collectives --- **/ + template ())> + T Reduce (T d, const MPI_Op & op, int root = 0) + { + if (size == 1) return d; + + T global_d; + MPI_Reduce (&d, &global_d, 1, GetMPIType(), op, root, comm); + return global_d; + } + template ())> T AllReduce (T d, const MPI_Op & op) const { @@ -162,6 +192,8 @@ namespace ngcore static MPI_Comm MPI_COMM_WORLD = 12345, MPI_COMM_NULL = 10000; typedef int MPI_Op; + typedef int MPI_Request; + enum { MPI_SUM = 0, MPI_MIN = 1, MPI_MAX = 2 }; class NgMPI_Comm @@ -182,6 +214,15 @@ namespace ngcore template void MyMPI_Recv (T & val, int src, int tag) const { ; } + template + MPI_Request ISend (T & val, int dest, int tag) const { return 0; } + + template + MPI_Request IRecv (T & val, int dest, int tag) const { return 0; } + + template + T Reduce (T d, const MPI_Op & op, int root = 0) { return d; } + template T AllReduce (T d, const MPI_Op & op) const { return d; }