Line data Source code
1 : // SPDX-License-Identifier: GPL-3.0-or-later 2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 3 : /// @file 4 : /// @author Jan-Lukas Wynen 5 : #pragma once 6 : 7 : #include <array> 8 : #include <cassert> 9 : #include <cstddef> 10 : #include <utility> 11 : 12 : #include "scipp/common/index.h" 13 : 14 : namespace scipp { 15 : /// Compute a flat index from strides and a multi-dimensional index. 16 : /// 17 : /// @return sum_{i=0}^{ndim} ( strides[i] * indices[i] ) 18 : /// @note This function uses *strides* and not a shape, meaning strides[d] 19 : /// is not the extent of the array in dimension d but rather the step 20 : /// length in the flat index to advance one element in d. 21 : /// Therefore, some conversion of parameters is required when inverting 22 : /// the result with `extract_indices`. 23 : template <class ForwardIt1, class ForwardIt2> 24 2416730 : constexpr auto flat_index_from_strides(ForwardIt1 strides_it, 25 : const ForwardIt1 strides_end, 26 : ForwardIt2 indices_it) noexcept { 27 2416730 : std::decay_t<decltype(*strides_it)> result = 0; 28 5500858 : for (; strides_it != strides_end; ++strides_it, ++indices_it) { 29 3084128 : result += *strides_it * *indices_it; 30 : } 31 2416730 : return result; 32 : } 33 : 34 : /// Compute the bounds of a piece of memory. 35 : /// 36 : /// Given a pointer `p` to some memory, a shape, and strides, this function 37 : /// returns a begin and an end index such that 38 : /// - `p + begin` is the smallest reachable address and 39 : /// - `p + end` is one past the largest reachable address 40 : /// 41 : /// \return A pair of indices `{begin, end}`. 42 : template <class ForwardIt1, class ForwardIt2> 43 22984 : constexpr auto memory_bounds(ForwardIt1 shape_it, const ForwardIt1 shape_end, 44 : ForwardIt2 strides_it) noexcept { 45 22984 : if (shape_it == shape_end) { 46 : // Scalars are one element wide in memory, this would not be handled 47 : // correctly by the code below. 48 482 : return std::pair{scipp::index{0}, scipp::index{1}}; 49 : } 50 22502 : scipp::index begin = 0; 51 22502 : scipp::index end = 0; 52 49923 : for (; shape_it != shape_end; ++shape_it, ++strides_it) { 53 26975 : if (*strides_it < 0) 54 669 : begin += *shape_it * *strides_it; 55 : else 56 26752 : end += *shape_it * *strides_it; 57 : } 58 22502 : return std::pair{begin, end}; 59 : } 60 : 61 : /// Extract individual indices from a flat index. 62 : /// 63 : /// Let 64 : /// I = i_0 + l_0 * (i_1 + l_1 * (i_2 + ... (i_{n-2} + l_{n-2} * i_{n-1}))) 65 : /// be a flat index computed from indices {i_d} and shape {l_d} in 66 : /// 'column-major' order. Here, this means that i_0 is the fasted moving index 67 : /// and i_{n-1} is slowest. 68 : /// 69 : /// If I == prod_{d=0}^ndim (l_d), i.e. one element past the end, 70 : /// the resulting indices are i_d = 0 for d < ndim-1, i_{ndim-1} = l_{ndim-1} 71 : /// unless l_{ndim-1} = 0, see below. 72 : /// This allows setting 'end-iterators' in a well defined manner. 73 : /// However, the result is undefined for greater values of I. 74 : /// 75 : /// Values of array elements in `indices` with d > ndim-1 are unspecified 76 : /// except when ndim == 0, i_0 = I. 77 : /// 78 : /// Any number of l_d maybe 0 which yields i_d = 0. 79 : /// Except for the one-past-the-end case described above, i_{ndim-1} = 1 if 80 : /// l_{ndim-1} to allow this case to be distinguishable from an index 81 : /// to the end. 82 : /// 83 : /// @param flat_index I 84 : /// @param shape_it Begin iterator for {l_d}. 85 : /// @param shape_end End iterator for {l_d}. 86 : /// @param indices_it Begin iterator for {i_d}. 87 : /// `*indices_it` must always be writeable, even when 88 : /// `shape_it == shape_end`. 89 : /// @note This function uses a *shape*, i.e. individual dimension sizes 90 : /// to encode the size of the array. 91 : /// Therefore, some conversion of parameters is required when inverting 92 : /// the result with `flat_index_from_strides`. 93 : template <class It1, class It2> 94 5962312 : constexpr void extract_indices(scipp::index flat_index, It1 shape_it, 95 : It1 shape_end, It2 indices_it) noexcept { 96 5962312 : if (shape_it == shape_end) { 97 2592917 : *indices_it = flat_index; 98 2592917 : return; 99 : } 100 3369395 : shape_end--; // The last element is set after the loop. 101 4713458 : for (; shape_it != shape_end; ++shape_it, ++indices_it) { 102 1344063 : if (*shape_it != 0) { 103 1342936 : const scipp::index aux = flat_index / *shape_it; 104 1342936 : *indices_it = flat_index - aux * *shape_it; 105 1342936 : flat_index = aux; 106 : } else { 107 1127 : *indices_it = 0; 108 : } 109 : } 110 3369395 : *indices_it = flat_index; 111 : } 112 : /* Implementation notes for extract_indices 113 : * 114 : * With ndim == 2, we have 115 : * I = i_0 + l_0 * i_1 116 : * All numbers are positive integers. Thus I can be decomposed using 117 : * integer division as follows (note that i_0 < l_0): 118 : * x = I / l_0 119 : * i_0 = I - x * l_0 120 : * i_1 = x 121 : * 122 : * With ndim == 3, we have 123 : * I = i_0 + l_0 * (i_1 + l_1 * i_2) 124 : * which can be decomposed as above: 125 : * x = I / l_0 126 : * i_0 = I - x * l_0 127 : * Noting that 128 : * x = i_1 _ l_1 * i_2 129 : * we can compute i_1 and i_2 by applying the algorithm recursively. 130 : * 131 : * The function implements this algorithm for arbitrary dimensions by rolling 132 : * the recursion into a loop. 133 : */ 134 : 135 : } // namespace scipp