LCOV - code coverage report
Current view: top level - python - bind_operators.h (source / functions) Hit Total Coverage
Test: coverage.info Lines: 178 217 82.0 %
Date: 2024-12-01 01:56:34 Functions: 256 341 75.1 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Simon Heybrock
       5             : #pragma once
       6             : 
       7             : #include "scipp/dataset/arithmetic.h"
       8             : #include "scipp/dataset/astype.h"
       9             : #include "scipp/dataset/generated_comparison.h"
      10             : #include "scipp/dataset/generated_logical.h"
      11             : #include "scipp/dataset/generated_math.h"
      12             : #include "scipp/dataset/to_unit.h"
      13             : #include "scipp/dataset/util.h"
      14             : #include "scipp/units/except.h"
      15             : #include "scipp/variable/arithmetic.h"
      16             : #include "scipp/variable/astype.h"
      17             : #include "scipp/variable/comparison.h"
      18             : #include "scipp/variable/logical.h"
      19             : #include "scipp/variable/pow.h"
      20             : #include "scipp/variable/to_unit.h"
      21             : 
      22             : #include "dtype.h"
      23             : #include "format.h"
      24             : #include "pybind11.h"
      25             : 
      26             : namespace py = pybind11;
      27             : 
      28             : template <class T, class... Ignored>
      29           9 : void bind_common_operators(pybind11::class_<T, Ignored...> &c) {
      30          25 :   c.def("__abs__", [](const T &self) { return abs(self); });
      31         216 :   c.def("__repr__", [](const T &self) { return to_string(self); });
      32          11 :   c.def("__bool__", [](const T &self) {
      33             :     if constexpr (std::is_same_v<T, scipp::Variable>) {
      34        4974 :       if (self.unit() != scipp::units::none)
      35           2 :         throw scipp::except::UnitError(
      36             :             "The truth value of a variable with unit is undefined.");
      37        4972 :       return self.template value<bool>() == true;
      38             :     }
      39           2 :     throw std::runtime_error("The truth value of a variable, data array, or "
      40             :                              "dataset is ambiguous. Use any() or all().");
      41             :   });
      42           9 :   c.def(
      43             :       "copy",
      44        9017 :       [](const T &self, const bool deep) { return deep ? copy(self) : self; },
      45           9 :       py::arg("deep") = true, py::call_guard<py::gil_scoped_release>(),
      46             :       R"(
      47             :       Return a (by default deep) copy.
      48             : 
      49             :       If `deep=True` (the default), a deep copy is made. Otherwise, a shallow
      50             :       copy is made, and the returned data (and meta data) values are new views
      51             :       of the data and meta data values of this object.)");
      52           9 :   c.def(
      53          13 :       "__copy__", [](const T &self) { return self; },
      54           0 :       py::call_guard<py::gil_scoped_release>(), "Return a (shallow) copy.");
      55           9 :   c.def(
      56             :        "__deepcopy__",
      57         182 :        [](const T &self, const py::typing::Dict<py::object, py::object> &) {
      58         182 :          return copy(self);
      59             :        },
      60           9 :        py::call_guard<py::gil_scoped_release>(), "Return a (deep) copy.")
      61           9 :       .def(
      62             :           "__sizeof__",
      63         312 :           [](const T &self) {
      64         312 :             return size_of(self, scipp::SizeofTag::ViewOnly);
      65             :           },
      66             :           R"doc(Return the size of the object in bytes.
      67             : 
      68             : The size includes the object itself and all arrays contained in it.
      69             : But arrays may be counted multiple times if components share buffers,
      70             : e.g. multiple coordinates referencing the same memory.
      71             : Conversely, the size may be underestimated. Especially, but not only,
      72             : with dtype=PyObject.
      73             : 
      74             : This function only includes memory of the current slice. Use
      75             : ``underlying_size`` to get the full memory size of the underlying structure.)doc")
      76           9 :       .def(
      77             :           "underlying_size",
      78         312 :           [](const T &self) {
      79         312 :             return size_of(self, scipp::SizeofTag::Underlying);
      80             :           },
      81             :           R"doc(Return the size of the object in bytes.
      82             : 
      83             : The size includes the object itself and all arrays contained in it.
      84             : But arrays may be counted multiple times if components share buffers,
      85             : e.g. multiple coordinates referencing the same memory.
      86             : Conversely, the size may be underestimated. Especially, but not only,
      87             : with dtype=PyObject.
      88             : 
      89             : This function includes all memory of the underlying buffers. Use
      90             : ``__sizeof__`` to get the size of the current slice only.)doc");
      91           9 : }
      92             : 
      93             : template <class T, class... Ignored>
      94           6 : void bind_astype(py::class_<T, Ignored...> &c) {
      95           6 :   c.def(
      96             :       "astype",
      97        5007 :       [](const T &self, const py::object &type, const bool copy) {
      98        5007 :         const auto [scipp_dtype, dtype_unit] =
      99             :             cast_dtype_and_unit(type, DefaultUnit{});
     100        5012 :         if (dtype_unit.has_value() &&
     101          25 :             (dtype_unit != scipp::units::one && dtype_unit != self.unit())) {
     102           2 :           throw scipp::except::UnitError(scipp::python::format(
     103             :               "Conversion of units via the dtype is not allowed. Occurred when "
     104             :               "trying to change dtype from ",
     105           2 :               self.dtype(), " to ", type,
     106             :               ". Use to_unit in combination with astype."));
     107             :         }
     108        5006 :         [[maybe_unused]] py::gil_scoped_release release;
     109             :         return astype(self, scipp_dtype,
     110             :                       copy ? scipp::CopyPolicy::Always
     111       10008 :                            : scipp::CopyPolicy::TryAvoid);
     112        5006 :       },
     113          12 :       py::arg("type"), py::kw_only(), py::arg("copy") = true,
     114             :       R"(
     115             :         Converts a Variable or DataArray to a different dtype.
     116             : 
     117             :         If the dtype is unchanged and ``copy`` is `False`, the object
     118             :         is returned without making a deep copy.
     119             : 
     120             :         :param type: Target dtype.
     121             :         :param copy: If `False`, return the input object if possible.
     122             :                      If `True`, the function always returns a new object.
     123             :         :raises: If the data cannot be converted to the requested dtype.
     124             :         :return: New variable or data array with specified dtype.
     125             :         :rtype: Union[scipp.Variable, scipp.DataArray])");
     126           6 : }
     127             : 
     128             : template <class Other, class T, class... Ignored>
     129          12 : void bind_inequality_to_operator(pybind11::class_<T, Ignored...> &c) {
     130          12 :   c.def(
     131          48 :       "__eq__", [](const T &a, const Other &b) { return a == b; },
     132           0 :       py::is_operator(), py::call_guard<py::gil_scoped_release>());
     133          12 :   c.def(
     134          16 :       "__ne__", [](const T &a, const Other &b) { return a != b; },
     135           0 :       py::is_operator(), py::call_guard<py::gil_scoped_release>());
     136          12 : }
     137             : 
     138             : struct Identity {
     139       86344 :   template <class T> const T &operator()(const T &x) const noexcept {
     140       86344 :     return x;
     141             :   }
     142             : };
     143             : struct ScalarToVariable {
     144        9997 :   template <class T> scipp::Variable operator()(const T &x) const noexcept {
     145        9997 :     return x * scipp::units::one;
     146             :   }
     147             : };
     148             : 
     149             : template <class RHSSetup> struct OpBinder {
     150             :   template <class Other, class T, class... Ignored>
     151          36 :   static void in_place_binary(pybind11::class_<T, Ignored...> &c) {
     152             :     using namespace scipp;
     153             :     // In-place operators return py::object due to the way in-place operators
     154             :     // work in Python (assigning return value to this). This avoids extra
     155             :     // copies, and additionally ensures that all references to the object keep
     156             :     // referencing the same object after the operation.
     157             :     // WARNING: It is crucial to explicitly return 'py::object &' here.
     158             :     // Otherwise the py::object is returned by value, which increments the
     159             :     // reference count, which is not only suboptimal but also incorrect since
     160             :     // we have released the GIL via py::gil_scoped_release.
     161          36 :     c.def(
     162             :         "__iadd__",
     163          61 :         [](py::object &a, Other &b) -> py::object & {
     164          61 :           a.cast<T &>() += RHSSetup{}(b);
     165          58 :           return a;
     166             :         },
     167           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     168          36 :     c.def(
     169             :         "__isub__",
     170          14 :         [](py::object &a, Other &b) -> py::object & {
     171          14 :           a.cast<T &>() -= RHSSetup{}(b);
     172          14 :           return a;
     173             :         },
     174           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     175          36 :     c.def(
     176             :         "__imul__",
     177          69 :         [](py::object &a, Other &b) -> py::object & {
     178          71 :           a.cast<T &>() *= RHSSetup{}(b);
     179          66 :           return a;
     180             :         },
     181           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     182          36 :     c.def(
     183             :         "__itruediv__",
     184          13 :         [](py::object &a, Other &b) -> py::object & {
     185          13 :           a.cast<T &>() /= RHSSetup{}(b);
     186          13 :           return a;
     187             :         },
     188           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     189             :     if constexpr (!(std::is_same_v<T, Dataset> ||
     190             :                     std::is_same_v<Other, Dataset>)) {
     191          21 :       c.def(
     192             :           "__imod__",
     193           0 :           [](py::object &a, Other &b) -> py::object & {
     194           0 :             a.cast<T &>() %= RHSSetup{}(b);
     195           0 :             return a;
     196             :           },
     197           0 :           py::is_operator(), py::call_guard<py::gil_scoped_release>());
     198          21 :       c.def(
     199             :           "__ifloordiv__",
     200           2 :           [](py::object &a, Other &b) -> py::object & {
     201           2 :             floor_divide_equals(a.cast<T &>(), RHSSetup{}(b));
     202           2 :             return a;
     203             :           },
     204           0 :           py::is_operator(), py::call_guard<py::gil_scoped_release>());
     205             :       if constexpr (!(std::is_same_v<T, DataArray> ||
     206             :                       std::is_same_v<Other, DataArray>)) {
     207           9 :         c.def(
     208             :             "__ipow__",
     209           4 :             [](T &base, Other &exponent) -> T & {
     210           4 :               return pow(base, RHSSetup{}(exponent), base);
     211             :             },
     212           0 :             py::is_operator(), py::call_guard<py::gil_scoped_release>());
     213             :       }
     214             :     }
     215          36 :   }
     216             : 
     217             :   template <class Other, class T, class... Ignored>
     218          42 :   static void binary(pybind11::class_<T, Ignored...> &c) {
     219             :     using namespace scipp;
     220          42 :     c.def(
     221        1350 :         "__add__", [](const T &a, const Other &b) { return a + RHSSetup{}(b); },
     222           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     223          42 :     c.def(
     224        3629 :         "__sub__", [](const T &a, const Other &b) { return a - RHSSetup{}(b); },
     225           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     226          42 :     c.def(
     227       67488 :         "__mul__", [](const T &a, const Other &b) { return a * RHSSetup{}(b); },
     228           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     229          42 :     c.def(
     230             :         "__truediv__",
     231        8716 :         [](const T &a, const Other &b) { return a / RHSSetup{}(b); },
     232           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     233             :     if constexpr (!(std::is_same_v<T, Dataset> ||
     234             :                     std::is_same_v<Other, Dataset>)) {
     235          24 :       c.def(
     236             :           "__floordiv__",
     237          88 :           [](const T &a, const Other &b) {
     238          88 :             return floor_divide(a, RHSSetup{}(b));
     239             :           },
     240           0 :           py::is_operator(), py::call_guard<py::gil_scoped_release>());
     241          24 :       c.def(
     242             :           "__mod__",
     243          65 :           [](const T &a, const Other &b) { return a % RHSSetup{}(b); },
     244           0 :           py::is_operator(), py::call_guard<py::gil_scoped_release>());
     245          24 :       c.def(
     246             :           "__pow__",
     247          41 :           [](const T &base, const Other &exponent) {
     248          41 :             return pow(base, RHSSetup{}(exponent));
     249             :           },
     250           0 :           py::is_operator(), py::call_guard<py::gil_scoped_release>());
     251             :     }
     252          42 :   }
     253             : 
     254             :   template <class Other, class T, class... Ignored>
     255          12 :   static void reverse_binary(pybind11::class_<T, Ignored...> &c) {
     256             :     using namespace scipp;
     257          12 :     c.def(
     258        8484 :         "__radd__", [](const T &a, const Other b) { return RHSSetup{}(b) + a; },
     259           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     260          12 :     c.def(
     261           9 :         "__rsub__", [](const T &a, const Other b) { return RHSSetup{}(b)-a; },
     262           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     263          12 :     c.def(
     264         963 :         "__rmul__", [](const T &a, const Other b) { return RHSSetup{}(b)*a; },
     265           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     266          12 :     c.def(
     267             :         "__rtruediv__",
     268          29 :         [](const T &a, const Other b) { return RHSSetup{}(b) / a; },
     269           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     270             :     if constexpr (!(std::is_same_v<T, Dataset> ||
     271             :                     std::is_same_v<Other, Dataset>)) {
     272          12 :       c.def(
     273             :           "__rfloordiv__",
     274           4 :           [](const T &a, const Other &b) {
     275           4 :             return floor_divide(RHSSetup{}(b), a);
     276             :           },
     277           0 :           py::is_operator(), py::call_guard<py::gil_scoped_release>());
     278          12 :       c.def(
     279             :           "__rmod__",
     280           4 :           [](const T &a, const Other &b) { return RHSSetup{}(b) % a; },
     281           0 :           py::is_operator(), py::call_guard<py::gil_scoped_release>());
     282          12 :       c.def(
     283             :           "__rpow__",
     284           4 :           [](const T &exponent, const Other &base) {
     285           4 :             return pow(RHSSetup{}(base), exponent);
     286             :           },
     287           0 :           py::is_operator(), py::call_guard<py::gil_scoped_release>());
     288             :     }
     289          12 :   }
     290             : 
     291             :   template <class Other, class T, class... Ignored>
     292          15 :   static void comparison(pybind11::class_<T, Ignored...> &c) {
     293          15 :     c.def(
     294         106 :         "__eq__", [](T &a, Other &b) { return equal(a, RHSSetup{}(b)); },
     295           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     296          15 :     c.def(
     297          72 :         "__ne__", [](T &a, Other &b) { return not_equal(a, RHSSetup{}(b)); },
     298           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     299          15 :     c.def(
     300          76 :         "__lt__", [](T &a, Other &b) { return less(a, RHSSetup{}(b)); },
     301           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     302          15 :     c.def(
     303        4963 :         "__gt__", [](T &a, Other &b) { return greater(a, RHSSetup{}(b)); },
     304           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     305          15 :     c.def(
     306          40 :         "__le__", [](T &a, Other &b) { return less_equal(a, RHSSetup{}(b)); },
     307           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     308          15 :     c.def(
     309             :         "__ge__",
     310          47 :         [](T &a, Other &b) { return greater_equal(a, RHSSetup{}(b)); },
     311           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     312          15 :   }
     313             : };
     314             : 
     315             : template <class Other, class T, class... Ignored>
     316          18 : static void bind_in_place_binary(pybind11::class_<T, Ignored...> &c) {
     317          18 :   OpBinder<Identity>::in_place_binary<Other>(c);
     318          18 : }
     319             : 
     320             : template <class Other, class T, class... Ignored>
     321          24 : static void bind_binary(pybind11::class_<T, Ignored...> &c) {
     322          24 :   OpBinder<Identity>::binary<Other>(c);
     323          24 : }
     324             : 
     325             : template <class Other, class T, class... Ignored>
     326           9 : static void bind_comparison(pybind11::class_<T, Ignored...> &c) {
     327           9 :   OpBinder<Identity>::comparison<Other>(c);
     328           9 : }
     329             : 
     330             : template <class T, class... Ignored>
     331           9 : void bind_in_place_binary_scalars(pybind11::class_<T, Ignored...> &c) {
     332           9 :   OpBinder<ScalarToVariable>::in_place_binary<double>(c);
     333           9 :   OpBinder<ScalarToVariable>::in_place_binary<int64_t>(c);
     334           9 : }
     335             : 
     336             : template <class T, class... Ignored>
     337           9 : void bind_binary_scalars(pybind11::class_<T, Ignored...> &c) {
     338           9 :   OpBinder<ScalarToVariable>::binary<double>(c);
     339           9 :   OpBinder<ScalarToVariable>::binary<int64_t>(c);
     340           9 : }
     341             : 
     342             : template <class T, class... Ignored>
     343           6 : static void bind_reverse_binary_scalars(pybind11::class_<T, Ignored...> &c) {
     344           6 :   OpBinder<ScalarToVariable>::reverse_binary<double>(c);
     345           6 :   OpBinder<ScalarToVariable>::reverse_binary<int64_t>(c);
     346           6 : }
     347             : 
     348             : template <class T, class... Ignored>
     349           3 : void bind_comparison_scalars(pybind11::class_<T, Ignored...> &c) {
     350           3 :   OpBinder<ScalarToVariable>::comparison<double>(c);
     351           3 :   OpBinder<ScalarToVariable>::comparison<int64_t>(c);
     352           3 : }
     353             : 
     354             : template <class T, class... Ignored>
     355           6 : void bind_unary(pybind11::class_<T, Ignored...> &c) {
     356           6 :   c.def(
     357       32087 :       "__neg__", [](const T &a) { return -a; }, py::is_operator(),
     358           6 :       py::call_guard<py::gil_scoped_release>());
     359           6 : }
     360             : 
     361             : template <class T, class... Ignored>
     362           6 : void bind_boolean_unary(pybind11::class_<T, Ignored...> &c) {
     363           6 :   c.def(
     364         208 :       "__invert__", [](const T &a) { return ~a; }, py::is_operator(),
     365           6 :       py::call_guard<py::gil_scoped_release>());
     366           6 : }
     367             : 
     368             : template <class Other, class T, class... Ignored>
     369           9 : void bind_logical(pybind11::class_<T, Ignored...> &c) {
     370             :   using T1 = const T;
     371             :   using T2 = const Other;
     372           9 :   c.def(
     373           9 :       "__or__", [](const T1 &a, const T2 &b) { return a | b; },
     374           0 :       py::is_operator(), py::call_guard<py::gil_scoped_release>());
     375           9 :   c.def(
     376           4 :       "__xor__", [](const T1 &a, const T2 &b) { return a ^ b; },
     377           0 :       py::is_operator(), py::call_guard<py::gil_scoped_release>());
     378           9 :   c.def(
     379          12 :       "__and__", [](const T1 &a, const T2 &b) { return a & b; },
     380           0 :       py::is_operator(), py::call_guard<py::gil_scoped_release>());
     381           9 :   c.def(
     382             :       "__ior__",
     383          62 :       [](const py::object &a, const T2 &b) -> const py::object & {
     384          62 :         a.cast<T &>() |= b;
     385          62 :         return a;
     386             :       },
     387           0 :       py::is_operator(), py::call_guard<py::gil_scoped_release>());
     388           9 :   c.def(
     389             :       "__ixor__",
     390           9 :       [](const py::object &a, const T2 &b) -> const py::object & {
     391           9 :         a.cast<T &>() ^= b;
     392           9 :         return a;
     393             :       },
     394           0 :       py::is_operator(), py::call_guard<py::gil_scoped_release>());
     395           9 :   c.def(
     396             :       "__iand__",
     397           2 :       [](const py::object &a, const T2 &b) -> const py::object & {
     398           2 :         a.cast<T &>() &= b;
     399           2 :         return a;
     400             :       },
     401           0 :       py::is_operator(), py::call_guard<py::gil_scoped_release>());
     402           9 : }

Generated by: LCOV version 1.14