diff --git a/libsrc/core/generate_mpi_sources.py b/libsrc/core/generate_mpi_sources.py index da494a05..a5cd83c7 100644 --- a/libsrc/core/generate_mpi_sources.py +++ b/libsrc/core/generate_mpi_sources.py @@ -32,13 +32,13 @@ functions = [ ("int", "MPI_Type_commit", "MPI_Datatype*"), ("int", "MPI_Type_contiguous", "int", "MPI_Datatype", "MPI_Datatype*"), ("int", "MPI_Type_create_resized", "MPI_Datatype", "MPI_Aint", "MPI_Aint", "MPI_Datatype*"), - ("int", "MPI_Type_create_struct", "int", "int*", "MPI_Aint*", "MPI_Datatype*", "MPI_Datatype*"), + ("int", "MPI_Type_create_struct", "int", "int*:0", "MPI_Aint*:0", "MPI_Datatype*:0", "MPI_Datatype*"), ("int", "MPI_Type_free", "MPI_Datatype*"), ("int", "MPI_Type_get_extent", "MPI_Datatype", "MPI_Aint*", "MPI_Aint*"), - ("int", "MPI_Type_indexed", "int", "int*", "int*", "MPI_Datatype", "MPI_Datatype*"), + ("int", "MPI_Type_indexed", "int", "int*:0", "int*:0", "MPI_Datatype", "MPI_Datatype*"), ("int", "MPI_Wait", "MPI_Request*", "MPI_Status*"), - ("int", "MPI_Waitall", "int", "MPI_Request*", "MPI_Status*"), - ("int", "MPI_Waitany", "int", "MPI_Request*", "int*", "MPI_Status*"), + ("int", "MPI_Waitall", "int", "MPI_Request*:0", "MPI_Status*"), + ("int", "MPI_Waitany", "int", "MPI_Request*:0", "int*", "MPI_Status*"), ] constants = [ @@ -69,8 +69,23 @@ constants = [ ("void*", "MPI_IN_PLACE"), ] -def get_args(f): - return ["NG_"+a if a.startswith("MPI_") else a for a in f[2:]] +def get_args(f, counts=False): + args = [] + for arg in f[2:]: + has_count = ':' in arg + if has_count: + s, count = arg.split(':') + count = int(count) + else: + s = arg + count = None + if s.startswith("MPI_"): + s = "NG_" + s + if counts: + args.append((s, count)) + else: + args.append(s) + return args def generate_declarations(): code = "" @@ -109,15 +124,24 @@ def generate_init(): for f in functions: ret = f[0] name = f[1] - args = get_args(f) + args = get_args(f, counts=True) in_args ='' call_args = '' - for i, a in enumerate(args): + for i in range(len(args)): + arg, count = args[i] if i > 0: in_args += ', ' call_args += ', ' - in_args += a + f" arg{i}" - call_args += f" ng2mpi(arg{i})" + in_args += arg + f" arg{i}" + if not arg.startswith("NG_"): + # plain type (like int, int *, etc.), just pass the argument along + call_args += f" arg{i}" + elif count is None: + # MPI type (by value or pointer), but just one object, no arrays + call_args += f" ng2mpi(arg{i})" + else: + # arrays of MPI types, we need to copy them due to incompatible size + call_args += f" ng2mpi(arg{i}, arg{count})" code += f"NG_{name} = []({in_args})->{ret} {{ return {name}({call_args}); }};\n" for _, name in constants: diff --git a/libsrc/core/ng_mpi.cpp b/libsrc/core/ng_mpi.cpp index 92bac08f..daf1410a 100644 --- a/libsrc/core/ng_mpi.cpp +++ b/libsrc/core/ng_mpi.cpp @@ -33,6 +33,20 @@ NG_MPI_Status* mpi2ng(MPI_Status* status) { NG_MPI_Comm mpi2ng(MPI_Comm comm) { return reinterpret_cast(comm); } #endif +template +void gather_strided_array(size_t count, char* data) { + static_assert(size <= stride, "Size must be less than or equal to stride"); + if constexpr (size < stride) { + char* dst = data; + char* src = data; + for (auto i : Range(count)) { + memcpy(dst, src, size); + dst += size; + src += stride; + } + } +} + template T cast_ng2mpi(uintptr_t obj) { if constexpr (std::is_pointer_v) @@ -49,6 +63,13 @@ T cast_ng2mpi(uintptr_t* ptr) { return static_cast(ptr); } +template +T* cast_ng2mpi(TSrc* ptr, int count) { + gather_strided_array(count, + reinterpret_cast(ptr)); + return reinterpret_cast(ptr); +} + MPI_Comm ng2mpi(NG_MPI_Comm comm) { static_assert(sizeof(MPI_Comm) <= sizeof(comm.value), "Size mismatch"); static_assert(alignof(MPI_Comm) <= alignof(NG_MPI_Comm), "Size mismatch"); @@ -70,15 +91,24 @@ MPI_Group* ng2mpi(NG_MPI_Group* group) { MPI_Datatype* ng2mpi(NG_MPI_Datatype* type) { return cast_ng2mpi(&type->value); } +MPI_Datatype* ng2mpi(NG_MPI_Datatype* type, int count) { + return cast_ng2mpi(&type->value, count); +} MPI_Request* ng2mpi(NG_MPI_Request* request) { return cast_ng2mpi(&request->value); } +MPI_Request* ng2mpi(NG_MPI_Request* request, int count) { + return cast_ng2mpi(&request->value, count); +} MPI_Status* ng2mpi(NG_MPI_Status* status) { return reinterpret_cast(status); } MPI_Aint* ng2mpi(NG_MPI_Aint* aint) { return reinterpret_cast(aint); } +MPI_Aint* ng2mpi(NG_MPI_Aint* aint, int count) { + return cast_ng2mpi(aint, count); +} MPI_Datatype ng2mpi(NG_MPI_Datatype type) { static_assert(sizeof(MPI_Datatype) <= sizeof(type.value), "Size mismatch"); diff --git a/libsrc/core/ng_mpi_generated_init.hpp b/libsrc/core/ng_mpi_generated_init.hpp index 3b3b93ef..dfdb0dfa 100644 --- a/libsrc/core/ng_mpi_generated_init.hpp +++ b/libsrc/core/ng_mpi_generated_init.hpp @@ -1,43 +1,43 @@ NG_MPI_Wtime = []()->double { return MPI_Wtime(); }; -NG_MPI_Allgather = [](void* arg0, int arg1, NG_MPI_Datatype arg2, void* arg3, int arg4, NG_MPI_Datatype arg5, NG_MPI_Comm arg6)->int { return MPI_Allgather( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5), ng2mpi(arg6)); }; -NG_MPI_Allreduce = [](void* arg0, void* arg1, int arg2, NG_MPI_Datatype arg3, NG_MPI_Op arg4, NG_MPI_Comm arg5)->int { return MPI_Allreduce( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5)); }; -NG_MPI_Alltoall = [](void* arg0, int arg1, NG_MPI_Datatype arg2, void* arg3, int arg4, NG_MPI_Datatype arg5, NG_MPI_Comm arg6)->int { return MPI_Alltoall( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5), ng2mpi(arg6)); }; +NG_MPI_Allgather = [](void* arg0, int arg1, NG_MPI_Datatype arg2, void* arg3, int arg4, NG_MPI_Datatype arg5, NG_MPI_Comm arg6)->int { return MPI_Allgather( arg0, arg1, ng2mpi(arg2), arg3, arg4, ng2mpi(arg5), ng2mpi(arg6)); }; +NG_MPI_Allreduce = [](void* arg0, void* arg1, int arg2, NG_MPI_Datatype arg3, NG_MPI_Op arg4, NG_MPI_Comm arg5)->int { return MPI_Allreduce( arg0, arg1, arg2, ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5)); }; +NG_MPI_Alltoall = [](void* arg0, int arg1, NG_MPI_Datatype arg2, void* arg3, int arg4, NG_MPI_Datatype arg5, NG_MPI_Comm arg6)->int { return MPI_Alltoall( arg0, arg1, ng2mpi(arg2), arg3, arg4, ng2mpi(arg5), ng2mpi(arg6)); }; NG_MPI_Barrier = [](NG_MPI_Comm arg0)->int { return MPI_Barrier( ng2mpi(arg0)); }; -NG_MPI_Bcast = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, NG_MPI_Comm arg4)->int { return MPI_Bcast( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4)); }; -NG_MPI_Comm_create_group = [](NG_MPI_Comm arg0, NG_MPI_Group arg1, int arg2, NG_MPI_Comm* arg3)->int { return MPI_Comm_create_group( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3)); }; +NG_MPI_Bcast = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, NG_MPI_Comm arg4)->int { return MPI_Bcast( arg0, arg1, ng2mpi(arg2), arg3, ng2mpi(arg4)); }; +NG_MPI_Comm_create_group = [](NG_MPI_Comm arg0, NG_MPI_Group arg1, int arg2, NG_MPI_Comm* arg3)->int { return MPI_Comm_create_group( ng2mpi(arg0), ng2mpi(arg1), arg2, ng2mpi(arg3)); }; NG_MPI_Comm_free = [](NG_MPI_Comm* arg0)->int { return MPI_Comm_free( ng2mpi(arg0)); }; NG_MPI_Comm_group = [](NG_MPI_Comm arg0, NG_MPI_Group* arg1)->int { return MPI_Comm_group( ng2mpi(arg0), ng2mpi(arg1)); }; -NG_MPI_Comm_rank = [](NG_MPI_Comm arg0, int* arg1)->int { return MPI_Comm_rank( ng2mpi(arg0), ng2mpi(arg1)); }; -NG_MPI_Comm_size = [](NG_MPI_Comm arg0, int* arg1)->int { return MPI_Comm_size( ng2mpi(arg0), ng2mpi(arg1)); }; +NG_MPI_Comm_rank = [](NG_MPI_Comm arg0, int* arg1)->int { return MPI_Comm_rank( ng2mpi(arg0), arg1); }; +NG_MPI_Comm_size = [](NG_MPI_Comm arg0, int* arg1)->int { return MPI_Comm_size( ng2mpi(arg0), arg1); }; NG_MPI_Finalize = []()->int { return MPI_Finalize(); }; -NG_MPI_Gather = [](void* arg0, int arg1, NG_MPI_Datatype arg2, void* arg3, int arg4, NG_MPI_Datatype arg5, int arg6, NG_MPI_Comm arg7)->int { return MPI_Gather( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5), ng2mpi(arg6), ng2mpi(arg7)); }; -NG_MPI_Get_count = [](NG_MPI_Status* arg0, NG_MPI_Datatype arg1, int* arg2)->int { return MPI_Get_count( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2)); }; -NG_MPI_Get_processor_name = [](char* arg0, int* arg1)->int { return MPI_Get_processor_name( ng2mpi(arg0), ng2mpi(arg1)); }; -NG_MPI_Group_incl = [](NG_MPI_Group arg0, int arg1, int* arg2, NG_MPI_Group* arg3)->int { return MPI_Group_incl( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3)); }; -NG_MPI_Init = [](int* arg0, char*** arg1)->int { return MPI_Init( ng2mpi(arg0), ng2mpi(arg1)); }; -NG_MPI_Init_thread = [](int* arg0, char*** arg1, int arg2, int* arg3)->int { return MPI_Init_thread( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3)); }; -NG_MPI_Initialized = [](int* arg0)->int { return MPI_Initialized( ng2mpi(arg0)); }; -NG_MPI_Iprobe = [](int arg0, int arg1, NG_MPI_Comm arg2, int* arg3, NG_MPI_Status* arg4)->int { return MPI_Iprobe( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4)); }; -NG_MPI_Irecv = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, int arg4, NG_MPI_Comm arg5, NG_MPI_Request* arg6)->int { return MPI_Irecv( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5), ng2mpi(arg6)); }; -NG_MPI_Isend = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, int arg4, NG_MPI_Comm arg5, NG_MPI_Request* arg6)->int { return MPI_Isend( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5), ng2mpi(arg6)); }; -NG_MPI_Probe = [](int arg0, int arg1, NG_MPI_Comm arg2, NG_MPI_Status* arg3)->int { return MPI_Probe( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3)); }; -NG_MPI_Query_thread = [](int* arg0)->int { return MPI_Query_thread( ng2mpi(arg0)); }; -NG_MPI_Recv = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, int arg4, NG_MPI_Comm arg5, NG_MPI_Status* arg6)->int { return MPI_Recv( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5), ng2mpi(arg6)); }; -NG_MPI_Reduce = [](void* arg0, void* arg1, int arg2, NG_MPI_Datatype arg3, NG_MPI_Op arg4, int arg5, NG_MPI_Comm arg6)->int { return MPI_Reduce( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5), ng2mpi(arg6)); }; -NG_MPI_Reduce_local = [](void* arg0, void* arg1, int arg2, NG_MPI_Datatype arg3, NG_MPI_Op arg4)->int { return MPI_Reduce_local( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4)); }; +NG_MPI_Gather = [](void* arg0, int arg1, NG_MPI_Datatype arg2, void* arg3, int arg4, NG_MPI_Datatype arg5, int arg6, NG_MPI_Comm arg7)->int { return MPI_Gather( arg0, arg1, ng2mpi(arg2), arg3, arg4, ng2mpi(arg5), arg6, ng2mpi(arg7)); }; +NG_MPI_Get_count = [](NG_MPI_Status* arg0, NG_MPI_Datatype arg1, int* arg2)->int { return MPI_Get_count( ng2mpi(arg0), ng2mpi(arg1), arg2); }; +NG_MPI_Get_processor_name = [](char* arg0, int* arg1)->int { return MPI_Get_processor_name( arg0, arg1); }; +NG_MPI_Group_incl = [](NG_MPI_Group arg0, int arg1, int* arg2, NG_MPI_Group* arg3)->int { return MPI_Group_incl( ng2mpi(arg0), arg1, arg2, ng2mpi(arg3)); }; +NG_MPI_Init = [](int* arg0, char*** arg1)->int { return MPI_Init( arg0, arg1); }; +NG_MPI_Init_thread = [](int* arg0, char*** arg1, int arg2, int* arg3)->int { return MPI_Init_thread( arg0, arg1, arg2, arg3); }; +NG_MPI_Initialized = [](int* arg0)->int { return MPI_Initialized( arg0); }; +NG_MPI_Iprobe = [](int arg0, int arg1, NG_MPI_Comm arg2, int* arg3, NG_MPI_Status* arg4)->int { return MPI_Iprobe( arg0, arg1, ng2mpi(arg2), arg3, ng2mpi(arg4)); }; +NG_MPI_Irecv = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, int arg4, NG_MPI_Comm arg5, NG_MPI_Request* arg6)->int { return MPI_Irecv( arg0, arg1, ng2mpi(arg2), arg3, arg4, ng2mpi(arg5), ng2mpi(arg6)); }; +NG_MPI_Isend = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, int arg4, NG_MPI_Comm arg5, NG_MPI_Request* arg6)->int { return MPI_Isend( arg0, arg1, ng2mpi(arg2), arg3, arg4, ng2mpi(arg5), ng2mpi(arg6)); }; +NG_MPI_Probe = [](int arg0, int arg1, NG_MPI_Comm arg2, NG_MPI_Status* arg3)->int { return MPI_Probe( arg0, arg1, ng2mpi(arg2), ng2mpi(arg3)); }; +NG_MPI_Query_thread = [](int* arg0)->int { return MPI_Query_thread( arg0); }; +NG_MPI_Recv = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, int arg4, NG_MPI_Comm arg5, NG_MPI_Status* arg6)->int { return MPI_Recv( arg0, arg1, ng2mpi(arg2), arg3, arg4, ng2mpi(arg5), ng2mpi(arg6)); }; +NG_MPI_Reduce = [](void* arg0, void* arg1, int arg2, NG_MPI_Datatype arg3, NG_MPI_Op arg4, int arg5, NG_MPI_Comm arg6)->int { return MPI_Reduce( arg0, arg1, arg2, ng2mpi(arg3), ng2mpi(arg4), arg5, ng2mpi(arg6)); }; +NG_MPI_Reduce_local = [](void* arg0, void* arg1, int arg2, NG_MPI_Datatype arg3, NG_MPI_Op arg4)->int { return MPI_Reduce_local( arg0, arg1, arg2, ng2mpi(arg3), ng2mpi(arg4)); }; NG_MPI_Request_free = [](NG_MPI_Request* arg0)->int { return MPI_Request_free( ng2mpi(arg0)); }; -NG_MPI_Scatter = [](void* arg0, int arg1, NG_MPI_Datatype arg2, void* arg3, int arg4, NG_MPI_Datatype arg5, int arg6, NG_MPI_Comm arg7)->int { return MPI_Scatter( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5), ng2mpi(arg6), ng2mpi(arg7)); }; -NG_MPI_Send = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, int arg4, NG_MPI_Comm arg5)->int { return MPI_Send( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4), ng2mpi(arg5)); }; +NG_MPI_Scatter = [](void* arg0, int arg1, NG_MPI_Datatype arg2, void* arg3, int arg4, NG_MPI_Datatype arg5, int arg6, NG_MPI_Comm arg7)->int { return MPI_Scatter( arg0, arg1, ng2mpi(arg2), arg3, arg4, ng2mpi(arg5), arg6, ng2mpi(arg7)); }; +NG_MPI_Send = [](void* arg0, int arg1, NG_MPI_Datatype arg2, int arg3, int arg4, NG_MPI_Comm arg5)->int { return MPI_Send( arg0, arg1, ng2mpi(arg2), arg3, arg4, ng2mpi(arg5)); }; NG_MPI_Type_commit = [](NG_MPI_Datatype* arg0)->int { return MPI_Type_commit( ng2mpi(arg0)); }; -NG_MPI_Type_contiguous = [](int arg0, NG_MPI_Datatype arg1, NG_MPI_Datatype* arg2)->int { return MPI_Type_contiguous( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2)); }; +NG_MPI_Type_contiguous = [](int arg0, NG_MPI_Datatype arg1, NG_MPI_Datatype* arg2)->int { return MPI_Type_contiguous( arg0, ng2mpi(arg1), ng2mpi(arg2)); }; NG_MPI_Type_create_resized = [](NG_MPI_Datatype arg0, NG_MPI_Aint arg1, NG_MPI_Aint arg2, NG_MPI_Datatype* arg3)->int { return MPI_Type_create_resized( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3)); }; -NG_MPI_Type_create_struct = [](int arg0, int* arg1, NG_MPI_Aint* arg2, NG_MPI_Datatype* arg3, NG_MPI_Datatype* arg4)->int { return MPI_Type_create_struct( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4)); }; +NG_MPI_Type_create_struct = [](int arg0, int* arg1, NG_MPI_Aint* arg2, NG_MPI_Datatype* arg3, NG_MPI_Datatype* arg4)->int { return MPI_Type_create_struct( arg0, arg1, ng2mpi(arg2, arg0), ng2mpi(arg3, arg0), ng2mpi(arg4)); }; NG_MPI_Type_free = [](NG_MPI_Datatype* arg0)->int { return MPI_Type_free( ng2mpi(arg0)); }; NG_MPI_Type_get_extent = [](NG_MPI_Datatype arg0, NG_MPI_Aint* arg1, NG_MPI_Aint* arg2)->int { return MPI_Type_get_extent( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2)); }; -NG_MPI_Type_indexed = [](int arg0, int* arg1, int* arg2, NG_MPI_Datatype arg3, NG_MPI_Datatype* arg4)->int { return MPI_Type_indexed( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3), ng2mpi(arg4)); }; +NG_MPI_Type_indexed = [](int arg0, int* arg1, int* arg2, NG_MPI_Datatype arg3, NG_MPI_Datatype* arg4)->int { return MPI_Type_indexed( arg0, arg1, arg2, ng2mpi(arg3), ng2mpi(arg4)); }; NG_MPI_Wait = [](NG_MPI_Request* arg0, NG_MPI_Status* arg1)->int { return MPI_Wait( ng2mpi(arg0), ng2mpi(arg1)); }; -NG_MPI_Waitall = [](int arg0, NG_MPI_Request* arg1, NG_MPI_Status* arg2)->int { return MPI_Waitall( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2)); }; -NG_MPI_Waitany = [](int arg0, NG_MPI_Request* arg1, int* arg2, NG_MPI_Status* arg3)->int { return MPI_Waitany( ng2mpi(arg0), ng2mpi(arg1), ng2mpi(arg2), ng2mpi(arg3)); }; +NG_MPI_Waitall = [](int arg0, NG_MPI_Request* arg1, NG_MPI_Status* arg2)->int { return MPI_Waitall( arg0, ng2mpi(arg1, arg0), ng2mpi(arg2)); }; +NG_MPI_Waitany = [](int arg0, NG_MPI_Request* arg1, int* arg2, NG_MPI_Status* arg3)->int { return MPI_Waitany( arg0, ng2mpi(arg1, arg0), arg2, ng2mpi(arg3)); }; NG_MPI_COMM_WORLD = mpi2ng(MPI_COMM_WORLD); NG_MPI_CHAR = mpi2ng(MPI_CHAR); NG_MPI_CXX_DOUBLE_COMPLEX = mpi2ng(MPI_CXX_DOUBLE_COMPLEX);