diff --git a/libsrc/core/ng_mpi_wrapper.cpp b/libsrc/core/ng_mpi_wrapper.cpp index 328e5ac7..208fb69d 100644 --- a/libsrc/core/ng_mpi_wrapper.cpp +++ b/libsrc/core/ng_mpi_wrapper.cpp @@ -33,7 +33,6 @@ void InitMPI(std::optional mpi_lib_path) { cout << IM(3) << "InitMPI" << endl; std::string vendor = ""; - std::string ng_lib_name = ""; std::string mpi4py_lib_file = ""; if (mpi_lib_path) { @@ -97,6 +96,7 @@ void InitMPI(std::optional mpi_lib_path) { #endif // WIN32 } + std::string ng_lib_name = ""; if (vendor == "Open MPI") ng_lib_name = "ng_openmpi"; else if (vendor == "MPICH") @@ -106,6 +106,8 @@ void InitMPI(std::optional mpi_lib_path) { else throw std::runtime_error("Unknown MPI vendor: " + vendor); + ng_lib_name += NETGEN_SHARED_LIBRARY_SUFFIX; + // Load the ng_mpi wrapper and call ng_init_mpi to set all function pointers typedef void (*ng_init_handle)(); ng_mpi_lib = std::make_unique(ng_lib_name); @@ -119,7 +121,26 @@ static std::runtime_error no_mpi() { #if defined(NG_PYTHON) && defined(NG_MPI4PY) decltype(NG_MPI_CommFromMPI4Py) NG_MPI_CommFromMPI4Py = - [](py::handle, NG_MPI_Comm &) -> bool { throw no_mpi(); }; + [](py::handle py_obj, NG_MPI_Comm &ng_comm) -> bool { + // If this gets called, it means that we want to convert an mpi4py + // communicator to a Netgen MPI communicator, but the Netgen MPI wrapper + // runtime was not yet initialized. + + // store the current address of this function + auto old_converter_address = NG_MPI_CommFromMPI4Py; + + // initialize the MPI wrapper runtime, this sets all the function pointers + InitMPI(); + + // if the initialization was successful, the function pointer should have + // changed + // -> call the actual conversion function + if (NG_MPI_CommFromMPI4Py != old_converter_address) + return NG_MPI_CommFromMPI4Py(py_obj, ng_comm); + + // otherwise, something strange happened + throw no_mpi(); +}; decltype(NG_MPI_CommToMPI4Py) NG_MPI_CommToMPI4Py = [](NG_MPI_Comm) -> py::handle { throw no_mpi(); }; #endif