LCOV - code coverage report
Current view: top level - variable - arithmetic.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 75 75 100.0 %
Date: 2024-12-01 01:56:34 Functions: 17 17 100.0 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : #include "scipp/variable/arithmetic.h"
       4             : #include "scipp/core/dtype.h"
       5             : #include "scipp/core/eigen.h"
       6             : #include "scipp/core/element/arithmetic.h"
       7             : #include "scipp/core/spatial_transforms.h"
       8             : #include "scipp/variable/astype.h"
       9             : #include "scipp/variable/pow.h"
      10             : #include "scipp/variable/transform.h"
      11             : #include "scipp/variable/variable_factory.h"
      12             : 
      13             : namespace scipp::variable {
      14             : 
      15             : namespace {
      16             : 
      17      143352 : bool is_transform_with_translation(const Variable &var) {
      18      286690 :   return var.dtype() == dtype<Eigen::Affine3d> ||
      19      286690 :          var.dtype() == dtype<scipp::core::Translation>;
      20             : }
      21             : 
      22          26 : auto make_factor(const Variable &prototype, const double value) {
      23          26 :   const auto unit = variableFactory().elem_unit(prototype) == units::none
      24          26 :                         ? units::none
      25          26 :                         : units::one;
      26          52 :   return astype(makeVariable<double>(Values{value}, unit),
      27          78 :                 variableFactory().elem_dtype(prototype));
      28             : }
      29             : 
      30             : /// True if a and b are correlated, currently only if referencing same.
      31      493001 : bool correlated(const Variable &a, const Variable &b) {
      32      493001 :   return variableFactory().has_variances(a) &&
      33      493001 :          variableFactory().has_variances(b) && a.is_same(b);
      34             : }
      35             : 
      36             : } // namespace
      37             : 
      38      305680 : Variable operator+(const Variable &a, const Variable &b) {
      39      305680 :   if (correlated(a, b))
      40          36 :     return a * make_factor(a, 2.0);
      41      305662 :   return transform(a, b, core::element::add, "add");
      42             : }
      43             : 
      44       34597 : Variable operator-(const Variable &a, const Variable &b) {
      45       34597 :   if (correlated(a, b))
      46           4 :     return a * make_factor(a, 0.0);
      47       34595 :   return transform(a, b, core::element::subtract, "subtract");
      48             : }
      49             : 
      50      143331 : Variable operator*(const Variable &a, const Variable &b) {
      51      143352 :   if (is_transform_with_translation(a) &&
      52          35 :       (is_transform_with_translation(b) ||
      53      143345 :        b.dtype() == dtype<Eigen::Vector3d>)) {
      54             :     return transform(a, b, core::element::apply_spatial_transformation,
      55          38 :                      std::string_view("apply_spatial_transformation"));
      56             :   } else {
      57      143311 :     if (correlated(a, b))
      58           2 :       return pow(a, make_factor(a, 2.0));
      59             :     return transform(a, b, core::element::multiply,
      60      286616 :                      std::string_view("multiply"));
      61             :   }
      62             : }
      63             : 
      64        8780 : Variable operator/(const Variable &a, const Variable &b) {
      65        8780 :   if (correlated(a, b))
      66           2 :     return pow(a, make_factor(a, 0.0));
      67        8779 :   return transform(a, b, core::element::divide, "divide");
      68             : }
      69             : 
      70         155 : Variable &operator+=(Variable &a, const Variable &b) {
      71         155 :   operator+=(Variable(a), b);
      72         145 :   return a;
      73             : }
      74             : 
      75          31 : Variable &operator-=(Variable &a, const Variable &b) {
      76          31 :   operator-=(Variable(a), b);
      77          29 :   return a;
      78             : }
      79             : 
      80          63 : Variable &operator*=(Variable &a, const Variable &b) {
      81          63 :   operator*=(Variable(a), b);
      82          61 :   return a;
      83             : }
      84             : 
      85          19 : Variable &operator/=(Variable &a, const Variable &b) {
      86          19 :   operator/=(Variable(a), b);
      87          17 :   return a;
      88             : }
      89             : 
      90           2 : Variable &floor_divide_equals(Variable &a, const Variable &b) {
      91           2 :   floor_divide_equals(Variable(a), b);
      92           2 :   return a;
      93             : }
      94             : 
      95         320 : Variable operator+=(Variable &&a, const Variable &b) {
      96         320 :   if (correlated(a, b))
      97           2 :     return a *= make_factor(a, 2.0);
      98         319 :   transform_in_place(a, b, core::element::add_equals,
      99         319 :                      std::string_view("add_equals"));
     100         308 :   return std::move(a);
     101             : }
     102             : 
     103          92 : Variable operator-=(Variable &&a, const Variable &b) {
     104          92 :   if (correlated(a, b))
     105           2 :     return a *= make_factor(a, 0.0);
     106          91 :   transform_in_place(a, b, core::element::subtract_equals,
     107          91 :                      std::string_view("subtract_equals"));
     108          89 :   return std::move(a);
     109             : }
     110             : 
     111         151 : Variable operator*=(Variable &&a, const Variable &b) {
     112         151 :   if (correlated(a, b))
     113           2 :     return pow(a, make_factor(a, 2.0), a);
     114         150 :   transform_in_place(a, b, core::element::multiply_equals,
     115         150 :                      std::string_view("multiply_equals"));
     116         145 :   return std::move(a);
     117             : }
     118             : 
     119          70 : Variable operator/=(Variable &&a, const Variable &b) {
     120          70 :   if (correlated(a, b))
     121           2 :     return pow(a, make_factor(a, 0.0), a);
     122          69 :   transform_in_place(a, b, core::element::divide_equals,
     123          69 :                      std::string_view("divide_equals"));
     124          67 :   return std::move(a);
     125             : }
     126             : 
     127           2 : Variable floor_divide_equals(Variable &&a, const Variable &b) {
     128           2 :   transform_in_place(a, b, core::element::floor_divide_equals,
     129           2 :                      std::string_view("divide_equals"));
     130           2 :   return std::move(a);
     131             : }
     132             : 
     133             : } // namespace scipp::variable

Generated by: LCOV version 1.14