#ifndef NGCORE_MPIWRAPPER_HPP
#define NGCORE_MPIWRAPPER_HPP

// Thin C++ wrapper around the MPI communicator calls used by ngcore.
// With PARALLEL defined, NgMPI_Comm forwards to MPI. Otherwise a serial
// stand-in with the same member functions is provided that behaves like
// a communicator with exactly one rank.

#include <cstddef>
#include <string>

#ifdef PARALLEL
// Stops Open MPI's mpi.h from pulling in its (deprecated) C++ bindings.
#define OMPI_SKIP_MPICXX
#include <mpi.h>
#endif

namespace ngcore
{

#ifdef PARALLEL

  /// Maps a C++ type to its MPI datatype via a static MPIType().
  /// The primary template has no MPIType(), so GetMPIType<T> (and every
  /// member template constrained on it) is unavailable for unmapped T.
  template <class T> struct MPI_typetrait { };

  template <> struct MPI_typetrait<int> {
    static MPI_Datatype MPIType () { return MPI_INT; } };

  template <> struct MPI_typetrait<short> {
    static MPI_Datatype MPIType () { return MPI_SHORT; } };

  template <> struct MPI_typetrait<char> {
    static MPI_Datatype MPIType () { return MPI_CHAR; } };

  // MPI_UNSIGNED_CHAR (not MPI_CHAR) is the matching predefined type and
  // the one that is valid in reduction operations.
  template <> struct MPI_typetrait<unsigned char> {
    static MPI_Datatype MPIType () { return MPI_UNSIGNED_CHAR; } };

  // NOTE(review): assumes std::size_t is 64 bits wide.
  template <> struct MPI_typetrait<std::size_t> {
    static MPI_Datatype MPIType () { return MPI_UINT64_T; } };

  template <> struct MPI_typetrait<double> {
    static MPI_Datatype MPIType () { return MPI_DOUBLE; } };

  // NOTE(review): MPI_C_BOOL describes C's _Bool; MPI-3 also offers
  // MPI_CXX_BOOL for C++ bool.
  template <> struct MPI_typetrait<bool> {
    static MPI_Datatype MPIType () { return MPI_C_BOOL; } };

  /// Returns the MPI datatype for T. Participates in overload resolution
  /// only when MPI_typetrait<T> is specialized.
  template <class T, class T2 = decltype(MPI_typetrait<T>::MPIType())>
  inline MPI_Datatype GetMPIType ()
  {
    return MPI_typetrait<T>::MPIType();
  }


  /// MPI communicator handle with optional shared ownership.
  ///
  /// Constructed with owns == true, the wrapper allocates a reference
  /// counter shared by all its copies; the last copy to be destroyed (or
  /// reassigned) calls MPI_Comm_free. With owns == false (the default) the
  /// MPI_Comm is only borrowed and never freed here. Rank and size are
  /// queried once at construction and cached.
  class NgMPI_Comm
  {
    MPI_Comm comm;
    int * refcount;   // nullptr => this object does not own comm
    int rank, size;

    /// Give up this object's share of ownership. If it was the last owner,
    /// free the communicator and the counter. Leaves refcount == nullptr.
    void Release ()
    {
      if (refcount && --(*refcount) == 0)
        {
          MPI_Comm_free(&comm);
          delete refcount;
        }
      refcount = nullptr;
    }

  public:
    /// Placeholder: rank 0 of 1, holding MPI_COMM_NULL, owning nothing.
    NgMPI_Comm ()
      : comm(MPI_COMM_NULL), refcount(nullptr), rank(0), size(1)
    { ; }

    /// Wrap an existing communicator.
    /// @param _comm communicator to wrap
    /// @param owns  if true, _comm is freed when the last copy goes away
    NgMPI_Comm (MPI_Comm _comm, bool owns = false)
      : comm(_comm), refcount(owns ? new int{1} : nullptr)
    {
      MPI_Comm_rank(comm, &rank);
      MPI_Comm_size(comm, &size);
    }

    NgMPI_Comm (const NgMPI_Comm & c)
      : comm(c.comm), refcount(c.refcount), rank(c.rank), size(c.size)
    {
      if (refcount) (*refcount)++;
    }

    /// Steals c's share of ownership; c keeps the handle but no longer owns it.
    NgMPI_Comm (NgMPI_Comm && c) noexcept
      : comm(c.comm), refcount(c.refcount), rank(c.rank), size(c.size)
    {
      c.refcount = nullptr;
    }

    ~NgMPI_Comm ()
    {
      Release();
    }

