mpi_wrapper

Joachim Schöberl 2019-02-11 21:37:00 +01:00
parent 1074593664
commit 9ced2f561f
4 changed files with 67 additions and 61 deletions


@@ -23,8 +23,11 @@ namespace ngcore
template <> struct MPI_typetrait<char> {
static MPI_Datatype MPIType () { return MPI_CHAR; } };
template <> struct MPI_typetrait<unsigned char> {
static MPI_Datatype MPIType () { return MPI_CHAR; } };
template <> struct MPI_typetrait<size_t> {
static MPI_Datatype MPIType () { return MPI_UNIT64_T; } };
static MPI_Datatype MPIType () { return MPI_UINT64_T; } };
template <> struct MPI_typetrait<double> {
static MPI_Datatype MPIType () { return MPI_DOUBLE; } };
@@ -43,6 +46,7 @@ namespace ngcore
{
MPI_Comm comm;
int * refcount;
int rank, size;
public:
NgMPI_Comm (MPI_Comm _comm, bool owns = false)
: comm(_comm)
@@ -51,16 +55,19 @@ namespace ngcore
refcount = nullptr;
else
refcount = new int{1};
MPI_Comm_rank(comm, &rank);
MPI_Comm_size(comm, &size);
}
NgMPI_Comm (const NgMPI_Comm & c)
: comm(c.comm), refcount(c.refcount)
: comm(c.comm), refcount(c.refcount), rank(c.rank), size(c.size)
{
if (refcount) (*refcount)++;
}
NgMPI_Comm (NgMPI_Comm && c)
: comm(c.comm), refcount(c.refcount)
: comm(c.comm), refcount(c.refcount), rank(c.rank), size(c.size)
{
c.refcount = nullptr;
}
@@ -74,8 +81,21 @@ namespace ngcore
operator MPI_Comm() const { return comm; }
auto Rank() const { int r; MPI_Comm_rank(comm, &r); return r; }
auto Size() const { int s; MPI_Comm_size(comm, &s); return s; }
int Rank() const { return rank; } // int r; MPI_Comm_rank(comm, &r); return r; }
int Size() const { return size; } // int s; MPI_Comm_size(comm, &s); return s; }
template<typename T, typename T2 = decltype(GetMPIType<T>())>
void Send( T & val, int dest, int tag) {
MPI_Send (&val, 1, GetMPIType<T>(), dest, tag, comm);
}
template<typename T, typename T2 = decltype(GetMPIType<T>())>
void MyMPI_Recv (T & val, int src, int tag) {
MPI_Recv (&val, 1, GetMPIType<T>(), src, tag, comm, MPI_STATUS_IGNORE);
}
};
@@ -90,6 +110,14 @@ namespace ngcore
size_t Rank() const { return 0; }
size_t Size() const { return 1; }
template<typename T>
void Send( T & val, int dest, int tag) { ; }
template<typename T>
void MyMPI_Recv (T & val, int src, int tag) { ; }
};
#endif

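For orientation, a minimal usage sketch of the refactored wrapper above (not part of the commit). It assumes an MPI build, that the header changed in this hunk is available as "mpi_wrapper.hpp" (the include path is a guess), and that GetMPIType<T>() dispatches to the MPI_typetrait specializations shown (that helper itself is not part of this hunk).

// usage sketch, not part of the commit -- exercises the cached Rank()/Size()
// and the new typed Send / MyMPI_Recv members of ngcore::NgMPI_Comm
#include <mpi.h>
#include "mpi_wrapper.hpp"   // assumed include path for the header changed above

int main (int argc, char ** argv)
{
  MPI_Init (&argc, &argv);
  {
    ngcore::NgMPI_Comm comm (MPI_COMM_WORLD);   // non-owning wrapper; rank/size cached once
    double val = (comm.Rank() == 0) ? 42.0 : 0.0;
    if (comm.Size() > 1)
      {
        if (comm.Rank() == 0)
          comm.Send (val, 1, 0);                // MPI_Send, datatype from MPI_typetrait<double>
        else if (comm.Rank() == 1)
          comm.MyMPI_Recv (val, 0, 0);          // MPI_Recv, datatype from MPI_typetrait<double>
      }
  }   // wrapper goes out of scope before MPI_Finalize
  MPI_Finalize ();
  return 0;
}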

@@ -23,7 +23,7 @@ namespace netgen
#endif
/** This is the "standard" communicator that will be used for netgen-objects. **/
extern DLL_HEADER MPI_Comm ng_comm;
extern DLL_HEADER NgMPI_Comm ng_comm;
#ifdef PARALLEL
inline int MyMPI_GetNTasks (MPI_Comm comm = ng_comm)
@@ -44,6 +44,7 @@ namespace netgen
inline int MyMPI_GetId (MPI_Comm comm = ng_comm) { return 0; }
#endif
/*
#ifdef PARALLEL
// For python wrapping of communicators
struct PyMPI_Comm {
@@ -68,7 +69,8 @@ namespace netgen
inline int Size() const { return 1; }
};
#endif
*/
#ifdef PARALLEL
template <class T>
inline MPI_Datatype MyGetMPIType ( )


@@ -32,7 +32,7 @@ namespace netgen
// TraceGlobal glob2("global2");
// global communicator for netgen
DLL_HEADER MPI_Comm ng_comm = MPI_COMM_WORLD;
DLL_HEADER NgMPI_Comm ng_comm = MPI_COMM_WORLD;
weak_ptr<Mesh> global_mesh;
void SetGlobalMesh (shared_ptr<Mesh> m)

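A short aside (not part of the commit): the copy-initialization above compiles because NgMPI_Comm's MPI_Comm constructor is non-explicit and defaults to owns = false, so the global wraps MPI_COMM_WORLD without taking ownership; the destructor itself is not shown in this diff.

// sketch, not part of the commit: both forms produce a non-owning wrapper,
// i.e. refcount stays nullptr and there is nothing for the wrapper to release
ngcore::NgMPI_Comm g1 = MPI_COMM_WORLD;          // copy-initialization, owns defaults to false
ngcore::NgMPI_Comm g2 (MPI_COMM_WORLD, false);   // equivalent explicit form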

@@ -494,20 +494,20 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
py::class_<Mesh,shared_ptr<Mesh>>(m, "Mesh")
// .def(py::init<>("create empty mesh"))
.def(py::init( [] (int dim, shared_ptr<PyMPI_Comm> pycomm)
.def(py::init( [] (int dim, NgMPI_Comm comm)
{
auto mesh = make_shared<Mesh>();
mesh->SetCommunicator(pycomm!=nullptr ? pycomm->comm : netgen::ng_comm);
mesh->SetCommunicator(comm);
mesh -> SetDimension(dim);
SetGlobalMesh(mesh); // for visualization
mesh -> SetGeometry (nullptr);
return mesh;
} ),
py::arg("dim")=3, py::arg("comm")=nullptr
py::arg("dim")=3, py::arg("comm")=NgMPI_Comm(ng_comm)
)
.def(NGSPickle<Mesh>())
.def_property_readonly("comm", [](const Mesh & amesh)
{ return make_shared<PyMPI_Comm>(amesh.GetCommunicator()); },
.def_property_readonly("comm", [](const Mesh & amesh) -> NgMPI_Comm
{ return amesh.GetCommunicator(); },
"MPI-communicator the Mesh lives in")
/*
.def("__init__",
@@ -521,8 +521,7 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
*/
.def_property_readonly("_timestamp", &Mesh::GetTimeStamp)
.def("Distribute", [](shared_ptr<Mesh> self, shared_ptr<PyMPI_Comm> pycomm) {
MPI_Comm comm = pycomm!=nullptr ? pycomm->comm : self->GetCommunicator();
.def("Distribute", [](shared_ptr<Mesh> self, NgMPI_Comm comm) {
self->SetCommunicator(comm);
if(MyMPI_GetNTasks(comm)==1) return self;
// if(MyMPI_GetNTasks(comm)==2) throw NgException("Sorry, cannot handle communicators with NP=2!");
@@ -530,10 +529,10 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
if(MyMPI_GetId(comm)==0) self->Distribute();
else self->SendRecvMesh();
return self;
}, py::arg("comm")=nullptr)
.def("Receive", [](shared_ptr<PyMPI_Comm> pycomm) {
}, py::arg("comm")=NgMPI_Comm(ng_comm))
.def("Receive", [](NgMPI_Comm comm) {
auto mesh = make_shared<Mesh>();
mesh->SetCommunicator(pycomm->comm);
mesh->SetCommunicator(comm);
mesh->SendRecvMesh();
return mesh;
})
@@ -933,57 +932,34 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
return old;
}));
py::class_<PyMPI_Comm, shared_ptr<PyMPI_Comm>> (m, "MPI_Comm")
.def_property_readonly ("rank", &PyMPI_Comm::Rank)
.def_property_readonly ("size", &PyMPI_Comm::Size)
// .def_property_readonly ("rank", [](PyMPI_Comm & c) { cout << "rank for " << c.comm << endl; return c.Rank(); })
// .def_property_readonly ("size", [](PyMPI_Comm & c) { cout << "size for " << c.comm << endl; return c.Size(); })
py::class_<NgMPI_Comm> (m, "MPI_Comm")
.def_property_readonly ("rank", &NgMPI_Comm::Rank)
.def_property_readonly ("size", &NgMPI_Comm::Size)
#ifdef PARALLEL
.def("Barrier", [](PyMPI_Comm & c) { MPI_Barrier(c.comm); })
.def("WTime", [](PyMPI_Comm & c) { return MPI_Wtime(); })
.def("Barrier", [](NgMPI_Comm & c) { MPI_Barrier(c); })
.def("WTime", [](NgMPI_Comm & c) { return MPI_Wtime(); })
#else
.def("Barrier", [](PyMPI_Comm & c) { })
.def("WTime", [](PyMPI_Comm & c) { return -1.0; })
.def("Barrier", [](NgMPI_Comm & c) { })
.def("WTime", [](NgMPI_Comm & c) { return -1.0; })
#endif
.def("Sum", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
.def("Min", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
.def("Max", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
.def("Sum", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
.def("Min", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
.def("Max", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
.def("Sum", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
.def("Min", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
.def("Max", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
.def("SubComm", [](PyMPI_Comm & c, std::vector<int> proc_list) {
.def("Sum", [](NgMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_SUM, c); })
.def("Min", [](NgMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MIN, c); })
.def("Max", [](NgMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MAX, c); })
.def("Sum", [](NgMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_SUM, c); })
.def("Min", [](NgMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MIN, c); })
.def("Max", [](NgMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MAX, c); })
.def("Sum", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_SUM, c); })
.def("Min", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MIN, c); })
.def("Max", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MAX, c); })
.def("SubComm", [](NgMPI_Comm & c, std::vector<int> proc_list) {
Array<int> procs(proc_list.size());
for (int i = 0; i < procs.Size(); i++)
procs[i] = proc_list[i];
if (!procs.Contains(c.Rank()))
throw Exception("rank "+ToString(c.Rank())+" not in subcomm");
MPI_Comm subcomm = MyMPI_SubCommunicator(c.comm, procs);
return make_shared<PyMPI_Comm>(subcomm, true);
/*
Array<int> procs;
if (py::extract<py::list> (proc_list).check()) {
py::list pylist = py::extract<py::list> (proc_list)();
procs.SetSize(py::len(pyplist));
for (int i = 0; i < py::len(pylist); i++)
procs[i] = py::extract<int>(pylist[i])();
}
else {
throw Exception("SubComm needs a list!");
}
if(!procs.Size()) {
cout << "warning, tried to construct empty communicator, returning MPI_COMM_NULL" << endl;
return make_shared<PyMPI_Comm>(MPI_COMM_NULL);
}
else if(procs.Size()==2) {
throw Exception("Sorry, NGSolve cannot handle NP=2.");
}
MPI_Comm subcomm = MyMPI_SubCommunicator(c.comm, procs);
return make_shared<PyMPI_Comm>(subcomm, true);
*/
MPI_Comm subcomm = MyMPI_SubCommunicator(c, procs);
return NgMPI_Comm(subcomm, true);
}, py::arg("procs"));
;
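Finally, a sketch (not part of the commit) of why the rewritten lambdas can pass the wrapper where the old code passed c.comm: NgMPI_Comm converts implicitly to MPI_Comm via operator MPI_Comm(), so plain MPI calls and the existing MyMPI_* helpers keep their MPI_Comm parameters unchanged. The example uses a raw MPI_Allreduce to stay self-contained; the include path is a guess.

// sketch, not part of the commit -- the implicit operator MPI_Comm() supplies
// the raw communicator to any function expecting an MPI_Comm
#include <mpi.h>
#include "mpi_wrapper.hpp"   // assumed include path for the wrapper header

double GlobalSum (const ngcore::NgMPI_Comm & comm, double x)
{
  double sum = 0.0;
  // plain MPI_Allreduce instead of MyMPI_AllReduceNG, to keep the sketch standalone;
  // `comm` converts to MPI_Comm exactly as in the Sum/Min/Max/Barrier lambdas above
  MPI_Allreduce (&x, &sum, 1, MPI_DOUBLE, MPI_SUM, comm);
  return sum;
}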