diff --git a/libsrc/core/python_ngcore.hpp b/libsrc/core/python_ngcore.hpp index 7a2fe5b2..7b22ee79 100644 --- a/libsrc/core/python_ngcore.hpp +++ b/libsrc/core/python_ngcore.hpp @@ -26,7 +26,48 @@ namespace ngcore throw py::type_error("Cannot convert Python object to C Array"); return arr; } - + + template ::index_type> + void ExportArray (py::module &m) + { + using TFlat = FlatArray; + using TArray = Array; + std::string suffix = std::string(typeid(T).name()) + "_" + typeid(TIND).name(); + std::string fname = std::string("FlatArray_") + suffix; + py::class_(m, fname.c_str()) + .def ("__len__", [] ( TFlat &self ) { return self.Size(); } ) + .def ("__getitem__", + [](TFlat & self, TIND i) -> T& + { + static constexpr int base = IndexBASE(); + static_assert(base==0 || base==1, "IndexBASE not in [0,1]"); + if (i < 0 || i >= self.Size()) + throw py::index_error(); + if(base==1) i++; + return self[i]; // Access from Python is always 0-based + }, + py::return_value_policy::reference) + .def("__iter__", [] ( TFlat & self) { + return py::make_iterator (self.begin(),self.end()); + }, py::keep_alive<0,1>()) // keep array alive while iterator is used + + ; + + std::string aname = std::string("Array_") + suffix; + py::class_(m, aname.c_str()) + .def(py::init([] (size_t n) { return new TArray(n); }),py::arg("n"), "Makes array of given length") + .def(py::init([] (std::vector const & x) + { + size_t s = x.size(); + TArray tmp(s); + for (size_t i : Range(tmp)) + tmp[TIND(i)] = x[i]; + return tmp; + }), py::arg("vec"), "Makes array with given list of elements") + + ; + } + void NGCORE_API SetFlag(Flags &flags, std::string s, py::object value); // Parse python kwargs to flags Flags NGCORE_API CreateFlagsFromKwArgs(const py::kwargs& kwargs, py::object pyclass = py::none(), diff --git a/libsrc/core/python_ngcore_export.cpp b/libsrc/core/python_ngcore_export.cpp index 90dbf72f..2f409aa4 100644 --- a/libsrc/core/python_ngcore_export.cpp +++ b/libsrc/core/python_ngcore_export.cpp @@ -6,6 +6,11 @@ using namespace std; PYBIND11_MODULE(pyngcore, m) // NOLINT { + ExportArray(m); + ExportArray(m); + ExportArray(m); + ExportArray(m); + py::class_(m, "Flags") .def(py::init<>()) .def("__str__", &ToString) diff --git a/libsrc/meshing/python_mesh.cpp b/libsrc/meshing/python_mesh.cpp index 11b4d85e..a8a40150 100644 --- a/libsrc/meshing/python_mesh.cpp +++ b/libsrc/meshing/python_mesh.cpp @@ -40,28 +40,6 @@ namespace netgen } -template -void ExportArray (py::module &m) -{ - using TA = NgArray; - string name = string("Array_") + typeid(T).name(); - py::class_>(m, name.c_str()) - .def ("__len__", [] ( NgArray &self ) { return self.Size(); } ) - .def ("__getitem__", - FunctionPointer ([](NgArray & self, TIND i) -> T& - { - if (i < BASE || i >= BASE+self.Size()) - throw py::index_error(); - return self[i]; - }), - py::return_value_policy::reference) - .def("__iter__", [] ( TA & self) { - return py::make_iterator (self.begin(),self.end()); - }, py::keep_alive<0,1>()) // keep array alive while iterator is used - - ; -} - void TranslateException (const NgException & ex) { string err = string("Netgen exception: ")+ex.What(); @@ -527,11 +505,11 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m) - ExportArray(m); - ExportArray(m); - ExportArray(m); + ExportArray(m); + ExportArray(m); + ExportArray(m); ExportArray(m); - ExportArray(m); + ExportArray(m); ExportArray(m); py::implicitly_convertible< int, PointIndex>(); diff --git a/tests/pytest/meshes.py b/tests/pytest/meshes.py new file mode 100644 index 00000000..d1c9a7dc --- /dev/null +++ b/tests/pytest/meshes.py @@ -0,0 +1,11 @@ +import pytest + +@pytest.fixture +def unit_mesh_2d(): + import netgen.geom2d as g2d + return g2d.unit_square.GenerateMesh(maxh=0.2) + +@pytest.fixture +def unit_mesh_3d(): + import netgen.csg as csg + return csg.unit_cube.GenerateMesh(maxh=0.2) diff --git a/tests/pytest/test_meshclass.py b/tests/pytest/test_meshclass.py new file mode 100644 index 00000000..0144422f --- /dev/null +++ b/tests/pytest/test_meshclass.py @@ -0,0 +1,22 @@ +import pyngcore +import netgen + +from meshes import unit_mesh_3d + +def test_element_arrays(unit_mesh_3d): + mesh = unit_mesh_3d + el0 = mesh.Elements0D() + el1 = mesh.Elements1D() + el2 = mesh.Elements2D() + el3 = mesh.Elements3D() + p = mesh.Points() + + assert len(el2) > 0 + assert len(el3) > 0 + assert len(p) > 0 + + for el in el2: + assert len(el.vertices) == 3 + + for el in el3: + assert len(el.vertices) == 4