LCOV - code coverage report
Current view: top level - python - bind_operators.h (source / functions) Hit Total Coverage
Test: coverage.info Lines: 177 216 81.9 %
Date: 2024-04-28 01:25:40 Functions: 253 330 76.7 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Simon Heybrock
       5             : #pragma once
       6             : 
       7             : #include "scipp/dataset/arithmetic.h"
       8             : #include "scipp/dataset/astype.h"
       9             : #include "scipp/dataset/generated_comparison.h"
      10             : #include "scipp/dataset/generated_logical.h"
      11             : #include "scipp/dataset/generated_math.h"
      12             : #include "scipp/dataset/to_unit.h"
      13             : #include "scipp/dataset/util.h"
      14             : #include "scipp/units/except.h"
      15             : #include "scipp/variable/arithmetic.h"
      16             : #include "scipp/variable/astype.h"
      17             : #include "scipp/variable/comparison.h"
      18             : #include "scipp/variable/logical.h"
      19             : #include "scipp/variable/pow.h"
      20             : #include "scipp/variable/to_unit.h"
      21             : 
      22             : #include "dtype.h"
      23             : #include "format.h"
      24             : #include "pybind11.h"
      25             : 
      26             : namespace py = pybind11;
      27             : 
      28             : template <class T, class... Ignored>
      29           9 : void bind_common_operators(pybind11::class_<T, Ignored...> &c) {
      30          25 :   c.def("__abs__", [](const T &self) { return abs(self); });
      31         220 :   c.def("__repr__", [](const T &self) { return to_string(self); });
      32          11 :   c.def("__bool__", [](const T &self) {
      33             :     if constexpr (std::is_same_v<T, scipp::Variable>) {
      34        4814 :       if (self.unit() != scipp::units::none)
      35           2 :         throw scipp::except::UnitError(
      36             :             "The truth value of a variable with unit is undefined.");
      37        4812 :       return self.template value<bool>() == true;
      38             :     }
      39           2 :     throw std::runtime_error("The truth value of a variable, data array, or "
      40             :                              "dataset is ambiguous. Use any() or all().");
      41             :   });
      42           9 :   c.def(
      43             :       "copy",
      44        8419 :       [](const T &self, const bool deep) { return deep ? copy(self) : self; },
      45           9 :       py::arg("deep") = true, py::call_guard<py::gil_scoped_release>(),
      46             :       R"(
      47             :       Return a (by default deep) copy.
      48             : 
      49             :       If `deep=True` (the default), a deep copy is made. Otherwise, a shallow
      50             :       copy is made, and the returned data (and meta data) values are new views
      51             :       of the data and meta data values of this object.)");
      52           9 :   c.def(
      53          13 :       "__copy__", [](const T &self) { return self; },
      54           0 :       py::call_guard<py::gil_scoped_release>(), "Return a (shallow) copy.");
      55           9 :   c.def(
      56             :        "__deepcopy__",
      57         182 :        [](const T &self, const py::dict &) { return copy(self); },
      58           9 :        py::call_guard<py::gil_scoped_release>(), "Return a (deep) copy.")
      59           9 :       .def(
      60             :           "__sizeof__",
      61         312 :           [](const T &self) {
      62         312 :             return size_of(self, scipp::SizeofTag::ViewOnly);
      63             :           },
      64             :           R"doc(Return the size of the object in bytes.
      65             : 
      66             : The size includes the object itself and all arrays contained in it.
      67             : But arrays may be counted multiple times if components share buffers,
      68             : e.g. multiple coordinates referencing the same memory.
      69             : Conversely, the size may be underestimated. Especially, but not only,
      70             : with dtype=PyObject.
      71             : 
      72             : This function only includes memory of the current slice. Use
      73             : ``underlying_size`` to get the full memory size of the underlying structure.)doc")
      74           9 :       .def(
      75             :           "underlying_size",
      76         312 :           [](const T &self) {
      77         312 :             return size_of(self, scipp::SizeofTag::Underlying);
      78             :           },
      79             :           R"doc(Return the size of the object in bytes.
      80             : 
      81             : The size includes the object itself and all arrays contained in it.
      82             : But arrays may be counted multiple times if components share buffers,
      83             : e.g. multiple coordinates referencing the same memory.
      84             : Conversely, the size may be underestimated. Especially, but not only,
      85             : with dtype=PyObject.
      86             : 
      87             : This function includes all memory of the underlying buffers. Use
      88             : ``__sizeof__`` to get the size of the current slice only.)doc");
      89           9 : }
      90             : 
      91             : template <class T, class... Ignored>
      92           6 : void bind_astype(py::class_<T, Ignored...> &c) {
      93           6 :   c.def(
      94             :       "astype",
      95        4855 :       [](const T &self, const py::object &type, const bool copy) {
      96        4855 :         const auto [scipp_dtype, dtype_unit] =
      97             :             cast_dtype_and_unit(type, DefaultUnit{});
      98        4860 :         if (dtype_unit.has_value() &&
      99          25 :             (dtype_unit != scipp::units::one && dtype_unit != self.unit())) {
     100           2 :           throw scipp::except::UnitError(scipp::python::format(
     101             :               "Conversion of units via the dtype is not allowed. Occurred when "
     102             :               "trying to change dtype from ",
     103           2 :               self.dtype(), " to ", type,
     104             :               ". Use to_unit in combination with astype."));
     105             :         }
     106        4854 :         [[maybe_unused]] py::gil_scoped_release release;
     107             :         return astype(self, scipp_dtype,
     108             :                       copy ? scipp::CopyPolicy::Always
     109        9704 :                            : scipp::CopyPolicy::TryAvoid);
     110        4854 :       },
     111          12 :       py::arg("type"), py::kw_only(), py::arg("copy") = true,
     112             :       R"(
     113             :         Converts a Variable or DataArray to a different dtype.
     114             : 
     115             :         If the dtype is unchanged and ``copy`` is `False`, the object
     116             :         is returned without making a deep copy.
     117             : 
     118             :         :param type: Target dtype.
     119             :         :param copy: If `False`, return the input object if possible.
     120             :                      If `True`, the function always returns a new object.
     121             :         :raises: If the data cannot be converted to the requested dtype.
     122             :         :return: New variable or data array with specified dtype.
     123             :         :rtype: Union[scipp.Variable, scipp.DataArray])");
     124           6 : }
     125             : 
     126             : template <class Other, class T, class... Ignored>
     127          12 : void bind_inequality_to_operator(pybind11::class_<T, Ignored...> &c) {
     128          12 :   c.def(
     129          48 :       "__eq__", [](const T &a, const Other &b) { return a == b; },
     130           0 :       py::is_operator(), py::call_guard<py::gil_scoped_release>());
     131          12 :   c.def(
     132          16 :       "__ne__", [](const T &a, const Other &b) { return a != b; },
     133           0 :       py::is_operator(), py::call_guard<py::gil_scoped_release>());
     134          12 : }
     135             : 
     136             : struct Identity {
     137       14904 :   template <class T> const T &operator()(const T &x) const noexcept {
     138       14904 :     return x;
     139             :   }
     140             : };
     141             : struct ScalarToVariable {
     142       63317 :   template <class T> scipp::Variable operator()(const T &x) const noexcept {
     143       63317 :     return x * scipp::units::one;
     144             :   }
     145             : };
     146             : 
     147             : template <class RHSSetup> struct OpBinder {
     148             :   template <class Other, class T, class... Ignored>
     149          36 :   static void in_place_binary(pybind11::class_<T, Ignored...> &c) {
     150             :     using namespace scipp;
     151             :     // In-place operators return py::object due to the way in-place operators
     152             :     // work in Python (assigning return value to this). This avoids extra
     153             :     // copies, and additionally ensures that all references to the object keep
     154             :     // referencing the same object after the operation.
     155          36 :     c.def(
     156             :         "__iadd__",
     157          59 :         [](py::object &a, Other &b) {
     158          59 :           a.cast<T &>() += RHSSetup{}(b);
     159          56 :           return a;
     160             :         },
     161           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     162          36 :     c.def(
     163             :         "__isub__",
     164          14 :         [](py::object &a, Other &b) {
     165          14 :           a.cast<T &>() -= RHSSetup{}(b);
     166          14 :           return a;
     167             :         },
     168           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     169          36 :     c.def(
     170             :         "__imul__",
     171          69 :         [](py::object &a, Other &b) {
     172          71 :           a.cast<T &>() *= RHSSetup{}(b);
     173          66 :           return a;
     174             :         },
     175           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     176          36 :     c.def(
     177             :         "__itruediv__",
     178          13 :         [](py::object &a, Other &b) {
     179          13 :           a.cast<T &>() /= RHSSetup{}(b);
     180          13 :           return a;
     181             :         },
     182           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     183             :     if constexpr (!(std::is_same_v<T, Dataset> ||
     184             :                     std::is_same_v<Other, Dataset>)) {
     185          21 :       c.def(
     186             :           "__imod__",
     187           0 :           [](py::object &a, Other &b) {
     188           0 :             a.cast<T &>() %= RHSSetup{}(b);
     189           0 :             return a;
     190             :           },
     191           0 :           py::is_operator(), py::call_guard<py::gil_scoped_release>());
     192          21 :       c.def(
     193             :           "__ifloordiv__",
     194           2 :           [](py::object &a, Other &b) {
     195           2 :             floor_divide_equals(a.cast<T &>(), RHSSetup{}(b));
     196           2 :             return a;
     197             :           },
     198           0 :           py::is_operator(), py::call_guard<py::gil_scoped_release>());
     199             :       if constexpr (!(std::is_same_v<T, DataArray> ||
     200             :                       std::is_same_v<Other, DataArray>)) {
     201           9 :         c.def(
     202             :             "__ipow__",
     203           4 :             [](T &base, Other &exponent) {
     204           8 :               return pow(base, RHSSetup{}(exponent), base);
     205             :             },
     206           0 :             py::is_operator(), py::call_guard<py::gil_scoped_release>());
     207             :       }
     208             :     }
     209          36 :   }
     210             : 
     211             :   template <class Other, class T, class... Ignored>
     212          36 :   static void binary(pybind11::class_<T, Ignored...> &c) {
     213             :     using namespace scipp;
     214          36 :     c.def(
     215        1191 :         "__add__", [](const T &a, const Other &b) { return a + RHSSetup{}(b); },
     216           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     217          36 :     c.def(
     218        3277 :         "__sub__", [](const T &a, const Other &b) { return a - RHSSetup{}(b); },
     219           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     220          36 :     c.def(
     221        2570 :         "__mul__", [](const T &a, const Other &b) { return a * RHSSetup{}(b); },
     222           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     223          36 :     c.def(
     224             :         "__truediv__",
     225        2847 :         [](const T &a, const Other &b) { return a / RHSSetup{}(b); },
     226           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     227             :     if constexpr (!(std::is_same_v<T, Dataset> ||
     228             :                     std::is_same_v<Other, Dataset>)) {
     229          24 :       c.def(
     230             :           "__floordiv__",
     231          88 :           [](const T &a, const Other &b) {
     232          88 :             return floor_divide(a, RHSSetup{}(b));
     233             :           },
     234           0 :           py::is_operator(), py::call_guard<py::gil_scoped_release>());
     235          24 :       c.def(
     236             :           "__mod__",
     237          65 :           [](const T &a, const Other &b) { return a % RHSSetup{}(b); },
     238           0 :           py::is_operator(), py::call_guard<py::gil_scoped_release>());
     239          24 :       c.def(
     240             :           "__pow__",
     241          41 :           [](const T &base, const Other &exponent) {
     242          41 :             return pow(base, RHSSetup{}(exponent));
     243             :           },
     244           0 :           py::is_operator(), py::call_guard<py::gil_scoped_release>());
     245             :     }
     246          36 :   }
     247             : 
     248             :   template <class Other, class T, class... Ignored>
     249          12 :   static void reverse_binary(pybind11::class_<T, Ignored...> &c) {
     250             :     using namespace scipp;
     251          12 :     c.def(
     252        5967 :         "__radd__", [](const T &a, const Other b) { return RHSSetup{}(b) + a; },
     253           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     254          12 :     c.def(
     255           9 :         "__rsub__", [](const T &a, const Other b) { return RHSSetup{}(b)-a; },
     256           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     257          12 :     c.def(
     258       53486 :         "__rmul__", [](const T &a, const Other b) { return RHSSetup{}(b)*a; },
     259           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     260          12 :     c.def(
     261             :         "__rtruediv__",
     262        3366 :         [](const T &a, const Other b) { return RHSSetup{}(b) / a; },
     263           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     264             :     if constexpr (!(std::is_same_v<T, Dataset> ||
     265             :                     std::is_same_v<Other, Dataset>)) {
     266          12 :       c.def(
     267             :           "__rfloordiv__",
     268           4 :           [](const T &a, const Other &b) {
     269           4 :             return floor_divide(RHSSetup{}(b), a);
     270             :           },
     271           0 :           py::is_operator(), py::call_guard<py::gil_scoped_release>());
     272          12 :       c.def(
     273             :           "__rmod__",
     274           4 :           [](const T &a, const Other &b) { return RHSSetup{}(b) % a; },
     275           0 :           py::is_operator(), py::call_guard<py::gil_scoped_release>());
     276          12 :       c.def(
     277             :           "__rpow__",
     278           4 :           [](const T &exponent, const Other &base) {
     279           4 :             return pow(RHSSetup{}(base), exponent);
     280             :           },
     281           0 :           py::is_operator(), py::call_guard<py::gil_scoped_release>());
     282             :     }
     283          12 :   }
     284             : 
     285             :   template <class Other, class T, class... Ignored>
     286          15 :   static void comparison(pybind11::class_<T, Ignored...> &c) {
     287          15 :     c.def(
     288         104 :         "__eq__", [](T &a, Other &b) { return equal(a, RHSSetup{}(b)); },
     289           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     290          15 :     c.def(
     291          72 :         "__ne__", [](T &a, Other &b) { return not_equal(a, RHSSetup{}(b)); },
     292           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     293          15 :     c.def(
     294          75 :         "__lt__", [](T &a, Other &b) { return less(a, RHSSetup{}(b)); },
     295           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     296          15 :     c.def(
     297        4805 :         "__gt__", [](T &a, Other &b) { return greater(a, RHSSetup{}(b)); },
     298           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     299          15 :     c.def(
     300          38 :         "__le__", [](T &a, Other &b) { return less_equal(a, RHSSetup{}(b)); },
     301           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     302          15 :     c.def(
     303             :         "__ge__",
     304          47 :         [](T &a, Other &b) { return greater_equal(a, RHSSetup{}(b)); },
     305           0 :         py::is_operator(), py::call_guard<py::gil_scoped_release>());
     306          15 :   }
     307             : };
     308             : 
     309             : template <class Other, class T, class... Ignored>
     310          18 : static void bind_in_place_binary(pybind11::class_<T, Ignored...> &c) {
     311          18 :   OpBinder<Identity>::in_place_binary<Other>(c);
     312          18 : }
     313             : 
     314             : template <class Other, class T, class... Ignored>
     315          24 : static void bind_binary(pybind11::class_<T, Ignored...> &c) {
     316          24 :   OpBinder<Identity>::binary<Other>(c);
     317          24 : }
     318             : 
     319             : template <class Other, class T, class... Ignored>
     320           9 : static void bind_comparison(pybind11::class_<T, Ignored...> &c) {
     321           9 :   OpBinder<Identity>::comparison<Other>(c);
     322           9 : }
     323             : 
     324             : template <class T, class... Ignored>
     325           9 : void bind_in_place_binary_scalars(pybind11::class_<T, Ignored...> &c) {
     326           9 :   OpBinder<ScalarToVariable>::in_place_binary<double>(c);
     327           9 :   OpBinder<ScalarToVariable>::in_place_binary<int64_t>(c);
     328           9 : }
     329             : 
     330             : template <class T, class... Ignored>
     331           6 : void bind_binary_scalars(pybind11::class_<T, Ignored...> &c) {
     332           6 :   OpBinder<ScalarToVariable>::binary<double>(c);
     333           6 :   OpBinder<ScalarToVariable>::binary<int64_t>(c);
     334           6 : }
     335             : 
     336             : template <class T, class... Ignored>
     337           6 : static void bind_reverse_binary_scalars(pybind11::class_<T, Ignored...> &c) {
     338           6 :   OpBinder<ScalarToVariable>::reverse_binary<double>(c);
     339           6 :   OpBinder<ScalarToVariable>::reverse_binary<int64_t>(c);
     340           6 : }
     341             : 
     342             : template <class T, class... Ignored>
     343           3 : void bind_comparison_scalars(pybind11::class_<T, Ignored...> &c) {
     344           3 :   OpBinder<ScalarToVariable>::comparison<double>(c);
     345           3 :   OpBinder<ScalarToVariable>::comparison<int64_t>(c);
     346           3 : }
     347             : 
     348             : template <class T, class... Ignored>
     349           6 : void bind_unary(pybind11::class_<T, Ignored...> &c) {
     350           6 :   c.def(
     351        1077 :       "__neg__", [](const T &a) { return -a; }, py::is_operator(),
     352           6 :       py::call_guard<py::gil_scoped_release>());
     353           6 : }
     354             : 
     355             : template <class T, class... Ignored>
     356           6 : void bind_boolean_unary(pybind11::class_<T, Ignored...> &c) {
     357           6 :   c.def(
     358         199 :       "__invert__", [](const T &a) { return ~a; }, py::is_operator(),
     359           6 :       py::call_guard<py::gil_scoped_release>());
     360           6 : }
     361             : 
     362             : template <class Other, class T, class... Ignored>
     363           9 : void bind_logical(pybind11::class_<T, Ignored...> &c) {
     364             :   using T1 = const T;
     365             :   using T2 = const Other;
     366           9 :   c.def(
     367           8 :       "__or__", [](const T1 &a, const T2 &b) { return a | b; },
     368           0 :       py::is_operator(), py::call_guard<py::gil_scoped_release>());
     369           9 :   c.def(
     370           4 :       "__xor__", [](const T1 &a, const T2 &b) { return a ^ b; },
     371           0 :       py::is_operator(), py::call_guard<py::gil_scoped_release>());
     372           9 :   c.def(
     373           8 :       "__and__", [](const T1 &a, const T2 &b) { return a & b; },
     374           0 :       py::is_operator(), py::call_guard<py::gil_scoped_release>());
     375           9 :   c.def(
     376             :       "__ior__",
     377           2 :       [](const py::object &a, const T2 &b) {
     378           2 :         a.cast<T &>() |= b;
     379           2 :         return a;
     380             :       },
     381           0 :       py::is_operator(), py::call_guard<py::gil_scoped_release>());
     382           9 :   c.def(
     383             :       "__ixor__",
     384           9 :       [](const py::object &a, const T2 &b) {
     385           9 :         a.cast<T &>() ^= b;
     386           9 :         return a;
     387             :       },
     388           0 :       py::is_operator(), py::call_guard<py::gil_scoped_release>());
     389           9 :   c.def(
     390             :       "__iand__",
     391           2 :       [](const py::object &a, const T2 &b) {
     392           2 :         a.cast<T &>() &= b;
     393           2 :         return a;
     394             :       },
     395           0 :       py::is_operator(), py::call_guard<py::gil_scoped_release>());
     396           9 : }

Generated by: LCOV version 1.14