Fix MPICH on Linux by handling array arguments correctly

This commit is contained in:
Matthias Hochsteger 2024-05-06 17:35:44 +02:00
parent ae719d58c4
commit 692e4afe3e
3 changed files with 93 additions and 39 deletions

View File

@ -32,13 +32,13 @@ functions = [
("int", "MPI_Type_commit", "MPI_Datatype*"),
("int", "MPI_Type_contiguous", "int", "MPI_Datatype", "MPI_Datatype*"),
("int", "MPI_Type_create_resized", "MPI_Datatype", "MPI_Aint", "MPI_Aint", "MPI_Datatype*"),
("int", "MPI_Type_create_struct", "int", "int*", "MPI_Aint*", "MPI_Datatype*", "MPI_Datatype*"),
("int", "MPI_Type_create_struct", "int", "int*:0", "MPI_Aint*:0", "MPI_Datatype*:0", "MPI_Datatype*"),
("int", "MPI_Type_free", "MPI_Datatype*"),
("int", "MPI_Type_get_extent", "MPI_Datatype", "MPI_Aint*", "MPI_Aint*"),
("int", "MPI_Type_indexed", "int", "int*", "int*", "MPI_Datatype", "MPI_Datatype*"),
("int", "MPI_Type_indexed", "int", "int*:0", "int*:0", "MPI_Datatype", "MPI_Datatype*"),
("int", "MPI_Wait", "MPI_Request*", "MPI_Status*"),
("int", "MPI_Waitall", "int", "MPI_Request*", "MPI_Status*"),
("int", "MPI_Waitany", "int", "MPI_Request*", "int*", "MPI_Status*"),
("int", "MPI_Waitall", "int", "MPI_Request*:0", "MPI_Status*"),
("int", "MPI_Waitany", "int", "MPI_Request*:0", "int*", "MPI_Status*"),
]
constants = [
@ -69,8 +69,23 @@ constants = [
("void*", "MPI_IN_PLACE"),
]
def get_args(f):
return ["NG_"+a if a.startswith("MPI_") else a for a in f[2:]]
def get_args(f, counts=False):
args = []
for arg in f[2:]:
has_count = ':' in arg
if has_count:
s, count = arg.split(':')
count = int(count)
else:
s = arg
count = None
if s.startswith("MPI_"):
s = "NG_" + s
if counts:
args.append((s, count))
else:
args.append(s)
return args
def generate_declarations():
code = ""
@ -109,15 +124,24 @@ def generate_init():
for f in functions:
ret = f[0]
name = f[1]
args = get_args(f)
args = get_args(f, counts=True)
in_args =''
call_args = ''
for i, a in enumerate(args):
for i in range(len(args)):
arg, count = args[i]
if i > 0:
in_args += ', '
call_args += ', '
in_args += a + f" arg{i}"
call_args += f" ng2mpi(arg{i})"
in_args += arg + f" arg{i}"
if not arg.startswith("NG_"):
# plain type (like int, int *, etc.), just pass the argument along
call_args += f" arg{i}"
elif count is None:
# MPI type (by value or pointer), but just one object, no arrays
call_args += f" ng2mpi(arg{i})"
else:
# arrays of MPI types, we need to copy them due to incompatible size
call_args += f" ng2mpi(arg{i}, arg{count})"
code += f"NG_{name} = []({in_args})->{ret} {{ return {name}({call_args}); }};\n"
for _, name in constants:

View File

@ -33,6 +33,20 @@ NG_MPI_Status* mpi2ng(MPI_Status* status) {
NG_MPI_Comm mpi2ng(MPI_Comm comm) { return reinterpret_cast<uintptr_t>(comm); }
#endif
template <size_t size, size_t stride>
void gather_strided_array(size_t count, char* data) {
static_assert(size <= stride, "Size must be less than or equal to stride");
if constexpr (size < stride) {
char* dst = data;
char* src = data;
for (auto i : Range(count)) {
memcpy(dst, src, size);
dst += size;
src += stride;
}
}
}
template <typename T>
T cast_ng2mpi(uintptr_t obj) {
if constexpr (std::is_pointer_v<T>)
@ -49,6 +63,13 @@ T cast_ng2mpi(uintptr_t* ptr) {
return static_cast<T>(ptr);
}
template <typename T, typename TSrc>
T* cast_ng2mpi(TSrc* ptr, int count) {
gather_strided_array<sizeof(T), sizeof(TSrc)>(count,
reinterpret_cast<char*>(ptr));
return reinterpret_cast<T*>(ptr);
}
MPI_Comm ng2mpi(NG_MPI_Comm comm) {
static_assert(sizeof(MPI_Comm) <= sizeof(comm.value), "Size mismatch");
static_assert(alignof(MPI_Comm) <= alignof(NG_MPI_Comm), "Size mismatch");
@ -70,15 +91,24 @@ MPI_Group* ng2mpi(NG_MPI_Group* group) {
MPI_Datatype* ng2mpi(NG_MPI_Datatype* type) {
return cast_ng2mpi<MPI_Datatype*>(&type->value);
}
MPI_Datatype* ng2mpi(NG_MPI_Datatype* type, int count) {
return cast_ng2mpi<MPI_Datatype>(&type->value, count);
}
MPI_Request* ng2mpi(NG_MPI_Request* request) {
return cast_ng2mpi<MPI_Request*>(&request->value);
}
MPI_Request* ng2mpi(NG_MPI_Request* request, int count) {
return cast_ng2mpi<MPI_Request>(&request->value, count);
}
MPI_Status* ng2mpi(NG_MPI_Status* status) {
return reinterpret_cast<MPI_Status*>(status);
}
MPI_Aint* ng2mpi(NG_MPI_Aint* aint) {
return reinterpret_cast<MPI_Aint*>(aint);
}
MPI_Aint* ng2mpi(NG_MPI_Aint* aint, int count) {
return cast_ng2mpi<MPI_Aint>(aint, count);
}
MPI_Datatype ng2mpi(NG_MPI_Datatype type) {
static_assert(sizeof(MPI_Datatype) <= sizeof(type.value), "Size mismatch");

View File

@ -1,43 +1,43 @@
NG_MPI_Wtime = []()->double { return MPI_Wtime(); };
NG_MPI_Allgather = [](void* arg0, int arg1, NG_MPI_Datatype arg2, void* arg3, int arg4, NG_MPI_Datatype arg5, NG_MPI_Comm arg6)->int { return MPI_Allgather( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5), ng2mpi(arg6)); };
NG_MPI_Allreduce = [](void* arg0, void* arg1, int arg2, NG_MPI_Datatype arg3, NG_MPI_Op arg4, NG_MPI_Comm arg5)->int { return MPI_Allreduce( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5)); };
NG_MPI_Alltoall = [](void* arg0, int arg1, NG_MPI_Datatype arg2, void* arg3, int arg4, NG_MPI_Datatype arg5, NG_MPI_Comm arg6)->int { return MPI_Alltoall( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5), ng2mpi(arg6)); };
NG_MPI_Allgather = [](void* arg0, int arg1, NG_MPI_Datatype arg2, void* arg3, int arg4, NG_MPI_Datatype arg5, NG_MPI_Comm arg6)->int { return MPI_Allgather( arg0, arg1, ng2mpi(arg2), arg3, arg4, ng2mpi(arg5), ng2mpi(arg6)); };
NG_MPI_Allreduce = [](void* arg0, void* arg1, int arg2, NG_MPI_Datatype arg3, NG_MPI_Op arg4, NG_MPI_Comm arg5)->int { return MPI_Allreduce( arg0, arg1, arg2, ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5)); };
NG_MPI_Alltoall = [](void* arg0, int arg1, NG_MPI_Datatype arg2, void* arg3, int arg4, NG_MPI_Datatype arg5, NG_MPI_Comm arg6)->int { return MPI_Alltoall( arg0, arg1, ng2mpi(arg2), arg3, arg4, ng2mpi(arg5), ng2mpi(arg6)); };
NG_MPI_Barrier = [](NG_MPI_Comm arg0)->int { return MPI_Barrier( ng2mpi(arg0)); };
NG_MPI_Bcast = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, NG_MPI_Comm arg4)->int { return MPI_Bcast( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4)); };
NG_MPI_Comm_create_group = [](NG_MPI_Comm arg0, NG_MPI_Group arg1, int arg2, NG_MPI_Comm* arg3)->int { return MPI_Comm_create_group( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3)); };
NG_MPI_Bcast = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, NG_MPI_Comm arg4)->int { return MPI_Bcast( arg0, arg1, ng2mpi(arg2), arg3, ng2mpi(arg4)); };
NG_MPI_Comm_create_group = [](NG_MPI_Comm arg0, NG_MPI_Group arg1, int arg2, NG_MPI_Comm* arg3)->int { return MPI_Comm_create_group( ng2mpi(arg0), ng2mpi(arg1), arg2, ng2mpi(arg3)); };
NG_MPI_Comm_free = [](NG_MPI_Comm* arg0)->int { return MPI_Comm_free( ng2mpi(arg0)); };
NG_MPI_Comm_group = [](NG_MPI_Comm arg0, NG_MPI_Group* arg1)->int { return MPI_Comm_group( ng2mpi(arg0), ng2mpi(arg1)); };
NG_MPI_Comm_rank = [](NG_MPI_Comm arg0, int* arg1)->int { return MPI_Comm_rank( ng2mpi(arg0), ng2mpi(arg1)); };
NG_MPI_Comm_size = [](NG_MPI_Comm arg0, int* arg1)->int { return MPI_Comm_size( ng2mpi(arg0), ng2mpi(arg1)); };
NG_MPI_Comm_rank = [](NG_MPI_Comm arg0, int* arg1)->int { return MPI_Comm_rank( ng2mpi(arg0), arg1); };
NG_MPI_Comm_size = [](NG_MPI_Comm arg0, int* arg1)->int { return MPI_Comm_size( ng2mpi(arg0), arg1); };
NG_MPI_Finalize = []()->int { return MPI_Finalize(); };
NG_MPI_Gather = [](void* arg0, int arg1, NG_MPI_Datatype arg2, void* arg3, int arg4, NG_MPI_Datatype arg5, int arg6, NG_MPI_Comm arg7)->int { return MPI_Gather( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5), ng2mpi(arg6), ng2mpi(arg7)); };
NG_MPI_Get_count = [](NG_MPI_Status* arg0, NG_MPI_Datatype arg1, int* arg2)->int { return MPI_Get_count( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2)); };
NG_MPI_Get_processor_name = [](char* arg0, int* arg1)->int { return MPI_Get_processor_name( ng2mpi(arg0), ng2mpi(arg1)); };
NG_MPI_Group_incl = [](NG_MPI_Group arg0, int arg1, int* arg2, NG_MPI_Group* arg3)->int { return MPI_Group_incl( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3)); };
NG_MPI_Init = [](int* arg0, char*** arg1)->int { return MPI_Init( ng2mpi(arg0), ng2mpi(arg1)); };
NG_MPI_Init_thread = [](int* arg0, char*** arg1, int arg2, int* arg3)->int { return MPI_Init_thread( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3)); };
NG_MPI_Initialized = [](int* arg0)->int { return MPI_Initialized( ng2mpi(arg0)); };
NG_MPI_Iprobe = [](int arg0, int arg1, NG_MPI_Comm arg2, int* arg3, NG_MPI_Status* arg4)->int { return MPI_Iprobe( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4)); };
NG_MPI_Irecv = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, int arg4, NG_MPI_Comm arg5, NG_MPI_Request* arg6)->int { return MPI_Irecv( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5), ng2mpi(arg6)); };
NG_MPI_Isend = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, int arg4, NG_MPI_Comm arg5, NG_MPI_Request* arg6)->int { return MPI_Isend( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5), ng2mpi(arg6)); };
NG_MPI_Probe = [](int arg0, int arg1, NG_MPI_Comm arg2, NG_MPI_Status* arg3)->int { return MPI_Probe( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3)); };
NG_MPI_Query_thread = [](int* arg0)->int { return MPI_Query_thread( ng2mpi(arg0)); };
NG_MPI_Recv = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, int arg4, NG_MPI_Comm arg5, NG_MPI_Status* arg6)->int { return MPI_Recv( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5), ng2mpi(arg6)); };
NG_MPI_Reduce = [](void* arg0, void* arg1, int arg2, NG_MPI_Datatype arg3, NG_MPI_Op arg4, int arg5, NG_MPI_Comm arg6)->int { return MPI_Reduce( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5), ng2mpi(arg6)); };
NG_MPI_Reduce_local = [](void* arg0, void* arg1, int arg2, NG_MPI_Datatype arg3, NG_MPI_Op arg4)->int { return MPI_Reduce_local( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4)); };
NG_MPI_Gather = [](void* arg0, int arg1, NG_MPI_Datatype arg2, void* arg3, int arg4, NG_MPI_Datatype arg5, int arg6, NG_MPI_Comm arg7)->int { return MPI_Gather( arg0, arg1, ng2mpi(arg2), arg3, arg4, ng2mpi(arg5), arg6, ng2mpi(arg7)); };
NG_MPI_Get_count = [](NG_MPI_Status* arg0, NG_MPI_Datatype arg1, int* arg2)->int { return MPI_Get_count( ng2mpi(arg0), ng2mpi(arg1), arg2); };
NG_MPI_Get_processor_name = [](char* arg0, int* arg1)->int { return MPI_Get_processor_name( arg0, arg1); };
NG_MPI_Group_incl = [](NG_MPI_Group arg0, int arg1, int* arg2, NG_MPI_Group* arg3)->int { return MPI_Group_incl( ng2mpi(arg0), arg1, arg2, ng2mpi(arg3)); };
NG_MPI_Init = [](int* arg0, char*** arg1)->int { return MPI_Init( arg0, arg1); };
NG_MPI_Init_thread = [](int* arg0, char*** arg1, int arg2, int* arg3)->int { return MPI_Init_thread( arg0, arg1, arg2, arg3); };
NG_MPI_Initialized = [](int* arg0)->int { return MPI_Initialized( arg0); };
NG_MPI_Iprobe = [](int arg0, int arg1, NG_MPI_Comm arg2, int* arg3, NG_MPI_Status* arg4)->int { return MPI_Iprobe( arg0, arg1, ng2mpi(arg2), arg3, ng2mpi(arg4)); };
NG_MPI_Irecv = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, int arg4, NG_MPI_Comm arg5, NG_MPI_Request* arg6)->int { return MPI_Irecv( arg0, arg1, ng2mpi(arg2), arg3, arg4, ng2mpi(arg5), ng2mpi(arg6)); };
NG_MPI_Isend = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, int arg4, NG_MPI_Comm arg5, NG_MPI_Request* arg6)->int { return MPI_Isend( arg0, arg1, ng2mpi(arg2), arg3, arg4, ng2mpi(arg5), ng2mpi(arg6)); };
NG_MPI_Probe = [](int arg0, int arg1, NG_MPI_Comm arg2, NG_MPI_Status* arg3)->int { return MPI_Probe( arg0, arg1, ng2mpi(arg2), ng2mpi(arg3)); };
NG_MPI_Query_thread = [](int* arg0)->int { return MPI_Query_thread( arg0); };
NG_MPI_Recv = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, int arg4, NG_MPI_Comm arg5, NG_MPI_Status* arg6)->int { return MPI_Recv( arg0, arg1, ng2mpi(arg2), arg3, arg4, ng2mpi(arg5), ng2mpi(arg6)); };
NG_MPI_Reduce = [](void* arg0, void* arg1, int arg2, NG_MPI_Datatype arg3, NG_MPI_Op arg4, int arg5, NG_MPI_Comm arg6)->int { return MPI_Reduce( arg0, arg1, arg2, ng2mpi(arg3), ng2mpi(arg4), arg5, ng2mpi(arg6)); };
NG_MPI_Reduce_local = [](void* arg0, void* arg1, int arg2, NG_MPI_Datatype arg3, NG_MPI_Op arg4)->int { return MPI_Reduce_local( arg0, arg1, arg2, ng2mpi(arg3), ng2mpi(arg4)); };
NG_MPI_Request_free = [](NG_MPI_Request* arg0)->int { return MPI_Request_free( ng2mpi(arg0)); };
NG_MPI_Scatter = [](void* arg0, int arg1, NG_MPI_Datatype arg2, void* arg3, int arg4, NG_MPI_Datatype arg5, int arg6, NG_MPI_Comm arg7)->int { return MPI_Scatter( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5), ng2mpi(arg6), ng2mpi(arg7)); };
NG_MPI_Send = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, int arg4, NG_MPI_Comm arg5)->int { return MPI_Send( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5)); };
NG_MPI_Scatter = [](void* arg0, int arg1, NG_MPI_Datatype arg2, void* arg3, int arg4, NG_MPI_Datatype arg5, int arg6, NG_MPI_Comm arg7)->int { return MPI_Scatter( arg0, arg1, ng2mpi(arg2), arg3, arg4, ng2mpi(arg5), arg6, ng2mpi(arg7)); };
NG_MPI_Send = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, int arg4, NG_MPI_Comm arg5)->int { return MPI_Send( arg0, arg1, ng2mpi(arg2), arg3, arg4, ng2mpi(arg5)); };
NG_MPI_Type_commit = [](NG_MPI_Datatype* arg0)->int { return MPI_Type_commit( ng2mpi(arg0)); };
NG_MPI_Type_contiguous = [](int arg0, NG_MPI_Datatype arg1, NG_MPI_Datatype* arg2)->int { return MPI_Type_contiguous( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2)); };
NG_MPI_Type_contiguous = [](int arg0, NG_MPI_Datatype arg1, NG_MPI_Datatype* arg2)->int { return MPI_Type_contiguous( arg0, ng2mpi(arg1), ng2mpi(arg2)); };
NG_MPI_Type_create_resized = [](NG_MPI_Datatype arg0, NG_MPI_Aint arg1, NG_MPI_Aint arg2, NG_MPI_Datatype* arg3)->int { return MPI_Type_create_resized( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3)); };
NG_MPI_Type_create_struct = [](int arg0, int* arg1, NG_MPI_Aint* arg2, NG_MPI_Datatype* arg3, NG_MPI_Datatype* arg4)->int { return MPI_Type_create_struct( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4)); };
NG_MPI_Type_create_struct = [](int arg0, int* arg1, NG_MPI_Aint* arg2, NG_MPI_Datatype* arg3, NG_MPI_Datatype* arg4)->int { return MPI_Type_create_struct( arg0, arg1, ng2mpi(arg2, arg0), ng2mpi(arg3, arg0), ng2mpi(arg4)); };
NG_MPI_Type_free = [](NG_MPI_Datatype* arg0)->int { return MPI_Type_free( ng2mpi(arg0)); };
NG_MPI_Type_get_extent = [](NG_MPI_Datatype arg0, NG_MPI_Aint* arg1, NG_MPI_Aint* arg2)->int { return MPI_Type_get_extent( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2)); };
NG_MPI_Type_indexed = [](int arg0, int* arg1, int* arg2, NG_MPI_Datatype arg3, NG_MPI_Datatype* arg4)->int { return MPI_Type_indexed( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4)); };
NG_MPI_Type_indexed = [](int arg0, int* arg1, int* arg2, NG_MPI_Datatype arg3, NG_MPI_Datatype* arg4)->int { return MPI_Type_indexed( arg0, arg1, arg2, ng2mpi(arg3), ng2mpi(arg4)); };
NG_MPI_Wait = [](NG_MPI_Request* arg0, NG_MPI_Status* arg1)->int { return MPI_Wait( ng2mpi(arg0), ng2mpi(arg1)); };
NG_MPI_Waitall = [](int arg0, NG_MPI_Request* arg1, NG_MPI_Status* arg2)->int { return MPI_Waitall( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2)); };
NG_MPI_Waitany = [](int arg0, NG_MPI_Request* arg1, int* arg2, NG_MPI_Status* arg3)->int { return MPI_Waitany( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3)); };
NG_MPI_Waitall = [](int arg0, NG_MPI_Request* arg1, NG_MPI_Status* arg2)->int { return MPI_Waitall( arg0, ng2mpi(arg1, arg0), ng2mpi(arg2)); };
NG_MPI_Waitany = [](int arg0, NG_MPI_Request* arg1, int* arg2, NG_MPI_Status* arg3)->int { return MPI_Waitany( arg0, ng2mpi(arg1, arg0), arg2, ng2mpi(arg3)); };
NG_MPI_COMM_WORLD = mpi2ng(MPI_COMM_WORLD);
NG_MPI_CHAR = mpi2ng(MPI_CHAR);
NG_MPI_CXX_DOUBLE_COMPLEX = mpi2ng(MPI_CXX_DOUBLE_COMPLEX);