array numpy buffer protocol

This commit is contained in:
Christopher Lackner 2019-09-10 23:01:05 +02:00
parent d6ebe15c17
commit 5288af641c
4 changed files with 66 additions and 3 deletions

View File

@ -7,7 +7,7 @@ using std::string;
namespace ngcore namespace ngcore
{ {
bool ngcore_have_numpy = false;
void SetFlag(Flags &flags, string s, py::object value) void SetFlag(Flags &flags, string s, py::object value)
{ {
if (py::isinstance<py::dict>(value)) if (py::isinstance<py::dict>(value))

View File

@ -4,6 +4,7 @@
#include "ngcore_api.hpp" // for operator new #include "ngcore_api.hpp" // for operator new
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/operators.h> #include <pybind11/operators.h>
#include <pybind11/numpy.h>
#include "array.hpp" #include "array.hpp"
#include "archive.hpp" #include "archive.hpp"
@ -13,6 +14,7 @@ namespace py = pybind11;
namespace ngcore namespace ngcore
{ {
NGCORE_API extern bool ngcore_have_numpy;
template<typename T> template<typename T>
Array<T> makeCArray(const py::object& obj) Array<T> makeCArray(const py::object& obj)
@ -29,6 +31,20 @@ namespace ngcore
return arr; return arr;
} }
namespace detail
{
template<typename T>
struct HasPyFormat
{
private:
template<typename T2>
static auto check(T2*) -> std::enable_if_t<std::is_same_v<decltype(std::declval<py::format_descriptor<T2>>().format()), std::string>, std::true_type>;
static auto check(...) -> std::false_type;
public:
static constexpr bool value = decltype(check((T*) nullptr))::value;
};
} // namespace detail
template <typename T, typename TIND=typename FlatArray<T>::index_type> template <typename T, typename TIND=typename FlatArray<T>::index_type>
void ExportArray (py::module &m) void ExportArray (py::module &m)
{ {
@ -36,7 +52,8 @@ namespace ngcore
using TArray = Array<T, TIND>; using TArray = Array<T, TIND>;
std::string suffix = std::string(typeid(T).name()) + "_" + typeid(TIND).name(); std::string suffix = std::string(typeid(T).name()) + "_" + typeid(TIND).name();
std::string fname = std::string("FlatArray_") + suffix; std::string fname = std::string("FlatArray_") + suffix;
py::class_<TFlat>(m, fname.c_str()) auto flatarray_class = py::class_<TFlat>(m, fname.c_str(),
py::buffer_protocol())
.def ("__len__", [] ( TFlat &self ) { return self.Size(); } ) .def ("__len__", [] ( TFlat &self ) { return self.Size(); } )
.def ("__getitem__", .def ("__getitem__",
[](TFlat & self, TIND i) -> T& [](TFlat & self, TIND i) -> T&
@ -77,6 +94,30 @@ namespace ngcore
; ;
if constexpr (detail::HasPyFormat<T>::value)
{
if(ngcore_have_numpy && !py::detail::npy_format_descriptor<T>::dtype().is_none())
{
flatarray_class
.def_buffer([](TFlat& self)
{
return py::buffer_info(
self.Addr(0),
sizeof(T),
py::format_descriptor<T>::format(),
1,
{ self.Size() },
{ sizeof(T) * (self.Addr(1) - self.Addr(0)) });
})
.def("NumPy", [](py::object self)
{
return py::module::import("numpy")
.attr("frombuffer")(self, py::detail::npy_format_descriptor<T>::dtype());
})
;
}
}
std::string aname = std::string("Array_") + suffix; std::string aname = std::string("Array_") + suffix;
py::class_<TArray, TFlat>(m, aname.c_str()) py::class_<TArray, TFlat>(m, aname.c_str())
.def(py::init([] (size_t n) { return new TArray(n); }),py::arg("n"), "Makes array of given length") .def(py::init([] (size_t n) { return new TArray(n); }),py::arg("n"), "Makes array of given length")

View File

@ -2,13 +2,18 @@
#include "bitarray.hpp" #include "bitarray.hpp"
#include "taskmanager.hpp" #include "taskmanager.hpp"
using namespace ngcore; using namespace ngcore;
using namespace std; using namespace std;
using namespace pybind11::literals; using namespace pybind11::literals;
PYBIND11_MODULE(pyngcore, m) // NOLINT PYBIND11_MODULE(pyngcore, m) // NOLINT
{ {
try
{
auto numpy = py::module::import("numpy");
ngcore_have_numpy = !numpy.is_none();
}
catch(...) {}
ExportArray<int>(m); ExportArray<int>(m);
ExportArray<unsigned>(m); ExportArray<unsigned>(m);
ExportArray<size_t>(m); ExportArray<size_t>(m);

View File

@ -0,0 +1,17 @@
from pyngcore import *
from numpy import sort, array
def test_array_numpy():
a = Array_i_m(5)
a[:] = 0
a[3:] = 2
assert(sum(a) == 4)
a[1] = 5
b = sort(a)
assert(all(b == array([0,0,2,2,5])))
assert(all(a == array([0,5,0,2,2])))
a.NumPy().sort()
assert(all(a == array([0,0,2,2,5])))
if __name__ == "__main__":
test_array_numpy()