Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 : #include "pybind11.h"
6 : #include "unit.h"
7 :
8 : #include "scipp/dataset/dataset.h"
9 : #include "scipp/dataset/to_unit.h"
10 : #include "scipp/variable/operations.h"
11 : #include "scipp/variable/to_unit.h"
12 :
13 : using namespace scipp;
14 : using namespace scipp::variable;
15 : using namespace scipp::dataset;
16 :
17 : namespace py = pybind11;
18 :
19 : namespace {
20 3 : template <typename T> void bind_norm(py::module &m) {
21 3 : m.def(
22 1 : "norm", [](const T &x) { return norm(x); }, py::arg("x"),
23 3 : py::call_guard<py::gil_scoped_release>());
24 3 : }
25 :
26 3 : template <typename T> void bind_nan_to_num(py::module &m) {
27 9 : m.def(
28 : "nan_to_num",
29 4 : [](const T &x, const std::optional<Variable> &nan,
30 : const std::optional<Variable> &posinf,
31 : const std::optional<Variable> &neginf) {
32 4 : Variable out(x);
33 4 : if (nan)
34 2 : nan_to_num(out, *nan, out);
35 4 : if (posinf)
36 2 : positive_inf_to_num(out, *posinf, out);
37 4 : if (neginf)
38 2 : negative_inf_to_num(out, *neginf, out);
39 4 : return out;
40 0 : },
41 6 : py::arg("x"), py::kw_only(), py::arg("nan") = std::optional<Variable>(),
42 6 : py::arg("posinf") = std::optional<Variable>(),
43 6 : py::arg("neginf") = std::optional<Variable>(),
44 0 : py::call_guard<py::gil_scoped_release>());
45 :
46 12 : m.def(
47 : "nan_to_num",
48 5 : [](const T &x, const std::optional<Variable> &nan,
49 : const std::optional<Variable> &posinf,
50 : const std::optional<Variable> &neginf, T &out) {
51 5 : if (nan)
52 2 : nan_to_num(x, *nan, out);
53 5 : if (posinf)
54 1 : positive_inf_to_num(x, *posinf, out);
55 5 : if (neginf)
56 1 : negative_inf_to_num(x, *neginf, out);
57 5 : return out;
58 : },
59 6 : py::arg("x"), py::kw_only(), py::arg("nan") = std::optional<Variable>(),
60 6 : py::arg("posinf") = std::optional<Variable>(),
61 6 : py::arg("neginf") = std::optional<Variable>(), py::arg("out"),
62 3 : py::call_guard<py::gil_scoped_release>());
63 3 : }
64 :
65 6 : template <class T> void bind_to_unit(py::module &m) {
66 6 : m.def(
67 : "to_unit",
68 20153 : [](const T &x, const ProtoUnit &unit, const bool copy) {
69 40287 : return to_unit(x, unit_or_default(unit),
70 60439 : copy ? CopyPolicy::Always : CopyPolicy::TryAvoid);
71 : },
72 12 : py::arg("x"), py::arg("unit"), py::arg("copy") = true,
73 6 : py::call_guard<py::gil_scoped_release>());
74 6 : }
75 :
76 6 : template <class T> void bind_as_const(py::module &m) {
77 46 : m.def("as_const", [](const T &x) { return x.as_const(); }, py::arg("x"));
78 6 : }
79 : } // namespace
80 :
81 3 : void init_unary(py::module &m) {
82 3 : bind_norm<Variable>(m);
83 3 : bind_nan_to_num<Variable>(m);
84 3 : bind_to_unit<Variable>(m);
85 3 : bind_to_unit<DataArray>(m);
86 3 : bind_as_const<Variable>(m);
87 3 : bind_as_const<DataArray>(m);
88 3 : }
|