LCOV - code coverage report
Current view: top level - core/include/scipp/core - multi_index.h (source / functions) Hit Total Coverage
Test: coverage.info Lines: 147 168 87.5 %
Date: 2024-12-01 01:56:34 Functions: 320 1129 28.3 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Simon Heybrock
       5             : #pragma once
       6             : 
       7             : #include <functional>
       8             : #include <numeric>
       9             : #include <optional>
      10             : 
      11             : #include "scipp/common/index_composition.h"
      12             : #include "scipp/core/dimensions.h"
      13             : #include "scipp/core/element_array_view.h"
      14             : 
      15             : namespace scipp::core {
      16             : namespace detail {
      17           0 : inline auto get_nested_dims() { return Dimensions(); }
      18             : 
      19             : template <class T, class... Ts>
      20       62039 : auto get_nested_dims(const T &param, const Ts &...params) {
      21       62039 :   const auto &bin_param = param.bucketParams();
      22       62039 :   return bin_param ? bin_param.dims : get_nested_dims(params...);
      23             : }
      24             : } // namespace detail
      25             : 
      26             : template <scipp::index N> class MultiIndex {
      27             : public:
      28             :   /// Determine from arguments if binned.
      29             :   template <class... Params>
      30     1895451 :   explicit MultiIndex(const ElementArrayViewParams &param,
      31             :                       const Params &...params)
      32             :       : MultiIndex{
      33     3743749 :             (!param.bucketParams() && (!params.bucketParams() && ...))
      34     3812420 :                 ? MultiIndex(param.dims(), param.strides(), params.strides()...)
      35             :                 : MultiIndex{binned_tag{},
      36             :                              detail::get_nested_dims(param, params...),
      37             :                              param.dims(), param,
      38     5729920 :                              ElementArrayViewParams{params}...}} {}
      39             : 
      40             :   /// Construct without bins.
      41             :   template <class... StridesArgs>
      42             :   explicit MultiIndex(const Dimensions &iter_dims,
      43             :                       const StridesArgs &...strides);
      44             : 
      45             : private:
      46             :   /// Use to disambiguate between constructors.
      47             :   struct binned_tag {};
      48             : 
      49             :   /// Construct with bins.
      50             :   template <class... Params>
      51             :   explicit MultiIndex(binned_tag, const Dimensions &inner_dims,
      52             :                       const Dimensions &bin_dims, const Params &...params);
      53             : 
      54             : public:
      55    11994340 :   void increment_outer() noexcept {
      56             :     // Go through all nested dims (with bins) / all dims (without bins)
      57             :     // where we have reached the end.
      58    23988680 :     increment_in_dims(
      59    14032801 :         [this](const scipp::index data) -> scipp::index & {
      60    14032801 :           return this->m_data_index[data];
      61             :         },
      62    11994340 :         0, m_inner_ndim - 1);
      63             :     // Nested dims incremented, move on to bins.
      64             :     // Note that we do not check whether there are any bins, instead whether
      65             :     // the outer Variable is scalar because the loop above is enough to set up
      66             :     // the coord in that case.
      67    11994340 :     if (has_bins() && dim_at_end(m_inner_ndim - 1))
      68     4714030 :       seek_bin();
      69    11994340 :   }
      70             : 
      71    14593105 :   void increment() noexcept {
      72    43779315 :     for (scipp::index data = 0; data < N; ++data)
      73    29186210 :       m_data_index[data] += m_stride[0][data];
      74    14593105 :     ++m_coord[0];
      75    14593105 :     if (dim_at_end(0))
      76      249851 :       increment_outer();
      77    14593105 :   }
      78             : 
      79    11744489 :   void increment_by(const scipp::index inner_distance) noexcept {
      80    37640604 :     for (scipp::index data = 0; data < N; ++data) {
      81    25896115 :       m_data_index[data] += inner_distance * m_stride[0][data];
      82             :     }
      83    11744489 :     m_coord[0] += inner_distance;
      84    11744489 :     if (dim_at_end(0))
      85    11744489 :       increment_outer();
      86    11744489 :   }
      87             : 
      88     1895309 :   [[nodiscard]] auto inner_strides() const noexcept {
      89     1895309 :     return scipp::span<const scipp::index>(m_stride[0].data(), N);
      90             :   }
      91             : 
      92    10038156 :   [[nodiscard]] scipp::index inner_distance_to_end() const noexcept {
      93    10038156 :     return m_shape[0] - m_coord[0];
      94             :   }
      95             : 
      96             :   [[nodiscard]] scipp::index
      97     1706334 :   inner_distance_to(const MultiIndex &other) const noexcept {
      98     1706334 :     return other.m_coord[0] - m_coord[0];
      99             :   }
     100             : 
     101             :   /// Set the absolute index. In the special case of iteration with bins,
     102             :   /// this sets the *index of the bin* and NOT the full index within the
     103             :   /// iterated data.
     104     3693699 :   void set_index(const scipp::index index) noexcept {
     105     3693699 :     if (has_bins()) {
     106       68630 :       set_bins_index(index);
     107             :     } else {
     108     3625069 :       extract_indices(index, shape_it(), shape_it(m_inner_ndim), coord_it());
     109    12762727 :       for (scipp::index data = 0; data < N; ++data) {
     110     9137658 :         m_data_index[data] = flat_index(data, 0, m_inner_ndim);
     111             :       }
     112             :     }
     113     3693699 :   }
     114             : 
     115           0 :   void set_to_end() noexcept {
     116           0 :     if (has_bins()) {
     117           0 :       set_to_end_bin();
     118             :     } else {
     119           0 :       if (m_inner_ndim == 0) {
     120           0 :         m_coord[0] = 1;
     121             :       } else {
     122           0 :         zero_out_coords(m_inner_ndim - 1);
     123           0 :         m_coord[m_inner_ndim - 1] = m_shape[m_inner_ndim - 1];
     124             :       }
     125           0 :       for (scipp::index data = 0; data < N; ++data) {
     126           0 :         m_data_index[data] = flat_index(data, 0, m_inner_ndim);
     127             :       }
     128             :     }
     129           0 :   }
     130             : 
     131    26337612 :   [[nodiscard]] constexpr auto get() const noexcept { return m_data_index; }
     132             : 
     133    13639798 :   bool operator==(const MultiIndex &other) const noexcept {
     134             :     // Assuming the number dimensions match to make the check cheaper.
     135    13639798 :     return m_coord == other.m_coord;
     136             :   }
     137             : 
     138    13639798 :   bool operator!=(const MultiIndex &other) const noexcept {
     139    13639798 :     return !(*this == other); // NOLINT
     140             :   }
     141             : 
     142             :   [[nodiscard]] bool
     143    11744490 :   in_same_chunk(const MultiIndex &other,
     144             :                 const scipp::index first_dim) const noexcept {
     145             :     // Take scalars of bins into account when calculating ndim.
     146    13204742 :     for (scipp::index dim = first_dim;
     147    13204742 :          dim < m_inner_ndim + std::max(bin_ndim(), scipp::index{1}); ++dim) {
     148    11498408 :       if (m_coord[dim] != other.m_coord[dim]) {
     149    10038156 :         return false;
     150             :       }
     151             :     }
     152     1706334 :     return true;
     153             :   }
     154             : 
     155           0 :   [[nodiscard]] auto begin() const noexcept {
     156           0 :     auto it(*this);
     157           0 :     it.set_index(0);
     158           0 :     return it;
     159             :   }
     160             : 
     161           0 :   [[nodiscard]] auto end() const noexcept {
     162           0 :     auto it(*this);
     163           0 :     it.set_to_end();
     164           0 :     return it;
     165             :   }
     166             : 
     167    76766018 :   [[nodiscard]] bool has_bins() const noexcept {
     168    76766018 :     return m_nested_dim_index != -1;
     169             :   }
     170             : 
     171             :   /// Return true if the first subindex has a 0 stride
     172      648015 :   [[nodiscard]] bool has_stride_zero() const noexcept {
     173     1088090 :     for (scipp::index dim = 0; dim < m_ndim; ++dim)
     174      536994 :       if (m_stride[dim][0] == 0)
     175       96919 :         return true;
     176      551096 :     return false;
     177             :   }
     178             : 
     179             : private:
     180   120412057 :   [[nodiscard]] auto dim_at_end(const scipp::index dim) const noexcept {
     181   120412057 :     return m_coord[dim] == std::max(m_shape[dim], scipp::index{1});
     182             :   }
     183             : 
     184    61077979 :   [[nodiscard]] bool at_end() const noexcept { return dim_at_end(last_dim()); }
     185             : 
     186    61077979 :   [[nodiscard]] scipp::index last_dim() const noexcept {
     187    61077979 :     if (has_bins()) {
     188    61077979 :       return bin_ndim() == 0 ? m_ndim : m_ndim - 1;
     189             :     } else {
     190           0 :       return std::max(m_ndim - 1, scipp::index{0});
     191             :     }
     192             :   }
     193             : 
     194             :   template <class F>
     195    12490057 :   void increment_in_dims(const F &data_index, const scipp::index begin_dim,
     196             :                          const scipp::index end_dim) {
     197    19276785 :     for (scipp::index dim = begin_dim; dim < end_dim && dim_at_end(dim);
     198             :          ++dim) {
     199    21950080 :       for (scipp::index data = 0; data < N; ++data) {
     200    15163352 :         data_index(data) +=
     201             :             // take a step in dimension dim+1
     202    15163352 :             m_stride[dim + 1][data]
     203             :             // rewind dimension dim (coord(d) == m_shape[d])
     204    15163352 :             - m_coord[dim] * m_stride[dim][data];
     205             :       }
     206     6786728 :       ++m_coord[dim + 1];
     207     6786728 :       m_coord[dim] = 0;
     208             :     }
     209    12490057 :   }
     210             : 
     211    74354660 :   [[nodiscard]] constexpr auto bin_ndim() const noexcept {
     212    74354660 :     return m_ndim - m_inner_ndim;
     213             :   }
     214             : 
     215    17052956 :   [[nodiscard]] bool current_bin_is_empty() const noexcept {
     216    17052956 :     return m_shape[m_nested_dim_index] == 0;
     217             :   }
     218             : 
     219             :   struct BinIterator {
     220             :     BinIterator() = default;
     221      144566 :     explicit BinIterator(const BucketParams &bucket_params,
     222             :                          const scipp::index outer_volume)
     223      144566 :         : m_is_binned{static_cast<bool>(bucket_params)},
     224             :           // indices can be != nullptr but outer_volume == 0 when Variable
     225             :           // was sliced.
     226      144566 :           m_indices{outer_volume == 0 ? nullptr : bucket_params.indices} {}
     227             : 
     228             :     const bool m_is_binned{false};
     229             :     scipp::index m_bin_index{0};
     230             :     const std::pair<scipp::index, scipp::index> *m_indices{nullptr};
     231             :   };
     232             : 
     233      495717 :   void increment_outer_bins() noexcept {
     234      991434 :     increment_in_dims(
     235     1130551 :         [this](const scipp::index data) -> scipp::index & {
     236     1130551 :           return this->m_bin[data].m_bin_index;
     237             :         },
     238      495717 :         m_inner_ndim, m_ndim - 1);
     239      495717 :   }
     240             : 
     241    16984326 :   void increment_bins() noexcept {
     242    16984326 :     const auto dim = m_inner_ndim;
     243    51034298 :     for (scipp::index data = 0; data < N; ++data) {
     244    34049972 :       m_bin[data].m_bin_index += m_stride[dim][data];
     245             :     }
     246    16984326 :     zero_out_coords(m_inner_ndim);
     247    16984326 :     ++m_coord[dim];
     248    16984326 :     if (dim_at_end(dim))
     249      495717 :       increment_outer_bins();
     250    16984326 :     if (!at_end()) {
     251    50875172 :       for (scipp::index data = 0; data < N; ++data) {
     252    33941490 :         load_bin_params(data);
     253             :       }
     254             :     }
     255    16984326 :   }
     256             : 
     257    16984326 :   void seek_bin() noexcept {
     258             :     do {
     259    16984326 :       increment_bins();
     260    16984326 :     } while (current_bin_is_empty() && !at_end());
     261     4729842 :   }
     262             : 
     263    34240755 :   void load_bin_params(const scipp::index data) noexcept {
     264    34240755 :     if (!m_bin[data].m_is_binned) {
     265     2423795 :       m_data_index[data] = flat_index(data, 0, m_ndim);
     266    31816960 :     } else if (!at_end()) {
     267             :       // All bins are guaranteed to have the same size.
     268    31752763 :       if (m_bin[data].m_indices != nullptr) {
     269    31689271 :         const auto [begin, end] =
     270    31689271 :             m_bin[data].m_indices[m_bin[data].m_bin_index];
     271    31689271 :         m_shape[m_nested_dim_index] = end - begin;
     272    31689271 :         m_data_index[data] = m_stride[m_nested_dim_index][data] * begin;
     273             :       } else {
     274             :         // m_indices can be nullptr if there are bins, but they are empty.
     275       63492 :         m_shape[m_nested_dim_index] = 0;
     276       63492 :         m_data_index[data] = 0;
     277             :       }
     278             :     }
     279             :     // else: at end of bins
     280    34240755 :   }
     281             : 
     282       68630 :   void set_bins_index(const scipp::index index) noexcept {
     283       68630 :     if (bin_ndim() == 0 && index != 0) {
     284             :       // Scalar outer dims and setting to / past end.
     285        3309 :       set_to_end_bin();
     286             :     } else {
     287       65321 :       zero_out_coords(m_inner_ndim);
     288       65321 :       extract_indices(index, shape_it(m_inner_ndim), shape_end(),
     289             :                       coord_it(m_inner_ndim));
     290             :     }
     291             : 
     292      223368 :     for (scipp::index data = 0; data < N; ++data) {
     293      154738 :       m_bin[data].m_bin_index = flat_index(data, m_inner_ndim, m_ndim);
     294      154738 :       load_bin_params(data);
     295             :     }
     296       68630 :     if (current_bin_is_empty() && !at_end())
     297        4114 :       seek_bin();
     298       68630 :   }
     299             : 
     300        3309 :   void set_to_end_bin() noexcept {
     301        3309 :     zero_out_coords(m_ndim);
     302        3309 :     if (bin_ndim() == 0) {
     303        3309 :       m_coord[m_inner_ndim] = 1;
     304             :     } else {
     305           0 :       m_coord[m_ndim - 1] = std::max(m_shape[m_ndim - 1], scipp::index{1});
     306             :     }
     307        3309 :   }
     308             : 
     309    11716191 :   scipp::index flat_index(const scipp::index i_data, scipp::index begin_index,
     310             :                           const scipp::index end_index) {
     311    11716191 :     scipp::index res = 0;
     312    24695088 :     for (; begin_index < end_index; ++begin_index) {
     313    12978897 :       res += m_coord[begin_index] * m_stride[begin_index][i_data];
     314             :     }
     315    11716191 :     return res;
     316             :   }
     317             : 
     318    17052956 :   void zero_out_coords(const scipp::index ndim) noexcept {
     319    17052956 :     const auto end = coord_it(ndim);
     320    34160675 :     for (auto it = coord_it(); it != end; ++it) {
     321    17107719 :       *it = 0;
     322             :     }
     323    17052956 :   }
     324             : 
     325    37796302 :   [[nodiscard]] auto coord_it(const scipp::index dim = 0) noexcept {
     326    37796302 :     return m_coord.begin() + dim;
     327             :   }
     328             : 
     329     7315459 :   [[nodiscard]] auto shape_it(const scipp::index dim = 0) noexcept {
     330     7315459 :     return std::next(m_shape.begin(), dim);
     331             :   }
     332             : 
     333       65321 :   [[nodiscard]] auto shape_end() noexcept { return m_shape.begin() + m_ndim; }
     334             : 
     335             :   /// Current flat index into the operands.
     336             :   std::array<scipp::index, N> m_data_index = {};
     337             :   // This does *not* 0-init the inner arrays!
     338             :   /// Stride for each operand in each dimension.
     339             :   std::array<std::array<scipp::index, N>, NDIM_OP_MAX> m_stride = {};
     340             :   /// Current index in iteration dimensions for both bin and inner dims.
     341             :   std::array<scipp::index, NDIM_OP_MAX + 1> m_coord = {};
     342             :   /// Shape of the iteration dimensions for both bin and inner dims.
     343             :   std::array<scipp::index, NDIM_OP_MAX + 1> m_shape = {};
     344             :   /// Total number of dimensions.
     345             :   scipp::index m_ndim{0};
     346             :   /// Number of dense dimensions, i.e. same as m_ndim when not binned,
     347             :   /// else number of dims in bins.
     348             :   scipp::index m_inner_ndim{0};
     349             :   /// Index of dim referred to by bin indices to distinguish, e.g., 2D bins
     350             :   /// slicing along first or second dim.
     351             :   /// -1 if not binned.
     352             :   scipp::index m_nested_dim_index{-1};
     353             :   /// Parameters of the currently loaded bins.
     354             :   std::array<BinIterator, N> m_bin = {};
     355             : };
     356             : 
     357             : template <class... StridesArgs>
     358             : MultiIndex(const Dimensions &, const StridesArgs &...)
     359             :     -> MultiIndex<sizeof...(StridesArgs)>;
     360             : template <class... Params>
     361             : MultiIndex(const ElementArrayViewParams &, const Params &...)
     362             :     -> MultiIndex<sizeof...(Params) + 1>;
     363             : 
     364             : } // namespace scipp::core

Generated by: LCOV version 1.14