Fix library name suffix, initialize MPI runtime wrapper as soon as an mpi4py comm gets converted into an NG_MPI_Comm

This commit is contained in:
Matthias Hochsteger 2024-05-07 11:27:51 +02:00
parent 9b5fc16397
commit e1be58011f

View File

@ -33,7 +33,6 @@ void InitMPI(std::optional<std::filesystem::path> mpi_lib_path) {
cout << IM(3) << "InitMPI" << endl;
std::string vendor = "";
std::string ng_lib_name = "";
std::string mpi4py_lib_file = "";
if (mpi_lib_path) {
@ -97,6 +96,7 @@ void InitMPI(std::optional<std::filesystem::path> mpi_lib_path) {
#endif // WIN32
}
std::string ng_lib_name = "";
if (vendor == "Open MPI")
ng_lib_name = "ng_openmpi";
else if (vendor == "MPICH")
@ -106,6 +106,8 @@ void InitMPI(std::optional<std::filesystem::path> mpi_lib_path) {
else
throw std::runtime_error("Unknown MPI vendor: " + vendor);
ng_lib_name += NETGEN_SHARED_LIBRARY_SUFFIX;
// Load the ng_mpi wrapper and call ng_init_mpi to set all function pointers
typedef void (*ng_init_handle)();
ng_mpi_lib = std::make_unique<SharedLibrary>(ng_lib_name);
@ -119,7 +121,26 @@ static std::runtime_error no_mpi() {
#if defined(NG_PYTHON) && defined(NG_MPI4PY)
decltype(NG_MPI_CommFromMPI4Py) NG_MPI_CommFromMPI4Py =
[](py::handle, NG_MPI_Comm &) -> bool { throw no_mpi(); };
[](py::handle py_obj, NG_MPI_Comm &ng_comm) -> bool {
// If this gets called, it means that we want to convert an mpi4py
// communicator to a Netgen MPI communicator, but the Netgen MPI wrapper
// runtime was not yet initialized.
// store the current address of this function
auto old_converter_address = NG_MPI_CommFromMPI4Py;
// initialize the MPI wrapper runtime, this sets all the function pointers
InitMPI();
// if the initialization was successful, the function pointer should have
// changed
// -> call the actual conversion function
if (NG_MPI_CommFromMPI4Py != old_converter_address)
return NG_MPI_CommFromMPI4Py(py_obj, ng_comm);
// otherwise, something strange happened
throw no_mpi();
};
decltype(NG_MPI_CommToMPI4Py) NG_MPI_CommToMPI4Py =
[](NG_MPI_Comm) -> py::handle { throw no_mpi(); };
#endif