Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 : #include "scipp/core/element/arithmetic.h"
6 : #include "scipp/dataset/dataset.h"
7 : #include "scipp/dataset/except.h"
8 : #include "scipp/dataset/util.h"
9 : #include "scipp/variable/arithmetic.h"
10 : #include "scipp/variable/transform.h"
11 : #include "scipp/variable/util.h"
12 :
13 : #include "dataset_operations_common.h"
14 :
15 : using namespace scipp::core;
16 :
17 : namespace scipp::dataset {
18 :
19 : namespace {
20 :
21 141 : template <class T, class Op> void dry_run_op(T &&a, const Variable &b, Op op) {
22 : // This dry run relies on the knowledge that the implementation of operations
23 : // for variable simply calls transform_in_place and nothing else.
24 : // TODO use proper op name here once dataset ops are generated
25 149 : variable::dry_run::transform_in_place(a.data(), b, op, "binary_arithmetic");
26 133 : }
27 :
28 118 : template <class T, class Op> void dry_run_op(T &&a, const DataArray &b, Op op) {
29 118 : expect::coords_are_superset(a, b, "");
30 110 : dry_run_op(a, b.data(), op);
31 102 : }
32 :
33 : template <class Op, class A, class B>
34 74 : auto &apply(const Op &op, A &a, const B &b) {
35 158 : for (const auto &item : b)
36 96 : dry_run_op(a[item.name()], item, op);
37 108 : for (const auto &item : b)
38 54 : op(a[item.name()], item);
39 54 : return a;
40 : }
41 :
42 67 : template <typename T> bool are_same(const T &a, const T &b) {
43 67 : return a.get() == b.get();
44 : }
45 :
46 : template <class A, class B>
47 : bool have_common_underlying(const A &a, const B &b) {
48 : return are_same(a.data_handle(), b.data_handle());
49 : }
50 :
51 : template <>
52 31 : bool have_common_underlying<DataArray, Variable>(const DataArray &a,
53 : const Variable &b) {
54 31 : return are_same(a.data().data_handle(), b.data_handle());
55 : }
56 :
57 : template <>
58 36 : bool have_common_underlying<DataArray, DataArray>(const DataArray &a,
59 : const DataArray &b) {
60 36 : return are_same(a.data().data_handle(), b.data().data_handle());
61 : }
62 :
63 : template <class Op, class A, class B>
64 69 : decltype(auto) apply_with_delay(const Op &op, A &&a, const B &b) {
65 142 : for (auto &&item : a)
66 73 : dry_run_op(item, b, op);
67 : // For `b` referencing data in `a` we delay operation. The alternative would
68 : // be to make a deep copy of `other` before starting the iteration over items.
69 65 : DataArray delayed;
70 : // Note the inefficiency here: We are comparing some or all of the coords for
71 : // each item. This could be improved by implementing the operations for
72 : // internal items of Dataset instead of DataArray.
73 132 : for (auto &&item : a) {
74 67 : if (have_common_underlying(item, b))
75 32 : delayed = item;
76 : else
77 35 : op(item, b);
78 : }
79 65 : if (delayed.is_valid())
80 32 : op(delayed, b);
81 130 : return std::forward<A>(a);
82 65 : }
83 :
84 : template <class Op, class A, class B>
85 55 : auto apply_with_broadcast(const Op &op, const A &a, const B &b) {
86 55 : Dataset res;
87 114 : for (const auto &item : b)
88 59 : if (const auto it = a.find(item.name()); it != a.end())
89 51 : res.setDataInit(item.name(), op(*it, item));
90 110 : return std::move(res).or_empty();
91 55 : }
92 :
93 : template <class Op, class A>
94 4 : auto apply_with_broadcast(const Op &op, const A &a, const DataArray &b) {
95 4 : Dataset res;
96 12 : for (const auto &item : a)
97 8 : res.setDataInit(item.name(), op(item, b));
98 8 : return std::move(res).or_empty();
99 4 : }
100 :
101 : template <class Op, class B>
102 4 : auto apply_with_broadcast(const Op &op, const DataArray &a, const B &b) {
103 4 : Dataset res;
104 4 : for (const auto &item : b)
105 0 : res.setDataInit(item.name(), op(a, item));
106 8 : return std::move(res).or_empty();
107 4 : }
108 :
109 : template <class Op, class A>
110 13 : auto apply_with_broadcast(const Op &op, const A &a, const Variable &b) {
111 13 : Dataset res;
112 30 : for (const auto &item : a)
113 17 : res.setDataInit(item.name(), op(item, b));
114 26 : return std::move(res).or_empty();
115 13 : }
116 :
117 : template <class Op, class B>
118 8 : auto apply_with_broadcast(const Op &op, const Variable &a, const B &b) {
119 8 : Dataset res;
120 20 : for (const auto &item : b)
121 12 : res.setDataInit(item.name(), op(a, item));
122 16 : return std::move(res).or_empty();
123 8 : }
124 :
125 : } // namespace
126 :
127 13 : Dataset &Dataset::operator+=(const DataArray &other) {
128 13 : return apply_with_delay(core::element::add_equals, *this, other);
129 : }
130 :
131 9 : Dataset &Dataset::operator-=(const DataArray &other) {
132 9 : return apply_with_delay(core::element::subtract_equals, *this, other);
133 : }
134 :
135 9 : Dataset &Dataset::operator*=(const DataArray &other) {
136 9 : return apply_with_delay(core::element::multiply_equals, *this, other);
137 : }
138 :
139 9 : Dataset &Dataset::operator/=(const DataArray &other) {
140 9 : return apply_with_delay(core::element::divide_equals, *this, other);
141 : }
142 :
143 9 : Dataset &Dataset::operator+=(const Variable &other) {
144 9 : return apply_with_delay(core::element::add_equals, *this, other);
145 : }
146 :
147 5 : Dataset &Dataset::operator-=(const Variable &other) {
148 5 : return apply_with_delay(core::element::subtract_equals, *this, other);
149 : }
150 :
151 10 : Dataset &Dataset::operator*=(const Variable &other) {
152 10 : return apply_with_delay(core::element::multiply_equals, *this, other);
153 : }
154 :
155 5 : Dataset &Dataset::operator/=(const Variable &other) {
156 5 : return apply_with_delay(core::element::divide_equals, *this, other);
157 : }
158 :
159 22 : Dataset &Dataset::operator+=(const Dataset &other) {
160 22 : return apply(core::element::add_equals, *this, other);
161 : }
162 :
163 18 : Dataset &Dataset::operator-=(const Dataset &other) {
164 18 : return apply(core::element::subtract_equals, *this, other);
165 : }
166 :
167 17 : Dataset &Dataset::operator*=(const Dataset &other) {
168 17 : return apply(core::element::multiply_equals, *this, other);
169 : }
170 :
171 17 : Dataset &Dataset::operator/=(const Dataset &other) {
172 17 : return apply(core::element::divide_equals, *this, other);
173 : }
174 :
175 27 : Dataset operator+(const Dataset &lhs, const Dataset &rhs) {
176 27 : return apply_with_broadcast(core::element::add, lhs, rhs);
177 : }
178 :
179 1 : Dataset operator+(const Dataset &lhs, const DataArray &rhs) {
180 1 : return apply_with_broadcast(core::element::add, lhs, rhs);
181 : }
182 :
183 1 : Dataset operator+(const DataArray &lhs, const Dataset &rhs) {
184 1 : return apply_with_broadcast(core::element::add, lhs, rhs);
185 : }
186 :
187 4 : Dataset operator+(const Dataset &lhs, const Variable &rhs) {
188 4 : return apply_with_broadcast(core::element::add, lhs, rhs);
189 : }
190 :
191 2 : Dataset operator+(const Variable &lhs, const Dataset &rhs) {
192 2 : return apply_with_broadcast(core::element::add, lhs, rhs);
193 : }
194 :
195 10 : Dataset operator-(const Dataset &lhs, const Dataset &rhs) {
196 10 : return apply_with_broadcast(core::element::subtract, lhs, rhs);
197 : }
198 :
199 1 : Dataset operator-(const Dataset &lhs, const DataArray &rhs) {
200 1 : return apply_with_broadcast(core::element::subtract, lhs, rhs);
201 : }
202 :
203 1 : Dataset operator-(const DataArray &lhs, const Dataset &rhs) {
204 1 : return apply_with_broadcast(core::element::subtract, lhs, rhs);
205 : }
206 :
207 3 : Dataset operator-(const Dataset &lhs, const Variable &rhs) {
208 3 : return apply_with_broadcast(core::element::subtract, lhs, rhs);
209 : }
210 :
211 2 : Dataset operator-(const Variable &lhs, const Dataset &rhs) {
212 2 : return apply_with_broadcast(core::element::subtract, lhs, rhs);
213 : }
214 :
215 9 : Dataset operator*(const Dataset &lhs, const Dataset &rhs) {
216 9 : return apply_with_broadcast(core::element::multiply, lhs, rhs);
217 : }
218 :
219 1 : Dataset operator*(const Dataset &lhs, const DataArray &rhs) {
220 1 : return apply_with_broadcast(core::element::multiply, lhs, rhs);
221 : }
222 :
223 1 : Dataset operator*(const DataArray &lhs, const Dataset &rhs) {
224 1 : return apply_with_broadcast(core::element::multiply, lhs, rhs);
225 : }
226 :
227 3 : Dataset operator*(const Dataset &lhs, const Variable &rhs) {
228 3 : return apply_with_broadcast(core::element::multiply, lhs, rhs);
229 : }
230 :
231 2 : Dataset operator*(const Variable &lhs, const Dataset &rhs) {
232 2 : return apply_with_broadcast(core::element::multiply, lhs, rhs);
233 : }
234 :
235 9 : Dataset operator/(const Dataset &lhs, const Dataset &rhs) {
236 9 : return apply_with_broadcast(core::element::divide, lhs, rhs);
237 : }
238 :
239 1 : Dataset operator/(const Dataset &lhs, const DataArray &rhs) {
240 1 : return apply_with_broadcast(core::element::divide, lhs, rhs);
241 : }
242 :
243 1 : Dataset operator/(const DataArray &lhs, const Dataset &rhs) {
244 1 : return apply_with_broadcast(core::element::divide, lhs, rhs);
245 : }
246 :
247 3 : Dataset operator/(const Dataset &lhs, const Variable &rhs) {
248 3 : return apply_with_broadcast(core::element::divide, lhs, rhs);
249 : }
250 :
251 2 : Dataset operator/(const Variable &lhs, const Dataset &rhs) {
252 2 : return apply_with_broadcast(core::element::divide, lhs, rhs);
253 : }
254 :
255 : } // namespace scipp::dataset
|