"""Generate the C++ header fragments that define the NG_MPI_* indirection layer.

Three files are written:

* ``ng_mpi_generated_declarations.hpp`` -- with ``NG_MPI_WRAPPER`` defined,
  ``NGCORE_API extern`` declarations of an ``NG_``-prefixed function pointer
  for every MPI function and an ``NG_``-prefixed variable for every MPI
  constant; otherwise static aliases of the real MPI symbols.
* ``ng_mpi_generated_dummy_init.hpp`` -- definitions in which every function
  throws ``no_mpi()`` and every constant is ``0``.
* ``ng_mpi_generated_init.hpp`` -- assignments binding each ``NG_`` symbol to
  the real MPI symbol, converting MPI handle types with ``ng2mpi`` (arguments)
  and ``mpi2ng`` (constants).
"""

import os

# Each entry is (return type, MPI function name, argument types...).
# Argument types starting with "MPI_" are rewritten to their "NG_MPI_" wrapper
# type. An argument written as "TYPE:K" is an array of TYPE whose length is
# given by argument number K (0-based); in the init header such arrays are
# passed to ng2mpi together with that length argument.
functions = [
    ("double", "MPI_Wtime"),
    ("int", "MPI_Allgather", "void*", "int", "MPI_Datatype", "void*", "int", "MPI_Datatype", "MPI_Comm"),
    ("int", "MPI_Allreduce", "void*", "void*", "int", "MPI_Datatype", "MPI_Op", "MPI_Comm"),
    ("int", "MPI_Alltoall", "void*", "int", "MPI_Datatype", "void*", "int", "MPI_Datatype", "MPI_Comm"),
    ("int", "MPI_Barrier", "MPI_Comm"),
    ("int", "MPI_Bcast", "void*", "int", "MPI_Datatype", "int", "MPI_Comm"),
    ("int", "MPI_Comm_c2f", "MPI_Comm"),
    ("int", "MPI_Comm_create", "MPI_Comm", "MPI_Group", "MPI_Comm*"),
    ("int", "MPI_Comm_create_group", "MPI_Comm", "MPI_Group", "int", "MPI_Comm*"),
    ("int", "MPI_Comm_free", "MPI_Comm*"),
    ("int", "MPI_Comm_group", "MPI_Comm", "MPI_Group*"),
    ("int", "MPI_Comm_rank", "MPI_Comm", "int*"),
    ("int", "MPI_Comm_size", "MPI_Comm", "int*"),
    ("int", "MPI_Finalize"),
    ("int", "MPI_Gather", "void*", "int", "MPI_Datatype", "void*", "int", "MPI_Datatype", "int", "MPI_Comm"),
    ("int", "MPI_Gatherv", "void*", "int", "MPI_Datatype", "void*", "int*", "int*", "MPI_Datatype", "int", "MPI_Comm"),
    ("int", "MPI_Get_count", "MPI_Status*", "MPI_Datatype", "int*"),
    ("int", "MPI_Get_processor_name", "char*", "int*"),
    ("int", "MPI_Group_incl", "MPI_Group", "int", "int*", "MPI_Group*"),
    ("int", "MPI_Init", "int*", "char***"),
    ("int", "MPI_Init_thread", "int*", "char***", "int", "int*"),
    ("int", "MPI_Initialized", "int*"),
    ("int", "MPI_Iprobe", "int", "int", "MPI_Comm", "int*", "MPI_Status*"),
    ("int", "MPI_Irecv", "void*", "int", "MPI_Datatype", "int", "int", "MPI_Comm", "MPI_Request*"),
    ("int", "MPI_Isend", "void*", "int", "MPI_Datatype", "int", "int", "MPI_Comm", "MPI_Request*"),
    ("int", "MPI_Probe", "int", "int", "MPI_Comm", "MPI_Status*"),
    ("int", "MPI_Query_thread", "int*"),
    ("int", "MPI_Recv", "void*", "int", "MPI_Datatype", "int", "int", "MPI_Comm", "MPI_Status*"),
    ("int", "MPI_Recv_init", "void*", "int", "MPI_Datatype", "int", "int", "MPI_Comm", "MPI_Request*"),
    ("int", "MPI_Reduce", "void*", "void*", "int", "MPI_Datatype", "MPI_Op", "int", "MPI_Comm"),
    ("int", "MPI_Reduce_local", "void*", "void*", "int", "MPI_Datatype", "MPI_Op"),
    ("int", "MPI_Request_free", "MPI_Request*"),
    ("int", "MPI_Scatter", "void*", "int", "MPI_Datatype", "void*", "int", "MPI_Datatype", "int", "MPI_Comm"),
    ("int", "MPI_Send", "void*", "int", "MPI_Datatype", "int", "int", "MPI_Comm"),
    ("int", "MPI_Send_init", "void*", "int", "MPI_Datatype", "int", "int", "MPI_Comm", "MPI_Request*"),
    ("int", "MPI_Startall", "int", "MPI_Request*:0"),
    ("int", "MPI_Type_commit", "MPI_Datatype*"),
    ("int", "MPI_Type_contiguous", "int", "MPI_Datatype", "MPI_Datatype*"),
    ("int", "MPI_Type_create_resized", "MPI_Datatype", "MPI_Aint", "MPI_Aint", "MPI_Datatype*"),
    ("int", "MPI_Type_create_struct", "int", "int*:0", "MPI_Aint*:0", "MPI_Datatype*:0", "MPI_Datatype*"),
    ("int", "MPI_Type_free", "MPI_Datatype*"),
    ("int", "MPI_Type_get_extent", "MPI_Datatype", "MPI_Aint*", "MPI_Aint*"),
    ("int", "MPI_Type_indexed", "int", "int*:0", "int*:0", "MPI_Datatype", "MPI_Datatype*"),
    ("int", "MPI_Type_size", "MPI_Datatype", "int*"),
    ("int", "MPI_Wait", "MPI_Request*", "MPI_Status*"),
    ("int", "MPI_Waitall", "int", "MPI_Request*:0", "MPI_Status*"),
    ("int", "MPI_Waitany", "int", "MPI_Request*:0", "int*", "MPI_Status*"),
]

