#ifndef NGCORE_MPIWRAPPER_HPP
#define NGCORE_MPIWRAPPER_HPP

#include <cstddef>
#include <string>

#ifdef PARALLEL
#define OMPI_SKIP_MPICXX
#include <mpi.h>
#endif

namespace ngcore
{

#ifdef PARALLEL

  /// Maps a C++ type to the matching MPI datatype handle.
  /// The primary template is intentionally empty: a type without a
  /// specialisation has no MPIType() and is rejected at compile time.
  template <class T> struct MPI_typetrait { };

  template <> struct MPI_typetrait<int> {
    static MPI_Datatype MPIType () { return MPI_INT; } };

  template <> struct MPI_typetrait<short> {
    static MPI_Datatype MPIType () { return MPI_SHORT; } };

  template <> struct MPI_typetrait<char> {
    static MPI_Datatype MPIType () { return MPI_CHAR; } };

  template <> struct MPI_typetrait<unsigned char> {
    static MPI_Datatype MPIType () { return MPI_CHAR; } };

  // NOTE(review): MPI_UINT64_T assumes a 64-bit size_t -- confirm for every target platform.
  template <> struct MPI_typetrait<size_t> {
    static MPI_Datatype MPIType () { return MPI_UINT64_T; } };

  template <> struct MPI_typetrait<double> {
    static MPI_Datatype MPIType () { return MPI_DOUBLE; } };

  template <> struct MPI_typetrait<bool> {
    static MPI_Datatype MPIType () { return MPI_C_BOOL; } };

  /// Returns the MPI datatype for T.  The defaulted second template argument
  /// removes this function from overload resolution for unsupported types.
  template <class T, class T2 = decltype(MPI_typetrait<T>::MPIType())>
  inline MPI_Datatype GetMPIType ()
  {
    return MPI_typetrait<T>::MPIType();
  }


  /// Wrapper around an MPI_Comm that caches rank and size and can optionally
  /// own the communicator.  Owning copies share a heap-allocated counter; the
  /// last owner to go away calls MPI_Comm_free.
  class NgMPI_Comm
  {
  protected:
    MPI_Comm comm;
    bool valid_comm;
    int * refcount;   // nullptr => this object does not own the communicator
    int rank, size;

    /// Gives up this object's share of the communicator.  When the shared
    /// counter reaches zero the communicator is freed and the counter deleted.
    void Release () noexcept
    {
      if (refcount && --(*refcount) == 0)
        {
          MPI_Comm_free(&comm);
          delete refcount;
        }
      refcount = nullptr;
    }

  public:
    /// Creates an invalid communicator wrapper (rank 0, size 1); converting
    /// it to MPI_Comm throws InvalidCommException.
    NgMPI_Comm ()
      : comm(MPI_COMM_NULL), valid_comm(false), refcount(nullptr), rank(0), size(1)
    { ; }

    /// Wraps an existing communicator.
    /// @param _comm  communicator to wrap
    /// @param owns   if true, the communicator is freed once the last copy
    ///               of this wrapper is destroyed
    NgMPI_Comm (MPI_Comm _comm, bool owns = false)
      : comm(_comm), valid_comm(true), refcount(owns ? new int{1} : nullptr),
        rank(0), size(1)
    {
      MPI_Comm_rank(comm, &rank);
      MPI_Comm_size(comm, &size);
    }

    /// Copies share ownership (if any) with the source.
    NgMPI_Comm (const NgMPI_Comm & c)
      : comm(c.comm), valid_comm(c.valid_comm), refcount(c.refcount),
        rank(c.rank), size(c.size)
    {
      if (refcount) (*refcount)++;
    }

    /// Takes over the source's ownership share; the source keeps its handle
    /// but no longer participates in reference counting.
    NgMPI_Comm (NgMPI_Comm && c) noexcept
      : comm(c.comm), valid_comm(c.valid_comm), refcount(c.refcount),
        rank(c.rank), size(c.size)
    {
      c.refcount = nullptr;
    }

    ~NgMPI_Comm()
    {
      Release();
    }

    NgMPI_Comm & operator= (const NgMPI_Comm & c)
    {
      // Self-assignment must not release: with a single owner that would
      // free the communicator we are about to keep using.
      if (this != &c)
        {
          Release();
          refcount = c.refcount;
          if (refcount) (*refcount)++;
          comm = c.comm;
          valid_comm = c.valid_comm;
          size = c.size;
          rank = c.rank;
        }
      return *this;
    }

    NgMPI_Comm & operator= (NgMPI_Comm && c) noexcept
    {
      if (this != &c)
        {
          Release();
          refcount = c.refcount;
          c.refcount = nullptr;
          comm = c.comm;
          valid_comm = c.valid_comm;
          size = c.size;
          rank = c.rank;
        }
      return *this;
    }

    /// Thrown when an invalid wrapper is converted to a raw MPI_Comm.
    class InvalidCommException : public Exception {
    public:
      InvalidCommException() : Exception("Do not have a valid communicator") { ; }
    };

    /// @throws InvalidCommException if no communicator has been set.
    operator MPI_Comm() const {
      if (!valid_comm) throw InvalidCommException();
      return comm;
    }

    int Rank() const { return rank; }
    int Size() const { return size; }

    /// Synchronises all ranks; skipped when there is only one rank.
    void Barrier() const {
      if (size > 1) MPI_Barrier (comm);
    }


