#ifndef NGCORE_MPIWRAPPER_HPP #define NGCORE_MPIWRAPPER_HPP #include <array> #include <complex> #include "array.hpp" #include "table.hpp" #include "exception.hpp" #include "profiler.hpp" #include "ngstream.hpp" #include "ng_mpi.hpp" namespace ngcore { #ifdef PARALLEL template <class T> struct MPI_typetrait { }; template <> struct MPI_typetrait<int> { static NG_MPI_Datatype MPIType () { return NG_MPI_INT; } }; template <> struct MPI_typetrait<short> { static NG_MPI_Datatype MPIType () { return NG_MPI_SHORT; } }; template <> struct MPI_typetrait<char> { static NG_MPI_Datatype MPIType () { return NG_MPI_CHAR; } }; template <> struct MPI_typetrait<signed char> { static NG_MPI_Datatype MPIType () { return NG_MPI_CHAR; } }; template <> struct MPI_typetrait<unsigned char> { static NG_MPI_Datatype MPIType () { return NG_MPI_CHAR; } }; template <> struct MPI_typetrait<size_t> { static NG_MPI_Datatype MPIType () { return NG_MPI_UINT64_T; } }; template <> struct MPI_typetrait<double> { static NG_MPI_Datatype MPIType () { return NG_MPI_DOUBLE; } }; template <> struct MPI_typetrait<std::complex<double>> { static NG_MPI_Datatype MPIType () { return NG_MPI_CXX_DOUBLE_COMPLEX; } }; template <> struct MPI_typetrait<bool> { static NG_MPI_Datatype MPIType () { return NG_MPI_C_BOOL; } }; template<typename T, size_t S> struct MPI_typetrait<std::array<T,S>> { static NG_MPI_Datatype MPIType () { static NG_MPI_Datatype NG_MPI_T = 0; if (!NG_MPI_T) { NG_MPI_Type_contiguous ( S, MPI_typetrait<T>::MPIType(), &NG_MPI_T); NG_MPI_Type_commit ( &NG_MPI_T ); } return NG_MPI_T; } }; template <class T, class T2 = decltype(MPI_typetrait<T>::MPIType())> inline NG_MPI_Datatype GetMPIType () { return MPI_typetrait<T>::MPIType(); } template <class T> inline NG_MPI_Datatype GetMPIType (T &) { return GetMPIType<T>(); } inline void MyMPI_WaitAll (FlatArray<NG_MPI_Request> requests) { static Timer t("MPI - WaitAll"); RegionTimer reg(t); if (!requests.Size()) return; NG_MPI_Waitall (requests.Size(), requests.Data(), NG_MPI_STATUSES_IGNORE); } inline int MyMPI_WaitAny (FlatArray<NG_MPI_Request> requests) { int nr; NG_MPI_Waitany (requests.Size(), requests.Data(), &nr, NG_MPI_STATUS_IGNORE); return nr; } class NgMPI_Comm { protected: NG_MPI_Comm comm; bool valid_comm; int * refcount; int rank, size; public: NgMPI_Comm () : valid_comm(false), refcount(nullptr), rank(0), size(1) { ; } NgMPI_Comm (NG_MPI_Comm _comm, bool owns = false) : comm(_comm), valid_comm(true) { int flag; NG_MPI_Initialized (&flag); if (!flag) { valid_comm = false; refcount = nullptr; rank = 0; size = 1; return; } if (!owns) refcount = nullptr; else refcount = new int{1}; NG_MPI_Comm_rank(comm, &rank); NG_MPI_Comm_size(comm, &size); } NgMPI_Comm (const NgMPI_Comm & c) : comm(c.comm), valid_comm(c.valid_comm), refcount(c.refcount), rank(c.rank), size(c.size) { if (refcount) (*refcount)++; } NgMPI_Comm (NgMPI_Comm && c) : comm(c.comm), valid_comm(c.valid_comm), refcount(c.refcount), rank(c.rank), size(c.size) { c.refcount = nullptr; } ~NgMPI_Comm() { if (refcount) if (--(*refcount) == 0) NG_MPI_Comm_free(&comm); } bool ValidCommunicator() const { return valid_comm; } NgMPI_Comm & operator= (const NgMPI_Comm & c) { if (refcount) if (--(*refcount) == 0) NG_MPI_Comm_free(&comm); refcount = c.refcount; if (refcount) (*refcount)++; comm = c.comm; valid_comm = c.valid_comm; size = c.size; rank = c.rank; return *this; } class InvalidCommException : public Exception { public: InvalidCommException() : Exception("Do not have a valid communicator") { ; } }; operator NG_MPI_Comm() const { if (!valid_comm) throw InvalidCommException(); return comm; } int Rank() const { return rank; } int Size() const { return size; } void Barrier() const { static Timer t("MPI - Barrier"); RegionTimer reg(t); if (size > 1) NG_MPI_Barrier (comm); } /** --- blocking P2P --- **/ template<typename T, typename T2 = decltype(GetMPIType<T>())> void Send (T & val, int dest, int tag) const { NG_MPI_Send (&val, 1, GetMPIType<T>(), dest, tag, comm); } void Send (const std::string & s, int dest, int tag) const { NG_MPI_Send( const_cast<char*> (&s[0]), s.length(), NG_MPI_CHAR, dest, tag, comm); } template<typename T, typename TI, typename T2 = decltype(GetMPIType<T>())> void Send(FlatArray<T,TI> s, int dest, int tag) const { NG_MPI_Send (s.Data(), s.Size(), GetMPIType<T>(), dest, tag, comm); } template<typename T, typename T2 = decltype(GetMPIType<T>())> void Recv (T & val, int src, int tag) const { NG_MPI_Recv (&val, 1, GetMPIType<T>(), src, tag, comm, NG_MPI_STATUS_IGNORE); } void Recv (std::string & s, int src, int tag) const { NG_MPI_Status status; int len; NG_MPI_Probe (src, tag, comm, &status); NG_MPI_Get_count (&status, NG_MPI_CHAR, &len); // s.assign (len, ' '); s.resize (len); NG_MPI_Recv( &s[0], len, NG_MPI_CHAR, src, tag, comm, NG_MPI_STATUS_IGNORE); } template <typename T, typename TI, typename T2 = decltype(GetMPIType<T>())> void Recv (FlatArray <T,TI> s, int src, int tag) const { NG_MPI_Recv (s.Data(), s.Size(), GetMPIType<T> (), src, tag, comm, NG_MPI_STATUS_IGNORE); } template <typename T, typename TI, typename T2 = decltype(GetMPIType<T>())> void Recv (Array <T,TI> & s, int src, int tag) const { NG_MPI_Status status; int len; const NG_MPI_Datatype NG_MPI_T = GetMPIType<T> (); NG_MPI_Probe (src, tag, comm, &status); NG_MPI_Get_count (&status, NG_MPI_T, &len); s.SetSize (len); NG_MPI_Recv (s.Data(), len, NG_MPI_T, src, tag, comm, NG_MPI_STATUS_IGNORE); } /** --- non-blocking P2P --- **/ template<typename T, typename T2 = decltype(GetMPIType<T>())> NG_MPI_Request ISend (T & val, int dest, int tag) const { NG_MPI_Request request; NG_MPI_Isend (&val, 1, GetMPIType<T>(), dest, tag, comm, &request); return request; } template<typename T, typename T2 = decltype(GetMPIType<T>())> NG_MPI_Request ISend (FlatArray<T> s, int dest, int tag) const { NG_MPI_Request request; NG_MPI_Isend (s.Data(), s.Size(), GetMPIType<T>(), dest, tag, comm, &request); return request; } template<typename T, typename T2 = decltype(GetMPIType<T>())> NG_MPI_Request IRecv (T & val, int dest, int tag) const { NG_MPI_Request request; NG_MPI_Irecv (&val, 1, GetMPIType<T>(), dest, tag, comm, &request); return request; } template<typename T, typename T2 = decltype(GetMPIType<T>())> NG_MPI_Request IRecv (FlatArray<T> s, int src, int tag) const { NG_MPI_Request request; NG_MPI_Irecv (s.Data(), s.Size(), GetMPIType<T>(), src, tag, comm, &request); return request; } /** --- collectives --- **/ template <typename T, typename T2 = decltype(GetMPIType<T>())> T Reduce (T d, const NG_MPI_Op & op, int root = 0) const { static Timer t("MPI - Reduce"); RegionTimer reg(t); if (size == 1) return d; T global_d; NG_MPI_Reduce (&d, &global_d, 1, GetMPIType<T>(), op, root, comm); return global_d; } template <typename T, typename T2 = decltype(GetMPIType<T>())> T AllReduce (T d, const NG_MPI_Op & op) const { static Timer t("MPI - AllReduce"); RegionTimer reg(t); if (size == 1) return d; T global_d; NG_MPI_Allreduce ( &d, &global_d, 1, GetMPIType<T>(), op, comm); return global_d; } template <typename T, typename T2 = decltype(GetMPIType<T>())> void AllReduce (FlatArray<T> d, const NG_MPI_Op & op) const { static Timer t("MPI - AllReduce Array"); RegionTimer reg(t); if (size == 1) return; NG_MPI_Allreduce (NG_MPI_IN_PLACE, d.Data(), d.Size(), GetMPIType<T>(), op, comm); } template <typename T, typename T2 = decltype(GetMPIType<T>())> void Bcast (T & s, int root = 0) const { if (size == 1) return; static Timer t("MPI - Bcast"); RegionTimer reg(t); NG_MPI_Bcast (&s, 1, GetMPIType<T>(), root, comm); } template <class T> void Bcast (Array<T> & d, int root = 0) { if (size == 1) return; int ds = d.Size(); Bcast (ds, root); if (Rank() != root) d.SetSize (ds); if (ds != 0) NG_MPI_Bcast (d.Data(), ds, GetMPIType<T>(), root, comm); } void Bcast (std::string & s, int root = 0) const { if (size == 1) return; int len = s.length(); Bcast (len, root); if (rank != 0) s.resize (len); NG_MPI_Bcast (&s[0], len, NG_MPI_CHAR, root, comm); } template <typename T> void AllToAll (FlatArray<T> send, FlatArray<T> recv) const { NG_MPI_Alltoall (send.Data(), 1, GetMPIType<T>(), recv.Data(), 1, GetMPIType<T>(), comm); } template <typename T> void ScatterRoot (FlatArray<T> send) const { if (size == 1) return; NG_MPI_Scatter (send.Data(), 1, GetMPIType<T>(), NG_MPI_IN_PLACE, -1, GetMPIType<T>(), 0, comm); } template <typename T> void Scatter (T & recv) const { if (size == 1) return; NG_MPI_Scatter (NULL, 0, GetMPIType<T>(), &recv, 1, GetMPIType<T>(), 0, comm); } template <typename T> void GatherRoot (FlatArray<T> recv) const { recv[0] = T(0); if (size == 1) return; NG_MPI_Gather (NG_MPI_IN_PLACE, 1, GetMPIType<T>(), recv.Data(), 1, GetMPIType<T>(), 0, comm); } template <typename T> void Gather (T send) const { if (size == 1) return; NG_MPI_Gather (&send, 1, GetMPIType<T>(), NULL, 1, GetMPIType<T>(), 0, comm); } template <typename T> void AllGather (T val, FlatArray<T> recv) const { if (size == 1) { recv[0] = val; return; } NG_MPI_Allgather (&val, 1, GetMPIType<T>(), recv.Data(), 1, GetMPIType<T>(), comm); } template <typename T> void ExchangeTable (DynamicTable<T> & send_data, DynamicTable<T> & recv_data, int tag) { Array<int> send_sizes(size); Array<int> recv_sizes(size); for (int i = 0; i < size; i++) send_sizes[i] = send_data[i].Size(); AllToAll (send_sizes, recv_sizes); recv_data = DynamicTable<T> (recv_sizes, true); Array<NG_MPI_Request> requests; for (int dest = 0; dest < size; dest++) if (dest != rank && send_data[dest].Size()) requests.Append (ISend (FlatArray<T>(send_data[dest]), dest, tag)); for (int dest = 0; dest < size; dest++) if (dest != rank && recv_data[dest].Size()) requests.Append (IRecv (FlatArray<T>(recv_data[dest]), dest, tag)); MyMPI_WaitAll (requests); } NgMPI_Comm SubCommunicator (FlatArray<int> procs) const { NG_MPI_Comm subcomm; NG_MPI_Group gcomm, gsubcomm; NG_MPI_Comm_group(comm, &gcomm); NG_MPI_Group_incl(gcomm, procs.Size(), procs.Data(), &gsubcomm); NG_MPI_Comm_create_group(comm, gsubcomm, 4242, &subcomm); return NgMPI_Comm(subcomm, true); } }; // class NgMPI_Comm #else // PARALLEL class NG_MPI_Comm { int nr; public: NG_MPI_Comm (int _nr = 0) : nr(_nr) { ; } operator int() const { return nr; } bool operator== (NG_MPI_Comm c2) const { return nr == c2.nr; } }; static NG_MPI_Comm NG_MPI_COMM_WORLD = 12345, NG_MPI_COMM_NULL = 10000; typedef int NG_MPI_Op; typedef int NG_MPI_Datatype; typedef int NG_MPI_Request; enum { NG_MPI_SUM = 0, NG_MPI_MIN = 1, NG_MPI_MAX = 2, NG_MPI_LOR = 4711 }; inline void NG_MPI_Type_contiguous ( int, NG_MPI_Datatype, NG_MPI_Datatype*) { ; } inline void NG_MPI_Type_commit ( NG_MPI_Datatype * ) { ; } template <class T> struct MPI_typetrait { static NG_MPI_Datatype MPIType () { return -1; } }; template <class T, class T2=void> inline NG_MPI_Datatype GetMPIType () { return -1; } class NgMPI_Comm { public: NgMPI_Comm () { ; } NgMPI_Comm (NG_MPI_Comm _comm, bool owns = false) { ; } size_t Rank() const { return 0; } size_t Size() const { return 1; } bool ValidCommunicator() const { return false; } void Barrier() const { ; } operator NG_MPI_Comm() const { return NG_MPI_Comm(); } template<typename T> void Send( T & val, int dest, int tag) const { ; } template<typename T> void Send(FlatArray<T> s, int dest, int tag) const { ; } template<typename T> void Recv (T & val, int src, int tag) const { ; } template <typename T> void Recv (FlatArray <T> s, int src, int tag) const { ; } template <typename T> void Recv (Array <T> & s, int src, int tag) const { ; } template<typename T> NG_MPI_Request ISend (T & val, int dest, int tag) const { return 0; } template<typename T> NG_MPI_Request ISend (FlatArray<T> s, int dest, int tag) const { return 0; } template<typename T> NG_MPI_Request IRecv (T & val, int dest, int tag) const { return 0; } template<typename T> NG_MPI_Request IRecv (FlatArray<T> s, int src, int tag) const { return 0; } template <typename T> T Reduce (T d, const NG_MPI_Op & op, int root = 0) const { return d; } template <typename T> T AllReduce (T d, const NG_MPI_Op & op) const { return d; } template <typename T> void AllReduce (FlatArray<T> d, const NG_MPI_Op & op) const { ; } template <typename T> void Bcast (T & s, int root = 0) const { ; } template <class T> void Bcast (Array<T> & d, int root = 0) { ; } template <typename T> void AllGather (T val, FlatArray<T> recv) const { recv[0] = val; } template <typename T> void ExchangeTable (DynamicTable<T> & send_data, DynamicTable<T> & recv_data, int tag) { ; } NgMPI_Comm SubCommunicator (FlatArray<int> procs) const { return *this; } }; inline void MyMPI_WaitAll (FlatArray<NG_MPI_Request> requests) { ; } inline int MyMPI_WaitAny (FlatArray<NG_MPI_Request> requests) { return 0; } #endif // PARALLEL } // namespace ngcore #endif // NGCORE_MPIWRAPPER_HPP