Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file Various transform functions for variables.
4 : ///
5 : /// The underlying mechanism of the implementation is as follows:
6 : /// 1. `visit<...>::apply` obtains the concrete underlying data type(s).
7 : /// 2. `Transform` is applied to that concrete container, calling
8 : /// `do_transform`. `Transform` essentially builds a callable accepting a
9 : /// container from a callable accepting an element of the container.
10 : /// 3. `do_transform` is essentially a fancy std::transform. It uses recursion
11 : /// to process optional flags (provided as base classes of the user-provided
12 : /// operator). It provides automatic handling of data that has variances in
13 : /// addition to values, calling a different transform implementation for each
14 : /// case (different instantiations of `transform_elements`).
15 : /// 4. The `transform_elements` function calls the overloaded operator for
16 : /// each element. This is also where multi-threading for the majority of
17 : /// scipp's operations is implemented.
18 : ///
19 : /// Handling of binned data is mostly hidden in this implementation, reducing
20 : /// code duplication:
21 : /// - `variableFactory()` is used for output creation and unit access.
22 : /// - `variableFactory()` is used in `visit.h` to obtain a direct pointer to the
23 : /// underlying buffer.
24 : /// - MultiIndex contains special handling for binned data, i.e., it can iterate
25 : /// the buffer in a binning-aware way.
26 : ///
27 : /// The mechanism for in-place transformation is mostly identical to the one
28 : /// outlined above.
29 : ///
30 : /// @author Simon Heybrock
31 : #pragma once
32 :
33 : #include <algorithm>
34 : #include <cassert>
35 : #include <string_view>
36 :
37 : #include "scipp/common/overloaded.h"
38 :
39 : #include "scipp/core/has_eval.h"
40 : #include "scipp/core/multi_index.h"
41 : #include "scipp/core/parallel.h"
42 : #include "scipp/core/transform_common.h"
43 : #include "scipp/core/value_and_variance.h"
44 : #include "scipp/core/values_and_variances.h"
45 :
46 : #include "scipp/variable/except.h"
47 : #include "scipp/variable/variable.h"
48 : #include "scipp/variable/variable_factory.h"
49 : #include "scipp/variable/visit.h"
50 :
51 : namespace scipp::variable {
52 :
53 : namespace detail {
54 :
/// Trait reporting whether `T` carries variances: true for
/// `ValueAndVariance<T>` and `ValuesAndVariances<T>`, false for every other
/// type. `value_maybe_variance` applies it to the decayed operand type.
55 : template <class T> struct has_variances : std::false_type {};
56 : template <class T>
57 : struct has_variances<ValueAndVariance<T>> : std::true_type {};
58 : template <class T>
59 : struct has_variances<ValuesAndVariances<T>> : std::true_type {};
/// Shorthand for `has_variances<T>::value`.
60 : template <class T>
61 : inline constexpr bool has_variances_v = has_variances<T>::value;
62 :
63 : /// Helper for the transform implementation to unify iteration of data with and
64 : /// without variances.
///
/// @param range Operand to read from: either something exposing `data()`, or
///              an object with `values` and `variances` members.
/// @param i     Flat index into the underlying storage.
/// @return For operands with variances, a `ValueAndVariance` built from the
///         i-th value and the i-th variance. Otherwise the expression
///         `range.data()[i]` itself; `decltype(auto)` preserves its value
///         category.
65 : template <class T>
66 1333055266 : static constexpr decltype(auto) value_maybe_variance(T &&range,
67 : const scipp::index i) {
68 : if constexpr (has_variances_v<std::decay_t<T>>) {
69 7962176 : return ValueAndVariance{range.values.data()[i], range.variances.data()[i]};
70 : } else {
71 1325093090 : return range.data()[i];
72 : }
73 : }
74 :
/// Select the object that describes an operand's array layout.
///
/// For `ValuesAndVariances` only the `values` member is returned; any other
/// operand is returned unchanged. The result is returned by value (`auto`).
/// Within this file the results are handed to the `core::MultiIndex`
/// constructor.
75 4575681 : template <class T> static constexpr auto array_params(T &&iterable) noexcept {
76 : if constexpr (is_ValuesAndVariances_v<std::decay_t<T>>)
77 26871 : return iterable.values;
78 : else
79 4548810 : return iterable;
80 : }
81 :
/// Inner-stride patterns that get a dedicated inner loop with strides known at
/// compile time (selected in `dispatch_inner_loop`).
///
/// The table is indexed by the number of operands and by whether the transform
/// is in-place. Each row holds one stride per operand. The primary template is
/// an empty table, i.e., no special cases for that combination.
82 : template <size_t N_Operands, bool in_place>
83 : inline constexpr auto stride_special_cases =
84 : std::array<std::array<scipp::index, N_Operands>, 0>{};
85 :
// NOTE: Declared with 2 rows but only one initializer is given. Aggregate
// initialization value-initializes the remaining row, so the second row is {0}.
// NOTE(review): confirm that the implicit all-zero row is intended.
86 : template <>
87 : inline constexpr auto stride_special_cases<1, true> =
88 : std::array<std::array<scipp::index, 1>, 2>{{{1}}};
89 :
// NOTE: Declared with 4 rows but only three initializers are given. The fourth
// row is value-initialized, i.e., {0, 0}.
// NOTE(review): confirm that the implicit all-zero row is intended.
90 : template <>
91 : inline constexpr auto stride_special_cases<2, true> =
92 : std::array<std::array<scipp::index, 2>, 4>{{{1, 1}, {0, 1}, {1, 0}}};
93 :
94 : template <>
95 : inline constexpr auto stride_special_cases<2, false> =
96 : std::array<std::array<scipp::index, 2>, 1>{{{1, 1}}};
97 :
98 : template <>
99 : inline constexpr auto stride_special_cases<3, false> =
100 : std::array<std::array<scipp::index, 3>, 3>{
101 : {{1, 1, 1}, {1, 0, 1}, {1, 1, 0}}};
102 :
103 : template <>
104 : inline constexpr auto stride_special_cases<4, false> =
105 : std::array<std::array<scipp::index, 4>, 4>{
106 : {{1, 1, 1, 1}, {1, 0, 1, 1}, {1, 1, 0, 1}, {1, 1, 1, 0}}};
107 :
108 : template <>
109 : inline constexpr auto stride_special_cases<5, false> =
110 : std::array<std::array<scipp::index, 5>, 4>{
111 : {{1, 1, 1, 1, 1}, {1, 1, 1, 1, 0}, {1, 1, 1, 0, 0}, {1, 1, 0, 0, 0}}};
112 :
/// Declaration-only helper (no definition is visible here; it is only used
/// inside `decltype` below). Its return type is a `std::integer_sequence` of
/// `scipp::index` holding the strides of special case `I` for the given
/// operand count and in-place flag.
113 : template <size_t I, size_t N_Operands, bool in_place, size_t... Is>
114 : auto stride_sequence_impl(std::index_sequence<Is...>) -> std::integer_sequence<
115 : scipp::index, stride_special_cases<N_Operands, in_place>.at(I)[Is]...>;
116 : // The above uses std::array::at instead of operator[] in order to circumvent
117 : // a false positive error in MSVC 19.
118 :
/// Exposes the stride sequence of special case `I` as the nested `type`.
119 : template <size_t I, size_t N_Operands, bool in_place> struct stride_sequence {
120 : using type = decltype(stride_sequence_impl<I, N_Operands, in_place>(
121 : std::make_index_sequence<N_Operands>{}));
122 : };
123 :
/// Alias for `stride_sequence<I, N_Operands, in_place>::type`.
124 : template <size_t I, size_t N_Operands, bool in_place>
125 : using make_stride_sequence =
126 : typename stride_sequence<I, N_Operands, in_place>::type;
127 :
/// Add the compile-time stride `Strides[k]` to `indices[k]` for every operand
/// `k`, using a fold expression over the comma operator.
128 : template <scipp::index... Strides, size_t... Is>
129 485153574 : void increment_impl(std::array<scipp::index, sizeof...(Strides)> &indices,
130 : std::integer_sequence<size_t, Is...>) noexcept {
131 485153574 : ((indices[Is] += Strides), ...);
132 485153574 : }
133 :
/// Advance all operand indices by strides known at compile time.
///
/// Forwards to `increment_impl` with an index sequence covering all operands.
134 : template <scipp::index... Strides>
135 485153574 : void increment(std::array<scipp::index, sizeof...(Strides)> &indices) noexcept {
136 485153574 : increment_impl<Strides...>(indices,
137 : std::make_index_sequence<sizeof...(Strides)>{});
138 485153574 : }
139 :
/// Advance all operand indices by strides known only at run time.
///
/// Reads `strides[0]` .. `strides[N - 1]`; no size check is performed here, so
/// the caller must pass at least `N` strides.
140 : template <size_t N>
141 88221365 : void increment(std::array<scipp::index, N> &indices,
142 : const scipp::span<const scipp::index> strides) noexcept {
143 375077358 : for (size_t i = 0; i < N; ++i) {
144 286855993 : indices[i] += strides[i];
145 : }
146 88221365 : }
147 :
/// Invoke `op` on one element of every input argument and return the result.
///
/// The I-th argument is read at `indices[I + 1]`; slot 0 of `indices` is the
/// position of the output (see `call`) and is therefore skipped.
148 : template <class Op, class Indices, class... Args, size_t... I>
149 89428048 : static constexpr auto call_impl(Op &&op, const Indices &indices,
150 : std::index_sequence<I...>, Args &&...args) {
151 89428048 : return op(value_maybe_variance(args, indices[I + 1])...);
152 : }
/// Compute one output element: `out[indices[0]] = op(args[indices[1..]]...)`.
///
/// If the output has variances, `out_` is a `ValueAndVariance` object rather
/// than a reference into the storage, so the assigned value and variance are
/// copied back into the underlying arrays explicitly afterwards.
153 : template <class Op, class Indices, class Out, class... Args>
154 89428048 : static constexpr void call(Op &&op, const Indices &indices, Out &&out,
155 : Args &&...args) {
156 89428048 : const auto i = indices.front();
157 89428048 : auto &&out_ = value_maybe_variance(out, i);
158 89428048 : out_ = call_impl(std::forward<Op>(op), indices,
159 : std::make_index_sequence<sizeof...(Args)>{},
160 : std::forward<Args>(args)...);
// Write-back is only needed when `out_` is a ValueAndVariance object.
161 : if constexpr (is_ValueAndVariance_v<std::decay_t<decltype(out_)>>) {
162 2777 : out.values.data()[i] = out_.value;
163 2777 : out.variances.data()[i] = out_.variance;
164 : }
165 89428048 : }
166 :
/// Invoke the in-place operator: `op(arg, args[indices[1..]]...)`.
///
/// `arg` is the already extracted element being modified. The `static_assert`
/// enforces that in-place operators return `void`.
167 : template <class Op, class Indices, class Arg, class... Args, size_t... I>
168 483946892 : static constexpr void call_in_place_impl(Op &&op, const Indices &indices,
169 : std::index_sequence<I...>, Arg &&arg,
170 : Args &&...args) {
171 : static_assert(std::is_same_v<decltype(op(arg, value_maybe_variance(
172 : args, indices[I + 1])...)),
173 : void>);
174 483946892 : op(arg, value_maybe_variance(args, indices[I + 1])...);
175 483946891 : }
/// Apply `op` in-place to the element of `arg` at `indices[0]`, reading the
/// remaining operands at `indices[1..]`.
176 : template <class Op, class Indices, class Arg, class... Args>
177 483946892 : static constexpr void call_in_place(Op &&op, const Indices &indices, Arg &&arg,
178 : Args &&...args) {
179 483946892 : const auto i = indices.front();
180 : // For dense data we conditionally create ValueAndVariance, which performs an
181 : // element copy, so the result may have to be updated after the call to `op`.
182 483946892 : auto &&arg_ = value_maybe_variance(arg, i);
183 483946892 : call_in_place_impl(std::forward<Op>(op), indices,
184 : std::make_index_sequence<sizeof...(Args)>{},
185 : std::forward<decltype(arg_)>(arg_),
186 : std::forward<Args>(args)...);
// Copy the modified value and variance back into the underlying arrays.
187 : if constexpr (is_ValueAndVariance_v<std::decay_t<decltype(arg_)>>) {
188 3976278 : arg.values.data()[i] = arg_.value;
189 3976278 : arg.variances.data()[i] = arg_.variance;
190 : }
191 483946891 : }
192 : /// Run transform with strides known at compile time.
///
/// Performs `n` iterations. Each iteration calls `call_in_place` or `call`
/// (depending on `in_place`) at the current `indices` and then advances every
/// index by its compile-time stride. `indices` is taken by value, so the
/// caller's array is not modified.
193 : template <bool in_place, class Op, class... Operands, scipp::index... Strides>
194 9656407 : static void inner_loop(Op &&op,
195 : std::array<scipp::index, sizeof...(Operands)> indices,
196 : std::integer_sequence<scipp::index, Strides...>,
197 : const scipp::index n, Operands &&...operands) {
198 : static_assert(sizeof...(Operands) == sizeof...(Strides));
199 :
200 494809981 : for (scipp::index i = 0; i < n; ++i) {
201 : if constexpr (in_place) {
202 404768922 : detail::call_in_place(op, indices, std::forward<Operands>(operands)...);
203 : } else {
204 80384652 : detail::call(op, indices, std::forward<Operands>(operands)...);
205 : }
206 485153574 : detail::increment<Strides...>(indices);
207 : }
208 9656407 : }
209 :
210 : /// Run transform with strides known at run time but bypassing MultiIndex.
///
/// Same loop as the compile-time overload, except that the indices are
/// advanced by the values in `strides` (one per operand). `indices` is taken
/// by value, so the caller's array is not modified.
211 : template <bool in_place, class Op, class... Operands>
212 1388499 : static void inner_loop(Op &&op,
213 : std::array<scipp::index, sizeof...(Operands)> indices,
214 : const scipp::span<const scipp::index> strides,
215 : const scipp::index n, Operands &&...operands) {
216 89609864 : for (scipp::index i = 0; i < n; ++i) {
217 : if constexpr (in_place) {
218 79177970 : detail::call_in_place(op, indices, std::forward<Operands>(operands)...);
219 : } else {
220 9043396 : detail::call(op, indices, std::forward<Operands>(operands)...);
221 : }
222 88221365 : detail::increment(indices, strides);
223 : }
224 1388498 : }
225 :
/// Pick the inner-loop implementation matching the run-time inner strides.
///
/// Compares `inner_strides` with row `I` of `stride_special_cases`. On a match
/// the loop with compile-time strides is run; otherwise the function recurses
/// with `I + 1`. Once `I` equals the number of special cases, the generic loop
/// with run-time strides is used.
///
/// The three-iterator `std::equal` compares as many elements as
/// `inner_strides` holds; this assumes `inner_strides` has one entry per
/// operand -- TODO confirm against `MultiIndex::inner_strides`.
226 : template <bool in_place, size_t I = 0, class Op, class... Operands>
227 19219528 : static void dispatch_inner_loop(
228 : Op &&op, const std::array<scipp::index, sizeof...(Operands)> &indices,
229 : const scipp::span<const scipp::index> inner_strides, const scipp::index n,
230 : Operands &&...operands) {
231 19219528 : constexpr auto N_Operands = sizeof...(Operands);
// Recursion endpoint: all special cases were tried without a match.
232 : if constexpr (I ==
233 : detail::stride_special_cases<N_Operands, in_place>.size()) {
234 1388499 : inner_loop<in_place>(std::forward<Op>(op), indices, inner_strides, n,
235 : std::forward<Operands>(operands)...);
236 : } else {
237 17831029 : if (std::equal(
238 : inner_strides.begin(), inner_strides.end(),
239 17831029 : detail::stride_special_cases<N_Operands, in_place>[I].begin())) {
240 9656407 : inner_loop<in_place>(
241 : std::forward<Op>(op), indices,
242 : detail::make_stride_sequence<I, N_Operands, in_place>{}, n,
243 : std::forward<Operands>(operands)...);
244 : } else {
245 8174622 : dispatch_inner_loop<in_place, I + 1>(op, indices, inner_strides, n,
246 : std::forward<Operands>(operands)...);
247 : }
248 : }
249 19219527 : }
250 :
/// Apply `op` to all elements of the inputs `other...`, writing to `out`.
///
/// A `core::MultiIndex` spanning the output and all inputs is built once. The
/// flat output range `[0, out.size())` is handed to
/// `core::parallel::parallel_for`; for every sub-range a copy of the index is
/// positioned at the sub-range's begin and end and processed by `run`.
251 : template <class Op, class Out, class... Ts>
252 1184878 : static void transform_elements(Op op, Out &&out, Ts &&...other) {
253 1184878 : const auto begin =
254 : core::MultiIndex(array_params(out), array_params(other)...);
255 :
// Process [indices, end) one inner run at a time via dispatch_inner_loop.
256 2362104 : auto run = [&](auto &indices, const auto &end) {
257 1184878 : const auto inner_strides = indices.inner_strides();
258 3641608 : while (indices != end) {
259 : // Shape can change when moving between bins -> recompute every time.
260 2456730 : const auto inner_size = indices.in_same_chunk(end, 1)
261 2456730 : ? indices.inner_distance_to(end)
262 : : indices.inner_distance_to_end();
263 2456730 : dispatch_inner_loop<false>(op, indices.get(), inner_strides, inner_size,
264 : std::forward<Out>(out),
265 2456730 : std::forward<Ts>(other)...);
// Advance by at least 1 even if inner_size is 0 (presumably to guarantee
// progress of the while loop -- TODO confirm).
266 2456730 : indices.increment_by(inner_size != 0 ? inner_size : 1);
267 : }
268 : };
269 :
// Copies of `begin` are used so that the shared index is never modified.
270 2362104 : auto run_parallel = [&](const auto &range) {
271 1184878 : auto indices = begin;
272 1184878 : indices.set_index(range.begin());
273 1184878 : auto end = begin;
274 1184878 : end.set_index(range.end());
275 1184878 : run(indices, end);
276 : };
277 1184878 : core::parallel::parallel_for(core::parallel::blocked_range(0, out.size()),
278 : run_parallel);
279 1184878 : }
280 :
/// Return `_.eval()` if the argument's decayed type satisfies
/// `core::has_eval_v`, otherwise forward the argument. The return type is
/// `auto`, so the result is returned by value in both branches.
281 : template <class T> static constexpr auto maybe_eval(T &&_) {
282 : if constexpr (core::has_eval_v<std::decay_t<T>>)
283 : return _.eval();
284 : else
285 : return std::forward<T>(_);
286 : }
287 :
/// Compile-time detection of a *violated* all-or-none-variances requirement.
///
/// True if and only if `Op` derives from
/// `core::transform_flags::expect_all_or_none_have_variance_t` and some, but
/// not all, of `Args` are `ValuesAndVariances`. Callers throw when this is
/// true.
288 : template <class Op, class... Args>
289 : constexpr bool check_all_or_none_variances =
290 : std::is_base_of_v<core::transform_flags::expect_all_or_none_have_variance_t,
291 : Op> &&
292 : !std::conjunction_v<is_ValuesAndVariances<std::decay_t<Args>>...> &&
293 : std::disjunction_v<is_ValuesAndVariances<std::decay_t<Args>>...>;
294 :
295 : /// Recursion endpoint for do_transform.
296 : ///
297 : /// Call transform_elements with or without variances for output, depending on
298 : /// whether any of the arguments has variances or not.
///
/// @param op        Element-wise operator.
/// @param out       Output accessor providing `values()` and `variances()`.
/// @param processed Tuple of the inputs collected by the recursive overload.
/// @throws except::VariancesError if `Op` requires all-or-none variances and
///         only some inputs have them.
299 : template <class Op, class Out, class Tuple>
300 1184879 : static void do_transform(Op op, Out &&out, Tuple &&processed) {
301 1184879 : auto out_val = out.values();
302 1184879 : std::apply(
303 3553027 : [&op, &out, &out_val](auto &&...args) {
304 : if constexpr (check_all_or_none_variances<Op, decltype(args)...>) {
305 1 : throw except::VariancesError(
306 : "Expected either all or none of inputs to have variances.");
// Output gets variances only if Op does not opt out, the output dtype can
// hold variances, and at least one input has variances.
307 : } else if constexpr (
308 : !std::is_base_of_v<core::transform_flags::no_out_variance_t, Op> &&
309 : core::canHaveVariances<typename Out::value_type>() &&
310 : (is_ValuesAndVariances_v<std::decay_t<decltype(args)>> || ...)) {
311 804 : auto out_var = out.variances();
312 804 : transform_elements(op, ValuesAndVariances{out_val, out_var},
313 : std::forward<decltype(args)>(args)...);
314 804 : } else {
315 1184074 : transform_elements(op, out_val,
316 : std::forward<decltype(args)>(args)...);
317 : }
318 : },
319 : std::forward<Tuple>(processed));
320 1184879 : }
321 :
322 : /// Helper for transform implementation, performing branching between output
323 : /// with and without variances as well as handling other operands with and
324 : /// without variances.
///
/// Consumes one input (`arg`) per recursion step. Depending on the run-time
/// check `arg.has_variances()`, either a `ValuesAndVariances` or the plain
/// values are appended to `processed`. The position of `arg` among the inputs
/// equals `std::tuple_size_v<Tuple>`, which is used for the per-argument flags
/// and in error messages.
///
/// @throws except::VariancesError if the variance flags of `Op` reject the
///         presence or absence of variances in this argument.
325 : template <class Op, class Out, class Tuple, class Arg, class... Args>
326 1952100 : static void do_transform(Op op, Out &&out, Tuple &&processed, const Arg &arg,
327 : const Args &...args) {
328 1952100 : auto vals = arg.values();
329 1952100 : if (arg.has_variances()) {
330 : if constexpr (std::is_base_of_v<
331 : core::transform_flags::expect_no_variance_arg_t<
332 : std::tuple_size_v<Tuple>>,
333 : Op>) {
334 1 : throw except::VariancesError("Variances in argument " +
335 : std::to_string(std::tuple_size_v<Tuple>) +
336 : " not supported.");
337 : } else if constexpr (
338 : std::is_base_of_v<
339 : core::transform_flags::
340 : expect_no_in_variance_if_out_cannot_have_variance_t,
341 : Op> &&
342 : !core::canHaveVariances<typename Out::value_type>()) {
343 0 : throw except::VariancesError(
344 : "Variances in argument " + std::to_string(std::tuple_size_v<Tuple>) +
345 : " not supported as output dtype cannot have variances");
346 : } else if constexpr (core::canHaveVariances<typename Arg::value_type>()) {
347 1495 : auto vars = arg.variances();
348 1495 : do_transform(
349 : op, std::forward<Out>(out),
350 1497 : std::tuple_cat(processed, std::tuple(ValuesAndVariances{vals, vars})),
351 : args...);
352 1495 : }
353 : // else {} // Cannot happen because args.has_variances()
354 : // implies canHaveVariances<value_type>.
355 : // The 2nd test is needed to avoid compilation errors
356 : // (has_variances is a runtime check).
357 : } else {
358 : if constexpr (std::is_base_of_v<
359 : core::transform_flags::expect_variance_arg_t<
360 : std::tuple_size_v<Tuple>>,
361 : Op>)
362 0 : throw except::VariancesError("Variances missing in argument " +
363 : std::to_string(std::tuple_size_v<Tuple>) +
364 : " . Must be set.");
365 1950605 : do_transform(op, std::forward<Out>(out),
366 3901209 : std::tuple_cat(processed, std::tuple(vals)), args...);
367 : }
368 1952100 : }
369 :
/// Pairs a data handle with a set of dimensions.
///
/// `values()` and `variances()` construct an object of the same type as
/// `data.values()` / `data.variances()` from that result together with `dims`.
/// Both members are references, so an `as_view` must not outlive the `data`
/// and `dims` objects it was created from.
370 : template <class T> struct as_view {
371 : using value_type = typename T::value_type;
372 3390821 : [[nodiscard]] bool has_variances() const { return data.has_variances(); }
373 3390821 : auto values() const { return decltype(data.values())(data.values(), dims); }
374 26076 : auto variances() const {
375 26076 : return decltype(data.variances())(data.variances(), dims);
376 : }
377 : T &data;
378 : const Dimensions &dims;
379 : };
/// Deduction guide so that `as_view{data, dims}` deduces `T` from `data`.
380 : template <class T> as_view(T &data, const Dimensions &dims) -> as_view<T>;
381 :
/// Check whether using `var` in an operation with result dimensions `dims`
/// would broadcast its variances.
///
/// @param dims Merged dimensions of the operation.
/// @param var  Operand to check.
/// @return `false` if `var` has no variances. Otherwise `true` if `dims` has
///         more dimensions than `var` (implicit broadcast) or, for non-empty
///         `dims`, if any stride of `var` is 0 (explicit broadcast).
382 : template <class T>
383 2102182 : bool bad_variance_broadcast(const Dimensions &dims, const T &var) {
384 2102182 : if (!var.has_variances())
385 2100022 : return false;
386 : // implicit broadcast
387 2160 : if (dims.ndim() > var.dims().ndim())
388 10 : return true;
389 : // there may be a stride==0 when an inner dim has length==0
390 2150 : if (dims.volume() == 0)
391 22 : return false;
392 : // explicit broadcast
393 4256 : return std::any_of(var.strides().begin(), var.strides().end(),
394 4063 : [](scipp::index s) { return s == 0; });
395 : }
396 :
/// Throw `except::VariancesError` explaining that variances cannot be
/// broadcast. The message lists, for every operand in `vars`, its dimensions
/// and whether it has variances. This function never returns.
397 : template <class... Vars>
398 14 : [[noreturn]] void throw_variances_broadcast_error(Vars &&...vars) {
399 48 : throw except::VariancesError(
400 : "Cannot broadcast object with variances as this would introduce "
401 : "unhandled correlations. Input dimensions were:\n" +
402 : ((to_string(vars.dims()) +
403 34 : " variances=" + (vars.has_variances() ? "True" : "False") + '\n') +
404 : ...) +
405 : "\n" + "See https://doi.org/10.3233/JNR-220049 for more background.");
406 : }
407 :
/// Functor applying the element-wise operator `op` to typed data handles and
/// returning the result as a new `Variable`.
///
/// `operator()` merges the dimensions of all inputs and rejects broadcasting
/// of variances unless `Op` derives from `force_variance_broadcast_t`. It then
/// derives the output element type and unit, creates the output via
/// `variableFactory()`, and fills it using `do_transform`.
408 : template <class Op> struct Transform {
409 : Op op;
410 1184904 : template <class... Ts> Variable operator()(Ts &&...handles) const {
411 1940499 : const auto dims = merge(handles.dims()...);
412 :
413 : if constexpr (!std::is_base_of_v<
414 : core::transform_flags::force_variance_broadcast_t, Op>) {
415 1134620 : if ((bad_variance_broadcast(dims, handles) || ...))
416 8 : throw_variances_broadcast_error(handles...);
// With binned operands present, a non-binned operand with variances is also
// treated as a variance broadcast.
417 1134612 : if ((handles.is_bins() || ...))
418 152 : if (((handles.has_variances() && !handles.is_bins()) || ...))
419 1 : throw_variances_broadcast_error(handles...);
420 : }
421 :
// Output element type: result type of `op` on the first element of each
// input, passed through maybe_eval.
422 : using Out = decltype(maybe_eval(op(handles.values()[0]...)));
423 1184895 : const bool variances =
424 : !std::is_base_of_v<core::transform_flags::no_out_variance_t, Op> &&
425 504667 : core::canHaveVariances<Out>() && (handles.has_variances() || ...);
// The output unit is obtained by applying the unwrapped operator to the
// element units of the inputs.
426 1184895 : auto unit = op.base_op()(variableFactory().elem_unit(*handles.m_var)...);
427 1184881 : auto out = variableFactory().create(dtype<Out>, dims, unit, variances,
428 1184881 : *handles.m_var...);
429 1184880 : do_transform(op, variable_access<Out>(out), std::tuple<>(),
430 : as_view{handles, dims}...);
431 2369756 : return out;
432 1184906 : }
433 : };
/// Deduction guide so that `Transform{op}` deduces `Op`.
434 : template <class Op> Transform(Op) -> Transform<Op>;
435 :
436 : // std::tuple_cat does not work correctly with clang-7. Issue with
437 : // Eigen::Vector3d.
//
// Type-level concatenation: `tuple_cat<C<A...>, C<B...>, ...>::type` is
// `C<A..., B..., ...>`. The primary template ends the recursion and yields its
// first argument unchanged.
438 : template <typename T, typename...> struct tuple_cat {
439 : using type = T;
440 : };
441 : template <template <typename...> class C, typename... Ts1, typename... Ts2,
442 : typename... Ts3>
443 : struct tuple_cat<C<Ts1...>, C<Ts2...>, Ts3...>
444 : : public tuple_cat<C<Ts1..., Ts2...>, Ts3...> {};
445 :
/// Wrapper deriving from the user-provided operator `Op`.
///
/// `base_op()` gives access to the unwrapped operator (used by `Transform` for
/// the unit computation). The call operator forwards to `Op::operator()`; if
/// any argument type satisfies `core::has_eval_v`, the template arguments are
/// spelled out explicitly for the reason given in the warning below.
446 : template <class Op> struct wrap_eigen : Op {
447 1184895 : const Op &base_op() const noexcept { return *this; }
448 89428048 : template <class... Ts> constexpr auto operator()(Ts &&...args) const {
449 : if constexpr ((core::has_eval_v<std::decay_t<Ts>> || ...))
450 : // WARNING! The explicit specification of the template arguments of
451 : // operator() is EXTREMELY IMPORTANT. It ensures that Eigen types are
452 : // passed BY REFERENCE and NOT BY VALUE. Passing by value leads to
453 : // construction of expressions of values on the stack, which are then
454 : // returned from the operator. One way to identify this is using
455 : // address-sanitizer, which finds a `stack-use-after-scope`.
456 155 : return Op::template operator()<Ts...>(std::forward<Ts>(args)...);
457 : else
458 89427893 : return Op::template operator()(std::forward<Ts>(args)...);
459 : }
460 : };
/// Deduction guide so that `wrap_eigen{op}` deduces the operator type.
461 : template <class... Ts> wrap_eigen(Ts...) -> wrap_eigen<Ts...>;
462 : } // namespace detail
463 :
/// Determine the tuple of supported type combinations for a transform.
///
/// - No explicit `Ts`: use the `types` member typedef of `Op`.
/// - Any of `Ts` is itself a tuple (per `visit_detail::is_tuple`): concatenate
///   them using `detail::tuple_cat`.
/// - Otherwise: wrap `Ts...` in a single `std::tuple`.
///
/// Only the type of the returned (value-initialized) object is of interest.
464 : template <class... Ts, class Op>
465 1809754 : static constexpr auto type_tuples(Op) noexcept {
466 : if constexpr (sizeof...(Ts) == 0)
467 1066505 : return typename Op::types{};
468 : else if constexpr ((visit_detail::is_tuple<Ts>::value || ...))
469 66 : return typename detail::tuple_cat<Ts...>::type{};
470 : else
471 743183 : return std::tuple<Ts...>{};
472 : }
473 :
/// Check whether the values of `a` and `b` overlap.
///
/// Operands with different `value_type` are reported as non-overlapping
/// without inspecting them; otherwise the decision is delegated to
/// `a.values().overlaps(b.values())`.
474 802572 : constexpr auto overlaps = [](const auto &a, const auto &b) {
475 : if constexpr (std::is_same_v<typename std::decay_t<decltype(a)>::value_type,
476 : typename std::decay_t<decltype(b)>::value_type>)
477 511417 : return a.values().overlaps(b.values());
478 : else
479 291155 : return false;
480 : };
481 :
482 : /// Helper class wrapping functions for in-place transform.
483 : ///
484 : /// The dry_run template argument can be used to disable any actual modification
485 : /// of data. This is used to implement operations on datasets with a strong
486 : /// exception guarantee.
487 : template <bool dry_run> struct in_place {
/// Apply `op` in-place to all elements of `arg`, reading from `other...`.
///
/// The `core::MultiIndex` is constructed even for a dry run; the dry run
/// returns right after, before any element is touched.
488 : template <class Op, class T, class... Ts>
489 636338 : static void transform_in_place_impl(Op op, T &&arg, Ts &&...other) {
490 : using namespace detail;
491 636356 : const auto begin =
492 : core::MultiIndex(array_params(arg), array_params(other)...);
493 : if constexpr (dry_run)
494 133 : return;
495 :
// Process [indices, end) one inner run at a time via dispatch_inner_loop.
496 1272116 : auto run = [&](auto &indices, const auto &end) {
497 636196 : const auto inner_strides = indices.inner_strides();
498 9224371 : while (indices != end) {
499 : // Shape can change when moving between bins -> recompute every time.
500 8588176 : const auto inner_size = indices.in_same_chunk(end, 1)
501 8588176 : ? indices.inner_distance_to(end)
502 : : indices.inner_distance_to_end();
503 8588176 : detail::dispatch_inner_loop<true>(op, indices.get(), inner_strides,
504 : inner_size, std::forward<T>(arg),
505 8565622 : std::forward<Ts>(other)...);
// Advance by at least 1 even if inner_size is 0 (presumably to guarantee
// progress of the while loop -- TODO confirm).
506 8588175 : indices.increment_by(inner_size != 0 ? inner_size : 1);
507 : }
508 : };
509 636196 : if (begin.has_stride_zero()) {
510 : // The output has a dimension with stride zero so parallelization must
511 : // be done differently. See parallelization in accumulate.h.
// Here the full range [0, arg.size()) is processed by a single call to `run`.
512 93359 : auto indices = begin;
513 93359 : auto end = begin;
514 93359 : end.set_index(arg.size());
515 93359 : run(indices, end);
516 : } else {
517 1085398 : auto run_parallel = [&](const auto &range) {
518 542837 : auto indices = begin; // copy so that run doesn't modify begin
519 542837 : indices.set_index(range.begin());
520 542837 : auto end = begin;
521 542837 : end.set_index(range.end());
522 542837 : run(indices, end);
523 : };
524 542837 : core::parallel::parallel_for(core::parallel::blocked_range(0, arg.size()),
525 : run_parallel);
526 : }
527 636195 : }
528 :
529 : /// Recursion endpoint for do_transform_in_place.
530 : ///
531 : /// Calls transform_in_place_impl unless the output has no variance even
532 : /// though it should.
///
/// The first tuple element is the output operand, the rest are inputs.
/// Accepted combinations (all decided at compile time):
/// - `Op` derives from `expect_in_variance_if_out_variance_t`: the output has
///   variances if and only if at least one input has variances.
/// - otherwise: the output has variances, or no input has variances.
/// - `Op` derives from `expect_no_variance_arg_t<0>`: always accepted.
///
/// @throws except::VariancesError if the all-or-none requirement of `Op` is
///         violated or if none of the combinations above applies.
533 : template <class Op, class Tuple>
534 636344 : static void do_transform_in_place(Op op, Tuple &&processed) {
535 : using namespace detail;
536 636344 : std::apply(
537 1272689 : [&op](auto &&arg, auto &&...args) {
538 : if constexpr (check_all_or_none_variances<Op, decltype(arg),
539 : decltype(args)...>) {
540 2 : throw except::VariancesError(
541 : "Expected either all or none of inputs to have variances.");
542 : } else {
543 636342 : constexpr bool in_var_if_out_var = std::is_base_of_v<
544 : core::transform_flags::expect_in_variance_if_out_variance_t,
545 : Op>;
546 636342 : constexpr bool arg_var =
547 : is_ValuesAndVariances_v<std::decay_t<decltype(arg)>>;
// The fold over an empty pack yields false, i.e., "no input has variances".
548 636342 : constexpr bool args_var =
549 : (is_ValuesAndVariances_v<std::decay_t<decltype(args)>> || ...);
550 : if constexpr ((in_var_if_out_var ? arg_var == args_var
551 : : arg_var || !args_var) ||
552 : std::is_base_of_v<core::transform_flags::
553 : expect_no_variance_arg_t<0>,
554 : Op>) {
555 636338 : transform_in_place_impl(op, std::forward<decltype(arg)>(arg),
556 : std::forward<decltype(args)>(args)...);
557 : } else {
558 4 : throw except::VariancesError(
559 : "Output has no variance but at least one input does.");
560 : }
561 : }
562 : },
563 : std::forward<Tuple>(processed));
564 636328 : }
565 :
566 : /// Helper for in-place transform implementation, performing branching between
567 : /// output with and without variances as well as handling other operands with
568 : /// and without variances.
///
/// Consumes one operand (`arg`) per recursion step and appends either a
/// `ValuesAndVariances` or the plain values to `processed`, depending on the
/// run-time check `arg.has_variances()`. The operand's position equals
/// `std::tuple_size_v<Tuple>`.
///
/// @throws except::VariancesError if the per-argument variance flags of `Op`
///         reject the presence or absence of variances in this operand.
569 : template <class Op, class Tuple, class Arg, class... Args>
570 1438721 : static void do_transform_in_place(Op op, Tuple &&processed, Arg &arg,
571 : const Args &...args) {
572 : using namespace detail;
573 1438721 : auto vals = arg.values();
574 1438721 : if (arg.has_variances()) {
575 : if constexpr (std::is_base_of_v<
576 : core::transform_flags::expect_no_variance_arg_t<
577 : std::tuple_size_v<Tuple>>,
578 : Op>) {
579 0 : throw except::VariancesError("Variances in argument " +
580 : std::to_string(std::tuple_size_v<Tuple>) +
581 : " not supported.");
582 : } else if constexpr (core::canHaveVariances<typename Arg::value_type>()) {
583 24581 : auto vars = arg.variances();
584 24584 : do_transform_in_place(
585 : op,
586 : std::tuple_cat(processed,
587 24606 : std::tuple(ValuesAndVariances{vals, vars})),
588 : args...);
589 24581 : }
590 : } else {
591 : if constexpr (std::is_base_of_v<
592 : core::transform_flags::expect_variance_arg_t<
593 : std::tuple_size_v<Tuple>>,
594 : Op>) {
595 : throw except::VariancesError("Argument " +
596 : std::to_string(std::tuple_size_v<Tuple>) +
597 : " must have variances.");
598 : } else {
599 1414162 : do_transform_in_place(op, std::tuple_cat(processed, std::tuple(vals)),
600 : args...);
601 : }
602 : }
603 1438721 : }
604 :
605 : /// Functor for in-place transformation, applying `op` to all elements.
606 : ///
607 : /// This is responsible for converting the client-provided functor `Op` which
608 : /// operates on elements to a functor for the data container, which is
609 : /// required by `visit`.
610 : template <class Op> struct TransformInPlace {
611 : Op op;
612 : template <class T, class... Ts>
613 636539 : void operator()(T &&out, Ts &&...handles) const {
614 : using namespace detail;
615 : // If there is an overlap between lhs and rhs we copy the rhs before
616 : // applying the operation.
// Only the case of exactly one rhs operand is supported; with any other
// number of rhs operands an overlap results in std::runtime_error.
617 614809 : if ((overlaps(out, handles) || ...)) {
618 : if constexpr (sizeof...(Ts) == 1) {
619 179 : auto copy = (handles.clone(), ...);
620 179 : return operator()(std::forward<T>(out), Ts(copy)...);
621 179 : } else {
622 0 : throw std::runtime_error(
623 : "Overlap handling only implemented for 2 inputs.");
624 : }
625 : }
626 1250958 : const auto dims = merge(out.dims(), handles.dims()...);
627 636344 : auto out_view = as_view{out, dims};
628 636344 : do_transform_in_place(op, std::tuple<>{}, out_view,
629 : as_view{handles, dims}...);
630 636344 : }
631 : };
632 : // gcc cannot deal with deduction guide for nested class => helper function.
633 636370 : template <class Op> static auto makeTransformInPlace(Op op) {
634 636370 : return TransformInPlace<Op>{detail::wrap_eigen{op}};
635 : }
636 :
/// Dispatch on the dtypes of `var` and `other...` (restricted to the type
/// combinations `Ts...`) and run the in-place transform.
///
/// @throws except::TypeError if `visit` raises `std::bad_variant_access`; the
///         message contains `name`.
637 : template <class... Ts, class Op, class Var, class... Other>
638 636370 : static void transform_data(const std::tuple<Ts...> &, Op op,
639 : const std::string_view &name, Var &&var,
640 : Other &&...other) {
641 : using namespace detail;
642 : try {
643 636370 : visit<Ts...>::apply(makeTransformInPlace(op), var, other...);
644 20 : } catch (const std::bad_variant_access &) {
645 10 : throw except::TypeError("'" + std::string(name) +
646 : "' does not support dtypes ",
647 : var, other...);
648 : }
649 636328 : }
/// Entry point for in-place transforms of `var` using inputs `other...`.
///
/// Validates dimensions, binned-ness, and variance broadcasting, computes the
/// new unit by applying `op` to the element units, checks that the unit can
/// be set, and then transforms the data. For a dry run the function returns
/// before the unit of `var` is set.
///
/// @throws except::BinnedDataError if `var` is not binned but an input is.
650 : template <class... Ts, class Op, class Var, class... Other>
651 485638 : static void transform(Op op, const std::string_view &name, Var &&var,
652 : const Other &...other) {
653 : using namespace detail;
// The dims of every input must be included in the dims of the target.
654 463908 : (scipp::expect::includes(var.dims(), other.dims()), ...);
655 :
656 485628 : if (!is_bins(var) && ((is_bins(other) || ...))) {
657 0 : throw except::BinnedDataError(
658 : "Cannot apply inplace operation where target is "
659 : "not binned but arguments are binned");
660 : }
661 :
662 : if constexpr (!std::is_base_of_v<
663 : core::transform_flags::force_variance_broadcast_t, Op>) {
664 284594 : if (const auto dims = merge(var.dims(), other.dims()...);
665 131431 : (bad_variance_broadcast(dims, other) || ...))
666 4 : throw_variances_broadcast_error(var, other...);
667 153157 : if (is_bins(var) || (is_bins(other) || ...))
668 19674 : if (((other.has_variances() && !is_bins(other)) || ...))
669 1 : throw_variances_broadcast_error(var, other...);
670 : }
671 :
// `op` is applied to the units in-place, i.e., it modifies `unit`.
672 485623 : auto unit = variableFactory().elem_unit(var);
673 485623 : op(unit, variableFactory().elem_unit(other)...);
674 : // Stop early in bad cases of changing units (if `var` is a slice):
675 485621 : variableFactory().expect_can_set_elem_unit(var, unit);
676 : // Wrapped implementation to convert multiple tuples into a parameter pack.
677 485624 : transform_data(type_tuples<Ts...>(op), op, name, std::forward<Var>(var),
678 : other...);
679 : if constexpr (dry_run)
680 133 : return;
681 485440 : variableFactory().set_elem_unit(var, unit);
682 485440 : }
683 : };
684 :
685 : /// Transform the data elements of a variable in-place.
686 : ///
687 : /// Note that this is deliberately not named `for_each`: Unlike std::for_each,
688 : /// this function does not promise in-order execution. This overload is
689 : /// equivalent to std::transform with a single input range and an output range
690 : /// identical to the input range, but avoids potentially costly element copies.
///
/// @param var  Variable that is modified.
/// @param op   Operator applied to each element (and to the unit).
/// @param name Operation name used in error messages.
691 : template <class... Ts, class Var, class Op>
692 21730 : void transform_in_place(Var &&var, Op op, const std::string_view &name) {
693 21730 : in_place<false>::transform<Ts...>(op, name, std::forward<Var>(var));
694 21730 : }
695 :
696 : /// Transform the data elements of a variable in-place.
697 : ///
698 : /// This overload is equivalent to std::transform with two input ranges and an
699 : /// output range identical to the first input range (`var`), but avoids
700 : /// potentially costly element copies.
///
/// @param var   Variable that is modified; also the first input.
/// @param other Second input; read only.
/// @param op    Operator applied to each pair of elements (and to the units).
/// @param name  Operation name used in error messages.
701 : template <class... TypePairs, class Var, class Op>
702 419054 : void transform_in_place(Var &&var, const Variable &other, Op op,
703 : const std::string_view &name) {
704 419054 : in_place<false>::transform<TypePairs...>(op, name, std::forward<Var>(var),
705 : other);
706 419003 : }
707 :
708 : /// Transform the data elements of a variable in-place.
///
/// Overload with two read-only inputs (`var1`, `var2`) in addition to the
/// modified variable `var`. `name` is used in error messages.
709 : template <class... TypePairs, class Var, class Op>
710 9898 : void transform_in_place(Var &&var, const Variable &var1, const Variable &var2,
711 : Op op, const std::string_view &name) {
712 9898 : in_place<false>::transform<TypePairs...>(op, name, std::forward<Var>(var),
713 : var1, var2);
714 9894 : }
715 :
716 : /// Transform the data elements of a variable in-place.
///
/// Overload with three read-only inputs (`var1`, `var2`, `var3`) in addition
/// to the modified variable `var`. `name` is used in error messages.
717 : template <class... TypePairs, class Var, class Op>
718 34815 : void transform_in_place(Var &&var, const Variable &var1, const Variable &var2,
719 : const Variable &var3, Op op,
720 : const std::string_view &name) {
721 34815 : in_place<false>::transform<TypePairs...>(op, name, std::forward<Var>(var),
722 : var1, var2, var3);
723 34813 : }
724 :
/// Dry-run variants of `transform_in_place`.
///
/// These forward to `in_place<true>`, which performs the same checks as the
/// regular versions but returns before modifying data or unit.
725 : namespace dry_run {
726 : template <class... Ts, class Var, class Op>
727 : void transform_in_place(Var &&var, Op op, const std::string_view &name) {
728 : in_place<true>::transform<Ts...>(op, name, std::forward<Var>(var));
729 : }
730 : template <class... TypePairs, class Var, class Op>
731 141 : void transform_in_place(Var &&var, const Variable &other, Op op,
732 : const std::string_view &name) {
733 141 : in_place<true>::transform<TypePairs...>(op, name, std::forward<Var>(var),
734 : other);
735 133 : }
736 : } // namespace dry_run
737 :
738 : namespace detail {
/// Dispatch on the dtypes of `vars...` (restricted to the type combinations
/// `Ts...`) and apply `op` via `Transform`, returning the new variable.
///
/// @throws except::TypeError if `visit` raises `std::bad_variant_access`; the
///         message contains `name`.
739 : template <class... Ts, class Op, class... Vars>
740 1184907 : Variable transform(std::tuple<Ts...> &&, Op op, const std::string_view &name,
741 : const Vars &...vars) {
742 : using namespace detail;
743 : try {
744 1184907 : return visit<Ts...>::apply(Transform{wrap_eigen{op}}, vars...);
745 32 : } catch (const std::bad_variant_access &) {
746 6 : throw except::TypeError(
747 6 : "'" + std::string(name) + "' does not support dtypes ", vars...);
748 : }
749 : }
750 : } // namespace detail
751 :
752 : /// Transform the data elements of a variable and return a new Variable.
753 : ///
754 : /// This overload is equivalent to std::transform with a single input range, but
755 : /// avoids the need to manually create a new variable for the output and the
756 : /// need for, e.g., std::back_inserter.
///
/// @param var  Input variable.
/// @param op   Operator applied to each element (and to the unit).
/// @param name Operation name used in error messages.
757 : template <class... Ts, class Op>
758 429309 : [[nodiscard]] Variable transform(const Variable &var, Op op,
759 : const std::string_view &name) {
760 429309 : return detail::transform(type_tuples<Ts...>(op), op, name, var);
761 : }
762 :
763 : /// Transform the data elements of two variables and return a new Variable.
764 : ///
765 : /// This overload is equivalent to std::transform with two input ranges, but
766 : /// avoids the need to manually create a new variable for the output and the
767 : /// need for, e.g., std::back_inserter.
///
/// @param var1 First input variable.
/// @param var2 Second input variable.
/// @param op   Operator applied to each pair of elements (and to the units).
/// @param name Operation name used in error messages.
768 : template <class... Ts, class Op>
769 744114 : [[nodiscard]] Variable transform(const Variable &var1, const Variable &var2,
770 : Op op, const std::string_view &name) {
771 744114 : return detail::transform(type_tuples<Ts...>(op), op, name, var1, var2);
772 : }
773 :
774 : /// Transform the data elements of three variables and return a new Variable.
///
/// `op` receives one element of each of `var1`, `var2`, and `var3`; `name` is
/// used in error messages.
775 : template <class... Ts, class Op>
776 11318 : [[nodiscard]] Variable transform(const Variable &var1, const Variable &var2,
777 : const Variable &var3, Op op,
778 : const std::string_view &name) {
779 11318 : return detail::transform(type_tuples<Ts...>(op), op, name, var1, var2, var3);
780 : }
781 :
782 : /// Transform the data elements of four variables and return a new Variable.
///
/// `op` receives one element of each of `var1` .. `var4`; `name` is used in
/// error messages.
783 : template <class... Ts, class Op>
784 166 : [[nodiscard]] Variable transform(const Variable &var1, const Variable &var2,
785 : const Variable &var3, const Variable &var4,
786 : Op op, const std::string_view &name) {
787 165 : return detail::transform(type_tuples<Ts...>(op), op, name, var1, var2, var3,
788 166 : var4);
789 : }
790 :
791 : } // namespace scipp::variable
|