Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 : #pragma once
6 :
7 : #include <functional>
8 : #include <numeric>
9 : #include <optional>
10 :
11 : #include "scipp/common/index_composition.h"
12 : #include "scipp/core/dimensions.h"
13 : #include "scipp/core/element_array_view.h"
14 :
15 : namespace scipp::core {
16 : namespace detail {
17 0 : inline auto get_nested_dims() { return Dimensions(); }
18 :
19 : template <class T, class... Ts>
20 59053 : auto get_nested_dims(const T ¶m, const Ts &...params) {
21 59053 : const auto &bin_param = param.bucketParams();
22 59053 : return bin_param ? bin_param.dims : get_nested_dims(params...);
23 : }
24 : } // namespace detail
25 :
26 : template <scipp::index N> class MultiIndex {
27 : public:
28 : /// Determine from arguments if binned.
29 : template <class... Params>
30 1821216 : explicit MultiIndex(const ElementArrayViewParams ¶m,
31 : const Params &...params)
32 : : MultiIndex{
33 3597127 : (!param.bucketParams() && (!params.bucketParams() && ...))
34 3662025 : ? MultiIndex(param.dims(), param.strides(), params.strides()...)
35 : : MultiIndex{binned_tag{},
36 : detail::get_nested_dims(param, params...),
37 : param.dims(), param,
38 5504962 : ElementArrayViewParams{params}...}} {}
39 :
40 : /// Construct without bins.
41 : template <class... StridesArgs>
42 : explicit MultiIndex(const Dimensions &iter_dims,
43 : const StridesArgs &...strides);
44 :
45 : private:
46 : /// Use to disambiguate between constructors.
47 : struct binned_tag {};
48 :
49 : /// Construct with bins.
50 : template <class... Params>
51 : explicit MultiIndex(binned_tag, const Dimensions &inner_dims,
52 : const Dimensions &bin_dims, const Params &...params);
53 :
54 : public:
55 11245135 : void increment_outer() noexcept {
56 : // Go through all nested dims (with bins) / all dims (without bins)
57 : // where we have reached the end.
58 22490270 : increment_in_dims(
59 12776411 : [this](const scipp::index data) -> scipp::index & {
60 12776411 : return this->m_data_index[data];
61 : },
62 11245135 : 0, m_inner_ndim - 1);
63 : // Nested dims incremented, move on to bins.
64 : // Note that we do not check whether there are any bins, instead whether
65 : // the outer Variable is scalar because the loop above is enough to set up
66 : // the coord in that case.
67 11245135 : if (has_bins() && dim_at_end(m_inner_ndim - 1))
68 4437402 : seek_bin();
69 11245135 : }
70 :
71 14209817 : void increment() noexcept {
72 42629451 : for (scipp::index data = 0; data < N; ++data)
73 28419634 : m_data_index[data] += m_stride[0][data];
74 14209817 : ++m_coord[0];
75 14209817 : if (dim_at_end(0))
76 200230 : increment_outer();
77 14209817 : }
78 :
79 11044905 : void increment_by(const scipp::index inner_distance) noexcept {
80 35258016 : for (scipp::index data = 0; data < N; ++data) {
81 24213111 : m_data_index[data] += inner_distance * m_stride[0][data];
82 : }
83 11044905 : m_coord[0] += inner_distance;
84 11044905 : if (dim_at_end(0))
85 11044905 : increment_outer();
86 11044905 : }
87 :
88 1821074 : [[nodiscard]] auto inner_strides() const noexcept {
89 1821074 : return scipp::span<const scipp::index>(m_stride[0].data(), N);
90 : }
91 :
92 9398371 : [[nodiscard]] scipp::index inner_distance_to_end() const noexcept {
93 9398371 : return m_shape[0] - m_coord[0];
94 : }
95 :
96 : [[nodiscard]] scipp::index
97 1646535 : inner_distance_to(const MultiIndex &other) const noexcept {
98 1646535 : return other.m_coord[0] - m_coord[0];
99 : }
100 :
101 : /// Set the absolute index. In the special case of iteration with bins,
102 : /// this sets the *index of the bin* and NOT the full index within the
103 : /// iterated data.
104 3548789 : void set_index(const scipp::index index) noexcept {
105 3548789 : if (has_bins()) {
106 64857 : set_bins_index(index);
107 : } else {
108 3483932 : extract_indices(index, shape_it(), shape_it(m_inner_ndim), coord_it());
109 12287678 : for (scipp::index data = 0; data < N; ++data) {
110 8803746 : m_data_index[data] = flat_index(data, 0, m_inner_ndim);
111 : }
112 : }
113 3548789 : }
114 :
115 0 : void set_to_end() noexcept {
116 0 : if (has_bins()) {
117 0 : set_to_end_bin();
118 : } else {
119 0 : if (m_inner_ndim == 0) {
120 0 : m_coord[0] = 1;
121 : } else {
122 0 : zero_out_coords(m_inner_ndim - 1);
123 0 : m_coord[m_inner_ndim - 1] = m_shape[m_inner_ndim - 1];
124 : }
125 0 : for (scipp::index data = 0; data < N; ++data) {
126 0 : m_data_index[data] = flat_index(data, 0, m_inner_ndim);
127 : }
128 : }
129 0 : }
130 :
131 25254740 : [[nodiscard]] constexpr auto get() const noexcept { return m_data_index; }
132 :
133 12865979 : bool operator==(const MultiIndex &other) const noexcept {
134 : // Assuming the number dimensions match to make the check cheaper.
135 12865979 : return m_coord == other.m_coord;
136 : }
137 :
138 12865979 : bool operator!=(const MultiIndex &other) const noexcept {
139 12865979 : return !(*this == other); // NOLINT
140 : }
141 :
142 : [[nodiscard]] bool
143 11044906 : in_same_chunk(const MultiIndex &other,
144 : const scipp::index first_dim) const noexcept {
145 : // Take scalars of bins into account when calculating ndim.
146 12407098 : for (scipp::index dim = first_dim;
147 12407098 : dim < m_inner_ndim + std::max(bin_ndim(), scipp::index{1}); ++dim) {
148 10760563 : if (m_coord[dim] != other.m_coord[dim]) {
149 9398371 : return false;
150 : }
151 : }
152 1646535 : return true;
153 : }
154 :
155 0 : [[nodiscard]] auto begin() const noexcept {
156 0 : auto it(*this);
157 0 : it.set_index(0);
158 0 : return it;
159 : }
160 :
161 0 : [[nodiscard]] auto end() const noexcept {
162 0 : auto it(*this);
163 0 : it.set_to_end();
164 0 : return it;
165 : }
166 :
167 73573580 : [[nodiscard]] bool has_bins() const noexcept {
168 73573580 : return m_nested_dim_index != -1;
169 : }
170 :
171 : /// Return true if the first subindex has a 0 stride
172 636196 : [[nodiscard]] bool has_stride_zero() const noexcept {
173 1060198 : for (scipp::index dim = 0; dim < m_ndim; ++dim)
174 517361 : if (m_stride[dim][0] == 0)
175 93359 : return true;
176 542837 : return false;
177 : }
178 :
179 : private:
180 115096967 : [[nodiscard]] auto dim_at_end(const scipp::index dim) const noexcept {
181 115096967 : return m_coord[dim] == std::max(m_shape[dim], scipp::index{1});
182 : }
183 :
184 58779656 : [[nodiscard]] bool at_end() const noexcept { return dim_at_end(last_dim()); }
185 :
186 58779656 : [[nodiscard]] scipp::index last_dim() const noexcept {
187 58779656 : if (has_bins()) {
188 58779656 : return bin_ndim() == 0 ? m_ndim : m_ndim - 1;
189 : } else {
190 0 : return std::max(m_ndim - 1, scipp::index{0});
191 : }
192 : }
193 :
194 : template <class F>
195 11647964 : void increment_in_dims(const F &data_index, const scipp::index begin_dim,
196 : const scipp::index end_dim) {
197 17853535 : for (scipp::index dim = begin_dim; dim < end_dim && dim_at_end(dim);
198 : ++dim) {
199 19898662 : for (scipp::index data = 0; data < N; ++data) {
200 13693091 : data_index(data) +=
201 : // take a step in dimension dim+1
202 13693091 : m_stride[dim + 1][data]
203 : // rewind dimension dim (coord(d) == m_shape[d])
204 13693091 : - m_coord[dim] * m_stride[dim][data];
205 : }
206 6205571 : ++m_coord[dim + 1];
207 6205571 : m_coord[dim] = 0;
208 : }
209 11647964 : }
210 :
211 71254818 : [[nodiscard]] constexpr auto bin_ndim() const noexcept {
212 71254818 : return m_ndim - m_inner_ndim;
213 : }
214 :
215 16383192 : [[nodiscard]] bool current_bin_is_empty() const noexcept {
216 16383192 : return m_shape[m_nested_dim_index] == 0;
217 : }
218 :
219 : struct BinIterator {
220 : BinIterator() = default;
221 140050 : explicit BinIterator(const BucketParams &bucket_params,
222 : const scipp::index outer_volume)
223 140050 : : m_is_binned{static_cast<bool>(bucket_params)},
224 : // indices can be != nullptr but outer_volume == 0 when Variable
225 : // was sliced.
226 140050 : m_indices{outer_volume == 0 ? nullptr : bucket_params.indices} {}
227 :
228 : const bool m_is_binned{false};
229 : scipp::index m_bin_index{0};
230 : const std::pair<scipp::index, scipp::index> *m_indices{nullptr};
231 : };
232 :
233 402829 : void increment_outer_bins() noexcept {
234 805658 : increment_in_dims(
235 916680 : [this](const scipp::index data) -> scipp::index & {
236 916680 : return this->m_bin[data].m_bin_index;
237 : },
238 402829 : m_inner_ndim, m_ndim - 1);
239 402829 : }
240 :
241 16318335 : void increment_bins() noexcept {
242 16318335 : const auto dim = m_inner_ndim;
243 49053576 : for (scipp::index data = 0; data < N; ++data) {
244 32735241 : m_bin[data].m_bin_index += m_stride[dim][data];
245 : }
246 16318335 : zero_out_coords(m_inner_ndim);
247 16318335 : ++m_coord[dim];
248 16318335 : if (dim_at_end(dim))
249 402829 : increment_outer_bins();
250 16318335 : if (!at_end()) {
251 48906766 : for (scipp::index data = 0; data < N; ++data) {
252 32634973 : load_bin_params(data);
253 : }
254 : }
255 16318335 : }
256 :
257 16318335 : void seek_bin() noexcept {
258 : do {
259 16318335 : increment_bins();
260 16318335 : } while (current_bin_is_empty() && !at_end());
261 4448517 : }
262 :
263 32922156 : void load_bin_params(const scipp::index data) noexcept {
264 32922156 : if (!m_bin[data].m_is_binned) {
265 2344748 : m_data_index[data] = flat_index(data, 0, m_ndim);
266 30577408 : } else if (!at_end()) {
267 : // All bins are guaranteed to have the same size.
268 30516695 : if (m_bin[data].m_indices != nullptr) {
269 30454065 : const auto [begin, end] =
270 30454065 : m_bin[data].m_indices[m_bin[data].m_bin_index];
271 30454065 : m_shape[m_nested_dim_index] = end - begin;
272 30454065 : m_data_index[data] = m_stride[m_nested_dim_index][data] * begin;
273 : } else {
274 : // m_indices can be nullptr if there are bins, but they are empty.
275 62630 : m_shape[m_nested_dim_index] = 0;
276 62630 : m_data_index[data] = 0;
277 : }
278 : }
279 : // else: at end of bins
280 32922156 : }
281 :
282 64857 : void set_bins_index(const scipp::index index) noexcept {
283 64857 : if (bin_ndim() == 0 && index != 0) {
284 : // Scalar outer dims and setting to / past end.
285 3207 : set_to_end_bin();
286 : } else {
287 61650 : zero_out_coords(m_inner_ndim);
288 61650 : extract_indices(index, shape_it(m_inner_ndim), shape_end(),
289 : coord_it(m_inner_ndim));
290 : }
291 :
292 212029 : for (scipp::index data = 0; data < N; ++data) {
293 147172 : m_bin[data].m_bin_index = flat_index(data, m_inner_ndim, m_ndim);
294 147172 : load_bin_params(data);
295 : }
296 64857 : if (current_bin_is_empty() && !at_end())
297 1937 : seek_bin();
298 64857 : }
299 :
300 3207 : void set_to_end_bin() noexcept {
301 3207 : zero_out_coords(m_ndim);
302 3207 : if (bin_ndim() == 0) {
303 3207 : m_coord[m_inner_ndim] = 1;
304 : } else {
305 0 : m_coord[m_ndim - 1] = std::max(m_shape[m_ndim - 1], scipp::index{1});
306 : }
307 3207 : }
308 :
309 11295666 : scipp::index flat_index(const scipp::index i_data, scipp::index begin_index,
310 : const scipp::index end_index) {
311 11295666 : scipp::index res = 0;
312 23685524 : for (; begin_index < end_index; ++begin_index) {
313 12389858 : res += m_coord[begin_index] * m_stride[begin_index][i_data];
314 : }
315 11295666 : return res;
316 : }
317 :
318 16383192 : void zero_out_coords(const scipp::index ndim) noexcept {
319 16383192 : const auto end = coord_it(ndim);
320 32821147 : for (auto it = coord_it(); it != end; ++it) {
321 16437955 : *it = 0;
322 : }
323 16383192 : }
324 :
325 36311966 : [[nodiscard]] auto coord_it(const scipp::index dim = 0) noexcept {
326 36311966 : return m_coord.begin() + dim;
327 : }
328 :
329 7029514 : [[nodiscard]] auto shape_it(const scipp::index dim = 0) noexcept {
330 7029514 : return std::next(m_shape.begin(), dim);
331 : }
332 :
333 61650 : [[nodiscard]] auto shape_end() noexcept { return m_shape.begin() + m_ndim; }
334 :
335 : /// Current flat index into the operands.
336 : std::array<scipp::index, N> m_data_index = {};
337 : // This does *not* 0-init the inner arrays!
338 : /// Stride for each operand in each dimension.
339 : std::array<std::array<scipp::index, N>, NDIM_OP_MAX> m_stride = {};
340 : /// Current index in iteration dimensions for both bin and inner dims.
341 : std::array<scipp::index, NDIM_OP_MAX + 1> m_coord = {};
342 : /// Shape of the iteration dimensions for both bin and inner dims.
343 : std::array<scipp::index, NDIM_OP_MAX + 1> m_shape = {};
344 : /// Total number of dimensions.
345 : scipp::index m_ndim{0};
346 : /// Number of dense dimensions, i.e. same as m_ndim when not binned,
347 : /// else number of dims in bins.
348 : scipp::index m_inner_ndim{0};
349 : /// Index of dim referred to by bin indices to distinguish, e.g., 2D bins
350 : /// slicing along first or second dim.
351 : /// -1 if not binned.
352 : scipp::index m_nested_dim_index{-1};
353 : /// Parameters of the currently loaded bins.
354 : std::array<BinIterator, N> m_bin = {};
355 : };
356 :
357 : template <class... StridesArgs>
358 : MultiIndex(const Dimensions &, const StridesArgs &...)
359 : -> MultiIndex<sizeof...(StridesArgs)>;
360 : template <class... Params>
361 : MultiIndex(const ElementArrayViewParams &, const Params &...)
362 : -> MultiIndex<sizeof...(Params) + 1>;
363 :
364 : } // namespace scipp::core
|