    /** --- blocking P2P --- **/

    /// Sends a single value to rank dest.
    template<typename T, typename T2 = decltype(GetMPIType<T>())>
    void Send (T & val, int dest, int tag) const {
      MPI_Send (&val, 1, GetMPIType<T>(), dest, tag, comm);
    }

    /// Receives a single value from rank src; the status is ignored.
    template<typename T, typename T2 = decltype(GetMPIType<T>())>
    void Recv (T & val, int src, int tag) const {
      MPI_Recv (&val, 1, GetMPIType<T>(), src, tag, comm, MPI_STATUS_IGNORE);
    }


    /** --- non-blocking P2P --- **/

    /// Starts a non-blocking send; val must stay alive until the returned
    /// request has completed.
    template<typename T, typename T2 = decltype(GetMPIType<T>())>
    MPI_Request ISend (T & val, int dest, int tag) const
    {
      MPI_Request request;
      MPI_Isend (&val, 1, GetMPIType<T>(), dest, tag, comm, &request);
      return request;
    }

    /// Starts a non-blocking receive.  The parameter named dest is passed to
    /// MPI_Irecv as the source rank.
    template<typename T, typename T2 = decltype(GetMPIType<T>())>
    MPI_Request IRecv (T & val, int dest, int tag) const
    {
      MPI_Request request;
      MPI_Irecv (&val, 1, GetMPIType<T>(), dest, tag, comm, &request);
      return request;
    }


    /** --- collectives --- **/

    /// Reduces d over all ranks onto root.  MPI defines the receive buffer
    /// only on root, so it is seeded with d to avoid returning an
    /// indeterminate value on the other ranks.
    template <typename T, typename T2 = decltype(GetMPIType<T>())>
    T Reduce (T d, const MPI_Op & op, int root = 0) const
    {
      if (size == 1) return d;

      T global_d = d;
      MPI_Reduce (&d, &global_d, 1, GetMPIType<T>(), op, root, comm);
      return global_d;
    }

    /// Reduces d over all ranks and returns the result on every rank.
    template <typename T, typename T2 = decltype(GetMPIType<T>())>
    T AllReduce (T d, const MPI_Op & op) const
    {
      if (size == 1) return d;

      T global_d;
      MPI_Allreduce (&d, &global_d, 1, GetMPIType<T>(), op, comm);
      return global_d;
    }

    /// Broadcasts a single value from root to all ranks.
    template <typename T, typename T2 = decltype(GetMPIType<T>())>
    void Bcast (T & s, int root = 0) const {
      if (size == 1) return;
      MPI_Bcast (&s, 1, GetMPIType<T>(), root, comm);
    }

    /// Broadcasts a string from root: first its length, then its characters.
    void Bcast (std::string & s, int root = 0) const
    {
      if (size == 1) return;
      int len = static_cast<int>(s.length());
      Bcast (len, root);
      // Every receiver (not just ranks other than 0) needs room for the data.
      if (rank != root) s.resize (len);
      MPI_Bcast (&s[0], len, MPI_CHAR, root, comm);
    }

  };


#else

  // Stand-ins used when PARALLEL is not defined, so that code written
  // against the wrapper compiles and behaves like a single-rank run.

  class MPI_Comm {
    int nr;
  public:
    MPI_Comm (int _nr = 0) : nr(_nr) { ; }
    operator int() const { return nr; }
    bool operator== (MPI_Comm c2) const { return nr == c2.nr; }
  };
  static MPI_Comm MPI_COMM_WORLD = 12345, MPI_COMM_NULL = 10000;

  typedef int MPI_Op;
  typedef int MPI_Request;

  enum { MPI_SUM = 0, MPI_MIN = 1, MPI_MAX = 2 };

  /// Single-rank replacement for the MPI-backed NgMPI_Comm: rank 0 of 1,
  /// point-to-point calls do nothing and reductions return their input.
  class NgMPI_Comm
  {

  public:
    NgMPI_Comm () { ; }
    NgMPI_Comm (MPI_Comm /* _comm */, bool /* owns */ = false) { ; }

    size_t Rank() const { return 0; }
    size_t Size() const { return 1; }
    void Barrier() const { ; }
    operator MPI_Comm() const { return MPI_Comm(); }

    template<typename T>
    void Send (T & /* val */, int /* dest */, int /* tag */) const { ; }

    /// Same name as in the MPI build, so callers compile in both configurations.
    template<typename T>
    void Recv (T & /* val */, int /* src */, int /* tag */) const { ; }

    /// Older name of Recv in this stub; kept so existing callers still compile.
    template<typename T>
    void MyMPI_Recv (T & /* val */, int /* src */, int /* tag */) const { ; }

    template<typename T>
    MPI_Request ISend (T & /* val */, int /* dest */, int /* tag */) const { return 0; }

    template<typename T>
    MPI_Request IRecv (T & /* val */, int /* dest */, int /* tag */) const { return 0; }

    template <typename T>
    T Reduce (T d, const MPI_Op & /* op */, int /* root */ = 0) const { return d; }

    template <typename T>
    T AllReduce (T d, const MPI_Op & /* op */) const { return d; }

    template <typename T>
    void Bcast (T & /* s */, int /* root */ = 0) const { ; }
  };

#endif

}

#endif