# Each entry is (C type, MPI constant name). Types starting with "MPI_" are
# rewritten to their "NG_MPI_" wrapper type.
constants = [
    ("MPI_Comm", "MPI_COMM_NULL"),
    ("MPI_Comm", "MPI_COMM_WORLD"),
    ("MPI_Datatype", "MPI_CHAR"),
    ("MPI_Datatype", "MPI_CXX_DOUBLE_COMPLEX"),
    ("MPI_Datatype", "MPI_C_BOOL"),
    ("MPI_Datatype", "MPI_DATATYPE_NULL"),
    ("MPI_Datatype", "MPI_DOUBLE"),
    ("MPI_Datatype", "MPI_FLOAT"),
    ("MPI_Datatype", "MPI_INT"),
    ("MPI_Datatype", "MPI_SHORT"),
    ("MPI_Datatype", "MPI_UINT64_T"),
    ("MPI_Op", "MPI_LOR"),
    ("MPI_Op", "MPI_MAX"),
    ("MPI_Op", "MPI_MIN"),
    ("MPI_Op", "MPI_SUM"),
    ("MPI_Request", "MPI_REQUEST_NULL"),
    ("MPI_Status*", "MPI_STATUSES_IGNORE"),
    ("MPI_Status*", "MPI_STATUS_IGNORE"),
    ("int", "MPI_ANY_SOURCE"),
    ("int", "MPI_ANY_TAG"),
    ("int", "MPI_MAX_PROCESSOR_NAME"),
    ("int", "MPI_PROC_NULL"),
    ("int", "MPI_ROOT"),
    ("int", "MPI_SUBVERSION"),
    ("int", "MPI_THREAD_MULTIPLE"),
    ("int", "MPI_THREAD_SINGLE"),
    ("int", "MPI_VERSION"),
    ("void*", "MPI_IN_PLACE"),
]


def _ng_type(typ):
    """Return the NG_ wrapper name for an MPI type; other types are unchanged."""
    return "NG_" + typ if typ.startswith("MPI_") else typ


def _write(filename, text, output_dir="."):
    """Write *text* to *filename* inside *output_dir*."""
    with open(os.path.join(output_dir, filename), "w", encoding="utf-8") as f:
        f.write(text)


