diff --git a/libsrc/general/mpi_interface.hpp b/libsrc/general/mpi_interface.hpp index 880abc4b..593d290a 100644 --- a/libsrc/general/mpi_interface.hpp +++ b/libsrc/general/mpi_interface.hpp @@ -18,56 +18,98 @@ namespace netgen using ngcore::ntasks; #ifndef PARALLEL + /** without MPI, we need a dummy typedef **/ typedef int MPI_Comm; #endif /** This is the "standard" communicator that will be used for netgen-objects. **/ extern MPI_Comm ng_comm; -#ifndef PARALLEL - enum { MPI_COMM_WORLD = 12345, MPI_COMM_NULL = 0}; - inline int MyMPI_GetNTasks (MPI_Comm comm = ng_comm) { return 1; } - inline int MyMPI_GetId (MPI_Comm comm = ng_comm) { return 0; } -#endif - - #ifdef PARALLEL - inline int MyMPI_GetNTasks (MPI_Comm comm = ng_comm) { int ntasks; MPI_Comm_size(comm, &ntasks); return ntasks; } - inline int MyMPI_GetId (MPI_Comm comm = ng_comm) { int id; MPI_Comm_rank(comm, &id); return id; } +#else + enum { MPI_COMM_WORLD = 12345, MPI_COMM_NULL = 0}; + inline int MyMPI_GetNTasks (MPI_Comm comm = ng_comm) { return 1; } + inline int MyMPI_GetId (MPI_Comm comm = ng_comm) { return 0; } +#endif +#ifdef PARALLEL + // For python wrapping of communicators + struct PyMPI_Comm { + MPI_Comm comm; + bool owns_comm; + PyMPI_Comm (MPI_Comm _comm, bool _owns_comm = false) : comm(_comm), owns_comm(_owns_comm) { } + PyMPI_Comm (const PyMPI_Comm & c) = delete; + ~PyMPI_Comm () { + if (owns_comm) + MPI_Comm_free(&comm); + } + inline int Rank() const { return MyMPI_GetId(comm); } + inline int Size() const { return MyMPI_GetNTasks(comm); } + }; +#else + // dummy without MPI + struct PyMPI_Comm { + MPI_Comm comm = 0; + PyMPI_Comm (MPI_Comm _comm, bool _owns_comm = false) { } + ~PyMPI_Comm () { } + inline int Rank() const { return 0; } + inline int Size() const { return 1; } + }; +#endif + +#ifdef PARALLEL + template + inline MPI_Datatype MyGetMPIType ( ) + { cerr << "ERROR in GetMPIType() -- no type found" << endl;return 0; } + template <> + inline MPI_Datatype MyGetMPIType ( ) + { return MPI_INT; } + template <> + inline MPI_Datatype MyGetMPIType ( ) + { return MPI_DOUBLE; } + template <> + inline MPI_Datatype MyGetMPIType ( ) + { return MPI_CHAR; } + template<> + inline MPI_Datatype MyGetMPIType ( ) + { return MPI_UINT64_T; } +#else + typedef int MPI_Datatype; + template inline MPI_Datatype MyGetMPIType ( ) { return 0; } +#endif + +#ifdef PARALLEL + inline MPI_Comm MyMPI_SubCommunicator(MPI_Comm comm, Array & procs) + { + MPI_Comm subcomm; + MPI_Group gcomm, gsubcomm; + MPI_Comm_group(comm, &gcomm); + MPI_Group_incl(gcomm, procs.Size(), &(procs[0]), &gsubcomm); + MPI_Comm_create_group(comm, gsubcomm, 6969, &subcomm); + return subcomm; + } +#else + inline MPI_Comm MyMPI_SubCommunicator(MPI_Comm comm, Array & procs) + { return comm; } +#endif + +#ifdef PARALLEL enum { MPI_TAG_CMD = 110 }; enum { MPI_TAG_MESH = 210 }; enum { MPI_TAG_VIS = 310 }; - template - MPI_Datatype MyGetMPIType ( ) - { cerr << "ERROR in GetMPIType() -- no type found" << endl;return 0; } - - template <> - inline MPI_Datatype MyGetMPIType ( ) - { return MPI_INT; } - - template <> - inline MPI_Datatype MyGetMPIType ( ) - { return MPI_DOUBLE; } - - template <> - inline MPI_Datatype MyGetMPIType ( ) - { return MPI_CHAR; } - - inline void MyMPI_Send (int i, int dest, int tag, MPI_Comm comm = ng_comm) { int hi = i; diff --git a/libsrc/meshing/python_mesh.cpp b/libsrc/meshing/python_mesh.cpp index 11ba4a7c..dbc498cb 100644 --- a/libsrc/meshing/python_mesh.cpp +++ b/libsrc/meshing/python_mesh.cpp @@ -17,6 +17,22 @@ namespace netgen { extern bool netgen_executable_started; extern shared_ptr ng_geometry; +#ifdef PARALLEL + /** we need allreduce in python-wrapped communicators **/ + template + inline T MyMPI_AllReduceNG (T d, const MPI_Op & op = MPI_SUM, MPI_Comm comm = ng_comm) + { + T global_d; + MPI_Allreduce ( &d, &global_d, 1, MyGetMPIType(), op, comm); + return global_d; + } +#else + enum { MPI_SUM = 0, MPI_MIN = 1, MPI_MAX = 2 }; + typedef int MPI_Op; + template + inline T MyMPI_AllReduceNG (T d, const MPI_Op & op = MPI_SUM, MPI_Comm comm = ng_comm) + { return d; } +#endif } @@ -503,19 +519,21 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m) py::class_>(m, "Mesh") // .def(py::init<>("create empty mesh")) - .def(py::init( [] (int dim) + .def(py::init( [] (int dim, shared_ptr pycomm) { auto mesh = make_shared(); - mesh->SetCommunicator(netgen::ng_comm); + mesh->SetCommunicator(pycomm!=nullptr ? pycomm->comm : netgen::ng_comm); mesh -> SetDimension(dim); SetGlobalMesh(mesh); // for visualization mesh -> SetGeometry (nullptr); return mesh; } ), - py::arg("dim")=3 + py::arg("dim")=3, py::arg("comm")=nullptr ) .def(NGSPickle()) - + .def_property_readonly("comm", [](const Mesh & amesh) + { return make_shared(amesh.GetCommunicator()); }, + "MPI-communicator the Mesh lives in") /* .def("__init__", [](Mesh *instance, int dim) @@ -528,15 +546,25 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m) */ .def_property_readonly("_timestamp", &Mesh::GetTimeStamp) + .def("Distribute", [](Mesh & self, shared_ptr pycomm) { + MPI_Comm comm = pycomm!=nullptr ? pycomm->comm : self.GetCommunicator(); + self.SetCommunicator(comm); + if(MyMPI_GetNTasks(comm)==1) return; + if(MyMPI_GetNTasks(comm)==2) throw NgException("Sorry, cannot handle communicators with NP=2!"); + cout << " rank " << MyMPI_GetId(comm) << " of " << MyMPI_GetNTasks(comm) << " called Distribute " << endl; + if(MyMPI_GetId(comm)==0) self.Distribute(); + else self.SendRecvMesh(); + }, py::arg("comm")=nullptr) .def("Load", FunctionPointer ([](Mesh & self, const string & filename) { istream * infile; + MPI_Comm comm = self.GetCommunicator(); + id = MyMPI_GetId(comm); + ntasks = MyMPI_GetNTasks(comm); + #ifdef PARALLEL - MPI_Comm_rank(netgen::ng_comm, &id); - MPI_Comm_size(netgen::ng_comm, &ntasks); - char* buf = nullptr; int strs = 0; if(id==0) { @@ -564,10 +592,10 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m) } /** Scatter the geometry-string **/ - MPI_Bcast(&strs, 1, MPI_INT, 0, MPI_COMM_WORLD); + MPI_Bcast(&strs, 1, MPI_INT, 0, comm); if(id!=0) buf = new char[strs]; - MPI_Bcast(buf, strs, MPI_CHAR, 0, MPI_COMM_WORLD); + MPI_Bcast(buf, strs, MPI_CHAR, 0, comm); if(id==0) delete infile; infile = new istringstream(string((const char*)buf, (size_t)strs)); @@ -922,6 +950,51 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m) printmessage_importance = importance; return old; })); + + py::class_> (m, "MPI_Comm") + .def_property_readonly ("rank", &PyMPI_Comm::Rank) + .def_property_readonly ("size", &PyMPI_Comm::Size) + // .def_property_readonly ("rank", [](PyMPI_Comm & c) { cout << "rank for " << c.comm << endl; return c.Rank(); }) + // .def_property_readonly ("size", [](PyMPI_Comm & c) { cout << "size for " << c.comm << endl; return c.Size(); }) +#ifdef PARALLEL + .def("Barrier", [](PyMPI_Comm & c) { MPI_Barrier(c.comm); }) + .def("WTime", [](PyMPI_Comm & c) { return MPI_Wtime(); }) +#else + .def("Barrier", [](PyMPI_Comm & c) { }) + .def("WTime", [](PyMPI_Comm & c) { return -1.0; }) +#endif + .def("Sum", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); }) + .def("Min", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); }) + .def("Max", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); }) + .def("Sum", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); }) + .def("Min", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); }) + .def("Max", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); }) + .def("Sum", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); }) + .def("Min", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); }) + .def("Max", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); }) + .def("SubComm", [](PyMPI_Comm & c, py::list proc_list) -> shared_ptr { + Array procs; + if (py::extract (proc_list).check()) { + py::list pylist = py::extract (proc_list)(); + procs.SetSize(py::len(pylist)); + for (int i = 0; i < py::len(pylist); i++) + procs[i] = py::extract(pylist[i])(); + } + else { + throw Exception("SubComm needs a list!"); + } + if(!procs.Size()) { + cout << "warning, tried to construct empty communicator, returning MPI_COMM_NULL" << endl; + return make_shared(MPI_COMM_NULL); + } + else if(procs.Size()==2) { + throw Exception("Sorry, NGSolve cannot handle NP=2."); + } + MPI_Comm subcomm = MyMPI_SubCommunicator(c.comm, procs); + return make_shared(subcomm, true); + }, py::arg("procs")); + ; + } PYBIND11_MODULE(libmesh, m) {