mirror of
https://github.com/NGSolve/netgen.git
synced 2024-12-25 21:40:33 +05:00
array numpy buffer protocol
This commit is contained in:
parent
d6ebe15c17
commit
5288af641c
@ -7,7 +7,7 @@ using std::string;
|
|||||||
|
|
||||||
namespace ngcore
|
namespace ngcore
|
||||||
{
|
{
|
||||||
|
bool ngcore_have_numpy = false;
|
||||||
void SetFlag(Flags &flags, string s, py::object value)
|
void SetFlag(Flags &flags, string s, py::object value)
|
||||||
{
|
{
|
||||||
if (py::isinstance<py::dict>(value))
|
if (py::isinstance<py::dict>(value))
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
#include "ngcore_api.hpp" // for operator new
|
#include "ngcore_api.hpp" // for operator new
|
||||||
#include <pybind11/pybind11.h>
|
#include <pybind11/pybind11.h>
|
||||||
#include <pybind11/operators.h>
|
#include <pybind11/operators.h>
|
||||||
|
#include <pybind11/numpy.h>
|
||||||
|
|
||||||
#include "array.hpp"
|
#include "array.hpp"
|
||||||
#include "archive.hpp"
|
#include "archive.hpp"
|
||||||
@ -13,6 +14,7 @@ namespace py = pybind11;
|
|||||||
|
|
||||||
namespace ngcore
|
namespace ngcore
|
||||||
{
|
{
|
||||||
|
NGCORE_API extern bool ngcore_have_numpy;
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
Array<T> makeCArray(const py::object& obj)
|
Array<T> makeCArray(const py::object& obj)
|
||||||
@ -29,6 +31,20 @@ namespace ngcore
|
|||||||
return arr;
|
return arr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace detail
|
||||||
|
{
|
||||||
|
template<typename T>
|
||||||
|
struct HasPyFormat
|
||||||
|
{
|
||||||
|
private:
|
||||||
|
template<typename T2>
|
||||||
|
static auto check(T2*) -> std::enable_if_t<std::is_same_v<decltype(std::declval<py::format_descriptor<T2>>().format()), std::string>, std::true_type>;
|
||||||
|
static auto check(...) -> std::false_type;
|
||||||
|
public:
|
||||||
|
static constexpr bool value = decltype(check((T*) nullptr))::value;
|
||||||
|
};
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
template <typename T, typename TIND=typename FlatArray<T>::index_type>
|
template <typename T, typename TIND=typename FlatArray<T>::index_type>
|
||||||
void ExportArray (py::module &m)
|
void ExportArray (py::module &m)
|
||||||
{
|
{
|
||||||
@ -36,7 +52,8 @@ namespace ngcore
|
|||||||
using TArray = Array<T, TIND>;
|
using TArray = Array<T, TIND>;
|
||||||
std::string suffix = std::string(typeid(T).name()) + "_" + typeid(TIND).name();
|
std::string suffix = std::string(typeid(T).name()) + "_" + typeid(TIND).name();
|
||||||
std::string fname = std::string("FlatArray_") + suffix;
|
std::string fname = std::string("FlatArray_") + suffix;
|
||||||
py::class_<TFlat>(m, fname.c_str())
|
auto flatarray_class = py::class_<TFlat>(m, fname.c_str(),
|
||||||
|
py::buffer_protocol())
|
||||||
.def ("__len__", [] ( TFlat &self ) { return self.Size(); } )
|
.def ("__len__", [] ( TFlat &self ) { return self.Size(); } )
|
||||||
.def ("__getitem__",
|
.def ("__getitem__",
|
||||||
[](TFlat & self, TIND i) -> T&
|
[](TFlat & self, TIND i) -> T&
|
||||||
@ -77,6 +94,30 @@ namespace ngcore
|
|||||||
|
|
||||||
;
|
;
|
||||||
|
|
||||||
|
if constexpr (detail::HasPyFormat<T>::value)
|
||||||
|
{
|
||||||
|
if(ngcore_have_numpy && !py::detail::npy_format_descriptor<T>::dtype().is_none())
|
||||||
|
{
|
||||||
|
flatarray_class
|
||||||
|
.def_buffer([](TFlat& self)
|
||||||
|
{
|
||||||
|
return py::buffer_info(
|
||||||
|
self.Addr(0),
|
||||||
|
sizeof(T),
|
||||||
|
py::format_descriptor<T>::format(),
|
||||||
|
1,
|
||||||
|
{ self.Size() },
|
||||||
|
{ sizeof(T) * (self.Addr(1) - self.Addr(0)) });
|
||||||
|
})
|
||||||
|
.def("NumPy", [](py::object self)
|
||||||
|
{
|
||||||
|
return py::module::import("numpy")
|
||||||
|
.attr("frombuffer")(self, py::detail::npy_format_descriptor<T>::dtype());
|
||||||
|
})
|
||||||
|
;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
std::string aname = std::string("Array_") + suffix;
|
std::string aname = std::string("Array_") + suffix;
|
||||||
py::class_<TArray, TFlat>(m, aname.c_str())
|
py::class_<TArray, TFlat>(m, aname.c_str())
|
||||||
.def(py::init([] (size_t n) { return new TArray(n); }),py::arg("n"), "Makes array of given length")
|
.def(py::init([] (size_t n) { return new TArray(n); }),py::arg("n"), "Makes array of given length")
|
||||||
|
@ -2,13 +2,18 @@
|
|||||||
#include "bitarray.hpp"
|
#include "bitarray.hpp"
|
||||||
#include "taskmanager.hpp"
|
#include "taskmanager.hpp"
|
||||||
|
|
||||||
|
|
||||||
using namespace ngcore;
|
using namespace ngcore;
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace pybind11::literals;
|
using namespace pybind11::literals;
|
||||||
|
|
||||||
PYBIND11_MODULE(pyngcore, m) // NOLINT
|
PYBIND11_MODULE(pyngcore, m) // NOLINT
|
||||||
{
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
auto numpy = py::module::import("numpy");
|
||||||
|
ngcore_have_numpy = !numpy.is_none();
|
||||||
|
}
|
||||||
|
catch(...) {}
|
||||||
ExportArray<int>(m);
|
ExportArray<int>(m);
|
||||||
ExportArray<unsigned>(m);
|
ExportArray<unsigned>(m);
|
||||||
ExportArray<size_t>(m);
|
ExportArray<size_t>(m);
|
||||||
|
17
tests/pytest/test_array.py
Normal file
17
tests/pytest/test_array.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from pyngcore import *
|
||||||
|
from numpy import sort, array
|
||||||
|
|
||||||
|
def test_array_numpy():
|
||||||
|
a = Array_i_m(5)
|
||||||
|
a[:] = 0
|
||||||
|
a[3:] = 2
|
||||||
|
assert(sum(a) == 4)
|
||||||
|
a[1] = 5
|
||||||
|
b = sort(a)
|
||||||
|
assert(all(b == array([0,0,2,2,5])))
|
||||||
|
assert(all(a == array([0,5,0,2,2])))
|
||||||
|
a.NumPy().sort()
|
||||||
|
assert(all(a == array([0,0,2,2,5])))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_array_numpy()
|
Loading…
Reference in New Issue
Block a user