LCOV - code coverage report
Current view: top level - common/include/scipp/common - index_composition.h (source / functions) Hit Total Coverage
Test: coverage.info Lines: 27 27 100.0 %
Date: 2024-04-28 01:25:40 Functions: 4 4 100.0 %

          Line data    Source code
       1             : // SPDX-License-Identifier: GPL-3.0-or-later
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Jan-Lukas Wynen
       5             : #pragma once
       6             : 
       7             : #include <array>
       8             : #include <cassert>
       9             : #include <cstddef>
      10             : #include <utility>
      11             : 
      12             : #include "scipp/common/index.h"
      13             : 
      14             : namespace scipp {
      15             : /// Compute a flat index from strides and a multi-dimensional index.
      16             : ///
      17             : /// @return sum_{i=0}^{ndim} ( strides[i] * indices[i] )
      18             : /// @note This function uses *strides* and not a shape, meaning strides[d]
      19             : ///       is not the extent of the array in dimension d but rather the step
      20             : ///       length in the flat index to advance one element in d.
      21             : ///       Therefore, some conversion of parameters is required when inverting
      22             : ///       the result with `extract_indices`.
      23             : template <class ForwardIt1, class ForwardIt2>
      24     2416730 : constexpr auto flat_index_from_strides(ForwardIt1 strides_it,
      25             :                                        const ForwardIt1 strides_end,
      26             :                                        ForwardIt2 indices_it) noexcept {
      27     2416730 :   std::decay_t<decltype(*strides_it)> result = 0;
      28     5500858 :   for (; strides_it != strides_end; ++strides_it, ++indices_it) {
      29     3084128 :     result += *strides_it * *indices_it;
      30             :   }
      31     2416730 :   return result;
      32             : }
      33             : 
      34             : /// Compute the bounds of a piece of memory.
      35             : ///
      36             : /// Given a pointer `p` to some memory, a shape, and strides, this function
      37             : /// returns a begin and an end index such that
      38             : /// - `p + begin` is the smallest reachable address and
      39             : /// - `p + end` is one past the largest reachable address
      40             : ///
      41             : /// \return A pair of indices `{begin, end}`.
      42             : template <class ForwardIt1, class ForwardIt2>
      43       22984 : constexpr auto memory_bounds(ForwardIt1 shape_it, const ForwardIt1 shape_end,
      44             :                              ForwardIt2 strides_it) noexcept {
      45       22984 :   if (shape_it == shape_end) {
      46             :     // Scalars are one element wide in memory, this would not be handled
      47             :     // correctly by the code below.
      48         482 :     return std::pair{scipp::index{0}, scipp::index{1}};
      49             :   }
      50       22502 :   scipp::index begin = 0;
      51       22502 :   scipp::index end = 0;
      52       49923 :   for (; shape_it != shape_end; ++shape_it, ++strides_it) {
      53       26975 :     if (*strides_it < 0)
      54         669 :       begin += *shape_it * *strides_it;
      55             :     else
      56       26752 :       end += *shape_it * *strides_it;
      57             :   }
      58       22502 :   return std::pair{begin, end};
      59             : }
      60             : 
      61             : /// Extract individual indices from a flat index.
      62             : ///
      63             : /// Let
      64             : ///     I = i_0 + l_0 * (i_1 + l_1 * (i_2 + ... (i_{n-2} + l_{n-2} * i_{n-1})))
      65             : /// be a flat index computed from indices {i_d} and shape {l_d} in
      66             : /// 'column-major' order. Here, this means that i_0 is the fasted moving index
      67             : /// and i_{n-1} is slowest.
      68             : ///
      69             : /// If I == prod_{d=0}^ndim (l_d), i.e. one element past the end,
      70             : /// the resulting indices are i_d = 0 for d < ndim-1, i_{ndim-1} = l_{ndim-1}
      71             : /// unless l_{ndim-1} = 0, see below.
      72             : /// This allows setting 'end-iterators' in a well defined manner.
      73             : /// However, the result is undefined for greater values of I.
      74             : ///
      75             : /// Values of array elements in `indices` with d > ndim-1 are unspecified
      76             : /// except when ndim == 0, i_0 = I.
      77             : ///
      78             : /// Any number of l_d maybe 0 which yields i_d = 0.
      79             : /// Except for the one-past-the-end case described above, i_{ndim-1} = 1 if
      80             : /// l_{ndim-1} to allow this case to be distinguishable from an index
      81             : /// to the end.
      82             : ///
      83             : /// @param flat_index I
      84             : /// @param shape_it Begin iterator for {l_d}.
      85             : /// @param shape_end End iterator for {l_d}.
      86             : /// @param indices_it Begin iterator for {i_d}.
      87             : ///                   `*indices_it` must always be writeable, even when
      88             : ///                   `shape_it == shape_end`.
      89             : /// @note This function uses a *shape*, i.e. individual dimension sizes
      90             : ///       to encode the size of the array.
      91             : ///       Therefore, some conversion of parameters is required when inverting
      92             : ///       the result with `flat_index_from_strides`.
      93             : template <class It1, class It2>
      94     5962312 : constexpr void extract_indices(scipp::index flat_index, It1 shape_it,
      95             :                                It1 shape_end, It2 indices_it) noexcept {
      96     5962312 :   if (shape_it == shape_end) {
      97     2592917 :     *indices_it = flat_index;
      98     2592917 :     return;
      99             :   }
     100     3369395 :   shape_end--; // The last element is set after the loop.
     101     4713458 :   for (; shape_it != shape_end; ++shape_it, ++indices_it) {
     102     1344063 :     if (*shape_it != 0) {
     103     1342936 :       const scipp::index aux = flat_index / *shape_it;
     104     1342936 :       *indices_it = flat_index - aux * *shape_it;
     105     1342936 :       flat_index = aux;
     106             :     } else {
     107        1127 :       *indices_it = 0;
     108             :     }
     109             :   }
     110     3369395 :   *indices_it = flat_index;
     111             : }
     112             : /* Implementation notes for extract_indices
     113             :  *
     114             :  * With ndim == 2, we have
     115             :  *     I = i_0 + l_0 * i_1
     116             :  * All numbers are non-negative integers. Thus I can be decomposed using
     117             :  * integer division as follows (note that i_0 < l_0):
     118             :  *     x = I / l_0
     119             :  *     i_0 = I - x * l_0
     120             :  *     i_1 = x
     121             :  *
     122             :  * With ndim == 3, we have
     123             :  *     I = i_0 + l_0 * (i_1 + l_1 * i_2)
     124             :  * which can be decomposed as above:
     125             :  *     x = I / l_0
     126             :  *     i_0 = I - x * l_0
     127             :  * Noting that
     128             :  *     x = i_1 + l_1 * i_2
     129             :  * we can compute i_1 and i_2 by applying the algorithm recursively.
     130             :  *
     131             :  * The function implements this algorithm for arbitrary dimensions by rolling
     132             :  * the recursion into a loop.
     133             :  */
     134             : 
     135             : } // namespace scipp

Generated by: LCOV version 1.14