    NgMPI_Comm & operator= (const NgMPI_Comm & c)
    {
      if (this != &c)
        {
          // Acquire the new reference before releasing the old one: when
          // both objects share one counter the communicator must survive.
          if (c.refcount) (*c.refcount)++;
          Release();
          comm = c.comm;
          refcount = c.refcount;
          rank = c.rank;
          size = c.size;
        }
      return *this;
    }

    NgMPI_Comm & operator= (NgMPI_Comm && c) noexcept
    {
      if (this != &c)
        {
          Release();
          comm = c.comm;
          refcount = c.refcount;
          rank = c.rank;
          size = c.size;
          c.refcount = nullptr;
        }
      return *this;
    }

    operator MPI_Comm() const { return comm; }

    int Rank() const { return rank; }
    int Size() const { return size; }

    /// MPI_Barrier; skipped on a single-rank communicator.
    void Barrier() const {
      if (size > 1) MPI_Barrier (comm);
    }

    /// Blocking send of one value of type T to rank dest.
    template <typename T, typename T2 = decltype(GetMPIType<T>())>
    void Send (T & val, int dest, int tag) const {
      MPI_Send (&val, 1, GetMPIType<T>(), dest, tag, comm);
    }

    /// Blocking receive of one value of type T from rank src (status ignored).
    template <typename T, typename T2 = decltype(GetMPIType<T>())>
    void MyMPI_Recv (T & val, int src, int tag) const {
      MPI_Recv (&val, 1, GetMPIType<T>(), src, tag, comm, MPI_STATUS_IGNORE);
    }

    /** --- collectives --- **/

    /// MPI_Allreduce of a single value with operation op.
    /// Returns d unchanged, without calling MPI, on a single-rank communicator.
    template <typename T, typename T2 = decltype(GetMPIType<T>())>
    T AllReduce (T d, const MPI_Op & op) const
    {
      if (size == 1) return d;

      T global_d;
      MPI_Allreduce ( &d, &global_d, 1, GetMPIType<T>(), op, comm);
      return global_d;
    }

    /// Broadcast one value of type T from root; no-op on a single rank.
    template <typename T, typename T2 = decltype(GetMPIType<T>())>
    void Bcast (T & s, int root = 0) const {
      if (size == 1) return ;
      MPI_Bcast (&s, 1, GetMPIType<T>(), root, comm);
    }

    /// Broadcast a string from root: first its length, then its characters.
    void Bcast (std::string & s, int root = 0) const
    {
      if (size == 1) return;
      int len = static_cast<int>(s.length());
      Bcast (len, root);
      // Every rank except the sender must make room for the payload
      // (comparing against 0 here overflowed rank 0 whenever root != 0).
      if (rank != root) s.resize (len);
      MPI_Bcast (&s[0], len, MPI_CHAR, root, comm);
    }
  };

#else // PARALLEL

  /// Serial stand-in for MPI's communicator handle: an integer tag.
  class MPI_Comm {
    int nr;
  public:
    MPI_Comm (int _nr = 0) : nr(_nr) { ; }
    operator int() const { return nr; }
    bool operator== (MPI_Comm c2) const { return nr == c2.nr; }
  };
  static MPI_Comm MPI_COMM_WORLD = 12345, MPI_COMM_NULL = 10000;

  /// Serial stand-in for MPI's reduction-operation handle, so that
  /// AllReduce below has a declared parameter type without mpi.h.
  typedef int MPI_Op;

  /// Serial communicator: always rank 0 of 1; every operation is a no-op
  /// and AllReduce returns its input.
  class NgMPI_Comm
  {
  public:
    NgMPI_Comm () { ; }
    NgMPI_Comm (MPI_Comm _comm, bool owns = false)
    { ; }

    // NOTE(review): these return std::size_t, whereas the PARALLEL build returns int.
    std::size_t Rank() const { return 0; }
    std::size_t Size() const { return 1; }
    void Barrier() const { ; }
    operator MPI_Comm() const { return MPI_Comm(); }

    template <typename T>
    void Send (T & val, int dest, int tag) const { ; }

    template <typename T>
    void MyMPI_Recv (T & val, int src, int tag) const { ; }

    template <typename T>
    T AllReduce (T d, const MPI_Op & op) const { return d; }

    // In-class member definitions are implicitly inline.
    template <typename T>
    void Bcast (T & s, int root = 0) const { ; }
  };

#endif // PARALLEL

} // namespace ngcore

#endif // NGCORE_MPIWRAPPER_HPP