LCOV - code coverage report
Current view: top level - dataset - arithmetic.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 125 126 99.2 %
Date: 2024-12-01 01:56:34 Functions: 79 79 100.0 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Simon Heybrock
       5             : #include "scipp/core/element/arithmetic.h"
       6             : #include "scipp/dataset/dataset.h"
       7             : #include "scipp/dataset/except.h"
       8             : #include "scipp/dataset/util.h"
       9             : #include "scipp/variable/arithmetic.h"
      10             : #include "scipp/variable/transform.h"
      11             : #include "scipp/variable/util.h"
      12             : 
      13             : #include "dataset_operations_common.h"
      14             : 
      15             : using namespace scipp::core;
      16             : 
      17             : namespace scipp::dataset {
      18             : 
      19             : namespace {
      20             : 
      21         141 : template <class T, class Op> void dry_run_op(T &&a, const Variable &b, Op op) {
      22             :   // This dry run relies on the knowledge that the implementation of operations
      23             :   // for variable simply calls transform_in_place and nothing else.
      24             :   // TODO use proper op name here once dataset ops are generated
      25         149 :   variable::dry_run::transform_in_place(a.data(), b, op, "binary_arithmetic");
      26         133 : }
      27             : 
      28         118 : template <class T, class Op> void dry_run_op(T &&a, const DataArray &b, Op op) {
      29         118 :   expect::coords_are_superset(a, b, "");
      30         110 :   dry_run_op(a, b.data(), op);
      31         102 : }
      32             : 
      33             : template <class Op, class A, class B>
      34          74 : auto &apply(const Op &op, A &a, const B &b) {
      35         158 :   for (const auto &item : b)
      36          96 :     dry_run_op(a[item.name()], item, op);
      37         108 :   for (const auto &item : b)
      38          54 :     op(a[item.name()], item);
      39          54 :   return a;
      40             : }
      41             : 
      42          67 : template <typename T> bool are_same(const T &a, const T &b) {
      43          67 :   return a.get() == b.get();
      44             : }
      45             : 
      46             : template <class A, class B>
      47             : bool have_common_underlying(const A &a, const B &b) {
      48             :   return are_same(a.data_handle(), b.data_handle());
      49             : }
      50             : 
      51             : template <>
      52          31 : bool have_common_underlying<DataArray, Variable>(const DataArray &a,
      53             :                                                  const Variable &b) {
      54          31 :   return are_same(a.data().data_handle(), b.data_handle());
      55             : }
      56             : 
      57             : template <>
      58          36 : bool have_common_underlying<DataArray, DataArray>(const DataArray &a,
      59             :                                                   const DataArray &b) {
      60          36 :   return are_same(a.data().data_handle(), b.data().data_handle());
      61             : }
      62             : 
      63             : template <class Op, class A, class B>
      64          69 : decltype(auto) apply_with_delay(const Op &op, A &&a, const B &b) {
      65         142 :   for (auto &&item : a)
      66          73 :     dry_run_op(item, b, op);
      67             :   // For `b` referencing data in `a` we delay operation. The alternative would
      68             :   // be to make a deep copy of `other` before starting the iteration over items.
      69          65 :   DataArray delayed;
      70             :   // Note the inefficiency here: We are comparing some or all of the coords for
      71             :   // each item. This could be improved by implementing the operations for
      72             :   // internal items of Dataset instead of DataArray.
      73         132 :   for (auto &&item : a) {
      74          67 :     if (have_common_underlying(item, b))
      75          32 :       delayed = item;
      76             :     else
      77          35 :       op(item, b);
      78             :   }
      79          65 :   if (delayed.is_valid())
      80          32 :     op(delayed, b);
      81         130 :   return std::forward<A>(a);
      82          65 : }
      83             : 
      84             : template <class Op, class A, class B>
      85          55 : auto apply_with_broadcast(const Op &op, const A &a, const B &b) {
      86          55 :   Dataset res;
      87         114 :   for (const auto &item : b)
      88          59 :     if (const auto it = a.find(item.name()); it != a.end())
      89          51 :       res.setDataInit(item.name(), op(*it, item));
      90         110 :   return std::move(res).or_empty();
      91          55 : }
      92             : 
      93             : template <class Op, class A>
      94           4 : auto apply_with_broadcast(const Op &op, const A &a, const DataArray &b) {
      95           4 :   Dataset res;
      96          12 :   for (const auto &item : a)
      97           8 :     res.setDataInit(item.name(), op(item, b));
      98           8 :   return std::move(res).or_empty();
      99           4 : }
     100             : 
     101             : template <class Op, class B>
     102           4 : auto apply_with_broadcast(const Op &op, const DataArray &a, const B &b) {
     103           4 :   Dataset res;
     104           4 :   for (const auto &item : b)
     105           0 :     res.setDataInit(item.name(), op(a, item));
     106           8 :   return std::move(res).or_empty();
     107           4 : }
     108             : 
     109             : template <class Op, class A>
     110          13 : auto apply_with_broadcast(const Op &op, const A &a, const Variable &b) {
     111          13 :   Dataset res;
     112          30 :   for (const auto &item : a)
     113          17 :     res.setDataInit(item.name(), op(item, b));
     114          26 :   return std::move(res).or_empty();
     115          13 : }
     116             : 
     117             : template <class Op, class B>
     118           8 : auto apply_with_broadcast(const Op &op, const Variable &a, const B &b) {
     119           8 :   Dataset res;
     120          20 :   for (const auto &item : b)
     121          12 :     res.setDataInit(item.name(), op(a, item));
     122          16 :   return std::move(res).or_empty();
     123           8 : }
     124             : 
     125             : } // namespace
     126             : 
     127          13 : Dataset &Dataset::operator+=(const DataArray &other) {
     128          13 :   return apply_with_delay(core::element::add_equals, *this, other);
     129             : }
     130             : 
     131           9 : Dataset &Dataset::operator-=(const DataArray &other) {
     132           9 :   return apply_with_delay(core::element::subtract_equals, *this, other);
     133             : }
     134             : 
     135           9 : Dataset &Dataset::operator*=(const DataArray &other) {
     136           9 :   return apply_with_delay(core::element::multiply_equals, *this, other);
     137             : }
     138             : 
     139           9 : Dataset &Dataset::operator/=(const DataArray &other) {
     140           9 :   return apply_with_delay(core::element::divide_equals, *this, other);
     141             : }
     142             : 
     143           9 : Dataset &Dataset::operator+=(const Variable &other) {
     144           9 :   return apply_with_delay(core::element::add_equals, *this, other);
     145             : }
     146             : 
     147           5 : Dataset &Dataset::operator-=(const Variable &other) {
     148           5 :   return apply_with_delay(core::element::subtract_equals, *this, other);
     149             : }
     150             : 
     151          10 : Dataset &Dataset::operator*=(const Variable &other) {
     152          10 :   return apply_with_delay(core::element::multiply_equals, *this, other);
     153             : }
     154             : 
     155           5 : Dataset &Dataset::operator/=(const Variable &other) {
     156           5 :   return apply_with_delay(core::element::divide_equals, *this, other);
     157             : }
     158             : 
     159          22 : Dataset &Dataset::operator+=(const Dataset &other) {
     160          22 :   return apply(core::element::add_equals, *this, other);
     161             : }
     162             : 
     163          18 : Dataset &Dataset::operator-=(const Dataset &other) {
     164          18 :   return apply(core::element::subtract_equals, *this, other);
     165             : }
     166             : 
     167          17 : Dataset &Dataset::operator*=(const Dataset &other) {
     168          17 :   return apply(core::element::multiply_equals, *this, other);
     169             : }
     170             : 
     171          17 : Dataset &Dataset::operator/=(const Dataset &other) {
     172          17 :   return apply(core::element::divide_equals, *this, other);
     173             : }
     174             : 
     175          27 : Dataset operator+(const Dataset &lhs, const Dataset &rhs) {
     176          27 :   return apply_with_broadcast(core::element::add, lhs, rhs);
     177             : }
     178             : 
     179           1 : Dataset operator+(const Dataset &lhs, const DataArray &rhs) {
     180           1 :   return apply_with_broadcast(core::element::add, lhs, rhs);
     181             : }
     182             : 
     183           1 : Dataset operator+(const DataArray &lhs, const Dataset &rhs) {
     184           1 :   return apply_with_broadcast(core::element::add, lhs, rhs);
     185             : }
     186             : 
     187           4 : Dataset operator+(const Dataset &lhs, const Variable &rhs) {
     188           4 :   return apply_with_broadcast(core::element::add, lhs, rhs);
     189             : }
     190             : 
     191           2 : Dataset operator+(const Variable &lhs, const Dataset &rhs) {
     192           2 :   return apply_with_broadcast(core::element::add, lhs, rhs);
     193             : }
     194             : 
     195          10 : Dataset operator-(const Dataset &lhs, const Dataset &rhs) {
     196          10 :   return apply_with_broadcast(core::element::subtract, lhs, rhs);
     197             : }
     198             : 
     199           1 : Dataset operator-(const Dataset &lhs, const DataArray &rhs) {
     200           1 :   return apply_with_broadcast(core::element::subtract, lhs, rhs);
     201             : }
     202             : 
     203           1 : Dataset operator-(const DataArray &lhs, const Dataset &rhs) {
     204           1 :   return apply_with_broadcast(core::element::subtract, lhs, rhs);
     205             : }
     206             : 
     207           3 : Dataset operator-(const Dataset &lhs, const Variable &rhs) {
     208           3 :   return apply_with_broadcast(core::element::subtract, lhs, rhs);
     209             : }
     210             : 
     211           2 : Dataset operator-(const Variable &lhs, const Dataset &rhs) {
     212           2 :   return apply_with_broadcast(core::element::subtract, lhs, rhs);
     213             : }
     214             : 
     215           9 : Dataset operator*(const Dataset &lhs, const Dataset &rhs) {
     216           9 :   return apply_with_broadcast(core::element::multiply, lhs, rhs);
     217             : }
     218             : 
     219           1 : Dataset operator*(const Dataset &lhs, const DataArray &rhs) {
     220           1 :   return apply_with_broadcast(core::element::multiply, lhs, rhs);
     221             : }
     222             : 
     223           1 : Dataset operator*(const DataArray &lhs, const Dataset &rhs) {
     224           1 :   return apply_with_broadcast(core::element::multiply, lhs, rhs);
     225             : }
     226             : 
     227           3 : Dataset operator*(const Dataset &lhs, const Variable &rhs) {
     228           3 :   return apply_with_broadcast(core::element::multiply, lhs, rhs);
     229             : }
     230             : 
     231           2 : Dataset operator*(const Variable &lhs, const Dataset &rhs) {
     232           2 :   return apply_with_broadcast(core::element::multiply, lhs, rhs);
     233             : }
     234             : 
     235           9 : Dataset operator/(const Dataset &lhs, const Dataset &rhs) {
     236           9 :   return apply_with_broadcast(core::element::divide, lhs, rhs);
     237             : }
     238             : 
     239           1 : Dataset operator/(const Dataset &lhs, const DataArray &rhs) {
     240           1 :   return apply_with_broadcast(core::element::divide, lhs, rhs);
     241             : }
     242             : 
     243           1 : Dataset operator/(const DataArray &lhs, const Dataset &rhs) {
     244           1 :   return apply_with_broadcast(core::element::divide, lhs, rhs);
     245             : }
     246             : 
     247           3 : Dataset operator/(const Dataset &lhs, const Variable &rhs) {
     248           3 :   return apply_with_broadcast(core::element::divide, lhs, rhs);
     249             : }
     250             : 
     251           2 : Dataset operator/(const Variable &lhs, const Dataset &rhs) {
     252           2 :   return apply_with_broadcast(core::element::divide, lhs, rhs);
     253             : }
     254             : 
     255             : } // namespace scipp::dataset

Generated by: LCOV version 1.14