netgen/libsrc/core/python_ngcore.cpp

169 lines
5.0 KiB
C++
Raw Normal View History

2018-12-28 17:43:15 +05:00
#include "logging.hpp"
#include "python_ngcore.hpp"
2018-12-28 17:43:15 +05:00
namespace py = pybind11;
2019-08-06 18:50:08 +05:00
using std::string;
2018-12-28 17:43:15 +05:00
2019-08-06 18:50:08 +05:00
namespace ngcore
2018-12-28 17:43:15 +05:00
{
2019-09-11 02:01:05 +05:00
// Global flag, initialised to false.
// NOTE(review): presumably switched to true elsewhere once numpy support has
// been detected -- confirm against the module initialisation code.
bool ngcore_have_numpy = false;
2020-09-04 17:47:49 +05:00
// Global switch, enabled by default.
// NOTE(review): presumably consulted by the pickling code to choose a
// parallel code path -- it is not read anywhere in this file; confirm.
bool parallel_pickling = true;
2020-08-19 17:50:11 +05:00
2019-08-06 18:50:08 +05:00
// Store the Python object `value` under the name `s` in `flags`.
//
// The value is forwarded to the Flags::SetFlag overload matching its Python
// type:
//   dict         -> a nested Flags object, filled recursively entry by entry
//   bool         -> bool
//   float        -> double
//   int          -> double (converted via int)
//   str          -> string
//   list / tuple -> Array<double> if the first element is a float or an int,
//                   Array<string> if the first element is a str; an empty
//                   sequence sets both an empty string array and an empty
//                   numeric array under the same name.
// Values of any other type are ignored without an error.
void SetFlag(Flags &flags, string s, py::object value)
{
  if (py::isinstance<py::dict>(value))
    {
      Flags nested_flags;
      for (auto item : value.cast<py::dict>())
        SetFlag(nested_flags, item.first.cast<string>(),
                item.second.cast<py::object>());
      flags.SetFlag(s, nested_flags);
      return;
    }

  // The scalar checks are intentionally independent (no else-if): a value
  // that satisfies several of them is forwarded once per matching check.
  // NOTE(review): a Python bool may also satisfy the int check, in which case
  // it is stored both as a bool and as a number -- confirm this is intended.
  if (py::isinstance<py::bool_>(value))
    flags.SetFlag(s, value.cast<bool>());

  if (py::isinstance<py::float_>(value))
    flags.SetFlag(s, value.cast<double>());

  if (py::isinstance<py::int_>(value))
    flags.SetFlag(s, double(value.cast<int>()));

  if (py::isinstance<py::str>(value))
    flags.SetFlag(s, value.cast<string>());

  // Shared handling of list and tuple values. The element type of the whole
  // sequence is decided by looking at its first element only.
  const auto set_sequence = [&flags, &s](const auto &seq)
  {
    if (py::len(seq) == 0)
      {
        // No element to inspect: register the name as both kinds of list.
        Array<string> dummystr;
        Array<double> dummydbl;
        flags.SetFlag(s, dummystr);
        flags.SetFlag(s, dummydbl);
        return;
      }

    if (py::isinstance<py::float_>(seq[0]) || py::isinstance<py::int_>(seq[0]))
      flags.SetFlag(s, makeCArray<double>(seq));

    if (py::isinstance<py::str>(seq[0]))
      flags.SetFlag(s, makeCArray<string>(seq));
  };

  if (py::isinstance<py::list>(value))
    set_sequence(py::cast<py::list>(value));

  // Tuples are treated exactly like lists. The element type has to be read
  // from the first element; testing the tuple object itself against
  // float/int/str can never succeed and would silently drop every tuple.
  if (py::isinstance<py::tuple>(value))
    set_sequence(py::cast<py::tuple>(value));
}
// Build a Flags object from Python keyword arguments.
//
//   kwargs  - the keyword arguments; the entry named "flags" (deprecated) is
//             expected to be a dict whose entries are used as additional flags
//   pyclass - the Python class the flags are created for, or None; when given
//             it must provide __flags_doc__() and may provide
//             __special_treated_flags__()
//   info    - passed through unchanged to the special-treatment callbacks
//
// Processing happens in three phases:
//   1. entries of the deprecated "flags" dict and all ordinary kwargs are
//      collected (kwargs override entries of the "flags" dict),
//   2. the collected entries are converted with SetFlag,
//   3. kwargs listed in __special_treated_flags__() are handed to their
//      callback together with a pointer to the resulting Flags and `info`.
Flags CreateFlagsFromKwArgs(const py::kwargs& kwargs, py::object pyclass, py::list info)
{
  static std::shared_ptr<Logger> logger = GetLogger("Flags");

  // Phase 1a: seed the collection from the deprecated "flags" kwarg.
  py::dict collected;
  if (kwargs.contains("flags"))
    {
      logger->warn("WARNING: using flags as kwarg is deprecated{}, use the flag arguments as kwargs instead!",
                   pyclass.is_none() ? "" : std::string(" in ") + std::string(py::str(pyclass)));
      auto legacy = py::cast<py::dict>(kwargs["flags"]);
      for (auto entry : legacy)
        {
          const auto legacy_name = entry.first.cast<string>();
          collected[legacy_name.c_str()] = entry.second;
        }
    }

  // Ask the class (if any) which kwargs are documented and which ones are
  // handled by a dedicated callback instead of SetFlag.
  py::dict special;
  if (!pyclass.is_none())
    {
      auto flags_doc = pyclass.attr("__flags_doc__")();
      for (auto entry : kwargs)
        {
          const auto kwname = entry.first.cast<string>();
          const bool documented = flags_doc.contains(kwname.c_str());
          if (!documented && !(kwname == "flags"))
            logger->warn("WARNING: kwarg '{}' is an undocumented flags option for class {}, maybe there is a typo?",
                         kwname, std::string(py::str(pyclass)));
        }
      if (py::hasattr(pyclass, "__special_treated_flags__"))
        special = pyclass.attr("__special_treated_flags__")();
    }

  // Phase 1b: ordinary kwargs (neither "flags" nor specially treated).
  for (auto entry : kwargs)
    {
      const auto kwname = entry.first.cast<string>();
      if (kwname == "flags")
        continue;
      if (special.contains(kwname.c_str()))
        continue;
      collected[kwname.c_str()] = entry.second;
    }

  // Phase 2: convert everything collected so far.
  Flags flags;
  for (auto entry : collected)
    SetFlag(flags, entry.first.cast<string>(), entry.second.cast<py::object>());

  // Phase 3: run the callbacks last, so they see the fully populated flags.
  for (auto entry : kwargs)
    {
      const auto kwname = entry.first.cast<string>();
      if (kwname == "flags")
        continue;
      if (!special.contains(kwname.c_str()))
        continue;
      special[kwname.c_str()](entry.second, &flags, info);
    }

  return flags;
}
// Convert a Flags object into a Python dict.
//
// The flag categories are visited in a fixed order: nested flags, string
// lists, numeric lists, strings, numbers, define flags, "any" flags. Each flag
// is stored under its name, so if a name occurs in several categories the
// category visited last determines the final dict entry.
// Nested Flags become nested dicts (recursively), list flags become Python
// lists, and "any" flags are converted with CastAnyToPy.
py::dict CreateDictFromFlags(const Flags& flags)
{
  py::dict result;
  // Receives the name of the i-th flag from each Get...Flag call below; it is
  // always filled in a separate statement before being used as a dict key.
  std::string name;

  // Copies the elements of an array-like container into a new Python list.
  const auto to_pylist = [](const auto& values)
  {
    py::list out;
    for (auto& element : values)
      out.append(element);
    return out;
  };

  for (auto idx : Range(flags.GetNFlagsFlags()))
    {
      auto& nested = flags.GetFlagsFlag(idx, name);
      result[name.c_str()] = CreateDictFromFlags(nested);
    }

  for (auto idx : Range(flags.GetNStringListFlags()))
    {
      auto strings = flags.GetStringListFlag(idx, name);
      result[name.c_str()] = to_pylist(*strings);
    }

  for (auto idx : Range(flags.GetNNumListFlags()))
    {
      auto numbers = flags.GetNumListFlag(idx, name);
      result[name.c_str()] = to_pylist(*numbers);
    }

  for (auto idx : Range(flags.GetNStringFlags()))
    {
      auto text = flags.GetStringFlag(idx, name);
      result[name.c_str()] = text;
    }

  for (auto idx : Range(flags.GetNNumFlags()))
    {
      auto number = flags.GetNumFlag(idx, name);
      result[name.c_str()] = number;
    }

  for (auto idx : Range(flags.GetNDefineFlags()))
    {
      auto defined = flags.GetDefineFlag(idx, name);
      result[name.c_str()] = defined;
    }

  for (auto idx : Range(flags.GetNAnyFlags()))
    {
      auto& any_value = flags.GetAnyFlag(idx, name);
      result[name.c_str()] = CastAnyToPy(any_value);
    }

  return result;
}
2019-08-06 18:50:08 +05:00
} // namespace ngcore