LCOV - code coverage report
Current view: top level - variable/include/scipp/variable - transform.h (source / functions) Hit Total Coverage
Test: coverage.info Lines: 240 245 98.0 %
Date: 2024-04-28 01:25:40 Functions: 9144 35928 25.5 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file Various transform functions for variables.
       4             : ///
       5             : /// The underlying mechanism of the implementation is as follows:
       6             : /// 1. `visit<...>::apply` obtains the concrete underlying data type(s).
       7             : /// 2. `Transform` is applied to that concrete container, calling
       8             : ///    `do_transform`. `Transform` essentially builds a callable accepting a
       9             : ///    container from a callable accepting an element of the container.
      10             : /// 3. `do_transform` is essentially a fancy std::transform. It uses recursion
      11             : ///    to process optional flags (provided as base classes of the user-provided
      12             : ///    operator). It provides automatic handling of data that has variances in
      13             : ///    addition to values, calling a different transform implementation for each
      14             : ///    case (different instantiations of `transform_elements`).
      15             : /// 4. The `transform_elements` function calls the overloaded operator for
      16             : ///    each element. This is also where multi-threading for the majority of
      17             : ///    scipp's operations is implemented.
      18             : ///
      19             : /// Handling of binned data is mostly hidden in this implementation, reducing
      20             : /// code duplication:
      21             : /// - `variableFactory()` is used for output creation and unit access.
      22             : /// - `variableFactory()` is used in `visit.h` to obtain a direct pointer to the
      23             : ///   underlying buffer.
      24             : /// - MultiIndex contains special handling for binned data, i.e., it can iterate
      25             : ///   the buffer in a binning-aware way.
      26             : ///
      27             : /// The mechanism for in-place transformation is mostly identical to the one
      28             : /// outlined above.
      29             : ///
      30             : /// @author Simon Heybrock
      31             : #pragma once
      32             : 
      33             : #include <algorithm>
      34             : #include <cassert>
      35             : #include <string_view>
      36             : 
      37             : #include "scipp/common/overloaded.h"
      38             : 
      39             : #include "scipp/core/has_eval.h"
      40             : #include "scipp/core/multi_index.h"
      41             : #include "scipp/core/parallel.h"
      42             : #include "scipp/core/transform_common.h"
      43             : #include "scipp/core/value_and_variance.h"
      44             : #include "scipp/core/values_and_variances.h"
      45             : 
      46             : #include "scipp/variable/except.h"
      47             : #include "scipp/variable/variable.h"
      48             : #include "scipp/variable/variable_factory.h"
      49             : #include "scipp/variable/visit.h"
      50             : 
      51             : namespace scipp::variable {
      52             : 
      53             : namespace detail {
      54             : 
      55             : template <class T> struct has_variances : std::false_type {};
      56             : template <class T>
      57             : struct has_variances<ValueAndVariance<T>> : std::true_type {};
      58             : template <class T>
      59             : struct has_variances<ValuesAndVariances<T>> : std::true_type {};
      60             : template <class T>
      61             : inline constexpr bool has_variances_v = has_variances<T>::value;
      62             : 
      63             : /// Helper for the transform implementation to unify iteration of data with and
      64             : /// without variances.
      65             : template <class T>
      66  1333055266 : static constexpr decltype(auto) value_maybe_variance(T &&range,
      67             :                                                      const scipp::index i) {
      68             :   if constexpr (has_variances_v<std::decay_t<T>>) {
      69     7962176 :     return ValueAndVariance{range.values.data()[i], range.variances.data()[i]};
      70             :   } else {
      71  1325093090 :     return range.data()[i];
      72             :   }
      73             : }
      74             : 
      75     4575681 : template <class T> static constexpr auto array_params(T &&iterable) noexcept {
      76             :   if constexpr (is_ValuesAndVariances_v<std::decay_t<T>>)
      77       26871 :     return iterable.values;
      78             :   else
      79     4548810 :     return iterable;
      80             : }
      81             : 
      82             : template <size_t N_Operands, bool in_place>
      83             : inline constexpr auto stride_special_cases =
      84             :     std::array<std::array<scipp::index, N_Operands>, 0>{};
      85             : 
      86             : template <>
      87             : inline constexpr auto stride_special_cases<1, true> =
      88             :     std::array<std::array<scipp::index, 1>, 2>{{{1}}};
      89             : 
      90             : template <>
      91             : inline constexpr auto stride_special_cases<2, true> =
      92             :     std::array<std::array<scipp::index, 2>, 4>{{{1, 1}, {0, 1}, {1, 0}}};
      93             : 
      94             : template <>
      95             : inline constexpr auto stride_special_cases<2, false> =
      96             :     std::array<std::array<scipp::index, 2>, 1>{{{1, 1}}};
      97             : 
      98             : template <>
      99             : inline constexpr auto stride_special_cases<3, false> =
     100             :     std::array<std::array<scipp::index, 3>, 3>{
     101             :         {{1, 1, 1}, {1, 0, 1}, {1, 1, 0}}};
     102             : 
     103             : template <>
     104             : inline constexpr auto stride_special_cases<4, false> =
     105             :     std::array<std::array<scipp::index, 4>, 4>{
     106             :         {{1, 1, 1, 1}, {1, 0, 1, 1}, {1, 1, 0, 1}, {1, 1, 1, 0}}};
     107             : 
     108             : template <>
     109             : inline constexpr auto stride_special_cases<5, false> =
     110             :     std::array<std::array<scipp::index, 5>, 4>{
     111             :         {{1, 1, 1, 1, 1}, {1, 1, 1, 1, 0}, {1, 1, 1, 0, 0}, {1, 1, 0, 0, 0}}};
     112             : 
     113             : template <size_t I, size_t N_Operands, bool in_place, size_t... Is>
     114             : auto stride_sequence_impl(std::index_sequence<Is...>) -> std::integer_sequence<
     115             :     scipp::index, stride_special_cases<N_Operands, in_place>.at(I)[Is]...>;
     116             : // The above uses std::array::at instead of operator[] in order to circumvent
     117             : // a false positive error in MSVC 19.
     118             : 
     119             : template <size_t I, size_t N_Operands, bool in_place> struct stride_sequence {
     120             :   using type = decltype(stride_sequence_impl<I, N_Operands, in_place>(
     121             :       std::make_index_sequence<N_Operands>{}));
     122             : };
     123             : 
     124             : template <size_t I, size_t N_Operands, bool in_place>
     125             : using make_stride_sequence =
     126             :     typename stride_sequence<I, N_Operands, in_place>::type;
     127             : 
     128             : template <scipp::index... Strides, size_t... Is>
     129   485153574 : void increment_impl(std::array<scipp::index, sizeof...(Strides)> &indices,
     130             :                     std::integer_sequence<size_t, Is...>) noexcept {
     131   485153574 :   ((indices[Is] += Strides), ...);
     132   485153574 : }
     133             : 
     134             : template <scipp::index... Strides>
     135   485153574 : void increment(std::array<scipp::index, sizeof...(Strides)> &indices) noexcept {
     136   485153574 :   increment_impl<Strides...>(indices,
     137             :                              std::make_index_sequence<sizeof...(Strides)>{});
     138   485153574 : }
     139             : 
     140             : template <size_t N>
     141    88221365 : void increment(std::array<scipp::index, N> &indices,
     142             :                const scipp::span<const scipp::index> strides) noexcept {
     143   375077358 :   for (size_t i = 0; i < N; ++i) {
     144   286855993 :     indices[i] += strides[i];
     145             :   }
     146    88221365 : }
     147             : 
     148             : template <class Op, class Indices, class... Args, size_t... I>
     149    89428048 : static constexpr auto call_impl(Op &&op, const Indices &indices,
     150             :                                 std::index_sequence<I...>, Args &&...args) {
     151    89428048 :   return op(value_maybe_variance(args, indices[I + 1])...);
     152             : }
     153             : template <class Op, class Indices, class Out, class... Args>
     154    89428048 : static constexpr void call(Op &&op, const Indices &indices, Out &&out,
     155             :                            Args &&...args) {
     156    89428048 :   const auto i = indices.front();
     157    89428048 :   auto &&out_ = value_maybe_variance(out, i);
     158    89428048 :   out_ = call_impl(std::forward<Op>(op), indices,
     159             :                    std::make_index_sequence<sizeof...(Args)>{},
     160             :                    std::forward<Args>(args)...);
     161             :   if constexpr (is_ValueAndVariance_v<std::decay_t<decltype(out_)>>) {
     162        2777 :     out.values.data()[i] = out_.value;
     163        2777 :     out.variances.data()[i] = out_.variance;
     164             :   }
     165    89428048 : }
     166             : 
     167             : template <class Op, class Indices, class Arg, class... Args, size_t... I>
     168   483946892 : static constexpr void call_in_place_impl(Op &&op, const Indices &indices,
     169             :                                          std::index_sequence<I...>, Arg &&arg,
     170             :                                          Args &&...args) {
     171             :   static_assert(std::is_same_v<decltype(op(arg, value_maybe_variance(
     172             :                                                     args, indices[I + 1])...)),
     173             :                                void>);
     174   483946892 :   op(arg, value_maybe_variance(args, indices[I + 1])...);
     175   483946891 : }
     176             : template <class Op, class Indices, class Arg, class... Args>
     177   483946892 : static constexpr void call_in_place(Op &&op, const Indices &indices, Arg &&arg,
     178             :                                     Args &&...args) {
     179   483946892 :   const auto i = indices.front();
     180             :   // For dense data we conditionally create ValueAndVariance, which performs an
     181             :   // element copy, so the result may have to be updated after the call to `op`.
     182   483946892 :   auto &&arg_ = value_maybe_variance(arg, i);
     183   483946892 :   call_in_place_impl(std::forward<Op>(op), indices,
     184             :                      std::make_index_sequence<sizeof...(Args)>{},
     185             :                      std::forward<decltype(arg_)>(arg_),
     186             :                      std::forward<Args>(args)...);
     187             :   if constexpr (is_ValueAndVariance_v<std::decay_t<decltype(arg_)>>) {
     188     3976278 :     arg.values.data()[i] = arg_.value;
     189     3976278 :     arg.variances.data()[i] = arg_.variance;
     190             :   }
     191   483946891 : }
     192             : /// Run transform with strides known at compile time.
     193             : template <bool in_place, class Op, class... Operands, scipp::index... Strides>
     194     9656407 : static void inner_loop(Op &&op,
     195             :                        std::array<scipp::index, sizeof...(Operands)> indices,
     196             :                        std::integer_sequence<scipp::index, Strides...>,
     197             :                        const scipp::index n, Operands &&...operands) {
     198             :   static_assert(sizeof...(Operands) == sizeof...(Strides));
     199             : 
     200   494809981 :   for (scipp::index i = 0; i < n; ++i) {
     201             :     if constexpr (in_place) {
     202   404768922 :       detail::call_in_place(op, indices, std::forward<Operands>(operands)...);
     203             :     } else {
     204    80384652 :       detail::call(op, indices, std::forward<Operands>(operands)...);
     205             :     }
     206   485153574 :     detail::increment<Strides...>(indices);
     207             :   }
     208     9656407 : }
     209             : 
     210             : /// Run transform with strides known at run time but bypassing MultiIndex.
     211             : template <bool in_place, class Op, class... Operands>
     212     1388499 : static void inner_loop(Op &&op,
     213             :                        std::array<scipp::index, sizeof...(Operands)> indices,
     214             :                        const scipp::span<const scipp::index> strides,
     215             :                        const scipp::index n, Operands &&...operands) {
     216    89609864 :   for (scipp::index i = 0; i < n; ++i) {
     217             :     if constexpr (in_place) {
     218    79177970 :       detail::call_in_place(op, indices, std::forward<Operands>(operands)...);
     219             :     } else {
     220     9043396 :       detail::call(op, indices, std::forward<Operands>(operands)...);
     221             :     }
     222    88221365 :     detail::increment(indices, strides);
     223             :   }
     224     1388498 : }
     225             : 
     226             : template <bool in_place, size_t I = 0, class Op, class... Operands>
     227    19219528 : static void dispatch_inner_loop(
     228             :     Op &&op, const std::array<scipp::index, sizeof...(Operands)> &indices,
     229             :     const scipp::span<const scipp::index> inner_strides, const scipp::index n,
     230             :     Operands &&...operands) {
     231    19219528 :   constexpr auto N_Operands = sizeof...(Operands);
     232             :   if constexpr (I ==
     233             :                 detail::stride_special_cases<N_Operands, in_place>.size()) {
     234     1388499 :     inner_loop<in_place>(std::forward<Op>(op), indices, inner_strides, n,
     235             :                          std::forward<Operands>(operands)...);
     236             :   } else {
     237    17831029 :     if (std::equal(
     238             :             inner_strides.begin(), inner_strides.end(),
     239    17831029 :             detail::stride_special_cases<N_Operands, in_place>[I].begin())) {
     240     9656407 :       inner_loop<in_place>(
     241             :           std::forward<Op>(op), indices,
     242             :           detail::make_stride_sequence<I, N_Operands, in_place>{}, n,
     243             :           std::forward<Operands>(operands)...);
     244             :     } else {
     245     8174622 :       dispatch_inner_loop<in_place, I + 1>(op, indices, inner_strides, n,
     246             :                                            std::forward<Operands>(operands)...);
     247             :     }
     248             :   }
     249    19219527 : }
     250             : 
     251             : template <class Op, class Out, class... Ts>
     252     1184878 : static void transform_elements(Op op, Out &&out, Ts &&...other) {
     253     1184878 :   const auto begin =
     254             :       core::MultiIndex(array_params(out), array_params(other)...);
     255             : 
     256     2362104 :   auto run = [&](auto &indices, const auto &end) {
     257     1184878 :     const auto inner_strides = indices.inner_strides();
     258     3641608 :     while (indices != end) {
     259             :       // Shape can change when moving between bins -> recompute every time.
     260     2456730 :       const auto inner_size = indices.in_same_chunk(end, 1)
     261     2456730 :                                   ? indices.inner_distance_to(end)
     262             :                                   : indices.inner_distance_to_end();
     263     2456730 :       dispatch_inner_loop<false>(op, indices.get(), inner_strides, inner_size,
     264             :                                  std::forward<Out>(out),
     265     2456730 :                                  std::forward<Ts>(other)...);
     266     2456730 :       indices.increment_by(inner_size != 0 ? inner_size : 1);
     267             :     }
     268             :   };
     269             : 
     270     2362104 :   auto run_parallel = [&](const auto &range) {
     271     1184878 :     auto indices = begin;
     272     1184878 :     indices.set_index(range.begin());
     273     1184878 :     auto end = begin;
     274     1184878 :     end.set_index(range.end());
     275     1184878 :     run(indices, end);
     276             :   };
     277     1184878 :   core::parallel::parallel_for(core::parallel::blocked_range(0, out.size()),
     278             :                                run_parallel);
     279     1184878 : }
     280             : 
     281             : template <class T> static constexpr auto maybe_eval(T &&_) {
     282             :   if constexpr (core::has_eval_v<std::decay_t<T>>)
     283             :     return _.eval();
     284             :   else
     285             :     return std::forward<T>(_);
     286             : }
     287             : 
     288             : template <class Op, class... Args>
     289             : constexpr bool check_all_or_none_variances =
     290             :     std::is_base_of_v<core::transform_flags::expect_all_or_none_have_variance_t,
     291             :                       Op> &&
     292             :     !std::conjunction_v<is_ValuesAndVariances<std::decay_t<Args>>...> &&
     293             :     std::disjunction_v<is_ValuesAndVariances<std::decay_t<Args>>...>;
     294             : 
     295             : /// Recursion endpoint for do_transform.
     296             : ///
     297             : /// Call transform_elements with or without variances for output, depending on
     298             : /// whether any of the arguments has variances or not.
     299             : template <class Op, class Out, class Tuple>
     300     1184879 : static void do_transform(Op op, Out &&out, Tuple &&processed) {
     301     1184879 :   auto out_val = out.values();
     302     1184879 :   std::apply(
     303     3553027 :       [&op, &out, &out_val](auto &&...args) {
     304             :         if constexpr (check_all_or_none_variances<Op, decltype(args)...>) {
     305           1 :           throw except::VariancesError(
     306             :               "Expected either all or none of inputs to have variances.");
     307             :         } else if constexpr (
     308             :             !std::is_base_of_v<core::transform_flags::no_out_variance_t, Op> &&
     309             :             core::canHaveVariances<typename Out::value_type>() &&
     310             :             (is_ValuesAndVariances_v<std::decay_t<decltype(args)>> || ...)) {
     311         804 :           auto out_var = out.variances();
     312         804 :           transform_elements(op, ValuesAndVariances{out_val, out_var},
     313             :                              std::forward<decltype(args)>(args)...);
     314         804 :         } else {
     315     1184074 :           transform_elements(op, out_val,
     316             :                              std::forward<decltype(args)>(args)...);
     317             :         }
     318             :       },
     319             :       std::forward<Tuple>(processed));
     320     1184879 : }
     321             : 
     322             : /// Helper for transform implementation, performing branching between output
     323             : /// with and without variances as well as handling other operands with and
     324             : /// without variances.
     325             : template <class Op, class Out, class Tuple, class Arg, class... Args>
     326     1952100 : static void do_transform(Op op, Out &&out, Tuple &&processed, const Arg &arg,
     327             :                          const Args &...args) {
     328     1952100 :   auto vals = arg.values();
     329     1952100 :   if (arg.has_variances()) {
     330             :     if constexpr (std::is_base_of_v<
     331             :                       core::transform_flags::expect_no_variance_arg_t<
     332             :                           std::tuple_size_v<Tuple>>,
     333             :                       Op>) {
     334           1 :       throw except::VariancesError("Variances in argument " +
     335             :                                    std::to_string(std::tuple_size_v<Tuple>) +
     336             :                                    " not supported.");
     337             :     } else if constexpr (
     338             :         std::is_base_of_v<
     339             :             core::transform_flags::
     340             :                 expect_no_in_variance_if_out_cannot_have_variance_t,
     341             :             Op> &&
     342             :         !core::canHaveVariances<typename Out::value_type>()) {
     343           0 :       throw except::VariancesError(
     344             :           "Variances in argument " + std::to_string(std::tuple_size_v<Tuple>) +
     345             :           " not supported as output dtype cannot have variances");
     346             :     } else if constexpr (core::canHaveVariances<typename Arg::value_type>()) {
     347        1495 :       auto vars = arg.variances();
     348        1495 :       do_transform(
     349             :           op, std::forward<Out>(out),
     350        1497 :           std::tuple_cat(processed, std::tuple(ValuesAndVariances{vals, vars})),
     351             :           args...);
     352        1495 :     }
     353             :     // else {}  // Cannot happen because args.has_variances()
     354             :     //             implies canHaveVariances<value_type>.
     355             :     //             The 2nd test is needed to avoid compilation errors
     356             :     //             (has_variances is a runtime check).
     357             :   } else {
     358             :     if constexpr (std::is_base_of_v<
     359             :                       core::transform_flags::expect_variance_arg_t<
     360             :                           std::tuple_size_v<Tuple>>,
     361             :                       Op>)
     362           0 :       throw except::VariancesError("Variances missing in argument " +
     363             :                                    std::to_string(std::tuple_size_v<Tuple>) +
     364             :                                    " . Must be set.");
     365     1950605 :     do_transform(op, std::forward<Out>(out),
     366     3901209 :                  std::tuple_cat(processed, std::tuple(vals)), args...);
     367             :   }
     368     1952100 : }
     369             : 
     370             : template <class T> struct as_view {
     371             :   using value_type = typename T::value_type;
     372     3390821 :   [[nodiscard]] bool has_variances() const { return data.has_variances(); }
     373     3390821 :   auto values() const { return decltype(data.values())(data.values(), dims); }
     374       26076 :   auto variances() const {
     375       26076 :     return decltype(data.variances())(data.variances(), dims);
     376             :   }
     377             :   T &data;
     378             :   const Dimensions &dims;
     379             : };
     380             : template <class T> as_view(T &data, const Dimensions &dims) -> as_view<T>;
     381             : 
     382             : template <class T>
     383     2102182 : bool bad_variance_broadcast(const Dimensions &dims, const T &var) {
     384     2102182 :   if (!var.has_variances())
     385     2100022 :     return false;
     386             :   // implicit broadcast
     387        2160 :   if (dims.ndim() > var.dims().ndim())
     388          10 :     return true;
     389             :   // there may be a stride==0 when an inner dim has length==0
     390        2150 :   if (dims.volume() == 0)
     391          22 :     return false;
     392             :   // explicit broadcast
     393        4256 :   return std::any_of(var.strides().begin(), var.strides().end(),
     394        4063 :                      [](scipp::index s) { return s == 0; });
     395             : }
     396             : 
     397             : template <class... Vars>
     398          14 : [[noreturn]] void throw_variances_broadcast_error(Vars &&...vars) {
     399          48 :   throw except::VariancesError(
     400             :       "Cannot broadcast object with variances as this would introduce "
     401             :       "unhandled correlations. Input dimensions were:\n" +
     402             :       ((to_string(vars.dims()) +
     403          34 :         " variances=" + (vars.has_variances() ? "True" : "False") + '\n') +
     404             :        ...) +
     405             :       "\n" + "See https://doi.org/10.3233/JNR-220049 for more background.");
     406             : }
     407             : 
     408             : template <class Op> struct Transform {
     409             :   Op op;
     410     1184904 :   template <class... Ts> Variable operator()(Ts &&...handles) const {
     411     1940499 :     const auto dims = merge(handles.dims()...);
     412             : 
     413             :     if constexpr (!std::is_base_of_v<
     414             :                       core::transform_flags::force_variance_broadcast_t, Op>) {
     415     1134620 :       if ((bad_variance_broadcast(dims, handles) || ...))
     416           8 :         throw_variances_broadcast_error(handles...);
     417     1134612 :       if ((handles.is_bins() || ...))
     418         152 :         if (((handles.has_variances() && !handles.is_bins()) || ...))
     419           1 :           throw_variances_broadcast_error(handles...);
     420             :     }
     421             : 
     422             :     using Out = decltype(maybe_eval(op(handles.values()[0]...)));
     423     1184895 :     const bool variances =
     424             :         !std::is_base_of_v<core::transform_flags::no_out_variance_t, Op> &&
     425      504667 :         core::canHaveVariances<Out>() && (handles.has_variances() || ...);
     426     1184895 :     auto unit = op.base_op()(variableFactory().elem_unit(*handles.m_var)...);
     427     1184881 :     auto out = variableFactory().create(dtype<Out>, dims, unit, variances,
     428     1184881 :                                         *handles.m_var...);
     429     1184880 :     do_transform(op, variable_access<Out>(out), std::tuple<>(),
     430             :                  as_view{handles, dims}...);
     431     2369756 :     return out;
     432     1184906 :   }
     433             : };
     434             : template <class Op> Transform(Op) -> Transform<Op>;
     435             : 
     436             : // std::tuple_cat does not work correctly with clang-7. Issue with
     437             : // Eigen::Vector3d.
     438             : template <typename T, typename...> struct tuple_cat {
     439             :   using type = T;
     440             : };
     441             : template <template <typename...> class C, typename... Ts1, typename... Ts2,
     442             :           typename... Ts3>
     443             : struct tuple_cat<C<Ts1...>, C<Ts2...>, Ts3...>
     444             :     : public tuple_cat<C<Ts1..., Ts2...>, Ts3...> {};
     445             : 
     446             : template <class Op> struct wrap_eigen : Op {
     447     1184895 :   const Op &base_op() const noexcept { return *this; }
     448    89428048 :   template <class... Ts> constexpr auto operator()(Ts &&...args) const {
     449             :     if constexpr ((core::has_eval_v<std::decay_t<Ts>> || ...))
     450             :       // WARNING! The explicit specification of the template arguments of
     451             :       // operator() is EXTREMELY IMPORTANT. It ensures that Eigen types are
     452             :       // passed BY REFERENCE and NOT BY VALUE. Passing by value leads to
     453             :       // construction of expressions of values on the stack, which are then
     454             :       // returned from the operator. One way to identify this is using
     455             :       // address-sanitizer, which finds a `stack-use-after-scope`.
     456         155 :       return Op::template operator()<Ts...>(std::forward<Ts>(args)...);
     457             :     else
     458    89427893 :       return Op::template operator()(std::forward<Ts>(args)...);
     459             :   }
     460             : };
     461             : template <class... Ts> wrap_eigen(Ts...) -> wrap_eigen<Ts...>;
     462             : } // namespace detail
     463             : 
     464             : template <class... Ts, class Op>
     465     1809754 : static constexpr auto type_tuples(Op) noexcept {
     466             :   if constexpr (sizeof...(Ts) == 0)
     467     1066505 :     return typename Op::types{};
     468             :   else if constexpr ((visit_detail::is_tuple<Ts>::value || ...))
     469          66 :     return typename detail::tuple_cat<Ts...>::type{};
     470             :   else
     471      743183 :     return std::tuple<Ts...>{};
     472             : }
     473             : 
     474      802572 : constexpr auto overlaps = [](const auto &a, const auto &b) {
     475             :   if constexpr (std::is_same_v<typename std::decay_t<decltype(a)>::value_type,
     476             :                                typename std::decay_t<decltype(b)>::value_type>)
     477      511417 :     return a.values().overlaps(b.values());
     478             :   else
     479      291155 :     return false;
     480             : };
     481             : 
     482             : /// Helper class wrapping functions for in-place transform.
     483             : ///
     484             : /// The dry_run template argument can be used to disable any actual modification
     485             : /// of data. This is used to implement operations on datasets with a strong
     486             : /// exception guarantee. With dry_run == true the checks (and their exceptions) still run, but the element loop and the final unit update are skipped.
     487             : template <bool dry_run> struct in_place {
     488             :   template <class Op, class T, class... Ts> // Element loop: applies op to arg (the output) together with other..., serially or via parallel_for
     489      636338 :   static void transform_in_place_impl(Op op, T &&arg, Ts &&...other) {
     490             :     using namespace detail;
     491      636356 :     const auto begin =
     492             :         core::MultiIndex(array_params(arg), array_params(other)...); // index built from the array parameters of the output and all inputs
     493             :     if constexpr (dry_run)
     494         133 :       return; // dry run: stop before any element is touched
     495             :     // `run` walks from `indices` to `end`, handing one inner chunk per iteration to dispatch_inner_loop.
     496     1272116 :     auto run = [&](auto &indices, const auto &end) {
     497      636196 :       const auto inner_strides = indices.inner_strides();
     498     9224371 :       while (indices != end) {
     499             :         // Shape can change when moving between bins -> recompute every time.
     500     8588176 :         const auto inner_size = indices.in_same_chunk(end, 1)
     501     8588176 :                                     ? indices.inner_distance_to(end)
     502             :                                     : indices.inner_distance_to_end();
     503     8588176 :         detail::dispatch_inner_loop<true>(op, indices.get(), inner_strides,
     504             :                                           inner_size, std::forward<T>(arg),
     505     8565622 :                                           std::forward<Ts>(other)...);
     506     8588175 :         indices.increment_by(inner_size != 0 ? inner_size : 1); // advance by at least 1, even when inner_size is 0
     507             :       }
     508             :     };
     509      636196 :     if (begin.has_stride_zero()) {
     510             :       // The output has a dimension with stride zero so parallelization must
     511             :       // be done differently. See parallelization in accumulate.h.
     512       93359 :       auto indices = begin;
     513       93359 :       auto end = begin;
     514       93359 :       end.set_index(arg.size()); // end position: index arg.size()
     515       93359 :       run(indices, end); // single call covering the whole output, no parallel_for
     516             :     } else {
     517     1085398 :       auto run_parallel = [&](const auto &range) {
     518      542837 :         auto indices = begin; // copy so that run doesn't modify begin
     519      542837 :         indices.set_index(range.begin());
     520      542837 :         auto end = begin;
     521      542837 :         end.set_index(range.end());
     522      542837 :         run(indices, end);
     523             :       };
     524      542837 :       core::parallel::parallel_for(core::parallel::blocked_range(0, arg.size()), // run_parallel receives sub-ranges of [0, arg.size())
     525             :                                    run_parallel);
     526             :     }
     527      636195 :   }
     528             : 
     529             :   /// Recursion endpoint for do_transform_in_place.
     530             :   ///
     531             :   /// Calls transform_in_place_impl unless the output has no variance even
     532             :   /// though it should.
     533             :   template <class Op, class Tuple> // Tuple: the processed arguments, output first
     534      636344 :   static void do_transform_in_place(Op op, Tuple &&processed) {
     535             :     using namespace detail;
     536      636344 :     std::apply(
     537     1272689 :         [&op](auto &&arg, auto &&...args) { // arg: the output; args: the remaining inputs
     538             :           if constexpr (check_all_or_none_variances<Op, decltype(arg),
     539             :                                                     decltype(args)...>) { // compile-time flag; when set, this combination of argument types is rejected
     540           2 :             throw except::VariancesError(
     541             :                 "Expected either all or none of inputs to have variances.");
     542             :           } else {
     543      636342 :             constexpr bool in_var_if_out_var = std::is_base_of_v<
     544             :                 core::transform_flags::expect_in_variance_if_out_variance_t,
     545             :                 Op>;
     546      636342 :             constexpr bool arg_var =
     547             :                 is_ValuesAndVariances_v<std::decay_t<decltype(arg)>>; // the output carries variances
     548      636342 :             constexpr bool args_var =
     549             :                 (is_ValuesAndVariances_v<std::decay_t<decltype(args)>> || ...); // at least one input carries variances
     550             :             if constexpr ((in_var_if_out_var ? arg_var == args_var // flag set: output has variances exactly when some input has
     551             :                                              : arg_var || !args_var) || // otherwise: inputs may have variances only if the output has
     552             :                           std::is_base_of_v<core::transform_flags::
     553             :                                                 expect_no_variance_arg_t<0>,
     554             :                                             Op>) { // or the op carries the expect_no_variance_arg_t<0> flag
     555      636338 :               transform_in_place_impl(op, std::forward<decltype(arg)>(arg),
     556             :                                       std::forward<decltype(args)>(args)...);
     557             :             } else {
     558           4 :               throw except::VariancesError(
     559             :                   "Output has no variance but at least one input does.");
     560             :             }
     561             :           }
     562             :         },
     563             :         std::forward<Tuple>(processed));
     564      636328 :   }
     565             : 
     566             :   /// Helper for in-place transform implementation, performing branching between
     567             :   /// output with and without variances as well as handling other operands with
     568             :   /// and without variances.
     569             :   template <class Op, class Tuple, class Arg, class... Args> // peels off one argument per call; `processed` grows by one entry each time
     570     1438721 :   static void do_transform_in_place(Op op, Tuple &&processed, Arg &arg,
     571             :                                     const Args &...args) {
     572             :     using namespace detail;
     573     1438721 :     auto vals = arg.values();
     574     1438721 :     if (arg.has_variances()) { // runtime property that selects which compile-time branch is executed
     575             :       if constexpr (std::is_base_of_v<
     576             :                         core::transform_flags::expect_no_variance_arg_t<
     577             :                             std::tuple_size_v<Tuple>>, // number of already processed arguments == index of `arg`
     578             :                         Op>) {
     579           0 :         throw except::VariancesError("Variances in argument " +
     580             :                                      std::to_string(std::tuple_size_v<Tuple>) +
     581             :                                      " not supported.");
     582             :       } else if constexpr (core::canHaveVariances<typename Arg::value_type>()) { // variance path only exists for element types that can have variances
     583       24581 :         auto vars = arg.variances();
     584       24584 :         do_transform_in_place(
     585             :             op,
     586             :             std::tuple_cat(processed,
     587       24606 :                            std::tuple(ValuesAndVariances{vals, vars})), // append values+variances, then continue with the remaining arguments
     588             :             args...);
     589       24581 :       }
     590             :     } else {
     591             :       if constexpr (std::is_base_of_v<
     592             :                         core::transform_flags::expect_variance_arg_t<
     593             :                             std::tuple_size_v<Tuple>>,
     594             :                         Op>) {
     595             :         throw except::VariancesError("Argument " +
     596             :                                      std::to_string(std::tuple_size_v<Tuple>) +
     597             :                                      " must have variances.");
     598             :       } else {
     599     1414162 :         do_transform_in_place(op, std::tuple_cat(processed, std::tuple(vals)), // append values only, then continue with the remaining arguments
     600             :                               args...);
     601             :       }
     602             :     }
     603     1438721 :   }
     604             : 
     605             :   /// Functor for in-place transformation, applying `op` to all elements.
     606             :   ///
     607             :   /// This is responsible for converting the client-provided functor `Op` which
     608             :   /// operates on elements to a functor for the data container, which is
     609             :   /// required by `visit`.
     610             :   template <class Op> struct TransformInPlace {
     611             :     Op op;
     612             :     template <class T, class... Ts>
     613      636539 :     void operator()(T &&out, Ts &&...handles) const {
     614             :       using namespace detail;
     615             :       // If there is an overlap between lhs and rhs we copy the rhs before
     616             :       // applying the operation.
     617      614809 :       if ((overlaps(out, handles) || ...)) {
     618             :         if constexpr (sizeof...(Ts) == 1) {
     619         179 :           auto copy = (handles.clone(), ...); // pack has exactly one element, so the comma fold yields that one clone
     620         179 :           return operator()(std::forward<T>(out), Ts(copy)...); // run again with the clone in place of the overlapping input
     621         179 :         } else {
     622           0 :           throw std::runtime_error(
     623             :               "Overlap handling only implemented for 2 inputs.");
     624             :         }
     625             :       }
     626     1250958 :       const auto dims = merge(out.dims(), handles.dims()...);
     627      636344 :       auto out_view = as_view{out, dims};
     628      636344 :       do_transform_in_place(op, std::tuple<>{}, out_view, // start with an empty tuple of processed arguments; the output goes first
     629             :                             as_view{handles, dims}...);
     630      636344 :     }
     631             :   };
     632             :   // gcc cannot deal with deduction guide for nested class => helper function.
     633      636370 :   template <class Op> static auto makeTransformInPlace(Op op) {
     634      636370 :     return TransformInPlace<Op>{detail::wrap_eigen{op}}; // op is wrapped in detail::wrap_eigen before being stored
     635             :   }
     636             : 
     637             :   template <class... Ts, class Op, class Var, class... Other> // Dispatch on the operands via visit<Ts...>; bad_variant_access is reported as TypeError
     638      636370 :   static void transform_data(const std::tuple<Ts...> &, Op op, // unnamed tuple only carries the type list Ts...
     639             :                              const std::string_view &name, Var &&var,
     640             :                              Other &&...other) {
     641             :     using namespace detail;
     642             :     try {
     643      636370 :       visit<Ts...>::apply(makeTransformInPlace(op), var, other...);
     644          20 :     } catch (const std::bad_variant_access &) { // translated into a TypeError that quotes `name`
     645          10 :       throw except::TypeError("'" + std::string(name) +
     646             :                                   "' does not support dtypes ",
     647             :                               var, other...);
     648             :     }
     649      636328 :   }
     650             :   template <class... Ts, class Op, class Var, class... Other> // Entry point: check dims/bins/variances, compute the new unit, transform the data, then store the unit
     651      485638 :   static void transform(Op op, const std::string_view &name, Var &&var,
     652             :                         const Other &...other) {
     653             :     using namespace detail;
     654      463908 :     (scipp::expect::includes(var.dims(), other.dims()), ...); // checked once per input against the dims of the output
     655             : 
     656      485628 :     if (!is_bins(var) && ((is_bins(other) || ...))) {
     657           0 :       throw except::BinnedDataError(
     658             :           "Cannot apply inplace operation where target is "
     659             :           "not binned but arguments are binned");
     660             :     }
     661             : 
     662             :     if constexpr (!std::is_base_of_v<
     663             :                       core::transform_flags::force_variance_broadcast_t, Op>) { // ops carrying this flag skip the variance-broadcast checks below
     664      284594 :       if (const auto dims = merge(var.dims(), other.dims()...);
     665      131431 :           (bad_variance_broadcast(dims, other) || ...))
     666           4 :         throw_variances_broadcast_error(var, other...);
     667      153157 :       if (is_bins(var) || (is_bins(other) || ...))
     668       19674 :         if (((other.has_variances() && !is_bins(other)) || ...)) // bins involved and some non-binned input has variances
     669           1 :           throw_variances_broadcast_error(var, other...);
     670             :     }
     671             : 
     672      485623 :     auto unit = variableFactory().elem_unit(var); // `auto` makes a local copy; var's unit is only written at the very end
     673      485623 :     op(unit, variableFactory().elem_unit(other)...); // op is invoked on units; its return value is not used here
     674             :     // Stop early in bad cases of changing units (if `var` is a slice):
     675      485621 :     variableFactory().expect_can_set_elem_unit(var, unit);
     676             :     // Wrapped implementation to convert multiple tuples into a parameter pack.
     677      485624 :     transform_data(type_tuples<Ts...>(op), op, name, std::forward<Var>(var),
     678             :                    other...);
     679             :     if constexpr (dry_run)
     680         133 :       return; // dry run: leave the unit of `var` unchanged
     681      485440 :     variableFactory().set_elem_unit(var, unit);
     682      485440 :   }
     683             : };
     684             : 
     685             : /// Transform the data elements of a variable in-place.
     686             : ///
     687             : /// Note that this is deliberately not named `for_each`: Unlike std::for_each,
     688             : /// this function does not promise in-order execution. This overload is
     689             : /// equivalent to std::transform with a single input range and an output range
     690             : /// identical to the input range, but avoids potentially costly element copies.
     691             : template <class... Ts, class Var, class Op> // Ts...: optional explicit type list; if empty, Op::types is used
     692       21730 : void transform_in_place(Var &&var, Op op, const std::string_view &name) { // name: operation name quoted in the error for unsupported dtypes
     693       21730 :   in_place<false>::transform<Ts...>(op, name, std::forward<Var>(var)); // in_place<false>: not a dry run, var is modified
     694       21730 : }
     695             : 
     696             : /// Transform the data elements of a variable in-place.
     697             : ///
     698             : /// This overload is equivalent to std::transform with two input ranges and an
     699             : /// output range identical to the second input range, but avoids potentially
     700             : /// costly element copies.
     701             : template <class... TypePairs, class Var, class Op> // TypePairs...: optional explicit type list; if empty, Op::types is used
     702      419054 : void transform_in_place(Var &&var, const Variable &other, Op op, // var: the in-place output, forwarded ahead of `other`; other: read-only input
     703             :                         const std::string_view &name) { // name: operation name quoted in the error for unsupported dtypes
     704      419054 :   in_place<false>::transform<TypePairs...>(op, name, std::forward<Var>(var), // in_place<false>: not a dry run, var is modified
     705             :                                            other);
     706      419003 : }
     707             : 
     708             : /// Transform the data elements of a variable in-place.
     709             : template <class... TypePairs, class Var, class Op> // TypePairs...: optional explicit type list; if empty, Op::types is used
     710        9898 : void transform_in_place(Var &&var, const Variable &var1, const Variable &var2, // var: the in-place output; var1, var2: read-only inputs
     711             :                         Op op, const std::string_view &name) { // name: operation name quoted in the error for unsupported dtypes
     712        9898 :   in_place<false>::transform<TypePairs...>(op, name, std::forward<Var>(var), // in_place<false>: not a dry run, var is modified
     713             :                                            var1, var2);
     714        9894 : }
     715             : 
     716             : /// Transform the data elements of a variable in-place.
     717             : template <class... TypePairs, class Var, class Op> // TypePairs...: optional explicit type list; if empty, Op::types is used
     718       34815 : void transform_in_place(Var &&var, const Variable &var1, const Variable &var2, // var: the in-place output; var1, var2, var3: read-only inputs
     719             :                         const Variable &var3, Op op,
     720             :                         const std::string_view &name) { // name: operation name quoted in the error for unsupported dtypes
     721       34815 :   in_place<false>::transform<TypePairs...>(op, name, std::forward<Var>(var), // in_place<false>: not a dry run, var is modified
     722             :                                            var1, var2, var3);
     723       34813 : }
     724             : 
     725             : namespace dry_run { // Variants that go through in_place<true>: the checks run and may throw, but the element loop and unit update are skipped
     726             : template <class... Ts, class Var, class Op>
     727             : void transform_in_place(Var &&var, Op op, const std::string_view &name) { // dry-run variant taking only the output
     728             :   in_place<true>::transform<Ts...>(op, name, std::forward<Var>(var));
     729             : }
     730             : template <class... TypePairs, class Var, class Op>
     731         141 : void transform_in_place(Var &&var, const Variable &other, Op op, // dry-run variant with one additional read-only input
     732             :                         const std::string_view &name) {
     733         141 :   in_place<true>::transform<TypePairs...>(op, name, std::forward<Var>(var),
     734             :                                           other);
     735         133 : }
     736             : } // namespace dry_run
     737             : 
     738             : namespace detail {
     739             : template <class... Ts, class Op, class... Vars> // Shared implementation of the transform overloads that return a new Variable
     740     1184907 : Variable transform(std::tuple<Ts...> &&, Op op, const std::string_view &name, // unnamed tuple only carries the type list Ts...
     741             :                    const Vars &...vars) {
     742             :   using namespace detail;
     743             :   try {
     744     1184907 :     return visit<Ts...>::apply(Transform{wrap_eigen{op}}, vars...); // op is wrapped in wrap_eigen and Transform before dispatch
     745          32 :   } catch (const std::bad_variant_access &) { // translated into a TypeError that quotes `name`
     746           6 :     throw except::TypeError(
     747           6 :         "'" + std::string(name) + "' does not support dtypes ", vars...);
     748             :   }
     749             : }
     750             : } // namespace detail
     751             : 
     752             : /// Transform the data elements of a variable and return a new Variable.
     753             : ///
     754             : /// This overload is equivalent to std::transform with a single input range, but
     755             : /// avoids the need to manually create a new variable for the output and the
     756             : /// need for, e.g., std::back_inserter.
     757             : template <class... Ts, class Op> // Ts...: optional explicit type list; if empty, Op::types is used
     758      429309 : [[nodiscard]] Variable transform(const Variable &var, Op op,
     759             :                                  const std::string_view &name) { // name: operation name quoted in the error for unsupported dtypes
     760      429309 :   return detail::transform(type_tuples<Ts...>(op), op, name, var); // type_tuples turns Ts.../Op::types into the tuple detail::transform expects
     761             : }
     762             : 
     763             : /// Transform the data elements of two variables and return a new Variable.
     764             : ///
     765             : /// This overload is equivalent to std::transform with two input ranges, but
     766             : /// avoids the need to manually create a new variable for the output and the
     767             : /// need for, e.g., std::back_inserter.
     768             : template <class... Ts, class Op> // Ts...: optional explicit type list; if empty, Op::types is used
     769      744114 : [[nodiscard]] Variable transform(const Variable &var1, const Variable &var2,
     770             :                                  Op op, const std::string_view &name) { // name: operation name quoted in the error for unsupported dtypes
     771      744114 :   return detail::transform(type_tuples<Ts...>(op), op, name, var1, var2); // inputs are passed on in the order given
     772             : }
     773             : 
     774             : /// Transform the data elements of three variables and return a new Variable.
     775             : template <class... Ts, class Op> // Ts...: optional explicit type list; if empty, Op::types is used
     776       11318 : [[nodiscard]] Variable transform(const Variable &var1, const Variable &var2,
     777             :                                  const Variable &var3, Op op,
     778             :                                  const std::string_view &name) { // name: operation name quoted in the error for unsupported dtypes
     779       11318 :   return detail::transform(type_tuples<Ts...>(op), op, name, var1, var2, var3); // inputs are passed on in the order given
     780             : }
     781             : 
     782             : /// Transform the data elements of four variables and return a new Variable.
     783             : template <class... Ts, class Op> // Ts...: optional explicit type list; if empty, Op::types is used
     784         166 : [[nodiscard]] Variable transform(const Variable &var1, const Variable &var2,
     785             :                                  const Variable &var3, const Variable &var4,
     786             :                                  Op op, const std::string_view &name) { // name: operation name quoted in the error for unsupported dtypes
     787         165 :   return detail::transform(type_tuples<Ts...>(op), op, name, var1, var2, var3, // inputs are passed on in the order given
     788         166 :                            var4);
     789             : }
     790             : 
     791             : } // namespace scipp::variable

Generated by: LCOV version 1.14