Wrap MPI-communicator on netgen side

This commit is contained in:
Lukas 2019-01-30 20:55:45 +01:00
parent da5c9723d1
commit 214b5c452d
2 changed files with 150 additions and 35 deletions

View File

@ -18,56 +18,98 @@ namespace netgen
using ngcore::ntasks;
#ifndef PARALLEL
/** without MPI, we need a dummy typedef **/
typedef int MPI_Comm;
#endif
/** This is the "standard" communicator that will be used for netgen-objects. **/
extern MPI_Comm ng_comm;
#ifndef PARALLEL
enum { MPI_COMM_WORLD = 12345, MPI_COMM_NULL = 0};
inline int MyMPI_GetNTasks (MPI_Comm comm = ng_comm) { return 1; }
inline int MyMPI_GetId (MPI_Comm comm = ng_comm) { return 0; }
#endif
#ifdef PARALLEL
inline int MyMPI_GetNTasks (MPI_Comm comm = ng_comm)
{
int ntasks;
MPI_Comm_size(comm, &ntasks);
return ntasks;
}
inline int MyMPI_GetId (MPI_Comm comm = ng_comm)
{
int id;
MPI_Comm_rank(comm, &id);
return id;
}
#else
enum { MPI_COMM_WORLD = 12345, MPI_COMM_NULL = 0};
inline int MyMPI_GetNTasks (MPI_Comm comm = ng_comm) { return 1; }
inline int MyMPI_GetId (MPI_Comm comm = ng_comm) { return 0; }
#endif
#ifdef PARALLEL
// For python wrapping of communicators
struct PyMPI_Comm {
MPI_Comm comm;
bool owns_comm;
PyMPI_Comm (MPI_Comm _comm, bool _owns_comm = false) : comm(_comm), owns_comm(_owns_comm) { }
PyMPI_Comm (const PyMPI_Comm & c) = delete;
~PyMPI_Comm () {
if (owns_comm)
MPI_Comm_free(&comm);
}
inline int Rank() const { return MyMPI_GetId(comm); }
inline int Size() const { return MyMPI_GetNTasks(comm); }
};
#else
// dummy without MPI
struct PyMPI_Comm {
MPI_Comm comm = 0;
PyMPI_Comm (MPI_Comm _comm, bool _owns_comm = false) { }
~PyMPI_Comm () { }
inline int Rank() const { return 0; }
inline int Size() const { return 1; }
};
#endif
#ifdef PARALLEL
template <class T>
inline MPI_Datatype MyGetMPIType ( )
{ cerr << "ERROR in GetMPIType() -- no type found" << endl;return 0; }
template <>
inline MPI_Datatype MyGetMPIType<int> ( )
{ return MPI_INT; }
template <>
inline MPI_Datatype MyGetMPIType<double> ( )
{ return MPI_DOUBLE; }
template <>
inline MPI_Datatype MyGetMPIType<char> ( )
{ return MPI_CHAR; }
template<>
inline MPI_Datatype MyGetMPIType<size_t> ( )
{ return MPI_UINT64_T; }
#else
typedef int MPI_Datatype;
template <class T> inline MPI_Datatype MyGetMPIType ( ) { return 0; }
#endif
#ifdef PARALLEL
inline MPI_Comm MyMPI_SubCommunicator(MPI_Comm comm, Array<int> & procs)
{
MPI_Comm subcomm;
MPI_Group gcomm, gsubcomm;
MPI_Comm_group(comm, &gcomm);
MPI_Group_incl(gcomm, procs.Size(), &(procs[0]), &gsubcomm);
MPI_Comm_create_group(comm, gsubcomm, 6969, &subcomm);
return subcomm;
}
#else
inline MPI_Comm MyMPI_SubCommunicator(MPI_Comm comm, Array<int> & procs)
{ return comm; }
#endif
#ifdef PARALLEL
enum { MPI_TAG_CMD = 110 };
enum { MPI_TAG_MESH = 210 };
enum { MPI_TAG_VIS = 310 };
template <class T>
MPI_Datatype MyGetMPIType ( )
{ cerr << "ERROR in GetMPIType() -- no type found" << endl;return 0; }
template <>
inline MPI_Datatype MyGetMPIType<int> ( )
{ return MPI_INT; }
template <>
inline MPI_Datatype MyGetMPIType<double> ( )
{ return MPI_DOUBLE; }
template <>
inline MPI_Datatype MyGetMPIType<char> ( )
{ return MPI_CHAR; }
inline void MyMPI_Send (int i, int dest, int tag, MPI_Comm comm = ng_comm)
{
int hi = i;

View File

@ -17,6 +17,22 @@ namespace netgen
{
extern bool netgen_executable_started;
extern shared_ptr<NetgenGeometry> ng_geometry;
#ifdef PARALLEL
/** we need allreduce in python-wrapped communicators **/
template <typename T>
inline T MyMPI_AllReduceNG (T d, const MPI_Op & op = MPI_SUM, MPI_Comm comm = ng_comm)
{
T global_d;
MPI_Allreduce ( &d, &global_d, 1, MyGetMPIType<T>(), op, comm);
return global_d;
}
#else
enum { MPI_SUM = 0, MPI_MIN = 1, MPI_MAX = 2 };
typedef int MPI_Op;
template <typename T>
inline T MyMPI_AllReduceNG (T d, const MPI_Op & op = MPI_SUM, MPI_Comm comm = ng_comm)
{ return d; }
#endif
}
@ -503,19 +519,21 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
py::class_<Mesh,shared_ptr<Mesh>>(m, "Mesh")
// .def(py::init<>("create empty mesh"))
.def(py::init( [] (int dim)
.def(py::init( [] (int dim, shared_ptr<PyMPI_Comm> pycomm)
{
auto mesh = make_shared<Mesh>();
mesh->SetCommunicator(netgen::ng_comm);
mesh->SetCommunicator(pycomm!=nullptr ? pycomm->comm : netgen::ng_comm);
mesh -> SetDimension(dim);
SetGlobalMesh(mesh); // for visualization
mesh -> SetGeometry (nullptr);
return mesh;
} ),
py::arg("dim")=3
py::arg("dim")=3, py::arg("comm")=nullptr
)
.def(NGSPickle<Mesh>())
.def_property_readonly("comm", [](const Mesh & amesh)
{ return make_shared<PyMPI_Comm>(amesh.GetCommunicator()); },
"MPI-communicator the Mesh lives in")
/*
.def("__init__",
[](Mesh *instance, int dim)
@ -528,15 +546,25 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
*/
.def_property_readonly("_timestamp", &Mesh::GetTimeStamp)
.def("Distribute", [](Mesh & self, shared_ptr<PyMPI_Comm> pycomm) {
MPI_Comm comm = pycomm!=nullptr ? pycomm->comm : self.GetCommunicator();
self.SetCommunicator(comm);
if(MyMPI_GetNTasks(comm)==1) return;
if(MyMPI_GetNTasks(comm)==2) throw NgException("Sorry, cannot handle communicators with NP=2!");
cout << " rank " << MyMPI_GetId(comm) << " of " << MyMPI_GetNTasks(comm) << " called Distribute " << endl;
if(MyMPI_GetId(comm)==0) self.Distribute();
else self.SendRecvMesh();
}, py::arg("comm")=nullptr)
.def("Load", FunctionPointer
([](Mesh & self, const string & filename)
{
istream * infile;
MPI_Comm comm = self.GetCommunicator();
id = MyMPI_GetId(comm);
ntasks = MyMPI_GetNTasks(comm);
#ifdef PARALLEL
MPI_Comm_rank(netgen::ng_comm, &id);
MPI_Comm_size(netgen::ng_comm, &ntasks);
char* buf = nullptr;
int strs = 0;
if(id==0) {
@ -564,10 +592,10 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
}
/** Scatter the geometry-string **/
MPI_Bcast(&strs, 1, MPI_INT, 0, MPI_COMM_WORLD);
MPI_Bcast(&strs, 1, MPI_INT, 0, comm);
if(id!=0)
buf = new char[strs];
MPI_Bcast(buf, strs, MPI_CHAR, 0, MPI_COMM_WORLD);
MPI_Bcast(buf, strs, MPI_CHAR, 0, comm);
if(id==0)
delete infile;
infile = new istringstream(string((const char*)buf, (size_t)strs));
@ -922,6 +950,51 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
printmessage_importance = importance;
return old;
}));
py::class_<PyMPI_Comm, shared_ptr<PyMPI_Comm>> (m, "MPI_Comm")
.def_property_readonly ("rank", &PyMPI_Comm::Rank)
.def_property_readonly ("size", &PyMPI_Comm::Size)
// .def_property_readonly ("rank", [](PyMPI_Comm & c) { cout << "rank for " << c.comm << endl; return c.Rank(); })
// .def_property_readonly ("size", [](PyMPI_Comm & c) { cout << "size for " << c.comm << endl; return c.Size(); })
#ifdef PARALLEL
.def("Barrier", [](PyMPI_Comm & c) { MPI_Barrier(c.comm); })
.def("WTime", [](PyMPI_Comm & c) { return MPI_Wtime(); })
#else
.def("Barrier", [](PyMPI_Comm & c) { })
.def("WTime", [](PyMPI_Comm & c) { return -1.0; })
#endif
.def("Sum", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
.def("Min", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
.def("Max", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
.def("Sum", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
.def("Min", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
.def("Max", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
.def("Sum", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
.def("Min", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
.def("Max", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
.def("SubComm", [](PyMPI_Comm & c, py::list proc_list) -> shared_ptr<PyMPI_Comm> {
Array<int> procs;
if (py::extract<py::list> (proc_list).check()) {
py::list pylist = py::extract<py::list> (proc_list)();
procs.SetSize(py::len(pylist));
for (int i = 0; i < py::len(pylist); i++)
procs[i] = py::extract<int>(pylist[i])();
}
else {
throw Exception("SubComm needs a list!");
}
if(!procs.Size()) {
cout << "warning, tried to construct empty communicator, returning MPI_COMM_NULL" << endl;
return make_shared<PyMPI_Comm>(MPI_COMM_NULL);
}
else if(procs.Size()==2) {
throw Exception("Sorry, NGSolve cannot handle NP=2.");
}
MPI_Comm subcomm = MyMPI_SubCommunicator(c.comm, procs);
return make_shared<PyMPI_Comm>(subcomm, true);
}, py::arg("procs"));
;
}
PYBIND11_MODULE(libmesh, m) {