diff --git a/libsrc/core/mpi_wrapper.hpp b/libsrc/core/mpi_wrapper.hpp
index 40b60b5c..6dca3f29 100644
--- a/libsrc/core/mpi_wrapper.hpp
+++ b/libsrc/core/mpi_wrapper.hpp
@@ -23,8 +23,11 @@ namespace ngcore
   template <> struct MPI_typetrait<char> {
     static MPI_Datatype MPIType () { return MPI_CHAR; } };
 
+  template <> struct MPI_typetrait<unsigned char> {
+    static MPI_Datatype MPIType () { return MPI_CHAR; } };
+
   template <> struct MPI_typetrait<size_t> {
-    static MPI_Datatype MPIType () { return MPI_UNIT64_T; } };
+    static MPI_Datatype MPIType () { return MPI_UINT64_T; } };
 
   template <> struct MPI_typetrait<double> {
     static MPI_Datatype MPIType () { return MPI_DOUBLE; } };
@@ -43,6 +46,7 @@ namespace ngcore
   {
     MPI_Comm comm;
     int * refcount;
+    int rank, size;
   public:
     NgMPI_Comm (MPI_Comm _comm, bool owns = false)
       : comm(_comm)
@@ -51,16 +55,19 @@ namespace ngcore
         refcount = nullptr;
       else
         refcount = new int{1};
+
+      MPI_Comm_rank(comm, &rank);
+      MPI_Comm_size(comm, &size);
     }
 
     NgMPI_Comm (const NgMPI_Comm & c)
-      : comm(c.comm), refcount(c.refcount)
+      : comm(c.comm), refcount(c.refcount), rank(c.rank), size(c.size)
     {
       if (refcount) (*refcount)++;
     }
 
     NgMPI_Comm (NgMPI_Comm && c)
-      : comm(c.comm), refcount(c.refcount)
+      : comm(c.comm), refcount(c.refcount), rank(c.rank), size(c.size)
     {
       c.refcount = nullptr;
     }
@@ -74,8 +81,21 @@ namespace ngcore
 
     operator MPI_Comm() const { return comm; }
 
-    auto Rank() const { int r; MPI_Comm_rank(comm, &r); return r; }
-    auto Size() const { int s; MPI_Comm_size(comm, &s); return s; }
+    int Rank() const { return rank; }    // int r; MPI_Comm_rank(comm, &r); return r; }
+    int Size() const { return size; }    // int s; MPI_Comm_size(comm, &s); return s; }
+
+
+    template<typename T, typename T2 = decltype(GetMPIType<T>())>
+    void Send (T & val, int dest, int tag) {
+      MPI_Send (&val, 1, GetMPIType<T>(), dest, tag, comm);
+    }
+
+    template<typename T, typename T2 = decltype(GetMPIType<T>())>
+    void MyMPI_Recv (T & val, int src, int tag) {
+      MPI_Recv (&val, 1, GetMPIType<T>(), src, tag, comm, MPI_STATUS_IGNORE);
+    }
+
+
   };
 
 
@@ -90,6 +110,14 @@ namespace ngcore
   {
     size_t Rank() const { return 0; }
     size_t Size() const { return 1; }
+
+
+
+    template<typename T>
+    void Send (T & val, int dest, int tag) { ; }
+
+    template<typename T>
+    void MyMPI_Recv (T & val, int src, int tag) { ; }
   };
 
 #endif
diff --git a/libsrc/general/mpi_interface.hpp b/libsrc/general/mpi_interface.hpp
index eb79865e..5fee6ef9 100644
--- a/libsrc/general/mpi_interface.hpp
+++ b/libsrc/general/mpi_interface.hpp
@@ -23,7 +23,7 @@ namespace netgen
 #endif
 
   /** This is the "standard" communicator that will be used for netgen-objects. **/
-  extern DLL_HEADER MPI_Comm ng_comm;
+  extern DLL_HEADER NgMPI_Comm ng_comm;
 
 #ifdef PARALLEL
   inline int MyMPI_GetNTasks (MPI_Comm comm = ng_comm)
@@ -44,6 +44,7 @@ namespace netgen
   inline int MyMPI_GetId (MPI_Comm comm = ng_comm) { return 0; }
 #endif
 
+  /*
 #ifdef PARALLEL
   // For python wrapping of communicators
   struct PyMPI_Comm {
@@ -68,7 +69,8 @@ namespace netgen
     inline int Size() const { return 1; }
   };
 #endif
-
+  */
+
 #ifdef PARALLEL
   template <class T>
   inline MPI_Datatype MyGetMPIType ( )
diff --git a/libsrc/meshing/global.cpp b/libsrc/meshing/global.cpp
index 327727be..88cb7a38 100644
--- a/libsrc/meshing/global.cpp
+++ b/libsrc/meshing/global.cpp
@@ -32,7 +32,7 @@ namespace netgen
   // TraceGlobal glob2("global2");
 
   // global communicator for netgen
-  DLL_HEADER MPI_Comm ng_comm = MPI_COMM_WORLD;
+  DLL_HEADER NgMPI_Comm ng_comm = MPI_COMM_WORLD;
 
   weak_ptr<Mesh> global_mesh;
   void SetGlobalMesh (shared_ptr<Mesh> m)
diff --git a/libsrc/meshing/python_mesh.cpp b/libsrc/meshing/python_mesh.cpp
index 34b00a0f..d4f8744f 100644
--- a/libsrc/meshing/python_mesh.cpp
+++ b/libsrc/meshing/python_mesh.cpp
@@ -494,20 +494,20 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
 
   py::class_<Mesh,shared_ptr<Mesh>>(m, "Mesh")
     // .def(py::init<>("create empty mesh"))
-    .def(py::init( [] (int dim, shared_ptr<PyMPI_Comm> pycomm)
+    .def(py::init( [] (int dim, NgMPI_Comm comm)
                    {
                      auto mesh = make_shared<Mesh>();
-                     mesh->SetCommunicator(pycomm!=nullptr ? pycomm->comm : netgen::ng_comm);
+                     mesh->SetCommunicator(comm);
                      mesh -> SetDimension(dim);
                      SetGlobalMesh(mesh);  // for visualization
                      mesh -> SetGeometry (nullptr);
                      return mesh;
                    } ),
-         py::arg("dim")=3, py::arg("comm")=nullptr
+         py::arg("dim")=3, py::arg("comm")=NgMPI_Comm(ng_comm)
         )
     .def(NGSPickle<Mesh>())
-    .def_property_readonly("comm", [](const Mesh & amesh)
-                           { return make_shared<PyMPI_Comm>(amesh.GetCommunicator()); },
+    .def_property_readonly("comm", [](const Mesh & amesh) -> NgMPI_Comm
+                           { return amesh.GetCommunicator(); },
                            "MPI-communicator the Mesh lives in")
     /*
     .def("__init__",
@@ -521,8 +520,7 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
     */
     .def_property_readonly("_timestamp", &Mesh::GetTimeStamp)
-    .def("Distribute", [](shared_ptr<Mesh> self, shared_ptr<PyMPI_Comm> pycomm) {
-      MPI_Comm comm = pycomm!=nullptr ? pycomm->comm : self->GetCommunicator();
+    .def("Distribute", [](shared_ptr<Mesh> self, NgMPI_Comm comm) {
       self->SetCommunicator(comm);
       if(MyMPI_GetNTasks(comm)==1) return self;
       // if(MyMPI_GetNTasks(comm)==2) throw NgException("Sorry, cannot handle communicators with NP=2!");
@@ -530,10 +529,10 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
       if(MyMPI_GetId(comm)==0) self->Distribute();
       else self->SendRecvMesh();
       return self;
-    }, py::arg("comm")=nullptr)
-    .def("Receive", [](shared_ptr<PyMPI_Comm> pycomm) {
+    }, py::arg("comm")=NgMPI_Comm(ng_comm))
+    .def("Receive", [](NgMPI_Comm comm) {
       auto mesh = make_shared<Mesh>();
-      mesh->SetCommunicator(pycomm->comm);
+      mesh->SetCommunicator(comm);
       mesh->SendRecvMesh();
       return mesh;
     })
@@ -933,57 +932,34 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
                                     return old;
                                   }));
 
-  py::class_<PyMPI_Comm, shared_ptr<PyMPI_Comm>> (m, "MPI_Comm")
-    .def_property_readonly ("rank", &PyMPI_Comm::Rank)
-    .def_property_readonly ("size", &PyMPI_Comm::Size)
-    // .def_property_readonly ("rank", [](PyMPI_Comm & c) { cout << "rank for " << c.comm << endl; return c.Rank(); })
-    // .def_property_readonly ("size", [](PyMPI_Comm & c) { cout << "size for " << c.comm << endl; return c.Size(); })
+  py::class_<NgMPI_Comm> (m, "MPI_Comm")
+    .def_property_readonly ("rank", &NgMPI_Comm::Rank)
+    .def_property_readonly ("size", &NgMPI_Comm::Size)
+
 #ifdef PARALLEL
-    .def("Barrier", [](PyMPI_Comm & c) { MPI_Barrier(c.comm); })
-    .def("WTime", [](PyMPI_Comm & c) { return MPI_Wtime(); })
+    .def("Barrier", [](NgMPI_Comm & c) { MPI_Barrier(c); })
+    .def("WTime", [](NgMPI_Comm & c) { return MPI_Wtime(); })
#else
-    .def("Barrier", [](PyMPI_Comm & c) { })
-    .def("WTime", [](PyMPI_Comm & c) { return -1.0; })
+    .def("Barrier", [](NgMPI_Comm & c) { })
+    .def("WTime", [](NgMPI_Comm & c) { return -1.0; })
 #endif
-    .def("Sum", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
-    .def("Min", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
-    .def("Max", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
-    .def("Sum", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
-    .def("Min", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
-    .def("Max", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
-    .def("Sum", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
-    .def("Min", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
-    .def("Max", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
-    .def("SubComm", [](PyMPI_Comm & c, std::vector<int> proc_list) {
+    .def("Sum", [](NgMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_SUM, c); })
+    .def("Min", [](NgMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MIN, c); })
+    .def("Max", [](NgMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MAX, c); })
+    .def("Sum", [](NgMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_SUM, c); })
+    .def("Min", [](NgMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MIN, c); })
+    .def("Max", [](NgMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MAX, c); })
+    .def("Sum", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_SUM, c); })
+    .def("Min", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MIN, c); })
+    .def("Max", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MAX, c); })
+    .def("SubComm", [](NgMPI_Comm & c, std::vector<int> proc_list) {
       Array<int> procs(proc_list.size());
       for (int i = 0; i < procs.Size(); i++)
         procs[i] = proc_list[i];
       if (!procs.Contains(c.Rank()))
         throw Exception("rank "+ToString(c.Rank())+" not in subcomm");
-      MPI_Comm subcomm = MyMPI_SubCommunicator(c.comm, procs);
-      return make_shared<PyMPI_Comm>(subcomm, true);
-
-      /*
-      Array<int> procs;
-      if (py::extract<py::list> (proc_list).check()) {
-        py::list pylist = py::extract<py::list> (proc_list)();
-        procs.SetSize(py::len(pyplist));
-        for (int i = 0; i < py::len(pylist); i++)
-          procs[i] = py::extract<int>(pylist[i])();
-      }
-      else {
-        throw Exception("SubComm needs a list!");
-      }
-      if(!procs.Size()) {
-        cout << "warning, tried to construct empty communicator, returning MPI_COMM_NULL" << endl;
-        return make_shared<PyMPI_Comm>(MPI_COMM_NULL);
-      }
-      else if(procs.Size()==2) {
-        throw Exception("Sorry, NGSolve cannot handle NP=2.");
-      }
-      MPI_Comm subcomm = MyMPI_SubCommunicator(c.comm, procs);
-      return make_shared<PyMPI_Comm>(subcomm, true);
-      */
+      MPI_Comm subcomm = MyMPI_SubCommunicator(c, procs);
+      return NgMPI_Comm(subcomm, true);
     }, py::arg("procs"));
   ;
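
Usage note (not part of the patch): a minimal sketch of how the NgMPI_Comm members introduced above might be exercised from C++ in a PARALLEL build. It assumes only the names visible in this diff (ngcore::NgMPI_Comm, Rank, Size, Send, MyMPI_Recv) plus standard MPI startup; the include path is a placeholder.

    // Minimal sketch, assuming the NgMPI_Comm interface added in this patch.
    #include <mpi.h>
    #include "core/mpi_wrapper.hpp"   // header path assumed

    int main(int argc, char** argv)
    {
      MPI_Init(&argc, &argv);
      {
        ngcore::NgMPI_Comm comm(MPI_COMM_WORLD);   // non-owning wrapper (owns = false)
        // Rank()/Size() now return the values cached in the constructor.
        if (comm.Size() >= 2)
        {
          double val = 42.0;
          if (comm.Rank() == 0)
            comm.Send(val, /*dest*/ 1, /*tag*/ 0);       // resolves MPI_typetrait<double>
          else if (comm.Rank() == 1)
            comm.MyMPI_Recv(val, /*src*/ 0, /*tag*/ 0);  // blocking receive into val
        }
      }
      MPI_Finalize();
      return 0;
    }

Caching rank and size in the constructor turns Rank()/Size() into trivial accessors at the cost of two MPI calls per constructed communicator, which is why the copy and move constructors now carry the cached values along.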