def get_args(f, counts=False):
    """Return the argument types of a ``functions`` entry, using NG_ type names.

    Args:
        f: An entry ``(ret, name, *arg_types)`` of ``functions``.
        counts: If True, return ``(type, count_index)`` pairs, where
            ``count_index`` is the integer after ``:`` in the argument spec
            (the index of the argument holding the array length), or None.

    Returns:
        A list of type strings, or of ``(type, count_index)`` tuples.

    Raises:
        ValueError: if the text after ``:`` is not an integer.
    """
    args = []
    for arg in f[2:]:
        typ, sep, count = arg.partition(":")
        count = int(count) if sep else None
        typ = _ng_type(typ)
        args.append((typ, count) if counts else typ)
    return args


def generate_declarations(output_dir="."):
    """Write ``ng_mpi_generated_declarations.hpp`` and return its contents.

    Under ``NG_MPI_WRAPPER`` the header declares extern NG_ function pointers
    and constants (assigned by the generated init / dummy-init code);
    otherwise every NG_ name is a static alias of the real MPI symbol.
    """
    wrapper = []
    no_wrapper = []
    for f in functions:
        ret, name = f[0], f[1]
        args = ", ".join(get_args(f))
        wrapper.append(f"NGCORE_API extern {ret} (*NG_{name})({args});\n")
        no_wrapper.append(f"static const auto NG_{name} = {name};\n")
    for typ, name in constants:
        wrapper.append(f"NGCORE_API extern {_ng_type(typ)} NG_{name};\n")
        no_wrapper.append(f"static const decltype({name}) NG_{name} = {name};\n")

    text = (
        "#ifdef NG_MPI_WRAPPER\n"
        + "".join(wrapper)
        + "#else // NG_MPI_WRAPPER\n"
        + "".join(no_wrapper)
        + "#endif // NG_MPI_WRAPPER\n"
    )
    _write("ng_mpi_generated_declarations.hpp", text, output_dir)
    return text


def generate_dummy_init(output_dir="."):
    """Write ``ng_mpi_generated_dummy_init.hpp`` and return its contents.

    Every NG_ function is defined as a lambda that throws ``no_mpi()``, and
    every NG_ constant is initialized to ``0``.
    """
    lines = []
    for f in functions:
        ret, name = f[0], f[1]
        args = ", ".join(get_args(f))
        lines.append(
            f"decltype(NG_{name}) NG_{name} = "
            f"[]({args})->{ret} {{ throw no_mpi(); }};\n"
        )
    for typ, name in constants:
        lines.append(f"{_ng_type(typ)} NG_{name} = 0;\n")

    text = "".join(lines)
    _write("ng_mpi_generated_dummy_init.hpp", text, output_dir)
    return text


def generate_init(output_dir="."):
    """Write ``ng_mpi_generated_init.hpp`` and return its contents.

    Every NG_ function pointer is assigned a lambda that forwards to the real
    MPI function. NG_MPI_* arguments are converted with ``ng2mpi``; array
    arguments also receive their length argument. Every NG_ constant is
    assigned ``mpi2ng`` of the real MPI constant.
    """
    lines = []
    for f in functions:
        ret, name = f[0], f[1]
        params = []
        call_args = []
        for i, (typ, count) in enumerate(get_args(f, counts=True)):
            params.append(f"{typ} arg{i}")
            if not typ.startswith("NG_"):
                # Plain type (int, int*, void*, ...): pass the argument along.
                call_args.append(f"arg{i}")
            elif count is None:
                # Single MPI object (by value or pointer), not an array.
                call_args.append(f"ng2mpi(arg{i})")
            else:
                # Array of MPI objects: NG_ and MPI types may differ in size,
                # so the array is converted with its length argument.
                call_args.append(f"ng2mpi(arg{i}, arg{count})")
        param_list = ", ".join(params)
        call_list = ", ".join(call_args)
        lines.append(
            f"NG_{name} = []({param_list})->{ret} "
            f"{{ return {name}({call_list}); }};\n"
        )
    for _, name in constants:
        lines.append(f"NG_{name} = mpi2ng({name});\n")

    text = "".join(lines)
    _write("ng_mpi_generated_init.hpp", text, output_dir)
    return text


if __name__ == "__main__":
    generate_declarations()
    generate_dummy_init()
    generate_init()