From 829defd3eb1ff70677691830c7191ff3df781d8a Mon Sep 17 00:00:00 2001 From: Christopher Lackner Date: Thu, 20 Dec 2018 17:01:27 +0100 Subject: [PATCH] archive now support python exported objects --- libsrc/core/archive.hpp | 215 +++++++++++++++++++++----------- libsrc/core/type_traits.hpp | 16 +++ libsrc/csg/python_csg.cpp | 18 +-- libsrc/geom2d/python_geom2d.cpp | 21 +--- libsrc/meshing/meshclass.cpp | 2 +- libsrc/meshing/python_mesh.cpp | 1 + libsrc/occ/python_occ.cpp | 18 +-- libsrc/stlgeom/python_stl.cpp | 18 +-- tests/pytest/test_pickling.py | 18 ++- 9 files changed, 182 insertions(+), 145 deletions(-) diff --git a/libsrc/core/archive.hpp b/libsrc/core/archive.hpp index bb3e1488..15ecb894 100644 --- a/libsrc/core/archive.hpp +++ b/libsrc/core/archive.hpp @@ -18,6 +18,10 @@ #include "type_traits.hpp" // for all_of_tmpl #include "version.hpp" // for VersionInfo +#ifdef NG_PYTHON +#include +#endif // NG_PYTHON + namespace ngcore { // Libraries using this archive can store their version here to implement backwards compatibility @@ -98,7 +102,8 @@ namespace ngcore // vectors for storing the unarchived (shared) pointers std::vector> nr2shared_ptr; std::vector nr2ptr; - + protected: + bool shallow_to_python = false; public: Archive() = delete; Archive(const Archive&) = delete; @@ -108,6 +113,31 @@ namespace ngcore virtual ~Archive() { ; } + template + Archive& Shallow(T& val) + { + static_assert(detail::is_any_pointer, "ShallowArchive must be given pointer type!"); +#ifdef NG_PYTHON + if(shallow_to_python) + { + if(is_output) + ShallowOutPython(pybind11::cast(val)); + else + val = pybind11::cast(ShallowInPython()); + } + else +#endif // NG_PYTHON + *this & val; + return *this; + } + +#ifdef NG_PYTHON + virtual void ShallowOutPython(pybind11::object /*unused*/) // NOLINT (copy by val is ok for this virt func) + { throw std::runtime_error("Should not get in ShallowToPython base class implementation!"); } + virtual pybind11::object ShallowInPython() + { throw std::runtime_error("Should not get in ShallowFromPython base class implementation!"); } +#endif // NG_PYTHON + Archive& operator=(const Archive&) = delete; Archive& operator=(Archive&&) = delete; @@ -526,16 +556,15 @@ namespace ngcore static constexpr size_t BUFFERSIZE = 1024; char buffer[BUFFERSIZE] = {}; size_t ptr = 0; - std::shared_ptr fout; + protected: + std::shared_ptr stream; public: BinaryOutArchive() = delete; BinaryOutArchive(const BinaryOutArchive&) = delete; BinaryOutArchive(BinaryOutArchive&&) = delete; - BinaryOutArchive(std::shared_ptr&& afout) - : Archive(true), fout(std::move(afout)) - { - (*this) & GetLibraryVersions(); - } + BinaryOutArchive(std::shared_ptr&& astream) + : Archive(true), stream(std::move(astream)) + { } BinaryOutArchive(const std::string& filename) : BinaryOutArchive(std::make_shared(filename)) {} ~BinaryOutArchive () override { FlushBuffer(); } @@ -543,9 +572,6 @@ namespace ngcore BinaryOutArchive& operator=(const BinaryOutArchive&) = delete; BinaryOutArchive& operator=(BinaryOutArchive&&) = delete; - const VersionInfo& GetVersion(const std::string& library) override - { return GetLibraryVersions()[library]; } - using Archive::operator&; Archive & operator & (double & d) override { return Write(d); } @@ -567,7 +593,7 @@ namespace ngcore (*this) & len; FlushBuffer(); if(len) - fout->write (&str[0], len); + stream->write (&str[0], len); return *this; } Archive & operator & (char *& str) override @@ -576,14 +602,14 @@ namespace ngcore (*this) & len; FlushBuffer(); if(len > 0) - fout->write (&str[0], len); // NOLINT + stream->write (&str[0], len); // NOLINT return *this; } void FlushBuffer() override { if (ptr > 0) { - fout->write(&buffer[0], ptr); + stream->write(&buffer[0], ptr); ptr = 0; } } @@ -594,7 +620,7 @@ namespace ngcore { if (unlikely(ptr > BUFFERSIZE-sizeof(T))) { - fout->write(&buffer[0], ptr); + stream->write(&buffer[0], ptr); *reinterpret_cast(&buffer[0]) = x; // NOLINT ptr = sizeof(T); return *this; @@ -608,20 +634,15 @@ namespace ngcore // BinaryInArchive ====================================================================== class NGCORE_API BinaryInArchive : public Archive { - std::map vinfo{}; - std::shared_ptr fin; + protected: + std::shared_ptr stream; public: - BinaryInArchive (std::shared_ptr&& afin) - : Archive(false), fin(std::move(afin)) - { - (*this) & vinfo; - } + BinaryInArchive (std::shared_ptr&& astream) + : Archive(false), stream(std::move(astream)) + { } BinaryInArchive (const std::string& filename) : BinaryInArchive(std::make_shared(filename)) { ; } - const VersionInfo& GetVersion(const std::string& library) override - { return vinfo[library]; } - using Archive::operator&; Archive & operator & (double & d) override { Read(d); return *this; } @@ -643,7 +664,7 @@ namespace ngcore (*this) & len; str.resize(len); if(len) - fin->read(&str[0], len); // NOLINT + stream->read(&str[0], len); // NOLINT return *this; } Archive & operator & (char *& str) override @@ -655,64 +676,60 @@ namespace ngcore else { str = new char[len+1]; // NOLINT - fin->read(&str[0], len); // NOLINT + stream->read(&str[0], len); // NOLINT str[len] = '\0'; // NOLINT } return *this; } Archive & Do (double * d, size_t n) override - { fin->read(reinterpret_cast(d), n*sizeof(double)); return *this; } // NOLINT + { stream->read(reinterpret_cast(d), n*sizeof(double)); return *this; } // NOLINT Archive & Do (int * i, size_t n) override - { fin->read(reinterpret_cast(i), n*sizeof(int)); return *this; } // NOLINT + { stream->read(reinterpret_cast(i), n*sizeof(int)); return *this; } // NOLINT Archive & Do (size_t * i, size_t n) override - { fin->read(reinterpret_cast(i), n*sizeof(size_t)); return *this; } // NOLINT + { stream->read(reinterpret_cast(i), n*sizeof(size_t)); return *this; } // NOLINT private: template inline void Read(T& val) - { fin->read(reinterpret_cast(&val), sizeof(T)); } // NOLINT + { stream->read(reinterpret_cast(&val), sizeof(T)); } // NOLINT }; // TextOutArchive ====================================================================== class NGCORE_API TextOutArchive : public Archive { - std::shared_ptr fout; + protected: + std::shared_ptr stream; public: - TextOutArchive (std::shared_ptr&& afout) - : Archive(true), fout(std::move(afout)) - { - (*this) & GetLibraryVersions(); - } + TextOutArchive (std::shared_ptr&& astream) + : Archive(true), stream(std::move(astream)) + { } TextOutArchive (const std::string& filename) : TextOutArchive(std::make_shared(filename)) { } - const VersionInfo& GetVersion(const std::string& library) override - { return GetLibraryVersions()[library]; } - using Archive::operator&; Archive & operator & (double & d) override - { *fout << d << '\n'; return *this; } + { *stream << d << '\n'; return *this; } Archive & operator & (int & i) override - { *fout << i << '\n'; return *this; } + { *stream << i << '\n'; return *this; } Archive & operator & (short & i) override - { *fout << i << '\n'; return *this; } + { *stream << i << '\n'; return *this; } Archive & operator & (long & i) override - { *fout << i << '\n'; return *this; } + { *stream << i << '\n'; return *this; } Archive & operator & (size_t & i) override - { *fout << i << '\n'; return *this; } + { *stream << i << '\n'; return *this; } Archive & operator & (unsigned char & i) override - { *fout << int(i) << '\n'; return *this; } + { *stream << int(i) << '\n'; return *this; } Archive & operator & (bool & b) override - { *fout << (b ? 't' : 'f') << '\n'; return *this; } + { *stream << (b ? 't' : 'f') << '\n'; return *this; } Archive & operator & (std::string & str) override { int len = str.length(); - *fout << len << '\n'; + *stream << len << '\n'; if(len) { - fout->write(&str[0], len); // NOLINT - *fout << '\n'; + stream->write(&str[0], len); // NOLINT + *stream << '\n'; } return *this; } @@ -722,8 +739,8 @@ namespace ngcore *this & len; if(len > 0) { - fout->write (&str[0], len); // NOLINT - *fout << '\n'; + stream->write (&str[0], len); // NOLINT + *stream << '\n'; } return *this; } @@ -732,44 +749,39 @@ namespace ngcore // TextInArchive ====================================================================== class NGCORE_API TextInArchive : public Archive { - std::map vinfo{}; - std::shared_ptr fin; + protected: + std::shared_ptr stream; public: - TextInArchive (std::shared_ptr&& afin) : - Archive(false), fin(std::move(afin)) - { - (*this) & vinfo; - } + TextInArchive (std::shared_ptr&& astream) : + Archive(false), stream(std::move(astream)) + { } TextInArchive (const std::string& filename) : TextInArchive(std::make_shared(filename)) {} - const VersionInfo& GetVersion(const std::string& library) override - { return vinfo[library]; } - using Archive::operator&; Archive & operator & (double & d) override - { *fin >> d; return *this; } + { *stream >> d; return *this; } Archive & operator & (int & i) override - { *fin >> i; return *this; } + { *stream >> i; return *this; } Archive & operator & (short & i) override - { *fin >> i; return *this; } + { *stream >> i; return *this; } Archive & operator & (long & i) override - { *fin >> i; return *this; } + { *stream >> i; return *this; } Archive & operator & (size_t & i) override - { *fin >> i; return *this; } + { *stream >> i; return *this; } Archive & operator & (unsigned char & i) override - { int _i; *fin >> _i; i = _i; return *this; } + { int _i; *stream >> _i; i = _i; return *this; } Archive & operator & (bool & b) override - { char c; *fin >> c; b = (c=='t'); return *this; } + { char c; *stream >> c; b = (c=='t'); return *this; } Archive & operator & (std::string & str) override { int len; - *fin >> len; + *stream >> len; char ch; - fin->get(ch); // '\n' + stream->get(ch); // '\n' str.resize(len); if(len) - fin->get(&str[0], len+1, '\0'); + stream->get(&str[0], len+1, '\0'); return *this; } Archive & operator & (char *& str) override @@ -785,13 +797,70 @@ namespace ngcore str = new char[len+1]; // NOLINT if(len) { - fin->get(ch); // \n - fin->get(&str[0], len+1, '\0'); // NOLINT + stream->get(ch); // \n + stream->get(&str[0], len+1, '\0'); // NOLINT } str[len] = '\0'; // NOLINT return *this; } }; + +#ifdef NG_PYTHON + namespace py = pybind11; + + template + class PyArchive : public ARCHIVE + { + private: + py::list lst; + size_t index = 0; + using ARCHIVE::stream; + public: + PyArchive(const py::object& alst = py::none()) : + ARCHIVE(std::make_shared()), + lst(alst.is_none() ? py::list() : py::cast(alst)) + { + ARCHIVE::shallow_to_python = true; + if(Input()) + stream = std::make_shared(py::cast(lst[py::len(lst)-1])); + } + + using ARCHIVE::Output; + using ARCHIVE::Input; + using ARCHIVE::FlushBuffer; + using ARCHIVE::operator&; + using ARCHIVE::operator<<; + using ARCHIVE::GetVersion; + void ShallowOutPython(py::object val) override { lst.append(val); } + py::object ShallowInPython() override { return lst[index++]; } + + py::list WriteOut() + { + FlushBuffer(); + lst.append(py::bytes(std::static_pointer_cast(stream)->str())); + return lst; + } + }; + + template + auto NGSPickle() + { + return py::pickle([](T& self) + { + PyArchive ar; + ar & self; + return py::make_tuple(ar.WriteOut()); + }, + [](py::tuple state) + { + auto val = std::make_unique(); + PyArchive ar(state[0]); + ar & *val; + return std::move(val); + }); + } + +#endif // NG_PYTHON } // namespace ngcore #endif // NETGEN_CORE_ARCHIVE_HPP diff --git a/libsrc/core/type_traits.hpp b/libsrc/core/type_traits.hpp index 3e7cd4b4..3863940b 100644 --- a/libsrc/core/type_traits.hpp +++ b/libsrc/core/type_traits.hpp @@ -1,6 +1,7 @@ #ifndef NETGEN_CORE_TYPE_TRAITS_HPP #define NETGEN_CORE_TYPE_TRAITS_HPP +#include #include namespace ngcore @@ -11,6 +12,21 @@ namespace ngcore template constexpr bool all_of_tmpl = std::is_same<_BoolArray, _BoolArray<(vals || true)...>>::value; // NOLINT + + template + struct is_any_pointer_impl : std::false_type {}; + + template + struct is_any_pointer_impl : std::true_type {}; + + template + struct is_any_pointer_impl> : std::true_type {}; + + template + struct is_any_pointer_impl> : std::true_type {}; + + template + constexpr bool is_any_pointer = is_any_pointer_impl::value; } // namespace detail } // namespace ngcore diff --git a/libsrc/csg/python_csg.cpp b/libsrc/csg/python_csg.cpp index b84bb3e1..74067f23 100644 --- a/libsrc/csg/python_csg.cpp +++ b/libsrc/csg/python_csg.cpp @@ -368,23 +368,7 @@ However, when r = 0, the top part becomes a point(tip) and meshing fails! geo->FindIdenticSurfaces(1e-8 * geo->MaxSize()); return geo; }), py::arg("filename")) - .def(py::pickle( - [](CSGeometry& self) - { - auto ss = make_shared(); - BinaryOutArchive archive(ss); - archive & self; - archive.FlushBuffer(); - return py::make_tuple(py::bytes(ss->str())); - }, - [](py::tuple state) - { - auto geo = make_shared(); - auto ss = make_shared (py::cast(state[0])); - BinaryInArchive archive(ss); - archive & (*geo); - return geo; - })) + .def(NGSPickle()) .def("Save", FunctionPointer([] (CSGeometry & self, string filename) { cout << "save geometry to file " << filename << endl; diff --git a/libsrc/geom2d/python_geom2d.cpp b/libsrc/geom2d/python_geom2d.cpp index 1fc34f99..1ba68b20 100644 --- a/libsrc/geom2d/python_geom2d.cpp +++ b/libsrc/geom2d/python_geom2d.cpp @@ -26,25 +26,8 @@ DLL_HEADER void ExportGeom2d(py::module &m) ng_geometry = geo; return geo; })) - .def(py::pickle( - [](SplineGeometry2d& self) - { - auto ss = make_shared(); - BinaryOutArchive archive(ss); - archive & self; - archive.FlushBuffer(); - return py::make_tuple(py::bytes(ss->str())); - }, - [](py::tuple state) - { - auto geo = make_shared(); - auto ss = make_shared (py::cast(state[0])); - BinaryInArchive archive(ss); - archive & (*geo); - return geo; - })) - - .def("Load",&SplineGeometry2d::Load) + .def(NGSPickle()) + .def("Load",&SplineGeometry2d::Load) .def("AppendPoint", FunctionPointer ([](SplineGeometry2d &self, double px, double py, double maxh, double hpref, string name) { diff --git a/libsrc/meshing/meshclass.cpp b/libsrc/meshing/meshclass.cpp index cee29d51..ebf75ff4 100644 --- a/libsrc/meshing/meshclass.cpp +++ b/libsrc/meshing/meshclass.cpp @@ -1316,7 +1316,7 @@ namespace netgen archive & *ident; - archive & geometry; + archive.Shallow(geometry); archive & *curvedelems; if (archive.Input()) diff --git a/libsrc/meshing/python_mesh.cpp b/libsrc/meshing/python_mesh.cpp index de97e28d..6360174a 100644 --- a/libsrc/meshing/python_mesh.cpp +++ b/libsrc/meshing/python_mesh.cpp @@ -493,6 +493,7 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m) } ), py::arg("dim")=3 ) + .def(NGSPickle()) /* .def("__init__", diff --git a/libsrc/occ/python_occ.cpp b/libsrc/occ/python_occ.cpp index 157dc93b..eec271a1 100644 --- a/libsrc/occ/python_occ.cpp +++ b/libsrc/occ/python_occ.cpp @@ -18,23 +18,7 @@ DLL_HEADER void ExportNgOCC(py::module &m) { py::class_, NetgenGeometry> (m, "OCCGeometry", R"raw_string(Use LoadOCCGeometry to load the geometry from a *.step file.)raw_string") .def(py::init<>()) - .def(py::pickle( - [](OCCGeometry& self) - { - auto ss = make_shared(); - BinaryOutArchive archive(ss); - archive & self; - archive.FlushBuffer(); - return py::make_tuple(py::bytes(ss->str())); - }, - [](py::tuple state) - { - auto geo = make_shared(); - auto ss = make_shared (py::cast(state[0])); - BinaryInArchive archive(ss); - archive & (*geo); - return geo; - })) + .def(NGSPickle()) .def("Heal",[](OCCGeometry & self, double tolerance, bool fixsmalledges, bool fixspotstripfaces, bool sewfaces, bool makesolids, bool splitpartitions) { self.tolerance = tolerance; diff --git a/libsrc/stlgeom/python_stl.cpp b/libsrc/stlgeom/python_stl.cpp index 9fbe49cd..4968078e 100644 --- a/libsrc/stlgeom/python_stl.cpp +++ b/libsrc/stlgeom/python_stl.cpp @@ -20,23 +20,7 @@ DLL_HEADER void ExportSTL(py::module & m) { py::class_, NetgenGeometry> (m,"STLGeometry") .def(py::init<>()) - .def(py::pickle( - [](STLGeometry& self) - { - auto ss = make_shared(); - BinaryOutArchive archive(ss); - archive & self; - archive.FlushBuffer(); - return py::make_tuple(py::bytes(ss->str())); - }, - [](py::tuple state) - { - auto geo = make_shared(); - auto ss = make_shared (py::cast(state[0])); - BinaryInArchive archive(ss); - archive & (*geo); - return geo; - })) + .def(NGSPickle()) .def("_visualizationData", [](shared_ptr stl_geo) { std::vector vertices; diff --git a/tests/pytest/test_pickling.py b/tests/pytest/test_pickling.py index 336b08e8..13e20f9e 100644 --- a/tests/pytest/test_pickling.py +++ b/tests/pytest/test_pickling.py @@ -85,5 +85,21 @@ def test_pickle_geom2d(): for val1, val2 in zip(vd1.values(), vd2.values()): assert numpy.array_equal(val1, val2) +def test_pickle_mesh(): + import netgen.csg as csg + geo = csg.CSGeometry() + brick = csg.OrthoBrick(csg.Pnt(-3,-3,-3), csg.Pnt(3,3,3)) + mesh = geo.GenerateMesh(maxh=0.2) + assert geo == mesh.GetGeometry() + dump = pickle.dumps([geo,mesh]) + geo2, mesh2 = pickle.loads(dump) + assert geo2 == mesh2.GetGeometry() + mesh.Save("msh1.vol.gz") + mesh2.Save("msh2.vol.gz") + import filecmp, os + assert filecmp.cmp("msh1.vol.gz", "msh2.vol.gz") + os.remove("msh1.vol.gz") + os.remove("msh2.vol.gz") + if __name__ == "__main__": - test_pickle_csg() + test_pickle_mesh()