From 14c39f828375db15e664640ccbde178dd6c97acb Mon Sep 17 00:00:00 2001 From: Joachim Schoeberl Date: Wed, 27 Nov 2024 21:16:36 +0100 Subject: [PATCH] introduce NgMPI_Request(s) --- libsrc/core/mpi_wrapper.hpp | 65 ++++++++++++++++++++++++++++++++- libsrc/meshing/parallelmesh.cpp | 29 +++++++++++---- libsrc/meshing/paralleltop.cpp | 10 +++-- 3 files changed, 90 insertions(+), 14 deletions(-) diff --git a/libsrc/core/mpi_wrapper.hpp b/libsrc/core/mpi_wrapper.hpp index 0f5e46af..a110912c 100644 --- a/libsrc/core/mpi_wrapper.hpp +++ b/libsrc/core/mpi_wrapper.hpp @@ -72,6 +72,57 @@ namespace ngcore return GetMPIType(); } + class NgMPI_Request + { + NG_MPI_Request request; + public: + NgMPI_Request (NG_MPI_Request requ) : request{requ} { } + NgMPI_Request (const NgMPI_Request&) = delete; + NgMPI_Request (NgMPI_Request && r) : request{r.request} { r.request = NG_MPI_REQUEST_NULL; } + ~NgMPI_Request () { NG_MPI_Wait (&request, NG_MPI_STATUS_IGNORE); } + void Wait() { NG_MPI_Wait (&request, NG_MPI_STATUS_IGNORE); } + operator NG_MPI_Request() && + { + auto tmp = request; + request = NG_MPI_REQUEST_NULL; + return tmp; + } + }; + + class NgMPI_Requests + { + Array requests; + public: + NgMPI_Requests() = default; + ~NgMPI_Requests() { WaitAll(); } + + NgMPI_Requests & operator+= (NgMPI_Request && r) + { + requests += NG_MPI_Request(std::move(r)); + return *this; + } + + NgMPI_Requests & operator+= (NG_MPI_Request r) + { + requests += r; + return *this; + } + + void WaitAll() + { + static Timer t("NgMPI - WaitAll"); RegionTimer reg(t); + if (!requests.Size()) return; + NG_MPI_Waitall (requests.Size(), requests.Data(), NG_MPI_STATUSES_IGNORE); + } + + int WaitAny () + { + int nr; + NG_MPI_Waitany (requests.Size(), requests.Data(), &nr, NG_MPI_STATUS_IGNORE); + return nr; + } + }; + inline void MyMPI_WaitAll (FlatArray requests) { @@ -341,7 +392,7 @@ namespace ngcore template - NG_MPI_Request IBcast (std::array & d, int root = 0) const + NgMPI_Request IBcast (std::array & d, int root = 0) const { NG_MPI_Request request;
NG_MPI_Ibcast (&d[0], S, GetMPIType(), root, comm, &request); @@ -349,7 +400,7 @@ namespace ngcore } template - NG_MPI_Request IBcast (FlatArray d, int root = 0) const + NgMPI_Request IBcast (FlatArray d, int root = 0) const { NG_MPI_Request request; int ds = d.Size(); @@ -481,6 +532,16 @@ namespace ngcore }; template inline NG_MPI_Datatype GetMPIType () { return -1; } + + class NgMPI_Request { }; + class NgMPI_Requests + { + public: + NgMPI_Requests & operator+= (NgMPI_Request &&) { return *this; } + NgMPI_Requests & operator+= (NG_MPI_Request r) { return *this; } + void WaitAll() { ; } + int WaitAny() { return 0; } + }; class NgMPI_Comm { diff --git a/libsrc/meshing/parallelmesh.cpp b/libsrc/meshing/parallelmesh.cpp index cca87834..a1e388b7 100644 --- a/libsrc/meshing/parallelmesh.cpp +++ b/libsrc/meshing/parallelmesh.cpp @@ -898,7 +898,7 @@ namespace netgen for( int k = 1; k < ntasks; k++) sendrequests[k] = comm.ISend(nnames, k, NG_MPI_TAG_MESH+7); #endif - sendrequests.SetSize(3); + // sendrequests.SetSize(3); /** Send bc/mat/cd*-names **/ // nr of names std::array nnames{0,0,0,0}; @@ -907,7 +907,10 @@ namespace netgen nnames[2] = GetNCD2Names(); nnames[3] = GetNCD3Names(); int tot_nn = nnames[0] + nnames[1] + nnames[2] + nnames[3]; - sendrequests[0] = comm.IBcast (nnames); + // sendrequests[0] = comm.IBcast (nnames); + + NgMPI_Requests requ; + requ += comm.IBcast (nnames); // (void) NG_MPI_Isend(nnames, 4, NG_MPI_INT, k, NG_MPI_TAG_MESH+6, comm, &sendrequests[k]); auto iterate_names = [&](auto func) { @@ -924,7 +927,8 @@ namespace netgen for( int k = 1; k < ntasks; k++) (void) NG_MPI_Isend(&name_sizes[0], tot_nn, NG_MPI_INT, k, NG_MPI_TAG_MESH+7, comm, &sendrequests[k]); */ - sendrequests[1] = comm.IBcast (name_sizes); + // sendrequests[1] = comm.IBcast (name_sizes); + requ += comm.IBcast (name_sizes); // names int strs = 0; iterate_names([&](auto ptr) { strs += (ptr==NULL) ?
0 : ptr->size(); }); @@ -941,10 +945,12 @@ namespace netgen (void) NG_MPI_Isend(&(compiled_names[0]), strs, NG_MPI_CHAR, k, NG_MPI_TAG_MESH+7, comm, &sendrequests[ntasks+k]); */ - sendrequests[2] = comm.IBcast (compiled_names); + // sendrequests[2] = comm.IBcast (compiled_names); + requ += comm.IBcast (compiled_names); PrintMessage ( 3, "wait for names"); - MyMPI_WaitAll (sendrequests); + // MyMPI_WaitAll (sendrequests); + requ.WaitAll(); comm.Barrier(); @@ -1208,10 +1214,13 @@ namespace netgen comm.Recv(nnames, 0, NG_MPI_TAG_MESH+7); */ - Array recvrequests(1); + // Array recvrequests(1); std::array nnames; + /* recvrequests[0] = comm.IBcast (nnames); MyMPI_WaitAll (recvrequests); + */ + comm.IBcast (nnames); // cout << "nnames = " << FlatArray(nnames) << endl; materials.SetSize(nnames[0]); @@ -1222,8 +1231,11 @@ namespace netgen int tot_nn = nnames[0] + nnames[1] + nnames[2] + nnames[3]; Array name_sizes(tot_nn); // NG_MPI_Recv(&name_sizes[0], tot_nn, NG_MPI_INT, 0, NG_MPI_TAG_MESH+7, comm, NG_MPI_STATUS_IGNORE); + /* recvrequests[0] = comm.IBcast (name_sizes); MyMPI_WaitAll (recvrequests); + */ + comm.IBcast (name_sizes); int tot_size = 0; for (int k = 0; k < tot_nn; k++) tot_size += name_sizes[k]; @@ -1231,8 +1243,9 @@ namespace netgen // NgArray compiled_names(tot_size); // NG_MPI_Recv(&(compiled_names[0]), tot_size, NG_MPI_CHAR, 0, NG_MPI_TAG_MESH+7, comm, NG_MPI_STATUS_IGNORE); Array compiled_names(tot_size); - recvrequests[0] = comm.IBcast (compiled_names); - MyMPI_WaitAll (recvrequests); + // recvrequests[0] = comm.IBcast (compiled_names); + // MyMPI_WaitAll (recvrequests); + comm.IBcast (compiled_names); tot_nn = tot_size = 0; auto write_names = [&] (auto & array) { diff --git a/libsrc/meshing/paralleltop.cpp b/libsrc/meshing/paralleltop.cpp index 191227ed..a62d3952 100644 --- a/libsrc/meshing/paralleltop.cpp +++ b/libsrc/meshing/paralleltop.cpp @@ -138,16 +138,18 @@ namespace netgen for (auto p : dps) send_data[p][nsend[p]++] = L2G(pi); - Array requests; 
+ // Array requests; + NgMPI_Requests requests; for (int i = 0; i < comm.Size(); i++) { if (nsend[i]) - requests.Append (comm.ISend (send_data[i], i, 200)); + requests += comm.ISend (send_data[i], i, 200); if (nrecv[i]) - requests.Append (comm.IRecv (recv_data[i], i, 200)); + requests += comm.IRecv (recv_data[i], i, 200); } - MyMPI_WaitAll (requests); + // MyMPI_WaitAll (requests); + requests.WaitAll(); Array cnt(comm.Size()); cnt = 0;