diff --git a/libsrc/core/mpi_wrapper.hpp b/libsrc/core/mpi_wrapper.hpp
index da5df796..c1fc47dd 100644
--- a/libsrc/core/mpi_wrapper.hpp
+++ b/libsrc/core/mpi_wrapper.hpp
@@ -174,10 +174,10 @@ namespace ngcore
     }
 
     template<typename T, typename T2 = decltype(GetMPIType<T>())>
-    MPI_Request IRecv (T & val, int src, int tag) const
+    MPI_Request IRecv (T & val, int dest, int tag) const
     {
       MPI_Request request;
-      MPI_Irecv (&val, 1, GetMPIType<T>(), src, tag, comm, &request);
+      MPI_Irecv (&val, 1, GetMPIType<T>(), dest, tag, comm, &request);
       return request;
     }
 
@@ -227,6 +227,16 @@ namespace ngcore
       MPI_Bcast (&s[0], len, MPI_CHAR, root, comm);
     }
 
+    NgMPI_Comm SubCommunicator (FlatArray<int> procs) const
+    {
+      MPI_Comm subcomm;
+      MPI_Group gcomm, gsubcomm;
+      MPI_Comm_group(comm, &gcomm);
+      MPI_Group_incl(gcomm, procs.Size(), procs.Data(), &gsubcomm);
+      MPI_Comm_create_group(comm, gsubcomm, 4242, &subcomm);
+      return NgMPI_Comm(subcomm, true);
+    }
+
   }; // class NgMPI_Comm
 
   NETGEN_INLINE void MyMPI_WaitAll (FlatArray<MPI_Request> requests)
@@ -313,15 +323,7 @@ namespace ngcore
 
 #endif // PARALLEL
 
-
-
-
-
-
-
-
-
-}
+} // namespace ngcore
 
 #endif // NGCORE_MPIWRAPPER_HPP
diff --git a/libsrc/general/mpi_interface.hpp b/libsrc/general/mpi_interface.hpp
index 08a15662..b7de6d64 100644
--- a/libsrc/general/mpi_interface.hpp
+++ b/libsrc/general/mpi_interface.hpp
@@ -94,21 +94,6 @@ namespace netgen
   template <class T> inline MPI_Datatype MyGetMPIType ( ) { return 0; }
 #endif
 
-#ifdef PARALLEL
-  inline MPI_Comm MyMPI_SubCommunicator(MPI_Comm comm, NgArray<int> & procs)
-  {
-    MPI_Comm subcomm;
-    MPI_Group gcomm, gsubcomm;
-    MPI_Comm_group(comm, &gcomm);
-    MPI_Group_incl(gcomm, procs.Size(), &(procs[0]), &gsubcomm);
-    MPI_Comm_create_group(comm, gsubcomm, 6969, &subcomm);
-    return subcomm;
-  }
-#else
-  inline MPI_Comm MyMPI_SubCommunicator(MPI_Comm comm, NgArray<int> & procs)
-  { return comm; }
-#endif
-
 #ifdef PARALLEL
   enum { MPI_TAG_CMD = 110 };
   enum { MPI_TAG_MESH = 210 };
diff --git a/libsrc/meshing/python_mesh.cpp b/libsrc/meshing/python_mesh.cpp
index 3e55ff75..dc28bab1 100644
--- a/libsrc/meshing/python_mesh.cpp
+++ b/libsrc/meshing/python_mesh.cpp
@@ -88,13 +88,12 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
     .def("Min", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MIN, c); })
     .def("Max", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MAX, c); })
     .def("SubComm", [](NgMPI_Comm & c, std::vector<int> proc_list) {
-        NgArray<int> procs(proc_list.size());
+        Array<int> procs(proc_list.size());
         for (int i = 0; i < procs.Size(); i++)
-          procs[i] = proc_list[i];
+          { procs[i] = proc_list[i]; }
         if (!procs.Contains(c.Rank()))
-          throw Exception("rank "+ToString(c.Rank())+" not in subcomm");
-        MPI_Comm subcomm = MyMPI_SubCommunicator(c, procs);
-        return NgMPI_Comm(subcomm, true);
+          { throw Exception("rank "+ToString(c.Rank())+" not in subcomm"); }
+        return c.SubCommunicator(procs);
       }, py::arg("procs"))
     ;
 
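
Note (not part of the patch): a minimal usage sketch of the new NgMPI_Comm::SubCommunicator, assuming an MPI-enabled build of ngcore; the include path and the even-ranks group are illustrative assumptions, not something this patch prescribes. MPI_Comm_create_group is collective only over the ranks contained in the group, which is why the Python binding above throws for ranks not listed in procs, and why the sketch guards the call the same way.

    #include <mpi.h>
    #include <core/mpi_wrapper.hpp>   // assumed include path for the header patched above

    using namespace ngcore;

    int main (int argc, char ** argv)
    {
      MPI_Init(&argc, &argv);
      {
        NgMPI_Comm comm(MPI_COMM_WORLD);   // non-owning wrapper

        // Illustrative group: the even ranks of the parent communicator.
        Array<int> procs;
        for (int p = 0; p < comm.Size(); p += 2)
          procs.Append(p);

        // MPI_Comm_create_group is collective over the group members only,
        // so ranks outside 'procs' must not enter this call.
        if (procs.Contains(comm.Rank()))
          {
            NgMPI_Comm sub = comm.SubCommunicator(procs);
            // sub.Rank() and sub.Size() are relative to the sub-communicator;
            // 'sub' owns its MPI_Comm (constructed with owns=true above) and
            // releases it when it goes out of scope.
          }
      }
      MPI_Finalize();
      return 0;
    }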