mpi_wrapper

Joachim Schöberl 2019-02-11 21:37:00 +01:00
parent 1074593664
commit 9ced2f561f
4 changed files with 67 additions and 61 deletions

Changed file 1 of 4:

@@ -23,8 +23,11 @@ namespace ngcore
   template <> struct MPI_typetrait<char> {
     static MPI_Datatype MPIType () { return MPI_CHAR; } };
 
+  template <> struct MPI_typetrait<unsigned char> {
+    static MPI_Datatype MPIType () { return MPI_CHAR; } };
+
   template <> struct MPI_typetrait<size_t> {
-    static MPI_Datatype MPIType () { return MPI_UNIT64_T; } };
+    static MPI_Datatype MPIType () { return MPI_UINT64_T; } };
 
   template <> struct MPI_typetrait<double> {
     static MPI_Datatype MPIType () { return MPI_DOUBLE; } };
@@ -43,6 +46,7 @@ namespace ngcore
   {
     MPI_Comm comm;
     int * refcount;
+    int rank, size;
   public:
     NgMPI_Comm (MPI_Comm _comm, bool owns = false)
       : comm(_comm)
@@ -51,16 +55,19 @@ namespace ngcore
         refcount = nullptr;
       else
         refcount = new int{1};
+
+      MPI_Comm_rank(comm, &rank);
+      MPI_Comm_size(comm, &size);
     }
 
     NgMPI_Comm (const NgMPI_Comm & c)
-      : comm(c.comm), refcount(c.refcount)
+      : comm(c.comm), refcount(c.refcount), rank(c.rank), size(c.size)
     {
       if (refcount) (*refcount)++;
     }
 
     NgMPI_Comm (NgMPI_Comm && c)
-      : comm(c.comm), refcount(c.refcount)
+      : comm(c.comm), refcount(c.refcount), rank(c.rank), size(c.size)
     {
       c.refcount = nullptr;
     }
@@ -74,8 +81,21 @@ namespace ngcore
 
     operator MPI_Comm() const { return comm; }
 
-    auto Rank() const { int r; MPI_Comm_rank(comm, &r); return r; }
-    auto Size() const { int s; MPI_Comm_size(comm, &s); return s; }
+    int Rank() const { return rank; }   // int r; MPI_Comm_rank(comm, &r); return r;
+    int Size() const { return size; }   // int s; MPI_Comm_size(comm, &s); return s;
+
+    template<typename T, typename T2 = decltype(GetMPIType<T>())>
+    void Send (T & val, int dest, int tag) {
+      MPI_Send (&val, 1, GetMPIType<T>(), dest, tag, comm);
+    }
+
+    template<typename T, typename T2 = decltype(GetMPIType<T>())>
+    void MyMPI_Recv (T & val, int src, int tag) {
+      MPI_Recv (&val, 1, GetMPIType<T>(), src, tag, comm, MPI_STATUS_IGNORE);
+    }
   };
@@ -90,6 +110,14 @@ namespace ngcore
     size_t Rank() const { return 0; }
     size_t Size() const { return 1; }
+
+    template<typename T>
+    void Send (T & val, int dest, int tag) { ; }
+
+    template<typename T>
+    void MyMPI_Recv (T & val, int src, int tag) { ; }
   };
 
 #endif
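With these additions an NgMPI_Comm carries its rank and size and can do simple point-to-point transfers, with the MPI datatype resolved through MPI_typetrait. Below is a minimal usage sketch, not part of the commit: the include path and the main() scaffold are assumptions, and a PARALLEL build with an MPI runtime is presumed.

    // Sketch only: exercises the new NgMPI_Comm Send / MyMPI_Recv shown above.
    // Assumption: the ngcore header defining NgMPI_Comm is reachable as below.
    #include <mpi.h>
    #include <core/mpi_wrapper.hpp>   // assumed include path, not taken from the commit

    int main (int argc, char ** argv)
    {
      MPI_Init (&argc, &argv);
      {
        ngcore::NgMPI_Comm comm(MPI_COMM_WORLD);   // non-owning wrap; rank/size are cached once

        double val = 0;
        if (comm.Rank() == 0)
          {
            val = 3.14;
            for (int dest = 1; dest < comm.Size(); dest++)
              comm.Send (val, dest, 0);            // datatype picked via MPI_typetrait<double>
          }
        else
          comm.MyMPI_Recv (val, 0, 0);
      }   // wrapper is destroyed before MPI_Finalize
      MPI_Finalize ();
      return 0;
    }

The dummy class in the non-PARALLEL branch keeps the same calls compiling in serial builds.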

Changed file 2 of 4:

@@ -23,7 +23,7 @@ namespace netgen
 #endif
 
   /** This is the "standard" communicator that will be used for netgen-objects. **/
-  extern DLL_HEADER MPI_Comm ng_comm;
+  extern DLL_HEADER NgMPI_Comm ng_comm;
 
 #ifdef PARALLEL
   inline int MyMPI_GetNTasks (MPI_Comm comm = ng_comm)
@@ -44,6 +44,7 @@ namespace netgen
   inline int MyMPI_GetId (MPI_Comm comm = ng_comm) { return 0; }
 #endif
 
+  /*
 #ifdef PARALLEL
   // For python wrapping of communicators
   struct PyMPI_Comm {
@@ -68,6 +69,7 @@ namespace netgen
     inline int Size() const { return 1; }
   };
 #endif
+  */
 
 #ifdef PARALLEL
   template <class T>

Changed file 3 of 4:

@@ -32,7 +32,7 @@ namespace netgen
   // TraceGlobal glob2("global2");
 
   // global communicator for netgen
-  DLL_HEADER MPI_Comm ng_comm = MPI_COMM_WORLD;
+  DLL_HEADER NgMPI_Comm ng_comm = MPI_COMM_WORLD;
 
   weak_ptr<Mesh> global_mesh;
   void SetGlobalMesh (shared_ptr<Mesh> m)
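Retyping ng_comm from MPI_Comm to NgMPI_Comm does not break the helpers above, because the wrapper keeps the implicit operator MPI_Comm() from the first file. A hedged sketch of that interplay (the function name is illustrative; a PARALLEL build with the relevant netgen headers already included is assumed):

    // Sketch only: the retyped ng_comm still feeds helpers and MPI calls that
    // expect a raw MPI_Comm, via NgMPI_Comm::operator MPI_Comm().
    #include <mpi.h>

    void ReportCommLayout ()                                   // illustrative, not in the commit
    {
      int ntasks = netgen::MyMPI_GetNTasks (netgen::ng_comm);  // NgMPI_Comm -> MPI_Comm conversion
      int id     = netgen::MyMPI_GetId (netgen::ng_comm);
      MPI_Barrier (netgen::ng_comm);                           // same conversion for direct MPI calls
      (void) ntasks; (void) id;
    }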

Changed file 4 of 4:

@@ -494,20 +494,20 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
   py::class_<Mesh,shared_ptr<Mesh>>(m, "Mesh")
     // .def(py::init<>("create empty mesh"))
-    .def(py::init( [] (int dim, shared_ptr<PyMPI_Comm> pycomm)
+    .def(py::init( [] (int dim, NgMPI_Comm comm)
                    {
                      auto mesh = make_shared<Mesh>();
-                     mesh->SetCommunicator(pycomm!=nullptr ? pycomm->comm : netgen::ng_comm);
+                     mesh->SetCommunicator(comm);
                      mesh -> SetDimension(dim);
                      SetGlobalMesh(mesh); // for visualization
                      mesh -> SetGeometry (nullptr);
                      return mesh;
                    } ),
-         py::arg("dim")=3, py::arg("comm")=nullptr
+         py::arg("dim")=3, py::arg("comm")=NgMPI_Comm(ng_comm)
         )
     .def(NGSPickle<Mesh>())
-    .def_property_readonly("comm", [](const Mesh & amesh)
-                           { return make_shared<PyMPI_Comm>(amesh.GetCommunicator()); },
+    .def_property_readonly("comm", [](const Mesh & amesh) -> NgMPI_Comm
+                           { return amesh.GetCommunicator(); },
                            "MPI-communicator the Mesh lives in")
     /*
     .def("__init__",
@@ -521,8 +521,7 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
     */
     .def_property_readonly("_timestamp", &Mesh::GetTimeStamp)
-    .def("Distribute", [](shared_ptr<Mesh> self, shared_ptr<PyMPI_Comm> pycomm) {
-           MPI_Comm comm = pycomm!=nullptr ? pycomm->comm : self->GetCommunicator();
+    .def("Distribute", [](shared_ptr<Mesh> self, NgMPI_Comm comm) {
           self->SetCommunicator(comm);
           if(MyMPI_GetNTasks(comm)==1) return self;
           // if(MyMPI_GetNTasks(comm)==2) throw NgException("Sorry, cannot handle communicators with NP=2!");
@@ -530,10 +529,10 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
           if(MyMPI_GetId(comm)==0) self->Distribute();
           else self->SendRecvMesh();
           return self;
-         }, py::arg("comm")=nullptr)
-    .def("Receive", [](shared_ptr<PyMPI_Comm> pycomm) {
+         }, py::arg("comm")=NgMPI_Comm(ng_comm))
+    .def("Receive", [](NgMPI_Comm comm) {
           auto mesh = make_shared<Mesh>();
-          mesh->SetCommunicator(pycomm->comm);
+          mesh->SetCommunicator(comm);
           mesh->SendRecvMesh();
           return mesh;
         })
@@ -933,57 +932,34 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
           return old;
         }));
 
-  py::class_<PyMPI_Comm, shared_ptr<PyMPI_Comm>> (m, "MPI_Comm")
-    .def_property_readonly ("rank", &PyMPI_Comm::Rank)
-    .def_property_readonly ("size", &PyMPI_Comm::Size)
-    // .def_property_readonly ("rank", [](PyMPI_Comm & c) { cout << "rank for " << c.comm << endl; return c.Rank(); })
-    // .def_property_readonly ("size", [](PyMPI_Comm & c) { cout << "size for " << c.comm << endl; return c.Size(); })
+  py::class_<NgMPI_Comm> (m, "MPI_Comm")
+    .def_property_readonly ("rank", &NgMPI_Comm::Rank)
+    .def_property_readonly ("size", &NgMPI_Comm::Size)
 #ifdef PARALLEL
-    .def("Barrier", [](PyMPI_Comm & c) { MPI_Barrier(c.comm); })
-    .def("WTime", [](PyMPI_Comm & c) { return MPI_Wtime(); })
+    .def("Barrier", [](NgMPI_Comm & c) { MPI_Barrier(c); })
+    .def("WTime", [](NgMPI_Comm & c) { return MPI_Wtime(); })
#else
-    .def("Barrier", [](PyMPI_Comm & c) { })
-    .def("WTime", [](PyMPI_Comm & c) { return -1.0; })
+    .def("Barrier", [](NgMPI_Comm & c) { })
+    .def("WTime", [](NgMPI_Comm & c) { return -1.0; })
 #endif
-    .def("Sum", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
-    .def("Min", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
-    .def("Max", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
-    .def("Sum", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
-    .def("Min", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
-    .def("Max", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
-    .def("Sum", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
-    .def("Min", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
-    .def("Max", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
-    .def("SubComm", [](PyMPI_Comm & c, std::vector<int> proc_list) {
+    .def("Sum", [](NgMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_SUM, c); })
+    .def("Min", [](NgMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MIN, c); })
+    .def("Max", [](NgMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MAX, c); })
+    .def("Sum", [](NgMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_SUM, c); })
+    .def("Min", [](NgMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MIN, c); })
+    .def("Max", [](NgMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MAX, c); })
+    .def("Sum", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_SUM, c); })
+    .def("Min", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MIN, c); })
+    .def("Max", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MAX, c); })
+    .def("SubComm", [](NgMPI_Comm & c, std::vector<int> proc_list) {
           Array<int> procs(proc_list.size());
          for (int i = 0; i < procs.Size(); i++)
            procs[i] = proc_list[i];
          if (!procs.Contains(c.Rank()))
            throw Exception("rank "+ToString(c.Rank())+" not in subcomm");
-         MPI_Comm subcomm = MyMPI_SubCommunicator(c.comm, procs);
-         return make_shared<PyMPI_Comm>(subcomm, true);
-         /*
-         Array<int> procs;
-         if (py::extract<py::list> (proc_list).check()) {
-           py::list pylist = py::extract<py::list> (proc_list)();
-           procs.SetSize(py::len(pylist));
-           for (int i = 0; i < py::len(pylist); i++)
-             procs[i] = py::extract<int>(pylist[i])();
-         }
-         else {
-           throw Exception("SubComm needs a list!");
-         }
-         if(!procs.Size()) {
-           cout << "warning, tried to construct empty communicator, returning MPI_COMM_NULL" << endl;
-           return make_shared<PyMPI_Comm>(MPI_COMM_NULL);
-         }
-         else if(procs.Size()==2) {
-           throw Exception("Sorry, NGSolve cannot handle NP=2.");
-         }
-         MPI_Comm subcomm = MyMPI_SubCommunicator(c.comm, procs);
-         return make_shared<PyMPI_Comm>(subcomm, true);
-         */
+         MPI_Comm subcomm = MyMPI_SubCommunicator(c, procs);
+         return NgMPI_Comm(subcomm, true);
        }, py::arg("procs"));
    ;
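In the SubComm binding the freshly created communicator is handed to NgMPI_Comm with owns = true, so its lifetime is tied to the wrapper's reference count rather than to the Python caller; returning the wrapper by value is safe because copies share that count. A hedged C++ sketch of the same pattern outside the bindings (the helper name is illustrative; a PARALLEL build and the headers used above are assumed, and the signature of MyMPI_SubCommunicator is taken on trust from its use in this hunk):

    // Sketch only: build an owning NgMPI_Comm around a sub-communicator,
    // mirroring the SubComm lambda above.
    using namespace netgen;

    NgMPI_Comm MakeEvenRankComm (const NgMPI_Comm & c)   // illustrative helper
    {
      Array<int> procs;
      for (int p = 0; p < c.Size(); p += 2)              // keep every second rank, as an example
        procs.Append (p);
      // as in the binding, only ranks contained in procs should call this
      MPI_Comm subcomm = MyMPI_SubCommunicator (c, procs);
      return NgMPI_Comm (subcomm, true);                 // owning: the refcount controls the communicator's lifetime
    }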