mirror of https://github.com/NGSolve/netgen.git
Wrap MPI-communicator on netgen side

commit 214b5c452d (parent da5c9723d1)
@@ -18,55 +18,97 @@ namespace netgen
   using ngcore::ntasks;
 
 #ifndef PARALLEL
+  /** without MPI, we need a dummy typedef **/
   typedef int MPI_Comm;
 #endif
 
   /** This is the "standard" communicator that will be used for netgen-objects. **/
   extern MPI_Comm ng_comm;
 
-#ifndef PARALLEL
-  enum { MPI_COMM_WORLD = 12345, MPI_COMM_NULL = 0};
-  inline int MyMPI_GetNTasks (MPI_Comm comm = ng_comm) { return 1; }
-  inline int MyMPI_GetId (MPI_Comm comm = ng_comm) { return 0; }
-#endif
-
 #ifdef PARALLEL
 
   inline int MyMPI_GetNTasks (MPI_Comm comm = ng_comm)
   {
     int ntasks;
     MPI_Comm_size(comm, &ntasks);
     return ntasks;
   }
 
   inline int MyMPI_GetId (MPI_Comm comm = ng_comm)
   {
     int id;
     MPI_Comm_rank(comm, &id);
     return id;
   }
+#else
+  enum { MPI_COMM_WORLD = 12345, MPI_COMM_NULL = 0};
+  inline int MyMPI_GetNTasks (MPI_Comm comm = ng_comm) { return 1; }
+  inline int MyMPI_GetId (MPI_Comm comm = ng_comm) { return 0; }
+#endif
 
-  enum { MPI_TAG_CMD = 110 };
-  enum { MPI_TAG_MESH = 210 };
-  enum { MPI_TAG_VIS = 310 };
+#ifdef PARALLEL
+  // For python wrapping of communicators
+  struct PyMPI_Comm {
+    MPI_Comm comm;
+    bool owns_comm;
+    PyMPI_Comm (MPI_Comm _comm, bool _owns_comm = false) : comm(_comm), owns_comm(_owns_comm) { }
+    PyMPI_Comm (const PyMPI_Comm & c) = delete;
+    ~PyMPI_Comm () {
+      if (owns_comm)
+        MPI_Comm_free(&comm);
+    }
+    inline int Rank() const { return MyMPI_GetId(comm); }
+    inline int Size() const { return MyMPI_GetNTasks(comm); }
+  };
+#else
+  // dummy without MPI
+  struct PyMPI_Comm {
+    MPI_Comm comm = 0;
+    PyMPI_Comm (MPI_Comm _comm, bool _owns_comm = false) { }
+    ~PyMPI_Comm () { }
+    inline int Rank() const { return 0; }
+    inline int Size() const { return 1; }
+  };
+#endif
 
+#ifdef PARALLEL
   template <class T>
-  MPI_Datatype MyGetMPIType ( )
+  inline MPI_Datatype MyGetMPIType ( )
   { cerr << "ERROR in GetMPIType() -- no type found" << endl;return 0; }
 
   template <>
   inline MPI_Datatype MyGetMPIType<int> ( )
   { return MPI_INT; }
 
   template <>
   inline MPI_Datatype MyGetMPIType<double> ( )
   { return MPI_DOUBLE; }
 
   template <>
   inline MPI_Datatype MyGetMPIType<char> ( )
   { return MPI_CHAR; }
+  template<>
+  inline MPI_Datatype MyGetMPIType<size_t> ( )
+  { return MPI_UINT64_T; }
+#else
+  typedef int MPI_Datatype;
+  template <class T> inline MPI_Datatype MyGetMPIType ( ) { return 0; }
+#endif
 
+#ifdef PARALLEL
+  inline MPI_Comm MyMPI_SubCommunicator(MPI_Comm comm, Array<int> & procs)
+  {
+    MPI_Comm subcomm;
+    MPI_Group gcomm, gsubcomm;
+    MPI_Comm_group(comm, &gcomm);
+    MPI_Group_incl(gcomm, procs.Size(), &(procs[0]), &gsubcomm);
+    MPI_Comm_create_group(comm, gsubcomm, 6969, &subcomm);
+    return subcomm;
+  }
+#else
+  inline MPI_Comm MyMPI_SubCommunicator(MPI_Comm comm, Array<int> & procs)
+  { return comm; }
+#endif
 
+#ifdef PARALLEL
+  enum { MPI_TAG_CMD = 110 };
+  enum { MPI_TAG_MESH = 210 };
+  enum { MPI_TAG_VIS = 310 };
 
   inline void MyMPI_Send (int i, int dest, int tag, MPI_Comm comm = ng_comm)
   {
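The ownership flag is the interesting part of PyMPI_Comm: wrappers around borrowed communicators (such as ng_comm) must not free them, while wrappers around communicators created by MyMPI_SubCommunicator should. A minimal sketch of the intended usage, assuming a PARALLEL build, the declarations from the hunk above, and the usual std names (shared_ptr, make_shared) in scope; the helper name wrap_subcomm is made up for illustration and is not part of the patch:

    // Sketch only: intended ownership semantics of PyMPI_Comm.
    // Assumes the calling rank is contained in 'procs'; otherwise
    // MyMPI_SubCommunicator returns MPI_COMM_NULL and Rank()/Size()
    // must not be called on the wrapper.
    shared_ptr<PyMPI_Comm> wrap_subcomm (MPI_Comm comm, Array<int> & procs)
    {
      // borrowed communicator (e.g. ng_comm): owns_comm defaults to false,
      // so the destructor leaves it alone
      auto borrowed = make_shared<PyMPI_Comm> (comm);

      // freshly created sub-communicator: pass owns_comm = true, so
      // ~PyMPI_Comm calls MPI_Comm_free exactly once when the last
      // shared_ptr (typically held on the python side) goes away
      MPI_Comm sub = MyMPI_SubCommunicator (comm, procs);
      return make_shared<PyMPI_Comm> (sub, true);
    }   // 'borrowed' dies here without freeing 'comm'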
@@ -17,6 +17,22 @@ namespace netgen
 {
   extern bool netgen_executable_started;
   extern shared_ptr<NetgenGeometry> ng_geometry;
+#ifdef PARALLEL
+  /** we need allreduce in python-wrapped communicators **/
+  template <typename T>
+  inline T MyMPI_AllReduceNG (T d, const MPI_Op & op = MPI_SUM, MPI_Comm comm = ng_comm)
+  {
+    T global_d;
+    MPI_Allreduce ( &d, &global_d, 1, MyGetMPIType<T>(), op, comm);
+    return global_d;
+  }
+#else
+  enum { MPI_SUM = 0, MPI_MIN = 1, MPI_MAX = 2 };
+  typedef int MPI_Op;
+  template <typename T>
+  inline T MyMPI_AllReduceNG (T d, const MPI_Op & op = MPI_SUM, MPI_Comm comm = ng_comm)
+  { return d; }
+#endif
 }
 
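MyMPI_AllReduceNG is what the python-side Sum/Min/Max methods call: it reduces a single value of any type for which MyGetMPIType<T>() is specialized, and the non-MPI dummy simply returns its argument. A short sketch of how it is used, assuming a valid communicator; the function names below are made-up stand-ins, not part of the patch:

    // Sketch: single-value reductions over a communicator 'comm'.
    // 'local_vol' and 'local_ne' stand for arbitrary per-rank quantities.
    double GlobalVolume (double local_vol, MPI_Comm comm)
    {
      // every rank receives the same reduced value
      return MyMPI_AllReduceNG (local_vol, MPI_SUM, comm);
    }

    size_t GlobalMaxElements (size_t local_ne, MPI_Comm comm)
    {
      // works because MyGetMPIType<size_t>() was added above (maps to MPI_UINT64_T)
      return MyMPI_AllReduceNG (local_ne, MPI_MAX, comm);
    }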
@@ -503,19 +519,21 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
   py::class_<Mesh,shared_ptr<Mesh>>(m, "Mesh")
     // .def(py::init<>("create empty mesh"))
 
-    .def(py::init( [] (int dim)
+    .def(py::init( [] (int dim, shared_ptr<PyMPI_Comm> pycomm)
                    {
                      auto mesh = make_shared<Mesh>();
-                     mesh->SetCommunicator(netgen::ng_comm);
+                     mesh->SetCommunicator(pycomm!=nullptr ? pycomm->comm : netgen::ng_comm);
                      mesh -> SetDimension(dim);
                      SetGlobalMesh(mesh); // for visualization
                      mesh -> SetGeometry (nullptr);
                      return mesh;
                    } ),
-         py::arg("dim")=3
+         py::arg("dim")=3, py::arg("comm")=nullptr
         )
     .def(NGSPickle<Mesh>())
+    .def_property_readonly("comm", [](const Mesh & amesh)
+                           { return make_shared<PyMPI_Comm>(amesh.GetCommunicator()); },
+                           "MPI-communicator the Mesh lives in")
     /*
     .def("__init__",
          [](Mesh *instance, int dim)
@@ -528,15 +546,25 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
     */
 
     .def_property_readonly("_timestamp", &Mesh::GetTimeStamp)
+    .def("Distribute", [](Mesh & self, shared_ptr<PyMPI_Comm> pycomm) {
+        MPI_Comm comm = pycomm!=nullptr ? pycomm->comm : self.GetCommunicator();
+        self.SetCommunicator(comm);
+        if(MyMPI_GetNTasks(comm)==1) return;
+        if(MyMPI_GetNTasks(comm)==2) throw NgException("Sorry, cannot handle communicators with NP=2!");
+        cout << " rank " << MyMPI_GetId(comm) << " of " << MyMPI_GetNTasks(comm) << " called Distribute " << endl;
+        if(MyMPI_GetId(comm)==0) self.Distribute();
+        else self.SendRecvMesh();
+      }, py::arg("comm")=nullptr)
     .def("Load", FunctionPointer
          ([](Mesh & self, const string & filename)
           {
            istream * infile;
 
-#ifdef PARALLEL
-           MPI_Comm_rank(netgen::ng_comm, &id);
-           MPI_Comm_size(netgen::ng_comm, &ntasks);
+           MPI_Comm comm = self.GetCommunicator();
+           id = MyMPI_GetId(comm);
+           ntasks = MyMPI_GetNTasks(comm);
 
+#ifdef PARALLEL
            char* buf = nullptr;
            int strs = 0;
            if(id==0) {
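The Distribute binding is collective over the chosen communicator: every rank has to enter it, rank 0 partitions and sends the mesh, all other ranks receive their part. A sketch of the equivalent call sequence seen from C++, assuming a PARALLEL build; distribute_mesh is a hypothetical helper used only to spell out the control flow of the binding above:

    // Sketch: what Mesh.Distribute(comm) does, restated as a free function.
    void distribute_mesh (Mesh & mesh, MPI_Comm comm)
    {
      mesh.SetCommunicator (comm);
      if (MyMPI_GetNTasks (comm) == 1) return;   // serial: nothing to distribute
      if (MyMPI_GetNTasks (comm) == 2)
        throw NgException ("Sorry, cannot handle communicators with NP=2!");
      if (MyMPI_GetId (comm) == 0)
        mesh.Distribute ();      // rank 0 partitions and ships the pieces
      else
        mesh.SendRecvMesh ();    // all other ranks receive theirs
    }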
@@ -564,10 +592,10 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
            }
 
            /** Scatter the geometry-string **/
-           MPI_Bcast(&strs, 1, MPI_INT, 0, MPI_COMM_WORLD);
+           MPI_Bcast(&strs, 1, MPI_INT, 0, comm);
            if(id!=0)
              buf = new char[strs];
-           MPI_Bcast(buf, strs, MPI_CHAR, 0, MPI_COMM_WORLD);
+           MPI_Bcast(buf, strs, MPI_CHAR, 0, comm);
            if(id==0)
              delete infile;
            infile = new istringstream(string((const char*)buf, (size_t)strs));
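Both MPI_Bcast calls now run over the mesh communicator instead of MPI_COMM_WORLD, so a mesh loaded on a sub-communicator no longer drags uninvolved ranks into the broadcast. The length-then-payload pattern used above, condensed into a hypothetical helper (MyMPI_BcastString is not part of the patch, just an illustration; it assumes an MPI build and std::string in scope as string):

    // Sketch: broadcast a string from rank 0 over an arbitrary communicator,
    // mirroring the length/payload broadcast in Load() above.
    string MyMPI_BcastString (string s, MPI_Comm comm)
    {
      int id;
      MPI_Comm_rank (comm, &id);
      int len = (id == 0) ? int (s.size()) : 0;
      MPI_Bcast (&len, 1, MPI_INT, 0, comm);   // everyone learns the length
      s.resize (len);                          // non-root ranks allocate space
      if (len > 0)
        MPI_Bcast (&s[0], len, MPI_CHAR, 0, comm);
      return s;
    }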
@@ -922,6 +950,51 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
                              printmessage_importance = importance;
                              return old;
                            }));
 
+  py::class_<PyMPI_Comm, shared_ptr<PyMPI_Comm>> (m, "MPI_Comm")
+    .def_property_readonly ("rank", &PyMPI_Comm::Rank)
+    .def_property_readonly ("size", &PyMPI_Comm::Size)
+    // .def_property_readonly ("rank", [](PyMPI_Comm & c) { cout << "rank for " << c.comm << endl; return c.Rank(); })
+    // .def_property_readonly ("size", [](PyMPI_Comm & c) { cout << "size for " << c.comm << endl; return c.Size(); })
+#ifdef PARALLEL
+    .def("Barrier", [](PyMPI_Comm & c) { MPI_Barrier(c.comm); })
+    .def("WTime", [](PyMPI_Comm & c) { return MPI_Wtime(); })
+#else
+    .def("Barrier", [](PyMPI_Comm & c) { })
+    .def("WTime", [](PyMPI_Comm & c) { return -1.0; })
+#endif
+    .def("Sum", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
+    .def("Min", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
+    .def("Max", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
+    .def("Sum", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
+    .def("Min", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
+    .def("Max", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
+    .def("Sum", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
+    .def("Min", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
+    .def("Max", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
+    .def("SubComm", [](PyMPI_Comm & c, py::list proc_list) -> shared_ptr<PyMPI_Comm> {
+        Array<int> procs;
+        if (py::extract<py::list> (proc_list).check()) {
+          py::list pylist = py::extract<py::list> (proc_list)();
+          procs.SetSize(py::len(pylist));
+          for (int i = 0; i < py::len(pylist); i++)
+            procs[i] = py::extract<int>(pylist[i])();
+        }
+        else {
+          throw Exception("SubComm needs a list!");
+        }
+        if(!procs.Size()) {
+          cout << "warning, tried to construct empty communicator, returning MPI_COMM_NULL" << endl;
+          return make_shared<PyMPI_Comm>(MPI_COMM_NULL);
+        }
+        else if(procs.Size()==2) {
+          throw Exception("Sorry, NGSolve cannot handle NP=2.");
+        }
+        MPI_Comm subcomm = MyMPI_SubCommunicator(c.comm, procs);
+        return make_shared<PyMPI_Comm>(subcomm, true);
+      }, py::arg("procs"));
+  ;
 
 }
 
 PYBIND11_MODULE(libmesh, m) {