mirror of
https://github.com/NGSolve/netgen.git
synced 2024-12-25 21:40:33 +05:00
Fix library name suffix, initialize MPI runtime wrapper as soon as an mpi4py comm gets converted into an NG_MPI_Comm
This commit is contained in:
parent
9b5fc16397
commit
e1be58011f
@ -33,7 +33,6 @@ void InitMPI(std::optional<std::filesystem::path> mpi_lib_path) {
|
||||
cout << IM(3) << "InitMPI" << endl;
|
||||
|
||||
std::string vendor = "";
|
||||
std::string ng_lib_name = "";
|
||||
std::string mpi4py_lib_file = "";
|
||||
|
||||
if (mpi_lib_path) {
|
||||
@ -97,6 +96,7 @@ void InitMPI(std::optional<std::filesystem::path> mpi_lib_path) {
|
||||
#endif // WIN32
|
||||
}
|
||||
|
||||
std::string ng_lib_name = "";
|
||||
if (vendor == "Open MPI")
|
||||
ng_lib_name = "ng_openmpi";
|
||||
else if (vendor == "MPICH")
|
||||
@ -106,6 +106,8 @@ void InitMPI(std::optional<std::filesystem::path> mpi_lib_path) {
|
||||
else
|
||||
throw std::runtime_error("Unknown MPI vendor: " + vendor);
|
||||
|
||||
ng_lib_name += NETGEN_SHARED_LIBRARY_SUFFIX;
|
||||
|
||||
// Load the ng_mpi wrapper and call ng_init_mpi to set all function pointers
|
||||
typedef void (*ng_init_handle)();
|
||||
ng_mpi_lib = std::make_unique<SharedLibrary>(ng_lib_name);
|
||||
@ -119,7 +121,26 @@ static std::runtime_error no_mpi() {
|
||||
|
||||
#if defined(NG_PYTHON) && defined(NG_MPI4PY)
|
||||
decltype(NG_MPI_CommFromMPI4Py) NG_MPI_CommFromMPI4Py =
|
||||
[](py::handle, NG_MPI_Comm &) -> bool { throw no_mpi(); };
|
||||
[](py::handle py_obj, NG_MPI_Comm &ng_comm) -> bool {
|
||||
// If this gets called, it means that we want to convert an mpi4py
|
||||
// communicator to a Netgen MPI communicator, but the Netgen MPI wrapper
|
||||
// runtime was not yet initialized.
|
||||
|
||||
// store the current address of this function
|
||||
auto old_converter_address = NG_MPI_CommFromMPI4Py;
|
||||
|
||||
// initialize the MPI wrapper runtime, this sets all the function pointers
|
||||
InitMPI();
|
||||
|
||||
// if the initialization was successful, the function pointer should have
|
||||
// changed
|
||||
// -> call the actual conversion function
|
||||
if (NG_MPI_CommFromMPI4Py != old_converter_address)
|
||||
return NG_MPI_CommFromMPI4Py(py_obj, ng_comm);
|
||||
|
||||
// otherwise, something strange happened
|
||||
throw no_mpi();
|
||||
};
|
||||
decltype(NG_MPI_CommToMPI4Py) NG_MPI_CommToMPI4Py =
|
||||
[](NG_MPI_Comm) -> py::handle { throw no_mpi(); };
|
||||
#endif
|
||||
|
Loading…
Reference in New Issue
Block a user