mirror of
https://github.com/NGSolve/netgen.git
synced 2024-12-25 21:40:33 +05:00
Export ngcore Arrays
This commit is contained in:
parent
3f4cc7a07d
commit
f570f31de9
@ -26,7 +26,48 @@ namespace ngcore
|
||||
throw py::type_error("Cannot convert Python object to C Array");
|
||||
return arr;
|
||||
}
|
||||
|
||||
|
||||
template <typename T, typename TIND=typename FlatArray<T>::index_type>
|
||||
void ExportArray (py::module &m)
|
||||
{
|
||||
using TFlat = FlatArray<T, TIND>;
|
||||
using TArray = Array<T, TIND>;
|
||||
std::string suffix = std::string(typeid(T).name()) + "_" + typeid(TIND).name();
|
||||
std::string fname = std::string("FlatArray_") + suffix;
|
||||
py::class_<TFlat>(m, fname.c_str())
|
||||
.def ("__len__", [] ( TFlat &self ) { return self.Size(); } )
|
||||
.def ("__getitem__",
|
||||
[](TFlat & self, TIND i) -> T&
|
||||
{
|
||||
static constexpr int base = IndexBASE<TIND>();
|
||||
static_assert(base==0 || base==1, "IndexBASE not in [0,1]");
|
||||
if (i < 0 || i >= self.Size())
|
||||
throw py::index_error();
|
||||
if(base==1) i++;
|
||||
return self[i]; // Access from Python is always 0-based
|
||||
},
|
||||
py::return_value_policy::reference)
|
||||
.def("__iter__", [] ( TFlat & self) {
|
||||
return py::make_iterator (self.begin(),self.end());
|
||||
}, py::keep_alive<0,1>()) // keep array alive while iterator is used
|
||||
|
||||
;
|
||||
|
||||
std::string aname = std::string("Array_") + suffix;
|
||||
py::class_<TArray, TFlat>(m, aname.c_str())
|
||||
.def(py::init([] (size_t n) { return new TArray(n); }),py::arg("n"), "Makes array of given length")
|
||||
.def(py::init([] (std::vector<T> const & x)
|
||||
{
|
||||
size_t s = x.size();
|
||||
TArray tmp(s);
|
||||
for (size_t i : Range(tmp))
|
||||
tmp[TIND(i)] = x[i];
|
||||
return tmp;
|
||||
}), py::arg("vec"), "Makes array with given list of elements")
|
||||
|
||||
;
|
||||
}
|
||||
|
||||
void NGCORE_API SetFlag(Flags &flags, std::string s, py::object value);
|
||||
// Parse python kwargs to flags
|
||||
Flags NGCORE_API CreateFlagsFromKwArgs(const py::kwargs& kwargs, py::object pyclass = py::none(),
|
||||
|
@ -6,6 +6,11 @@ using namespace std;
|
||||
|
||||
PYBIND11_MODULE(pyngcore, m) // NOLINT
|
||||
{
|
||||
ExportArray<int>(m);
|
||||
ExportArray<unsigned>(m);
|
||||
ExportArray<size_t>(m);
|
||||
ExportArray<double>(m);
|
||||
|
||||
py::class_<Flags>(m, "Flags")
|
||||
.def(py::init<>())
|
||||
.def("__str__", &ToString<Flags>)
|
||||
|
@ -40,28 +40,6 @@ namespace netgen
|
||||
}
|
||||
|
||||
|
||||
template <typename T, int BASE = 0, typename TIND = int>
|
||||
void ExportArray (py::module &m)
|
||||
{
|
||||
using TA = NgArray<T,BASE,TIND>;
|
||||
string name = string("Array_") + typeid(T).name();
|
||||
py::class_<NgArray<T,BASE,TIND>>(m, name.c_str())
|
||||
.def ("__len__", [] ( NgArray<T,BASE,TIND> &self ) { return self.Size(); } )
|
||||
.def ("__getitem__",
|
||||
FunctionPointer ([](NgArray<T,BASE,TIND> & self, TIND i) -> T&
|
||||
{
|
||||
if (i < BASE || i >= BASE+self.Size())
|
||||
throw py::index_error();
|
||||
return self[i];
|
||||
}),
|
||||
py::return_value_policy::reference)
|
||||
.def("__iter__", [] ( TA & self) {
|
||||
return py::make_iterator (self.begin(),self.end());
|
||||
}, py::keep_alive<0,1>()) // keep array alive while iterator is used
|
||||
|
||||
;
|
||||
}
|
||||
|
||||
void TranslateException (const NgException & ex)
|
||||
{
|
||||
string err = string("Netgen exception: ")+ex.What();
|
||||
@ -527,11 +505,11 @@ DLL_HEADER void ExportNetgenMeshing(py::module &m)
|
||||
|
||||
|
||||
|
||||
ExportArray<Element,0,size_t>(m);
|
||||
ExportArray<Element2d,0,size_t>(m);
|
||||
ExportArray<Segment,0,size_t>(m);
|
||||
ExportArray<Element,size_t>(m);
|
||||
ExportArray<Element2d,size_t>(m);
|
||||
ExportArray<Segment,size_t>(m);
|
||||
ExportArray<Element0d>(m);
|
||||
ExportArray<MeshPoint,PointIndex::BASE,PointIndex>(m);
|
||||
ExportArray<MeshPoint,PointIndex>(m);
|
||||
ExportArray<FaceDescriptor>(m);
|
||||
|
||||
py::implicitly_convertible< int, PointIndex>();
|
||||
|
11
tests/pytest/meshes.py
Normal file
11
tests/pytest/meshes.py
Normal file
@ -0,0 +1,11 @@
|
||||
import pytest
|
||||
|
||||
@pytest.fixture
|
||||
def unit_mesh_2d():
|
||||
import netgen.geom2d as g2d
|
||||
return g2d.unit_square.GenerateMesh(maxh=0.2)
|
||||
|
||||
@pytest.fixture
|
||||
def unit_mesh_3d():
|
||||
import netgen.csg as csg
|
||||
return csg.unit_cube.GenerateMesh(maxh=0.2)
|
22
tests/pytest/test_meshclass.py
Normal file
22
tests/pytest/test_meshclass.py
Normal file
@ -0,0 +1,22 @@
|
||||
import pyngcore
|
||||
import netgen
|
||||
|
||||
from meshes import unit_mesh_3d
|
||||
|
||||
def test_element_arrays(unit_mesh_3d):
|
||||
mesh = unit_mesh_3d
|
||||
el0 = mesh.Elements0D()
|
||||
el1 = mesh.Elements1D()
|
||||
el2 = mesh.Elements2D()
|
||||
el3 = mesh.Elements3D()
|
||||
p = mesh.Points()
|
||||
|
||||
assert len(el2) > 0
|
||||
assert len(el3) > 0
|
||||
assert len(p) > 0
|
||||
|
||||
for el in el2:
|
||||
assert len(el.vertices) == 3
|
||||
|
||||
for el in el3:
|
||||
assert len(el.vertices) == 4
|
Loading…
Reference in New Issue
Block a user