mirror of
https://github.com/NGSolve/netgen.git
synced 2025-01-12 22:20:35 +05:00
mpi_wrapper
This commit is contained in:
parent
1074593664
commit
9ced2f561f
@ -23,8 +23,11 @@ namespace ngcore
|
||||
template <> struct MPI_typetrait<char> {
|
||||
static MPI_Datatype MPIType () { return MPI_CHAR; } };
|
||||
|
||||
template <> struct MPI_typetrait<unsigned char> {
|
||||
static MPI_Datatype MPIType () { return MPI_CHAR; } };
|
||||
|
||||
template <> struct MPI_typetrait<size_t> {
|
||||
static MPI_Datatype MPIType () { return MPI_UNIT64_T; } };
|
||||
static MPI_Datatype MPIType () { return MPI_UINT64_T; } };
|
||||
|
||||
template <> struct MPI_typetrait<double> {
|
||||
static MPI_Datatype MPIType () { return MPI_DOUBLE; } };
|
||||
@ -43,6 +46,7 @@ namespace ngcore
|
||||
{
|
||||
MPI_Comm comm;
|
||||
int * refcount;
|
||||
int rank, size;
|
||||
public:
|
||||
NgMPI_Comm (MPI_Comm _comm, bool owns = false)
|
||||
: comm(_comm)
|
||||
@ -51,16 +55,19 @@ namespace ngcore
|
||||
refcount = nullptr;
|
||||
else
|
||||
refcount = new int{1};
|
||||
|
||||
MPI_Comm_rank(comm, &rank);
|
||||
MPI_Comm_size(comm, &size);
|
||||
}
|
||||
|
||||
NgMPI_Comm (const NgMPI_Comm & c)
|
||||
: comm(c.comm), refcount(c.refcount)
|
||||
: comm(c.comm), refcount(c.refcount), rank(c.rank), size(c.siez)
|
||||
{
|
||||
if (refcount) (*refcount)++;
|
||||
}
|
||||
|
||||
NgMPI_Comm (NgMPI_Comm && c)
|
||||
: comm(c.comm), refcount(c.refcount)
|
||||
: comm(c.comm), refcount(c.refcount), rank(c.rank), size(c.size)
|
||||
{
|
||||
c.refcount = nullptr;
|
||||
}
|
||||
@ -74,8 +81,21 @@ namespace ngcore
|
||||
|
||||
operator MPI_Comm() const { return comm; }
|
||||
|
||||
auto Rank() const { int r; MPI_Comm_rank(comm, &r); return r; }
|
||||
auto Size() const { int s; MPI_Comm_size(comm, &s); return s; }
|
||||
int Rank() const { return rank; } // int r; MPI_Comm_rank(comm, &r); return r; }
|
||||
int Size() const { return size; } // int s; MPI_Comm_size(comm, &s); return s; }
|
||||
|
||||
|
||||
template<typename T, typename T2 = decltype(GetMPIType<T>())>
|
||||
void Send( T & val, int dest, int tag) {
|
||||
MPI_Send (&val, 1, GetMPIType<T>(), dest, tag, comm);
|
||||
}
|
||||
|
||||
template<typename T, typename T2 = decltype(GetMPIType<T>())>
|
||||
void MyMPI_Recv (T & val, int src, int tag) {
|
||||
MPI_Recv (&val, 1, GetMPIType<T>(), src, tag, comm, MPI_STATUS_IGNORE);
|
||||
}
|
||||
|
||||
|
||||
};
|
||||
|
||||
|
||||
@ -90,6 +110,14 @@ namespace ngcore
|
||||
|
||||
size_t Rank() const { return 0; }
|
||||
size_t Size() const { return 1; }
|
||||
|
||||
|
||||
|
||||
template<typename T>
|
||||
void Send( T & val, int dest, int tag) { ; }
|
||||
|
||||
template<typename T>
|
||||
void MyMPI_Recv (T & val, int src, int tag) { ; }
|
||||
};
|
||||
|
||||
#endif
|
||||
|
@ -23,7 +23,7 @@ namespace netgen
|
||||
#endif
|
||||
|
||||
/** This is the "standard" communicator that will be used for netgen-objects. **/
|
||||
extern DLL_HEADER MPI_Comm ng_comm;
|
||||
extern DLL_HEADER NgMPI_Comm ng_comm;
|
||||
|
||||
#ifdef PARALLEL
|
||||
inline int MyMPI_GetNTasks (MPI_Comm comm = ng_comm)
|
||||
@ -44,6 +44,7 @@ namespace netgen
|
||||
inline int MyMPI_GetId (MPI_Comm comm = ng_comm) { return 0; }
|
||||
#endif
|
||||
|
||||
/*
|
||||
#ifdef PARALLEL
|
||||
// For python wrapping of communicators
|
||||
struct PyMPI_Comm {
|
||||
@ -68,7 +69,8 @@ namespace netgen
|
||||
inline int Size() const { return 1; }
|
||||
};
|
||||
#endif
|
||||
|
||||
*/
|
||||
|
||||
#ifdef PARALLEL
|
||||
template <class T>
|
||||
inline MPI_Datatype MyGetMPIType ( )
|
||||
|
@ -32,7 +32,7 @@ namespace netgen
|
||||
// TraceGlobal glob2("global2");
|
||||
|
||||
// global communicator for netgen
|
||||
DLL_HEADER MPI_Comm ng_comm = MPI_COMM_WORLD;
|
||||
DLL_HEADER MyMPI_Comm ng_comm = MPI_COMM_WORLD;
|
||||
|
||||
weak_ptr<Mesh> global_mesh;
|
||||
void SetGlobalMesh (shared_ptr<Mesh> m)
|
||||
|
@ -494,20 +494,20 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
|
||||
py::class_<Mesh,shared_ptr<Mesh>>(m, "Mesh")
|
||||
// .def(py::init<>("create empty mesh"))
|
||||
|
||||
.def(py::init( [] (int dim, shared_ptr<PyMPI_Comm> pycomm)
|
||||
.def(py::init( [] (int dim, NgMPI_Comm comm)
|
||||
{
|
||||
auto mesh = make_shared<Mesh>();
|
||||
mesh->SetCommunicator(pycomm!=nullptr ? pycomm->comm : netgen::ng_comm);
|
||||
mesh->SetCommunicator(comm);
|
||||
mesh -> SetDimension(dim);
|
||||
SetGlobalMesh(mesh); // for visualization
|
||||
mesh -> SetGeometry (nullptr);
|
||||
return mesh;
|
||||
} ),
|
||||
py::arg("dim")=3, py::arg("comm")=nullptr
|
||||
py::arg("dim")=3, py::arg("comm")=NgMPI_Comm(ng_comm)
|
||||
)
|
||||
.def(NGSPickle<Mesh>())
|
||||
.def_property_readonly("comm", [](const Mesh & amesh)
|
||||
{ return make_shared<PyMPI_Comm>(amesh.GetCommunicator()); },
|
||||
.def_property_readonly("comm", [](const Mesh & amesh) -> NgMPI_Comm
|
||||
{ return amesh.GetCommunicator(); },
|
||||
"MPI-communicator the Mesh lives in")
|
||||
/*
|
||||
.def("__init__",
|
||||
@ -521,8 +521,7 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
|
||||
*/
|
||||
|
||||
.def_property_readonly("_timestamp", &Mesh::GetTimeStamp)
|
||||
.def("Distribute", [](shared_ptr<Mesh> self, shared_ptr<PyMPI_Comm> pycomm) {
|
||||
MPI_Comm comm = pycomm!=nullptr ? pycomm->comm : self->GetCommunicator();
|
||||
.def("Distribute", [](shared_ptr<Mesh> self, NgMPI_Comm comm) {
|
||||
self->SetCommunicator(comm);
|
||||
if(MyMPI_GetNTasks(comm)==1) return self;
|
||||
// if(MyMPI_GetNTasks(comm)==2) throw NgException("Sorry, cannot handle communicators with NP=2!");
|
||||
@ -530,10 +529,10 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
|
||||
if(MyMPI_GetId(comm)==0) self->Distribute();
|
||||
else self->SendRecvMesh();
|
||||
return self;
|
||||
}, py::arg("comm")=nullptr)
|
||||
.def("Receive", [](shared_ptr<PyMPI_Comm> pycomm) {
|
||||
}, py::arg("comm")=NgMPI_Comm(ng_comm))
|
||||
.def("Receive", [](NgMPI_Comm comm) {
|
||||
auto mesh = make_shared<Mesh>();
|
||||
mesh->SetCommunicator(pycomm->comm);
|
||||
mesh->SetCommunicator(comm);
|
||||
mesh->SendRecvMesh();
|
||||
return mesh;
|
||||
})
|
||||
@ -933,57 +932,34 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
|
||||
return old;
|
||||
}));
|
||||
|
||||
py::class_<PyMPI_Comm, shared_ptr<PyMPI_Comm>> (m, "MPI_Comm")
|
||||
.def_property_readonly ("rank", &PyMPI_Comm::Rank)
|
||||
.def_property_readonly ("size", &PyMPI_Comm::Size)
|
||||
// .def_property_readonly ("rank", [](PyMPI_Comm & c) { cout << "rank for " << c.comm << endl; return c.Rank(); })
|
||||
// .def_property_readonly ("size", [](PyMPI_Comm & c) { cout << "size for " << c.comm << endl; return c.Size(); })
|
||||
py::class_<NgMPI_Comm> (m, "MPI_Comm")
|
||||
.def_property_readonly ("rank", &NgMPI_Comm::Rank)
|
||||
.def_property_readonly ("size", &NgMPI_Comm::Size)
|
||||
|
||||
#ifdef PARALLEL
|
||||
.def("Barrier", [](PyMPI_Comm & c) { MPI_Barrier(c.comm); })
|
||||
.def("WTime", [](PyMPI_Comm & c) { return MPI_Wtime(); })
|
||||
.def("Barrier", [](NgMPI_Comm & c) { MPI_Barrier(c); })
|
||||
.def("WTime", [](NgMPI_Comm & c) { return MPI_Wtime(); })
|
||||
#else
|
||||
.def("Barrier", [](PyMPI_Comm & c) { })
|
||||
.def("WTime", [](PyMPI_Comm & c) { return -1.0; })
|
||||
.def("Barrier", [](NgMPI_Comm & c) { })
|
||||
.def("WTime", [](NgMPI_Comm & c) { return -1.0; })
|
||||
#endif
|
||||
.def("Sum", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
|
||||
.def("Min", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
|
||||
.def("Max", [](PyMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
|
||||
.def("Sum", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
|
||||
.def("Min", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
|
||||
.def("Max", [](PyMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
|
||||
.def("Sum", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_SUM, c.comm); })
|
||||
.def("Min", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MIN, c.comm); })
|
||||
.def("Max", [](PyMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MAX, c.comm); })
|
||||
.def("SubComm", [](PyMPI_Comm & c, std::vector<int> proc_list) {
|
||||
.def("Sum", [](NgMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_SUM, c); })
|
||||
.def("Min", [](NgMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MIN, c); })
|
||||
.def("Max", [](NgMPI_Comm & c, double x) { return MyMPI_AllReduceNG(x, MPI_MAX, c); })
|
||||
.def("Sum", [](NgMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_SUM, c); })
|
||||
.def("Min", [](NgMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MIN, c); })
|
||||
.def("Max", [](NgMPI_Comm & c, int x) { return MyMPI_AllReduceNG(x, MPI_MAX, c); })
|
||||
.def("Sum", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_SUM, c); })
|
||||
.def("Min", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MIN, c); })
|
||||
.def("Max", [](NgMPI_Comm & c, size_t x) { return MyMPI_AllReduceNG(x, MPI_MAX, c); })
|
||||
.def("SubComm", [](NgMPI_Comm & c, std::vector<int> proc_list) {
|
||||
Array<int> procs(proc_list.size());
|
||||
for (int i = 0; i < procs.Size(); i++)
|
||||
procs[i] = proc_list[i];
|
||||
if (!procs.Contains(c.Rank()))
|
||||
throw Exception("rank "+ToString(c.Rank())+" not in subcomm");
|
||||
MPI_Comm subcomm = MyMPI_SubCommunicator(c.comm, procs);
|
||||
return make_shared<PyMPI_Comm>(subcomm, true);
|
||||
|
||||
/*
|
||||
Array<int> procs;
|
||||
if (py::extract<py::list> (proc_list).check()) {
|
||||
py::list pylist = py::extract<py::list> (proc_list)();
|
||||
procs.SetSize(py::len(pyplist));
|
||||
for (int i = 0; i < py::len(pylist); i++)
|
||||
procs[i] = py::extract<int>(pylist[i])();
|
||||
}
|
||||
else {
|
||||
throw Exception("SubComm needs a list!");
|
||||
}
|
||||
if(!procs.Size()) {
|
||||
cout << "warning, tried to construct empty communicator, returning MPI_COMM_NULL" << endl;
|
||||
return make_shared<PyMPI_Comm>(MPI_COMM_NULL);
|
||||
}
|
||||
else if(procs.Size()==2) {
|
||||
throw Exception("Sorry, NGSolve cannot handle NP=2.");
|
||||
}
|
||||
MPI_Comm subcomm = MyMPI_SubCommunicator(c.comm, procs);
|
||||
return make_shared<PyMPI_Comm>(subcomm, true);
|
||||
*/
|
||||
MPI_Comm subcomm = MyMPI_SubCommunicator(c, procs);
|
||||
return NgMPI_Comm(subcomm, true);
|
||||
}, py::arg("procs"));
|
||||
;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user