LCOV - code coverage report
Current view: top level - core/include/scipp/core - multi_index.h (source / functions) Hit Total Coverage
Test: coverage.info Lines: 147 168 87.5 %
Date: 2024-04-28 01:25:40 Functions: 316 1126 28.1 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Simon Heybrock
       5             : #pragma once
       6             : 
       7             : #include <functional>
       8             : #include <numeric>
       9             : #include <optional>
      10             : 
      11             : #include "scipp/common/index_composition.h"
      12             : #include "scipp/core/dimensions.h"
      13             : #include "scipp/core/element_array_view.h"
      14             : 
      15             : namespace scipp::core {
      16             : namespace detail {
      17           0 : inline auto get_nested_dims() { return Dimensions(); }
      18             : 
      19             : template <class T, class... Ts>
      20       59053 : auto get_nested_dims(const T &param, const Ts &...params) {
      21       59053 :   const auto &bin_param = param.bucketParams();
      22       59053 :   return bin_param ? bin_param.dims : get_nested_dims(params...);
      23             : }
      24             : } // namespace detail
      25             : 
      26             : template <scipp::index N> class MultiIndex {
      27             : public:
      28             :   /// Determine from arguments if binned.
      29             :   template <class... Params>
      30     1821216 :   explicit MultiIndex(const ElementArrayViewParams &param,
      31             :                       const Params &...params)
      32             :       : MultiIndex{
      33     3597127 :             (!param.bucketParams() && (!params.bucketParams() && ...))
      34     3662025 :                 ? MultiIndex(param.dims(), param.strides(), params.strides()...)
      35             :                 : MultiIndex{binned_tag{},
      36             :                              detail::get_nested_dims(param, params...),
      37             :                              param.dims(), param,
      38     5504962 :                              ElementArrayViewParams{params}...}} {}
      39             : 
      40             :   /// Construct without bins.
      41             :   template <class... StridesArgs>
      42             :   explicit MultiIndex(const Dimensions &iter_dims,
      43             :                       const StridesArgs &...strides);
      44             : 
      45             : private:
      46             :   /// Use to disambiguate between constructors.
      47             :   struct binned_tag {};
      48             : 
      49             :   /// Construct with bins.
      50             :   template <class... Params>
      51             :   explicit MultiIndex(binned_tag, const Dimensions &inner_dims,
      52             :                       const Dimensions &bin_dims, const Params &...params);
      53             : 
      54             : public:
      55    11245135 :   void increment_outer() noexcept {
      56             :     // Go through all nested dims (with bins) / all dims (without bins)
      57             :     // where we have reached the end.
      58    22490270 :     increment_in_dims(
      59    12776411 :         [this](const scipp::index data) -> scipp::index & {
      60    12776411 :           return this->m_data_index[data];
      61             :         },
      62    11245135 :         0, m_inner_ndim - 1);
      63             :     // Nested dims incremented, move on to bins.
      64             :     // Note that we do not check whether there are any bins, instead whether
      65             :     // the outer Variable is scalar because the loop above is enough to set up
      66             :     // the coord in that case.
      67    11245135 :     if (has_bins() && dim_at_end(m_inner_ndim - 1))
      68     4437402 :       seek_bin();
      69    11245135 :   }
      70             : 
      71    14209817 :   void increment() noexcept {
      72    42629451 :     for (scipp::index data = 0; data < N; ++data)
      73    28419634 :       m_data_index[data] += m_stride[0][data];
      74    14209817 :     ++m_coord[0];
      75    14209817 :     if (dim_at_end(0))
      76      200230 :       increment_outer();
      77    14209817 :   }
      78             : 
      79    11044905 :   void increment_by(const scipp::index inner_distance) noexcept {
      80    35258016 :     for (scipp::index data = 0; data < N; ++data) {
      81    24213111 :       m_data_index[data] += inner_distance * m_stride[0][data];
      82             :     }
      83    11044905 :     m_coord[0] += inner_distance;
      84    11044905 :     if (dim_at_end(0))
      85    11044905 :       increment_outer();
      86    11044905 :   }
      87             : 
      88     1821074 :   [[nodiscard]] auto inner_strides() const noexcept {
      89     1821074 :     return scipp::span<const scipp::index>(m_stride[0].data(), N);
      90             :   }
      91             : 
      92     9398371 :   [[nodiscard]] scipp::index inner_distance_to_end() const noexcept {
      93     9398371 :     return m_shape[0] - m_coord[0];
      94             :   }
      95             : 
      96             :   [[nodiscard]] scipp::index
      97     1646535 :   inner_distance_to(const MultiIndex &other) const noexcept {
      98     1646535 :     return other.m_coord[0] - m_coord[0];
      99             :   }
     100             : 
     101             :   /// Set the absolute index. In the special case of iteration with bins,
     102             :   /// this sets the *index of the bin* and NOT the full index within the
     103             :   /// iterated data.
     104     3548789 :   void set_index(const scipp::index index) noexcept {
     105     3548789 :     if (has_bins()) {
     106       64857 :       set_bins_index(index);
     107             :     } else {
     108     3483932 :       extract_indices(index, shape_it(), shape_it(m_inner_ndim), coord_it());
     109    12287678 :       for (scipp::index data = 0; data < N; ++data) {
     110     8803746 :         m_data_index[data] = flat_index(data, 0, m_inner_ndim);
     111             :       }
     112             :     }
     113     3548789 :   }
     114             : 
     115           0 :   void set_to_end() noexcept {
     116           0 :     if (has_bins()) {
     117           0 :       set_to_end_bin();
     118             :     } else {
     119           0 :       if (m_inner_ndim == 0) {
     120           0 :         m_coord[0] = 1;
     121             :       } else {
     122           0 :         zero_out_coords(m_inner_ndim - 1);
     123           0 :         m_coord[m_inner_ndim - 1] = m_shape[m_inner_ndim - 1];
     124             :       }
     125           0 :       for (scipp::index data = 0; data < N; ++data) {
     126           0 :         m_data_index[data] = flat_index(data, 0, m_inner_ndim);
     127             :       }
     128             :     }
     129           0 :   }
     130             : 
     131    25254740 :   [[nodiscard]] constexpr auto get() const noexcept { return m_data_index; }
     132             : 
     133    12865979 :   bool operator==(const MultiIndex &other) const noexcept {
     134             :     // Assuming the number dimensions match to make the check cheaper.
     135    12865979 :     return m_coord == other.m_coord;
     136             :   }
     137             : 
     138    12865979 :   bool operator!=(const MultiIndex &other) const noexcept {
     139    12865979 :     return !(*this == other); // NOLINT
     140             :   }
     141             : 
     142             :   [[nodiscard]] bool
     143    11044906 :   in_same_chunk(const MultiIndex &other,
     144             :                 const scipp::index first_dim) const noexcept {
     145             :     // Take scalars of bins into account when calculating ndim.
     146    12407098 :     for (scipp::index dim = first_dim;
     147    12407098 :          dim < m_inner_ndim + std::max(bin_ndim(), scipp::index{1}); ++dim) {
     148    10760563 :       if (m_coord[dim] != other.m_coord[dim]) {
     149     9398371 :         return false;
     150             :       }
     151             :     }
     152     1646535 :     return true;
     153             :   }
     154             : 
     155           0 :   [[nodiscard]] auto begin() const noexcept {
     156           0 :     auto it(*this);
     157           0 :     it.set_index(0);
     158           0 :     return it;
     159             :   }
     160             : 
     161           0 :   [[nodiscard]] auto end() const noexcept {
     162           0 :     auto it(*this);
     163           0 :     it.set_to_end();
     164           0 :     return it;
     165             :   }
     166             : 
     167    73573580 :   [[nodiscard]] bool has_bins() const noexcept {
     168    73573580 :     return m_nested_dim_index != -1;
     169             :   }
     170             : 
     171             :   /// Return true if the first subindex has a 0 stride
     172      636196 :   [[nodiscard]] bool has_stride_zero() const noexcept {
     173     1060198 :     for (scipp::index dim = 0; dim < m_ndim; ++dim)
     174      517361 :       if (m_stride[dim][0] == 0)
     175       93359 :         return true;
     176      542837 :     return false;
     177             :   }
     178             : 
     179             : private:
     180   115096967 :   [[nodiscard]] auto dim_at_end(const scipp::index dim) const noexcept {
     181   115096967 :     return m_coord[dim] == std::max(m_shape[dim], scipp::index{1});
     182             :   }
     183             : 
     184    58779656 :   [[nodiscard]] bool at_end() const noexcept { return dim_at_end(last_dim()); }
     185             : 
     186    58779656 :   [[nodiscard]] scipp::index last_dim() const noexcept {
     187    58779656 :     if (has_bins()) {
     188    58779656 :       return bin_ndim() == 0 ? m_ndim : m_ndim - 1;
     189             :     } else {
     190           0 :       return std::max(m_ndim - 1, scipp::index{0});
     191             :     }
     192             :   }
     193             : 
     194             :   template <class F>
     195    11647964 :   void increment_in_dims(const F &data_index, const scipp::index begin_dim,
     196             :                          const scipp::index end_dim) {
     197    17853535 :     for (scipp::index dim = begin_dim; dim < end_dim && dim_at_end(dim);
     198             :          ++dim) {
     199    19898662 :       for (scipp::index data = 0; data < N; ++data) {
     200    13693091 :         data_index(data) +=
     201             :             // take a step in dimension dim+1
     202    13693091 :             m_stride[dim + 1][data]
     203             :             // rewind dimension dim (coord(d) == m_shape[d])
     204    13693091 :             - m_coord[dim] * m_stride[dim][data];
     205             :       }
     206     6205571 :       ++m_coord[dim + 1];
     207     6205571 :       m_coord[dim] = 0;
     208             :     }
     209    11647964 :   }
     210             : 
     211    71254818 :   [[nodiscard]] constexpr auto bin_ndim() const noexcept {
     212    71254818 :     return m_ndim - m_inner_ndim;
     213             :   }
     214             : 
     215    16383192 :   [[nodiscard]] bool current_bin_is_empty() const noexcept {
     216    16383192 :     return m_shape[m_nested_dim_index] == 0;
     217             :   }
     218             : 
     219             :   struct BinIterator {
     220             :     BinIterator() = default;
     221      140050 :     explicit BinIterator(const BucketParams &bucket_params,
     222             :                          const scipp::index outer_volume)
     223      140050 :         : m_is_binned{static_cast<bool>(bucket_params)},
     224             :           // indices can be != nullptr but outer_volume == 0 when Variable
     225             :           // was sliced.
     226      140050 :           m_indices{outer_volume == 0 ? nullptr : bucket_params.indices} {}
     227             : 
     228             :     const bool m_is_binned{false};
     229             :     scipp::index m_bin_index{0};
     230             :     const std::pair<scipp::index, scipp::index> *m_indices{nullptr};
     231             :   };
     232             : 
     233      402829 :   void increment_outer_bins() noexcept {
     234      805658 :     increment_in_dims(
     235      916680 :         [this](const scipp::index data) -> scipp::index & {
     236      916680 :           return this->m_bin[data].m_bin_index;
     237             :         },
     238      402829 :         m_inner_ndim, m_ndim - 1);
     239      402829 :   }
     240             : 
     241    16318335 :   void increment_bins() noexcept {
     242    16318335 :     const auto dim = m_inner_ndim;
     243    49053576 :     for (scipp::index data = 0; data < N; ++data) {
     244    32735241 :       m_bin[data].m_bin_index += m_stride[dim][data];
     245             :     }
     246    16318335 :     zero_out_coords(m_inner_ndim);
     247    16318335 :     ++m_coord[dim];
     248    16318335 :     if (dim_at_end(dim))
     249      402829 :       increment_outer_bins();
     250    16318335 :     if (!at_end()) {
     251    48906766 :       for (scipp::index data = 0; data < N; ++data) {
     252    32634973 :         load_bin_params(data);
     253             :       }
     254             :     }
     255    16318335 :   }
     256             : 
     257    16318335 :   void seek_bin() noexcept {
     258             :     do {
     259    16318335 :       increment_bins();
     260    16318335 :     } while (current_bin_is_empty() && !at_end());
     261     4448517 :   }
     262             : 
     263    32922156 :   void load_bin_params(const scipp::index data) noexcept {
     264    32922156 :     if (!m_bin[data].m_is_binned) {
     265     2344748 :       m_data_index[data] = flat_index(data, 0, m_ndim);
     266    30577408 :     } else if (!at_end()) {
     267             :       // All bins are guaranteed to have the same size.
     268    30516695 :       if (m_bin[data].m_indices != nullptr) {
     269    30454065 :         const auto [begin, end] =
     270    30454065 :             m_bin[data].m_indices[m_bin[data].m_bin_index];
     271    30454065 :         m_shape[m_nested_dim_index] = end - begin;
     272    30454065 :         m_data_index[data] = m_stride[m_nested_dim_index][data] * begin;
     273             :       } else {
     274             :         // m_indices can be nullptr if there are bins, but they are empty.
     275       62630 :         m_shape[m_nested_dim_index] = 0;
     276       62630 :         m_data_index[data] = 0;
     277             :       }
     278             :     }
     279             :     // else: at end of bins
     280    32922156 :   }
     281             : 
     282       64857 :   void set_bins_index(const scipp::index index) noexcept {
     283       64857 :     if (bin_ndim() == 0 && index != 0) {
     284             :       // Scalar outer dims and setting to / past end.
     285        3207 :       set_to_end_bin();
     286             :     } else {
     287       61650 :       zero_out_coords(m_inner_ndim);
     288       61650 :       extract_indices(index, shape_it(m_inner_ndim), shape_end(),
     289             :                       coord_it(m_inner_ndim));
     290             :     }
     291             : 
     292      212029 :     for (scipp::index data = 0; data < N; ++data) {
     293      147172 :       m_bin[data].m_bin_index = flat_index(data, m_inner_ndim, m_ndim);
     294      147172 :       load_bin_params(data);
     295             :     }
     296       64857 :     if (current_bin_is_empty() && !at_end())
     297        1937 :       seek_bin();
     298       64857 :   }
     299             : 
     300        3207 :   void set_to_end_bin() noexcept {
     301        3207 :     zero_out_coords(m_ndim);
     302        3207 :     if (bin_ndim() == 0) {
     303        3207 :       m_coord[m_inner_ndim] = 1;
     304             :     } else {
     305           0 :       m_coord[m_ndim - 1] = std::max(m_shape[m_ndim - 1], scipp::index{1});
     306             :     }
     307        3207 :   }
     308             : 
     309    11295666 :   scipp::index flat_index(const scipp::index i_data, scipp::index begin_index,
     310             :                           const scipp::index end_index) {
     311    11295666 :     scipp::index res = 0;
     312    23685524 :     for (; begin_index < end_index; ++begin_index) {
     313    12389858 :       res += m_coord[begin_index] * m_stride[begin_index][i_data];
     314             :     }
     315    11295666 :     return res;
     316             :   }
     317             : 
     318    16383192 :   void zero_out_coords(const scipp::index ndim) noexcept {
     319    16383192 :     const auto end = coord_it(ndim);
     320    32821147 :     for (auto it = coord_it(); it != end; ++it) {
     321    16437955 :       *it = 0;
     322             :     }
     323    16383192 :   }
     324             : 
     325    36311966 :   [[nodiscard]] auto coord_it(const scipp::index dim = 0) noexcept {
     326    36311966 :     return m_coord.begin() + dim;
     327             :   }
     328             : 
     329     7029514 :   [[nodiscard]] auto shape_it(const scipp::index dim = 0) noexcept {
     330     7029514 :     return std::next(m_shape.begin(), dim);
     331             :   }
     332             : 
     333       61650 :   [[nodiscard]] auto shape_end() noexcept { return m_shape.begin() + m_ndim; }
     334             : 
     335             :   /// Current flat index into the operands.
     336             :   std::array<scipp::index, N> m_data_index = {};
     337             :   // This does *not* 0-init the inner arrays!
     338             :   /// Stride for each operand in each dimension.
     339             :   std::array<std::array<scipp::index, N>, NDIM_OP_MAX> m_stride = {};
     340             :   /// Current index in iteration dimensions for both bin and inner dims.
     341             :   std::array<scipp::index, NDIM_OP_MAX + 1> m_coord = {};
     342             :   /// Shape of the iteration dimensions for both bin and inner dims.
     343             :   std::array<scipp::index, NDIM_OP_MAX + 1> m_shape = {};
     344             :   /// Total number of dimensions.
     345             :   scipp::index m_ndim{0};
     346             :   /// Number of dense dimensions, i.e. same as m_ndim when not binned,
     347             :   /// else number of dims in bins.
     348             :   scipp::index m_inner_ndim{0};
     349             :   /// Index of dim referred to by bin indices to distinguish, e.g., 2D bins
     350             :   /// slicing along first or second dim.
     351             :   /// -1 if not binned.
     352             :   scipp::index m_nested_dim_index{-1};
     353             :   /// Parameters of the currently loaded bins.
     354             :   std::array<BinIterator, N> m_bin = {};
     355             : };
     356             : 
     357             : template <class... StridesArgs>
     358             : MultiIndex(const Dimensions &, const StridesArgs &...)
     359             :     -> MultiIndex<sizeof...(StridesArgs)>;
     360             : template <class... Params>
     361             : MultiIndex(const ElementArrayViewParams &, const Params &...)
     362             :     -> MultiIndex<sizeof...(Params) + 1>;
     363             : 
     364             : } // namespace scipp::core

Generated by: LCOV version 1.14