Export ngcore Arrays

This commit is contained in:
Matthias Hochsteger 2019-08-12 14:20:30 +02:00
parent 3f4cc7a07d
commit f570f31de9
5 changed files with 84 additions and 27 deletions

View File

@ -27,6 +27,47 @@ namespace ngcore
return arr; return arr;
} }
template <typename T, typename TIND=typename FlatArray<T>::index_type>
void ExportArray (py::module &m)
{
using TFlat = FlatArray<T, TIND>;
using TArray = Array<T, TIND>;
std::string suffix = std::string(typeid(T).name()) + "_" + typeid(TIND).name();
std::string fname = std::string("FlatArray_") + suffix;
py::class_<TFlat>(m, fname.c_str())
.def ("__len__", [] ( TFlat &self ) { return self.Size(); } )
.def ("__getitem__",
[](TFlat & self, TIND i) -> T&
{
static constexpr int base = IndexBASE<TIND>();
static_assert(base==0 || base==1, "IndexBASE not in [0,1]");
if (i < 0 || i >= self.Size())
throw py::index_error();
if(base==1) i++;
return self[i]; // Access from Python is always 0-based
},
py::return_value_policy::reference)
.def("__iter__", [] ( TFlat & self) {
return py::make_iterator (self.begin(),self.end());
}, py::keep_alive<0,1>()) // keep array alive while iterator is used
;
std::string aname = std::string("Array_") + suffix;
py::class_<TArray, TFlat>(m, aname.c_str())
.def(py::init([] (size_t n) { return new TArray(n); }),py::arg("n"), "Makes array of given length")
.def(py::init([] (std::vector<T> const & x)
{
size_t s = x.size();
TArray tmp(s);
for (size_t i : Range(tmp))
tmp[TIND(i)] = x[i];
return tmp;
}), py::arg("vec"), "Makes array with given list of elements")
;
}
void NGCORE_API SetFlag(Flags &flags, std::string s, py::object value); void NGCORE_API SetFlag(Flags &flags, std::string s, py::object value);
// Parse python kwargs to flags // Parse python kwargs to flags
Flags NGCORE_API CreateFlagsFromKwArgs(const py::kwargs& kwargs, py::object pyclass = py::none(), Flags NGCORE_API CreateFlagsFromKwArgs(const py::kwargs& kwargs, py::object pyclass = py::none(),

View File

@ -6,6 +6,11 @@ using namespace std;
PYBIND11_MODULE(pyngcore, m) // NOLINT PYBIND11_MODULE(pyngcore, m) // NOLINT
{ {
ExportArray<int>(m);
ExportArray<unsigned>(m);
ExportArray<size_t>(m);
ExportArray<double>(m);
py::class_<Flags>(m, "Flags") py::class_<Flags>(m, "Flags")
.def(py::init<>()) .def(py::init<>())
.def("__str__", &ToString<Flags>) .def("__str__", &ToString<Flags>)

View File

@ -40,28 +40,6 @@ namespace netgen
} }
template <typename T, int BASE = 0, typename TIND = int>
void ExportArray (py::module &m)
{
using TA = NgArray<T,BASE,TIND>;
string name = string("Array_") + typeid(T).name();
py::class_<NgArray<T,BASE,TIND>>(m, name.c_str())
.def ("__len__", [] ( NgArray<T,BASE,TIND> &self ) { return self.Size(); } )
.def ("__getitem__",
FunctionPointer ([](NgArray<T,BASE,TIND> & self, TIND i) -> T&
{
if (i < BASE || i >= BASE+self.Size())
throw py::index_error();
return self[i];
}),
py::return_value_policy::reference)
.def("__iter__", [] ( TA & self) {
return py::make_iterator (self.begin(),self.end());
}, py::keep_alive<0,1>()) // keep array alive while iterator is used
;
}
void TranslateException (const NgException & ex) void TranslateException (const NgException & ex)
{ {
string err = string("Netgen exception: ")+ex.What(); string err = string("Netgen exception: ")+ex.What();
@ -527,11 +505,11 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
ExportArray<Element,0,size_t>(m); ExportArray<Element,size_t>(m);
ExportArray<Element2d,0,size_t>(m); ExportArray<Element2d,size_t>(m);
ExportArray<Segment,0,size_t>(m); ExportArray<Segment,size_t>(m);
ExportArray<Element0d>(m); ExportArray<Element0d>(m);
ExportArray<MeshPoint,PointIndex::BASE,PointIndex>(m); ExportArray<MeshPoint,PointIndex>(m);
ExportArray<FaceDescriptor>(m); ExportArray<FaceDescriptor>(m);
py::implicitly_convertible< int, PointIndex>(); py::implicitly_convertible< int, PointIndex>();

11
tests/pytest/meshes.py Normal file
View File

@ -0,0 +1,11 @@
import pytest
@pytest.fixture
def unit_mesh_2d():
import netgen.geom2d as g2d
return g2d.unit_square.GenerateMesh(maxh=0.2)
@pytest.fixture
def unit_mesh_3d():
import netgen.csg as csg
return csg.unit_cube.GenerateMesh(maxh=0.2)

View File

@ -0,0 +1,22 @@
import pyngcore
import netgen
from meshes import unit_mesh_3d
def test_element_arrays(unit_mesh_3d):
mesh = unit_mesh_3d
el0 = mesh.Elements0D()
el1 = mesh.Elements1D()
el2 = mesh.Elements2D()
el3 = mesh.Elements3D()
p = mesh.Points()
assert len(el2) > 0
assert len(el3) > 0
assert len(p) > 0
for el in el2:
assert len(el.vertices) == 3
for el in el3:
assert len(el.vertices) == 4