diff --git a/libsrc/core/mpi_wrapper.hpp b/libsrc/core/mpi_wrapper.hpp
index b42c551e..07b15867 100644
--- a/libsrc/core/mpi_wrapper.hpp
+++ b/libsrc/core/mpi_wrapper.hpp
@@ -188,8 +188,15 @@ namespace ngcore
       MPI_Bcast (&s[0], len, MPI_CHAR, root, comm);
     }
 
-
-  };
+    NgMPI_Comm SubCommunicator (FlatArray<int> procs) const
+    {
+      MPI_Comm subcomm;
+      MPI_Group gcomm, gsubcomm;
+      MPI_Comm_group(comm, &gcomm);
+      MPI_Group_incl(gcomm, procs.Size(), procs.Data(), &gsubcomm);
+      MPI_Comm_create_group(comm, gsubcomm, 4242, &subcomm);
+      return NgMPI_Comm(subcomm, true);
+    }
 
 #else // PARALLEL
 
@@ -239,6 +246,9 @@ namespace ngcore
     template <typename T>
     void Bcast (T & s, int root = 0) const { ; }
+
+    NgMPI_Comm SubCommunicator (FlatArray<int> procs) const
+    { return *this; }
   };
 
 #endif // PARALLEL
diff --git a/libsrc/general/mpi_interface.hpp b/libsrc/general/mpi_interface.hpp
index 08a15662..b7de6d64 100644
--- a/libsrc/general/mpi_interface.hpp
+++ b/libsrc/general/mpi_interface.hpp
@@ -94,21 +94,6 @@ namespace netgen
   template <class T> inline MPI_Datatype MyGetMPIType ( ) { return 0; }
 #endif
 
-#ifdef PARALLEL
-  inline MPI_Comm MyMPI_SubCommunicator(MPI_Comm comm, NgArray<int> & procs)
-  {
-    MPI_Comm subcomm;
-    MPI_Group gcomm, gsubcomm;
-    MPI_Comm_group(comm, &gcomm);
-    MPI_Group_incl(gcomm, procs.Size(), &(procs[0]), &gsubcomm);
-    MPI_Comm_create_group(comm, gsubcomm, 6969, &subcomm);
-    return subcomm;
-  }
-#else
-  inline MPI_Comm MyMPI_SubCommunicator(MPI_Comm comm, NgArray<int> & procs)
-  { return comm; }
-#endif
-
 #ifdef PARALLEL
   enum { MPI_TAG_CMD = 110 };
   enum { MPI_TAG_MESH = 210 };
diff --git a/libsrc/meshing/python_mesh.cpp b/libsrc/meshing/python_mesh.cpp
index 3e55ff75..dc28bab1 100644
--- a/libsrc/meshing/python_mesh.cpp
+++ b/libsrc/meshing/python_mesh.cpp
@@ -88,13 +88,12 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
     .def("Min", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MIN, c); })
     .def("Max", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MAX, c); })
     .def("SubComm", [](NgMPI_Comm & c, std::vector<int> proc_list) {
-        NgArray<int> procs(proc_list.size());
+        Array<int> procs(proc_list.size());
         for (int i = 0; i < procs.Size(); i++)
-          procs[i] = proc_list[i];
+          { procs[i] = proc_list[i]; }
         if (!procs.Contains(c.Rank()))
-          throw Exception("rank "+ToString(c.Rank())+" not in subcomm");
-        MPI_Comm subcomm = MyMPI_SubCommunicator(c, procs);
-        return NgMPI_Comm(subcomm, true);
+          { throw Exception("rank "+ToString(c.Rank())+" not in subcomm"); }
+        return c.SubCommunicator(procs);
       }, py::arg("procs"));
     ;
 