modernize ParallelTopology

This commit is contained in:
Joachim Schöberl 2020-08-29 09:36:46 +02:00
parent fcee13be59
commit f8dd4be8d6
4 changed files with 164 additions and 75 deletions

View File

@ -54,7 +54,7 @@ namespace ngcore
NGCORE_API size_t * TablePrefixSum64 (FlatArray<size_t> entrysize) NGCORE_API size_t * TablePrefixSum64 (FlatArray<size_t> entrysize)
{ return TablePrefixSum2 (entrysize); } { return TablePrefixSum2 (entrysize); }
/*
BaseDynamicTable :: BaseDynamicTable (int size) BaseDynamicTable :: BaseDynamicTable (int size)
: data(size) : data(size)
{ {
@ -88,7 +88,6 @@ namespace ngcore
} }
} }
BaseDynamicTable :: ~BaseDynamicTable () BaseDynamicTable :: ~BaseDynamicTable ()
{ {
if (oneblock) if (oneblock)
@ -112,7 +111,7 @@ namespace ngcore
} }
} }
void BaseDynamicTable :: IncSize (int i, int elsize) void BaseDynamicTable :: IncSize (IndexType i, int elsize)
{ {
if (i < 0 || i >= data.Size()) if (i < 0 || i >= data.Size())
{ {
@ -135,7 +134,7 @@ namespace ngcore
line.size++; line.size++;
} }
void BaseDynamicTable :: DecSize (int i) void BaseDynamicTable :: DecSize (IndexType i)
{ {
if (i < 0 || i >= data.Size()) if (i < 0 || i >= data.Size())
{ {
@ -153,6 +152,7 @@ namespace ngcore
line.size--; line.size--;
} }
*/
void FilteredTableCreator::Add (size_t blocknr, int data) void FilteredTableCreator::Add (size_t blocknr, int data)
{ {

View File

@ -349,11 +349,13 @@ public:
}; };
/// Base class to generic DynamicTable. /// Base class to generic DynamicTable.
template <class IndexType = size_t>
class BaseDynamicTable class BaseDynamicTable
{ {
protected: protected:
static constexpr IndexType BASE = IndexBASE<IndexType>();
/// ///
struct linestruct struct linestruct
{ {
@ -366,24 +368,106 @@ public:
}; };
/// ///
Array<linestruct> data; Array<linestruct, IndexType> data;
/// ///
char * oneblock; char * oneblock;
public: public:
/// ///
NGCORE_API BaseDynamicTable (int size); BaseDynamicTable (int size)
: data(size)
{
for (auto & d : data)
{
d.maxsize = 0;
d.size = 0;
d.col = nullptr;
}
oneblock = nullptr;
}
/// ///
NGCORE_API BaseDynamicTable (const Array<int> & entrysizes, int elemsize); BaseDynamicTable (const Array<int, IndexType> & entrysizes, int elemsize)
: data(entrysizes.Size())
{
int cnt = 0;
int n = entrysizes.Size();
for (auto es : entrysizes)
cnt += es;
oneblock = new char[elemsize * cnt];
cnt = 0;
for (auto i : Range(data))
{
data[i].maxsize = entrysizes[i];
data[i].size = 0;
data[i].col = &oneblock[elemsize * cnt];
cnt += entrysizes[i];
}
}
/// ///
NGCORE_API ~BaseDynamicTable (); ~BaseDynamicTable ()
{
if (oneblock)
delete [] oneblock;
else
for (auto & d : data)
delete [] static_cast<char*> (d.col);
}
/// Changes Size of table to size, deletes data /// Changes Size of table to size, deletes data
NGCORE_API void SetSize (int size); void SetSize (int size)
{
for (auto & d : data)
delete [] static_cast<char*> (d.col);
data.SetSize(size);
for (auto & d : data)
{
d.maxsize = 0;
d.size = 0;
d.col = NULL;
}
}
/// ///
NGCORE_API void IncSize (int i, int elsize); void IncSize (IndexType i, int elsize)
{
NETGEN_CHECK_RANGE(i,BASE,data.Size()+BASE);
linestruct & line = data[i];
if (line.size == line.maxsize)
{
void * p = new char [(2*line.maxsize+5) * elsize];
memcpy (p, line.col, line.maxsize * elsize);
delete [] static_cast<char*> (line.col);
line.col = p;
line.maxsize = 2*line.maxsize+5;
}
line.size++;
}
NGCORE_API void DecSize (int i); void DecSize (IndexType i)
{
NETGEN_CHECK_RANGE(i,BASE,data.Size()+BASE);
/*
if (i < 0 || i >= data.Size())
{
std::cerr << "BaseDynamicTable::Dec: Out of range" << std::endl;
return;
}
*/
linestruct & line = data[i];
if (line.size == 0)
throw Exception ("BaseDynamicTable::Dec: EntrySize < 0");
line.size--;
}
}; };
@ -394,17 +478,19 @@ public:
A DynamicTable contains entries of variable size. Entry sizes can A DynamicTable contains entries of variable size. Entry sizes can
be increased dynamically. be increased dynamically.
*/ */
template <class T> template <class T, class IndexType = size_t>
class DynamicTable : public BaseDynamicTable class DynamicTable : public BaseDynamicTable<IndexType>
{ {
using BaseDynamicTable<IndexType>::data;
using BaseDynamicTable<IndexType>::oneblock;
public: public:
/// Creates table of size size /// Creates table of size size
DynamicTable (int size = 0) DynamicTable (int size = 0)
: BaseDynamicTable (size) { ; } : BaseDynamicTable<IndexType> (size) { }
/// Creates table with a priori fixed entry sizes. /// Creates table with a priori fixed entry sizes.
DynamicTable (const Array<int> & entrysizes) DynamicTable (const Array<int, IndexType> & entrysizes)
: BaseDynamicTable (entrysizes, sizeof(T)) { ; } : BaseDynamicTable<IndexType> (entrysizes, sizeof(T)) { }
DynamicTable & operator= (DynamicTable && tab2) DynamicTable & operator= (DynamicTable && tab2)
{ {
@ -412,19 +498,19 @@ public:
Swap (oneblock, tab2.oneblock); Swap (oneblock, tab2.oneblock);
return *this; return *this;
} }
/// Inserts element acont into row i. Does not test if already used. /// Inserts element acont into row i. Does not test if already used.
void Add (int i, const T & acont) void Add (IndexType i, const T & acont)
{ {
if (data[i].size == data[i].maxsize) if (data[i].size == data[i].maxsize)
IncSize (i, sizeof (T)); this->IncSize (i, sizeof (T));
else else
data[i].size++; data[i].size++;
static_cast<T*> (data[i].col) [data[i].size-1] = acont; static_cast<T*> (data[i].col) [data[i].size-1] = acont;
} }
/// Inserts element acont into row i, iff not yet exists. /// Inserts element acont into row i, iff not yet exists.
void AddUnique (int i, const T & cont) void AddUnique (IndexType i, const T & cont)
{ {
int es = EntrySize (i); int es = EntrySize (i);
int * line = const_cast<int*> (GetLine (i)); int * line = const_cast<int*> (GetLine (i));
@ -436,25 +522,25 @@ public:
/// Inserts element acont into row i. Does not test if already used. /// Inserts element acont into row i. Does not test if already used.
void AddEmpty (int i) void AddEmpty (IndexType i)
{ {
IncSize (i, sizeof (T)); IncSize (i, sizeof (T));
} }
/** Set the nr-th element in the i-th row to acont. /** Set the nr-th element in the i-th row to acont.
Does not check for overflow. */ Does not check for overflow. */
void Set (int i, int nr, const T & acont) void Set (IndexType i, int nr, const T & acont)
{ static_cast<T*> (data[i].col)[nr] = acont; } { static_cast<T*> (data[i].col)[nr] = acont; }
/** Returns the nr-th element in the i-th row. /** Returns the nr-th element in the i-th row.
Does not check for overflow. */ Does not check for overflow. */
const T & Get (int i, int nr) const const T & Get (IndexType i, int nr) const
{ return static_cast<T*> (data[i].col)[nr]; } { return static_cast<T*> (data[i].col)[nr]; }
/** Returns pointer to the first element in row i. */ /** Returns pointer to the first element in row i. */
const T * GetLine (int i) const const T * GetLine (IndexType i) const
{ return static_cast<T*> (data[i].col); } { return static_cast<T*> (data[i].col); }
@ -463,15 +549,15 @@ public:
{ return data.Size(); } { return data.Size(); }
/// Returns size of the i-th row. /// Returns size of the i-th row.
int EntrySize (int i) const int EntrySize (IndexType i) const
{ return data[i].size; } { return data[i].size; }
/// ///
void DecEntrySize (int i) void DecEntrySize (IndexType i)
{ DecSize(i); } { DecSize(i); }
/// Access entry i /// Access entry i
FlatArray<T> operator[] (int i) FlatArray<T> operator[] (IndexType i)
{ return FlatArray<T> (data[i].size, static_cast<T*> (data[i].col)); } { return FlatArray<T> (data[i].size, static_cast<T*> (data[i].col)); }
/* /*
@ -480,7 +566,7 @@ public:
ConstFlatArray operator[] (int i) const ConstFlatArray operator[] (int i) const
{ return FlatArray<T> (data[i].size, static_cast<T*> (data[i].col)); } { return FlatArray<T> (data[i].size, static_cast<T*> (data[i].col)); }
*/ */
FlatArray<T> operator[] (int i) const FlatArray<T> operator[] (IndexType i) const
{ return FlatArray<T> (data[i].size, static_cast<T*> (data[i].col)); } { return FlatArray<T> (data[i].size, static_cast<T*> (data[i].col)); }
}; };

View File

@ -51,8 +51,8 @@ namespace netgen
if ( mesh.GetCommunicator().Size() == 1 ) return; if ( mesh.GetCommunicator().Size() == 1 ) return;
int ned = mesh.GetTopology().GetNEdges(); size_t ned = mesh.GetTopology().GetNEdges();
int nfa = mesh.GetTopology().GetNFaces(); size_t nfa = mesh.GetTopology().GetNFaces();
if (glob_edge.Size() != ned) if (glob_edge.Size() != ned)
{ {
@ -89,29 +89,30 @@ namespace netgen
*testout << "enumerate globally, loc2distvert.size = " << loc2distvert.Size() *testout << "enumerate globally, loc2distvert.size = " << loc2distvert.Size()
<< ", glob_vert.size = " << glob_vert.Size() << endl; << ", glob_vert.size = " << glob_vert.Size() << endl;
// *testout << "old glob_vert = " << endl << glob_vert << endl;
if (rank == 0) if (rank == 0)
nv = 0; nv = 0;
IntRange newvr(oldnv, nv); // new vertex range // IntRange newvr(oldnv, nv); // new vertex range
auto new_pir = Range(PointIndex(oldnv+PointIndex::BASE),
PointIndex(nv+PointIndex::BASE));
glob_vert.SetSize (nv); glob_vert.SetSize (nv);
glob_vert.Range(newvr) = -1; glob_vert.Range(oldnv, nv) = -1;
int num_master_points = 0; int num_master_points = 0;
for (auto i : newvr) for (auto pi : new_pir)
{ {
auto dps = GetDistantPNums(i); auto dps = GetDistantProcs(pi);
// check sorted: // check sorted:
for (int j = 0; j+1 < dps.Size(); j++) for (int j = 0; j+1 < dps.Size(); j++)
if (dps[j+1] < dps[j]) cout << "wrong sort" << endl; if (dps[j+1] < dps[j]) cout << "wrong sort" << endl;
if (dps.Size() == 0 || dps[0] > comm.Rank()) if (dps.Size() == 0 || dps[0] > comm.Rank())
glob_vert[i] = num_master_points++; L2G(pi) = num_master_points++;
} }
*testout << "nummaster = " << num_master_points << endl; *testout << "nummaster = " << num_master_points << endl;
Array<int> first_master_point(comm.Size()); Array<int> first_master_point(comm.Size());
@ -120,7 +121,7 @@ namespace netgen
if (comm.AllReduce (oldnv, MPI_SUM) == 0) if (comm.AllReduce (oldnv, MPI_SUM) == 0)
max_oldv = PointIndex::BASE-1; max_oldv = PointIndex::BASE-1;
size_t num_glob_points = max_oldv+1; // PointIndex::BASE; size_t num_glob_points = max_oldv+1;
for (int i = 0; i < comm.Size(); i++) for (int i = 0; i < comm.Size(); i++)
{ {
int cur = first_master_point[i]; int cur = first_master_point[i];
@ -128,9 +129,9 @@ namespace netgen
num_glob_points += cur; num_glob_points += cur;
} }
for (auto i : newvr) for (auto pi : new_pir)
if (glob_vert[i] != -1) if (L2G(pi) != -1)
glob_vert[i] += first_master_point[comm.Rank()]; L2G(pi) += first_master_point[comm.Rank()];
// ScatterDofData (global_nums); // ScatterDofData (global_nums);
@ -139,12 +140,12 @@ namespace netgen
nrecv = 0; nrecv = 0;
/** Count send/recv size **/ /** Count send/recv size **/
for (auto i : newvr) for (auto pi : new_pir)
{ {
auto dps = GetDistantPNums(i); auto dps = GetDistantProcs(pi);
if (!dps.Size()) continue; if (!dps.Size()) continue;
if (rank < dps[0]) if (rank < dps[0])
for(auto p:dps) for (auto p : dps)
nsend[p]++; nsend[p]++;
else else
nrecv[dps[0]]++; nrecv[dps[0]]++;
@ -155,14 +156,12 @@ namespace netgen
/** Fill send_data **/ /** Fill send_data **/
nsend = 0; nsend = 0;
for (auto i : newvr) for (auto pi : new_pir)
{ if (auto dps = GetDistantProcs(pi); dps.Size())
auto dps = GetDistantPNums(i); if (rank < dps[0])
if (dps.Size() && rank < dps[0]) for (auto p : dps)
for(auto p : dps) send_data[p][nsend[p]++] = L2G(pi);
send_data[p][nsend[p]++] = glob_vert[i];
}
Array<MPI_Request> requests; Array<MPI_Request> requests;
for (int i = 0; i < comm.Size(); i++) for (int i = 0; i < comm.Size(); i++)
{ {
@ -176,21 +175,23 @@ namespace netgen
Array<int> cnt(comm.Size()); Array<int> cnt(comm.Size());
cnt = 0; cnt = 0;
for (auto i : newvr) /*
for (auto pi : new_pir)
{ {
auto dps = GetDistantPNums(i); auto dps = GetDistantProcs(pi);
if (dps.Size() > 0 && dps[0] < comm.Rank()) if (dps.Size() > 0 && dps[0] < comm.Rank())
{ {
int master = comm.Size(); int master = dps[0];
for (int j = 0; j < dps.Size(); j++) L2G(pi) = recv_data[master][cnt[master]++];
master = min (master, dps[j]);
if (master != dps[0])
cout << "master not the first one !" << endl;
glob_vert[i] = recv_data[master][cnt[master]++];
} }
} }
*/
for (auto pi : new_pir)
if (auto dps = GetDistantProcs(pi); dps.Size())
if (int master = dps[0]; master < comm.Rank())
L2G(pi) = recv_data[master][cnt[master]++];
/* /*
if (PointIndex::BASE==1) if (PointIndex::BASE==1)
for (auto & i : glob_vert) for (auto & i : glob_vert)
@ -208,13 +209,11 @@ namespace netgen
Array<int> index0(glob_vert.Size()); Array<int> index0(glob_vert.Size());
for (int pi : Range(index0)) for (int pi : Range(index0))
index0[pi] = pi; index0[pi] = pi;
QuickSortI (FlatArray<int> (glob_vert), index0); QuickSortI (glob_vert, index0);
comm.Barrier(); for (size_t i = 0; i+1 < glob_vert.Size(); i++)
for (int i = 0; i+1 < glob_vert.Size(); i++)
if (glob_vert[index0[i]] > glob_vert[index0[i+1]]) if (glob_vert[index0[i]] > glob_vert[index0[i+1]])
cout << "wrong ordering" << endl; cout << "wrong ordering" << endl;
comm.Barrier();
if (rank != 0) if (rank != 0)
{ {
@ -272,7 +271,7 @@ namespace netgen
// *testout << "l " << i << " globi "<< glob_vert[i] << " dist = " << loc2distvert[i] << endl; // *testout << "l " << i << " globi "<< glob_vert[i] << " dist = " << loc2distvert[i] << endl;
} }
for (int i = 0; i+1 < glob_vert.Size(); i++) for (size_t i = 0; i+1 < glob_vert.Size(); i++)
if (glob_vert[i] > glob_vert[i+1]) if (glob_vert[i] > glob_vert[i+1])
cout << "wrong ordering of globvert" << endl; cout << "wrong ordering of globvert" << endl;
@ -536,11 +535,11 @@ namespace netgen
// build exchange vertices // build exchange vertices
cnt_send = 0; cnt_send = 0;
for (PointIndex pi : mesh.Points().Range()) for (PointIndex pi : mesh.Points().Range())
for (int dist : GetDistantPNums(pi-PointIndex::BASE)) for (int dist : GetDistantProcs(pi))
cnt_send[dist-1]++; cnt_send[dist-1]++;
TABLE<int> dest2vert(cnt_send); TABLE<int> dest2vert(cnt_send);
for (PointIndex pi : mesh.Points().Range()) for (PointIndex pi : mesh.Points().Range())
for (int dist : GetDistantPNums(pi-PointIndex::BASE)) for (int dist : GetDistantProcs(pi))
dest2vert.Add (dist-1, pi); dest2vert.Add (dist-1, pi);
for (PointIndex pi = PointIndex::BASE; pi < newnv+PointIndex::BASE; pi++) for (PointIndex pi = PointIndex::BASE; pi < newnv+PointIndex::BASE; pi++)
@ -687,11 +686,11 @@ namespace netgen
// build exchange vertices // build exchange vertices
cnt_send = 0; cnt_send = 0;
for (PointIndex pi : mesh.Points().Range()) for (PointIndex pi : mesh.Points().Range())
for (int dist : GetDistantPNums(pi-PointIndex::BASE)) for (int dist : GetDistantProcs(pi))
cnt_send[dist-1]++; cnt_send[dist-1]++;
TABLE<int> dest2vert(cnt_send); TABLE<int> dest2vert(cnt_send);
for (PointIndex pi : mesh.Points().Range()) for (PointIndex pi : mesh.Points().Range())
for (int dist : GetDistantPNums(pi-PointIndex::BASE)) for (int dist : GetDistantProcs(pi))
dest2vert.Add (dist-1, pi); dest2vert.Add (dist-1, pi);
MPI_Group_free(&MPI_LocalGroup); MPI_Group_free(&MPI_LocalGroup);
@ -916,11 +915,11 @@ namespace netgen
// build exchange vertices // build exchange vertices
cnt_send = 0; cnt_send = 0;
for (PointIndex pi : mesh.Points().Range()) for (PointIndex pi : mesh.Points().Range())
for (int dist : GetDistantPNums(pi-PointIndex::BASE)) for (int dist : GetDistantProcs(pi))
cnt_send[dist-1]++; cnt_send[dist-1]++;
TABLE<int> dest2vert(cnt_send); TABLE<int> dest2vert(cnt_send);
for (PointIndex pi : mesh.Points().Range()) for (PointIndex pi : mesh.Points().Range())
for (int dist : GetDistantPNums(pi-PointIndex::BASE)) for (int dist : GetDistantProcs(pi))
dest2vert.Add (dist-1, pi); dest2vert.Add (dist-1, pi);
// exchange edges // exchange edges

View File

@ -123,6 +123,10 @@ namespace netgen
FlatArray<int> GetDistantEdgeNums (int locnum) const { return loc2distedge[locnum]; } FlatArray<int> GetDistantEdgeNums (int locnum) const { return loc2distedge[locnum]; }
FlatArray<int> GetDistantProcs (PointIndex pi) const { return loc2distvert[pi-PointIndex::BASE]; } FlatArray<int> GetDistantProcs (PointIndex pi) const { return loc2distvert[pi-PointIndex::BASE]; }
auto & L2G (PointIndex pi) { return glob_vert[pi-PointIndex::BASE]; }
auto L2G (PointIndex pi) const { return glob_vert[pi-PointIndex::BASE]; }
[[deprecated("Use GetDistantProcs(..).Contains instead!")]] [[deprecated("Use GetDistantProcs(..).Contains instead!")]]
bool IsExchangeVert (int dest, int vnum) const bool IsExchangeVert (int dest, int vnum) const
{ {