Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 : #pragma once
6 :
7 : #include "scipp/dataset/arithmetic.h"
8 : #include "scipp/dataset/astype.h"
9 : #include "scipp/dataset/generated_comparison.h"
10 : #include "scipp/dataset/generated_logical.h"
11 : #include "scipp/dataset/generated_math.h"
12 : #include "scipp/dataset/to_unit.h"
13 : #include "scipp/dataset/util.h"
14 : #include "scipp/units/except.h"
15 : #include "scipp/variable/arithmetic.h"
16 : #include "scipp/variable/astype.h"
17 : #include "scipp/variable/comparison.h"
18 : #include "scipp/variable/logical.h"
19 : #include "scipp/variable/pow.h"
20 : #include "scipp/variable/to_unit.h"
21 :
22 : #include "dtype.h"
23 : #include "format.h"
24 : #include "pybind11.h"
25 :
26 : namespace py = pybind11;
27 :
28 : template <class T, class... Ignored>
29 9 : void bind_common_operators(pybind11::class_<T, Ignored...> &c) {
30 25 : c.def("__abs__", [](const T &self) { return abs(self); });
31 216 : c.def("__repr__", [](const T &self) { return to_string(self); });
32 11 : c.def("__bool__", [](const T &self) {
33 : if constexpr (std::is_same_v<T, scipp::Variable>) {
34 4974 : if (self.unit() != scipp::units::none)
35 2 : throw scipp::except::UnitError(
36 : "The truth value of a variable with unit is undefined.");
37 4972 : return self.template value<bool>() == true;
38 : }
39 2 : throw std::runtime_error("The truth value of a variable, data array, or "
40 : "dataset is ambiguous. Use any() or all().");
41 : });
42 9 : c.def(
43 : "copy",
44 9017 : [](const T &self, const bool deep) { return deep ? copy(self) : self; },
45 9 : py::arg("deep") = true, py::call_guard<py::gil_scoped_release>(),
46 : R"(
47 : Return a (by default deep) copy.
48 :
49 : If `deep=True` (the default), a deep copy is made. Otherwise, a shallow
50 : copy is made, and the returned data (and meta data) values are new views
51 : of the data and meta data values of this object.)");
52 9 : c.def(
53 13 : "__copy__", [](const T &self) { return self; },
54 0 : py::call_guard<py::gil_scoped_release>(), "Return a (shallow) copy.");
55 9 : c.def(
56 : "__deepcopy__",
57 182 : [](const T &self, const py::typing::Dict<py::object, py::object> &) {
58 182 : return copy(self);
59 : },
60 9 : py::call_guard<py::gil_scoped_release>(), "Return a (deep) copy.")
61 9 : .def(
62 : "__sizeof__",
63 312 : [](const T &self) {
64 312 : return size_of(self, scipp::SizeofTag::ViewOnly);
65 : },
66 : R"doc(Return the size of the object in bytes.
67 :
68 : The size includes the object itself and all arrays contained in it.
69 : But arrays may be counted multiple times if components share buffers,
70 : e.g. multiple coordinates referencing the same memory.
71 : Conversely, the size may be underestimated. Especially, but not only,
72 : with dtype=PyObject.
73 :
74 : This function only includes memory of the current slice. Use
75 : ``underlying_size`` to get the full memory size of the underlying structure.)doc")
76 9 : .def(
77 : "underlying_size",
78 312 : [](const T &self) {
79 312 : return size_of(self, scipp::SizeofTag::Underlying);
80 : },
81 : R"doc(Return the size of the object in bytes.
82 :
83 : The size includes the object itself and all arrays contained in it.
84 : But arrays may be counted multiple times if components share buffers,
85 : e.g. multiple coordinates referencing the same memory.
86 : Conversely, the size may be underestimated. Especially, but not only,
87 : with dtype=PyObject.
88 :
89 : This function includes all memory of the underlying buffers. Use
90 : ``__sizeof__`` to get the size of the current slice only.)doc");
91 9 : }
92 :
93 : template <class T, class... Ignored>
94 6 : void bind_astype(py::class_<T, Ignored...> &c) {
95 6 : c.def(
96 : "astype",
97 5007 : [](const T &self, const py::object &type, const bool copy) {
98 5007 : const auto [scipp_dtype, dtype_unit] =
99 : cast_dtype_and_unit(type, DefaultUnit{});
100 5012 : if (dtype_unit.has_value() &&
101 25 : (dtype_unit != scipp::units::one && dtype_unit != self.unit())) {
102 2 : throw scipp::except::UnitError(scipp::python::format(
103 : "Conversion of units via the dtype is not allowed. Occurred when "
104 : "trying to change dtype from ",
105 2 : self.dtype(), " to ", type,
106 : ". Use to_unit in combination with astype."));
107 : }
108 5006 : [[maybe_unused]] py::gil_scoped_release release;
109 : return astype(self, scipp_dtype,
110 : copy ? scipp::CopyPolicy::Always
111 10008 : : scipp::CopyPolicy::TryAvoid);
112 5006 : },
113 12 : py::arg("type"), py::kw_only(), py::arg("copy") = true,
114 : R"(
115 : Converts a Variable or DataArray to a different dtype.
116 :
117 : If the dtype is unchanged and ``copy`` is `False`, the object
118 : is returned without making a deep copy.
119 :
120 : :param type: Target dtype.
121 : :param copy: If `False`, return the input object if possible.
122 : If `True`, the function always returns a new object.
123 : :raises: If the data cannot be converted to the requested dtype.
124 : :return: New variable or data array with specified dtype.
125 : :rtype: Union[scipp.Variable, scipp.DataArray])");
126 6 : }
127 :
128 : template <class Other, class T, class... Ignored>
129 12 : void bind_inequality_to_operator(pybind11::class_<T, Ignored...> &c) {
130 12 : c.def(
131 48 : "__eq__", [](const T &a, const Other &b) { return a == b; },
132 0 : py::is_operator(), py::call_guard<py::gil_scoped_release>());
133 12 : c.def(
134 16 : "__ne__", [](const T &a, const Other &b) { return a != b; },
135 0 : py::is_operator(), py::call_guard<py::gil_scoped_release>());
136 12 : }
137 :
/// Right-hand-operand adapter for OpBinder that passes the operand through
/// untouched (returns a reference to the very same object).
struct Identity {
  template <class T> const T &operator()(const T &operand) const noexcept { return operand; }
};
143 : struct ScalarToVariable {
144 9997 : template <class T> scipp::Variable operator()(const T &x) const noexcept {
145 9997 : return x * scipp::units::one;
146 : }
147 : };
148 :
/// Helper for binding arithmetic and comparison operators of ``T`` with a
/// right-hand operand of type ``Other``.
///
/// ``RHSSetup`` is a function object applied to the right-hand operand before
/// the C++ operator is invoked. In this file it is either ``Identity`` (operand
/// used as-is) or ``ScalarToVariable`` (scalar wrapped into a Variable).
/// Every operator bound here releases the GIL for the duration of the C++
/// call via ``py::call_guard<py::gil_scoped_release>``.
template <class RHSSetup> struct OpBinder {
  /// Bind ``__iadd__``, ``__isub__``, ``__imul__`` and ``__itruediv__``.
  /// If neither ``T`` nor ``Other`` is Dataset, also bind ``__imod__`` and
  /// ``__ifloordiv__``. ``__ipow__`` is bound only if, in addition, neither
  /// ``T`` nor ``Other`` is DataArray.
  template <class Other, class T, class... Ignored>
  static void in_place_binary(pybind11::class_<T, Ignored...> &c) {
    using namespace scipp;
    // In-place operators return py::object due to the way in-place operators
    // work in Python (assigning return value to this). This avoids extra
    // copies, and additionally ensures that all references to the object keep
    // referencing the same object after the operation.
    // WARNING: It is crucial to explicitly return 'py::object &' here.
    // Otherwise the py::object is returned by value, which increments the
    // reference count, which is not only suboptimal but also incorrect since
    // we have released the GIL via py::gil_scoped_release.
    // NOTE(review): `a.cast<T &>()` is also evaluated while the GIL is
    // released by the call guard; confirm this pybind11 cast does not need
    // the GIL.
    c.def(
        "__iadd__",
        [](py::object &a, Other &b) -> py::object & {
          a.cast<T &>() += RHSSetup{}(b);
          return a;
        },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
    c.def(
        "__isub__",
        [](py::object &a, Other &b) -> py::object & {
          a.cast<T &>() -= RHSSetup{}(b);
          return a;
        },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
    c.def(
        "__imul__",
        [](py::object &a, Other &b) -> py::object & {
          a.cast<T &>() *= RHSSetup{}(b);
          return a;
        },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
    c.def(
        "__itruediv__",
        [](py::object &a, Other &b) -> py::object & {
          a.cast<T &>() /= RHSSetup{}(b);
          return a;
        },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
    // Modulo / floor division are not bound when a Dataset is involved.
    if constexpr (!(std::is_same_v<T, Dataset> ||
                    std::is_same_v<Other, Dataset>)) {
      c.def(
          "__imod__",
          [](py::object &a, Other &b) -> py::object & {
            a.cast<T &>() %= RHSSetup{}(b);
            return a;
          },
          py::is_operator(), py::call_guard<py::gil_scoped_release>());
      c.def(
          "__ifloordiv__",
          [](py::object &a, Other &b) -> py::object & {
            floor_divide_equals(a.cast<T &>(), RHSSetup{}(b));
            return a;
          },
          py::is_operator(), py::call_guard<py::gil_scoped_release>());
      // In-place power is additionally not bound when a DataArray is
      // involved.
      if constexpr (!(std::is_same_v<T, DataArray> ||
                      std::is_same_v<Other, DataArray>)) {
        // Unlike the operators above, this returns `T &` (the result of the
        // 3-argument pow, which writes into `base`). NOTE(review): presumably
        // pybind11 maps the returned reference back to the existing Python
        // instance of `base`, preserving object identity -- confirm.
        c.def(
            "__ipow__",
            [](T &base, Other &exponent) -> T & {
              return pow(base, RHSSetup{}(exponent), base);
            },
            py::is_operator(), py::call_guard<py::gil_scoped_release>());
      }
    }
  }

  /// Bind ``__add__``, ``__sub__``, ``__mul__`` and ``__truediv__``, and, if
  /// neither ``T`` nor ``Other`` is Dataset, ``__floordiv__``, ``__mod__`` and
  /// ``__pow__``. ``T`` is always the left operand.
  template <class Other, class T, class... Ignored>
  static void binary(pybind11::class_<T, Ignored...> &c) {
    using namespace scipp;
    c.def(
        "__add__", [](const T &a, const Other &b) { return a + RHSSetup{}(b); },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
    c.def(
        "__sub__", [](const T &a, const Other &b) { return a - RHSSetup{}(b); },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
    c.def(
        "__mul__", [](const T &a, const Other &b) { return a * RHSSetup{}(b); },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
    c.def(
        "__truediv__",
        [](const T &a, const Other &b) { return a / RHSSetup{}(b); },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
    if constexpr (!(std::is_same_v<T, Dataset> ||
                    std::is_same_v<Other, Dataset>)) {
      c.def(
          "__floordiv__",
          [](const T &a, const Other &b) {
            return floor_divide(a, RHSSetup{}(b));
          },
          py::is_operator(), py::call_guard<py::gil_scoped_release>());
      c.def(
          "__mod__",
          [](const T &a, const Other &b) { return a % RHSSetup{}(b); },
          py::is_operator(), py::call_guard<py::gil_scoped_release>());
      c.def(
          "__pow__",
          [](const T &base, const Other &exponent) {
            return pow(base, RHSSetup{}(exponent));
          },
          py::is_operator(), py::call_guard<py::gil_scoped_release>());
    }
  }

  /// Bind the reflected operators (``__radd__`` etc.). Here the adapted
  /// ``Other`` operand becomes the LEFT operand and ``T`` the right one, e.g.
  /// ``__rsub__`` computes ``RHSSetup{}(b) - a``. The Dataset exclusion
  /// matches ``binary``.
  template <class Other, class T, class... Ignored>
  static void reverse_binary(pybind11::class_<T, Ignored...> &c) {
    using namespace scipp;
    c.def(
        "__radd__", [](const T &a, const Other b) { return RHSSetup{}(b) + a; },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
    c.def(
        "__rsub__", [](const T &a, const Other b) { return RHSSetup{}(b)-a; },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
    c.def(
        "__rmul__", [](const T &a, const Other b) { return RHSSetup{}(b)*a; },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
    c.def(
        "__rtruediv__",
        [](const T &a, const Other b) { return RHSSetup{}(b) / a; },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
    if constexpr (!(std::is_same_v<T, Dataset> ||
                    std::is_same_v<Other, Dataset>)) {
      c.def(
          "__rfloordiv__",
          [](const T &a, const Other &b) {
            return floor_divide(RHSSetup{}(b), a);
          },
          py::is_operator(), py::call_guard<py::gil_scoped_release>());
      c.def(
          "__rmod__",
          [](const T &a, const Other &b) { return RHSSetup{}(b) % a; },
          py::is_operator(), py::call_guard<py::gil_scoped_release>());
      // Reflected power: `T` is the exponent, the adapted operand the base.
      c.def(
          "__rpow__",
          [](const T &exponent, const Other &base) {
            return pow(RHSSetup{}(base), exponent);
          },
          py::is_operator(), py::call_guard<py::gil_scoped_release>());
    }
  }

  /// Bind the six rich comparisons via the free functions ``equal``,
  /// ``not_equal``, ``less``, ``greater``, ``less_equal`` and
  /// ``greater_equal`` (not via C++ ``operator==`` etc.).
  template <class Other, class T, class... Ignored>
  static void comparison(pybind11::class_<T, Ignored...> &c) {
    c.def(
        "__eq__", [](T &a, Other &b) { return equal(a, RHSSetup{}(b)); },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
    c.def(
        "__ne__", [](T &a, Other &b) { return not_equal(a, RHSSetup{}(b)); },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
    c.def(
        "__lt__", [](T &a, Other &b) { return less(a, RHSSetup{}(b)); },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
    c.def(
        "__gt__", [](T &a, Other &b) { return greater(a, RHSSetup{}(b)); },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
    c.def(
        "__le__", [](T &a, Other &b) { return less_equal(a, RHSSetup{}(b)); },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
    c.def(
        "__ge__",
        [](T &a, Other &b) { return greater_equal(a, RHSSetup{}(b)); },
        py::is_operator(), py::call_guard<py::gil_scoped_release>());
  }
};
314 :
315 : template <class Other, class T, class... Ignored>
316 18 : static void bind_in_place_binary(pybind11::class_<T, Ignored...> &c) {
317 18 : OpBinder<Identity>::in_place_binary<Other>(c);
318 18 : }
319 :
320 : template <class Other, class T, class... Ignored>
321 24 : static void bind_binary(pybind11::class_<T, Ignored...> &c) {
322 24 : OpBinder<Identity>::binary<Other>(c);
323 24 : }
324 :
325 : template <class Other, class T, class... Ignored>
326 9 : static void bind_comparison(pybind11::class_<T, Ignored...> &c) {
327 9 : OpBinder<Identity>::comparison<Other>(c);
328 9 : }
329 :
330 : template <class T, class... Ignored>
331 9 : void bind_in_place_binary_scalars(pybind11::class_<T, Ignored...> &c) {
332 9 : OpBinder<ScalarToVariable>::in_place_binary<double>(c);
333 9 : OpBinder<ScalarToVariable>::in_place_binary<int64_t>(c);
334 9 : }
335 :
336 : template <class T, class... Ignored>
337 9 : void bind_binary_scalars(pybind11::class_<T, Ignored...> &c) {
338 9 : OpBinder<ScalarToVariable>::binary<double>(c);
339 9 : OpBinder<ScalarToVariable>::binary<int64_t>(c);
340 9 : }
341 :
342 : template <class T, class... Ignored>
343 6 : static void bind_reverse_binary_scalars(pybind11::class_<T, Ignored...> &c) {
344 6 : OpBinder<ScalarToVariable>::reverse_binary<double>(c);
345 6 : OpBinder<ScalarToVariable>::reverse_binary<int64_t>(c);
346 6 : }
347 :
348 : template <class T, class... Ignored>
349 3 : void bind_comparison_scalars(pybind11::class_<T, Ignored...> &c) {
350 3 : OpBinder<ScalarToVariable>::comparison<double>(c);
351 3 : OpBinder<ScalarToVariable>::comparison<int64_t>(c);
352 3 : }
353 :
354 : template <class T, class... Ignored>
355 6 : void bind_unary(pybind11::class_<T, Ignored...> &c) {
356 6 : c.def(
357 32087 : "__neg__", [](const T &a) { return -a; }, py::is_operator(),
358 6 : py::call_guard<py::gil_scoped_release>());
359 6 : }
360 :
361 : template <class T, class... Ignored>
362 6 : void bind_boolean_unary(pybind11::class_<T, Ignored...> &c) {
363 6 : c.def(
364 208 : "__invert__", [](const T &a) { return ~a; }, py::is_operator(),
365 6 : py::call_guard<py::gil_scoped_release>());
366 6 : }
367 :
368 : template <class Other, class T, class... Ignored>
369 9 : void bind_logical(pybind11::class_<T, Ignored...> &c) {
370 : using T1 = const T;
371 : using T2 = const Other;
372 9 : c.def(
373 9 : "__or__", [](const T1 &a, const T2 &b) { return a | b; },
374 0 : py::is_operator(), py::call_guard<py::gil_scoped_release>());
375 9 : c.def(
376 4 : "__xor__", [](const T1 &a, const T2 &b) { return a ^ b; },
377 0 : py::is_operator(), py::call_guard<py::gil_scoped_release>());
378 9 : c.def(
379 12 : "__and__", [](const T1 &a, const T2 &b) { return a & b; },
380 0 : py::is_operator(), py::call_guard<py::gil_scoped_release>());
381 9 : c.def(
382 : "__ior__",
383 62 : [](const py::object &a, const T2 &b) -> const py::object & {
384 62 : a.cast<T &>() |= b;
385 62 : return a;
386 : },
387 0 : py::is_operator(), py::call_guard<py::gil_scoped_release>());
388 9 : c.def(
389 : "__ixor__",
390 9 : [](const py::object &a, const T2 &b) -> const py::object & {
391 9 : a.cast<T &>() ^= b;
392 9 : return a;
393 : },
394 0 : py::is_operator(), py::call_guard<py::gil_scoped_release>());
395 9 : c.def(
396 : "__iand__",
397 2 : [](const py::object &a, const T2 &b) -> const py::object & {
398 2 : a.cast<T &>() &= b;
399 2 : return a;
400 : },
401 0 : py::is_operator(), py::call_guard<py::gil_scoped_release>());
402 9 : }
|