Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 : #pragma once
6 :
7 : #include <functional>
8 : #include <numeric>
9 : #include <optional>
10 :
11 : #include "scipp/common/index_composition.h"
12 : #include "scipp/core/dimensions.h"
13 : #include "scipp/core/element_array_view.h"
14 :
15 : namespace scipp::core {
16 : namespace detail {
17 0 : inline auto get_nested_dims() { return Dimensions(); }
18 :
19 : template <class T, class... Ts>
20 62039 : auto get_nested_dims(const T ¶m, const Ts &...params) {
21 62039 : const auto &bin_param = param.bucketParams();
22 62039 : return bin_param ? bin_param.dims : get_nested_dims(params...);
23 : }
24 : } // namespace detail
25 :
26 : template <scipp::index N> class MultiIndex {
27 : public:
28 : /// Determine from arguments if binned.
29 : template <class... Params>
30 1895451 : explicit MultiIndex(const ElementArrayViewParams ¶m,
31 : const Params &...params)
32 : : MultiIndex{
33 3743749 : (!param.bucketParams() && (!params.bucketParams() && ...))
34 3812420 : ? MultiIndex(param.dims(), param.strides(), params.strides()...)
35 : : MultiIndex{binned_tag{},
36 : detail::get_nested_dims(param, params...),
37 : param.dims(), param,
38 5729920 : ElementArrayViewParams{params}...}} {}
39 :
40 : /// Construct without bins.
41 : template <class... StridesArgs>
42 : explicit MultiIndex(const Dimensions &iter_dims,
43 : const StridesArgs &...strides);
44 :
45 : private:
46 : /// Use to disambiguate between constructors.
47 : struct binned_tag {};
48 :
49 : /// Construct with bins.
50 : template <class... Params>
51 : explicit MultiIndex(binned_tag, const Dimensions &inner_dims,
52 : const Dimensions &bin_dims, const Params &...params);
53 :
54 : public:
55 11994340 : void increment_outer() noexcept {
56 : // Go through all nested dims (with bins) / all dims (without bins)
57 : // where we have reached the end.
58 23988680 : increment_in_dims(
59 14032801 : [this](const scipp::index data) -> scipp::index & {
60 14032801 : return this->m_data_index[data];
61 : },
62 11994340 : 0, m_inner_ndim - 1);
63 : // Nested dims incremented, move on to bins.
64 : // Note that we do not check whether there are any bins, instead whether
65 : // the outer Variable is scalar because the loop above is enough to set up
66 : // the coord in that case.
67 11994340 : if (has_bins() && dim_at_end(m_inner_ndim - 1))
68 4714030 : seek_bin();
69 11994340 : }
70 :
71 14593105 : void increment() noexcept {
72 43779315 : for (scipp::index data = 0; data < N; ++data)
73 29186210 : m_data_index[data] += m_stride[0][data];
74 14593105 : ++m_coord[0];
75 14593105 : if (dim_at_end(0))
76 249851 : increment_outer();
77 14593105 : }
78 :
79 11744489 : void increment_by(const scipp::index inner_distance) noexcept {
80 37640604 : for (scipp::index data = 0; data < N; ++data) {
81 25896115 : m_data_index[data] += inner_distance * m_stride[0][data];
82 : }
83 11744489 : m_coord[0] += inner_distance;
84 11744489 : if (dim_at_end(0))
85 11744489 : increment_outer();
86 11744489 : }
87 :
88 1895309 : [[nodiscard]] auto inner_strides() const noexcept {
89 1895309 : return scipp::span<const scipp::index>(m_stride[0].data(), N);
90 : }
91 :
92 10038156 : [[nodiscard]] scipp::index inner_distance_to_end() const noexcept {
93 10038156 : return m_shape[0] - m_coord[0];
94 : }
95 :
96 : [[nodiscard]] scipp::index
97 1706334 : inner_distance_to(const MultiIndex &other) const noexcept {
98 1706334 : return other.m_coord[0] - m_coord[0];
99 : }
100 :
101 : /// Set the absolute index. In the special case of iteration with bins,
102 : /// this sets the *index of the bin* and NOT the full index within the
103 : /// iterated data.
104 3693699 : void set_index(const scipp::index index) noexcept {
105 3693699 : if (has_bins()) {
106 68630 : set_bins_index(index);
107 : } else {
108 3625069 : extract_indices(index, shape_it(), shape_it(m_inner_ndim), coord_it());
109 12762727 : for (scipp::index data = 0; data < N; ++data) {
110 9137658 : m_data_index[data] = flat_index(data, 0, m_inner_ndim);
111 : }
112 : }
113 3693699 : }
114 :
115 0 : void set_to_end() noexcept {
116 0 : if (has_bins()) {
117 0 : set_to_end_bin();
118 : } else {
119 0 : if (m_inner_ndim == 0) {
120 0 : m_coord[0] = 1;
121 : } else {
122 0 : zero_out_coords(m_inner_ndim - 1);
123 0 : m_coord[m_inner_ndim - 1] = m_shape[m_inner_ndim - 1];
124 : }
125 0 : for (scipp::index data = 0; data < N; ++data) {
126 0 : m_data_index[data] = flat_index(data, 0, m_inner_ndim);
127 : }
128 : }
129 0 : }
130 :
131 26337612 : [[nodiscard]] constexpr auto get() const noexcept { return m_data_index; }
132 :
133 13639798 : bool operator==(const MultiIndex &other) const noexcept {
134 : // Assuming the number dimensions match to make the check cheaper.
135 13639798 : return m_coord == other.m_coord;
136 : }
137 :
138 13639798 : bool operator!=(const MultiIndex &other) const noexcept {
139 13639798 : return !(*this == other); // NOLINT
140 : }
141 :
142 : [[nodiscard]] bool
143 11744490 : in_same_chunk(const MultiIndex &other,
144 : const scipp::index first_dim) const noexcept {
145 : // Take scalars of bins into account when calculating ndim.
146 13204742 : for (scipp::index dim = first_dim;
147 13204742 : dim < m_inner_ndim + std::max(bin_ndim(), scipp::index{1}); ++dim) {
148 11498408 : if (m_coord[dim] != other.m_coord[dim]) {
149 10038156 : return false;
150 : }
151 : }
152 1706334 : return true;
153 : }
154 :
155 0 : [[nodiscard]] auto begin() const noexcept {
156 0 : auto it(*this);
157 0 : it.set_index(0);
158 0 : return it;
159 : }
160 :
161 0 : [[nodiscard]] auto end() const noexcept {
162 0 : auto it(*this);
163 0 : it.set_to_end();
164 0 : return it;
165 : }
166 :
167 76766018 : [[nodiscard]] bool has_bins() const noexcept {
168 76766018 : return m_nested_dim_index != -1;
169 : }
170 :
171 : /// Return true if the first subindex has a 0 stride
172 648015 : [[nodiscard]] bool has_stride_zero() const noexcept {
173 1088090 : for (scipp::index dim = 0; dim < m_ndim; ++dim)
174 536994 : if (m_stride[dim][0] == 0)
175 96919 : return true;
176 551096 : return false;
177 : }
178 :
179 : private:
180 120412057 : [[nodiscard]] auto dim_at_end(const scipp::index dim) const noexcept {
181 120412057 : return m_coord[dim] == std::max(m_shape[dim], scipp::index{1});
182 : }
183 :
184 61077979 : [[nodiscard]] bool at_end() const noexcept { return dim_at_end(last_dim()); }
185 :
186 61077979 : [[nodiscard]] scipp::index last_dim() const noexcept {
187 61077979 : if (has_bins()) {
188 61077979 : return bin_ndim() == 0 ? m_ndim : m_ndim - 1;
189 : } else {
190 0 : return std::max(m_ndim - 1, scipp::index{0});
191 : }
192 : }
193 :
194 : template <class F>
195 12490057 : void increment_in_dims(const F &data_index, const scipp::index begin_dim,
196 : const scipp::index end_dim) {
197 19276785 : for (scipp::index dim = begin_dim; dim < end_dim && dim_at_end(dim);
198 : ++dim) {
199 21950080 : for (scipp::index data = 0; data < N; ++data) {
200 15163352 : data_index(data) +=
201 : // take a step in dimension dim+1
202 15163352 : m_stride[dim + 1][data]
203 : // rewind dimension dim (coord(d) == m_shape[d])
204 15163352 : - m_coord[dim] * m_stride[dim][data];
205 : }
206 6786728 : ++m_coord[dim + 1];
207 6786728 : m_coord[dim] = 0;
208 : }
209 12490057 : }
210 :
211 74354660 : [[nodiscard]] constexpr auto bin_ndim() const noexcept {
212 74354660 : return m_ndim - m_inner_ndim;
213 : }
214 :
215 17052956 : [[nodiscard]] bool current_bin_is_empty() const noexcept {
216 17052956 : return m_shape[m_nested_dim_index] == 0;
217 : }
218 :
219 : struct BinIterator {
220 : BinIterator() = default;
221 144566 : explicit BinIterator(const BucketParams &bucket_params,
222 : const scipp::index outer_volume)
223 144566 : : m_is_binned{static_cast<bool>(bucket_params)},
224 : // indices can be != nullptr but outer_volume == 0 when Variable
225 : // was sliced.
226 144566 : m_indices{outer_volume == 0 ? nullptr : bucket_params.indices} {}
227 :
228 : const bool m_is_binned{false};
229 : scipp::index m_bin_index{0};
230 : const std::pair<scipp::index, scipp::index> *m_indices{nullptr};
231 : };
232 :
233 495717 : void increment_outer_bins() noexcept {
234 991434 : increment_in_dims(
235 1130551 : [this](const scipp::index data) -> scipp::index & {
236 1130551 : return this->m_bin[data].m_bin_index;
237 : },
238 495717 : m_inner_ndim, m_ndim - 1);
239 495717 : }
240 :
241 16984326 : void increment_bins() noexcept {
242 16984326 : const auto dim = m_inner_ndim;
243 51034298 : for (scipp::index data = 0; data < N; ++data) {
244 34049972 : m_bin[data].m_bin_index += m_stride[dim][data];
245 : }
246 16984326 : zero_out_coords(m_inner_ndim);
247 16984326 : ++m_coord[dim];
248 16984326 : if (dim_at_end(dim))
249 495717 : increment_outer_bins();
250 16984326 : if (!at_end()) {
251 50875172 : for (scipp::index data = 0; data < N; ++data) {
252 33941490 : load_bin_params(data);
253 : }
254 : }
255 16984326 : }
256 :
257 16984326 : void seek_bin() noexcept {
258 : do {
259 16984326 : increment_bins();
260 16984326 : } while (current_bin_is_empty() && !at_end());
261 4729842 : }
262 :
263 34240755 : void load_bin_params(const scipp::index data) noexcept {
264 34240755 : if (!m_bin[data].m_is_binned) {
265 2423795 : m_data_index[data] = flat_index(data, 0, m_ndim);
266 31816960 : } else if (!at_end()) {
267 : // All bins are guaranteed to have the same size.
268 31752763 : if (m_bin[data].m_indices != nullptr) {
269 31689271 : const auto [begin, end] =
270 31689271 : m_bin[data].m_indices[m_bin[data].m_bin_index];
271 31689271 : m_shape[m_nested_dim_index] = end - begin;
272 31689271 : m_data_index[data] = m_stride[m_nested_dim_index][data] * begin;
273 : } else {
274 : // m_indices can be nullptr if there are bins, but they are empty.
275 63492 : m_shape[m_nested_dim_index] = 0;
276 63492 : m_data_index[data] = 0;
277 : }
278 : }
279 : // else: at end of bins
280 34240755 : }
281 :
282 68630 : void set_bins_index(const scipp::index index) noexcept {
283 68630 : if (bin_ndim() == 0 && index != 0) {
284 : // Scalar outer dims and setting to / past end.
285 3309 : set_to_end_bin();
286 : } else {
287 65321 : zero_out_coords(m_inner_ndim);
288 65321 : extract_indices(index, shape_it(m_inner_ndim), shape_end(),
289 : coord_it(m_inner_ndim));
290 : }
291 :
292 223368 : for (scipp::index data = 0; data < N; ++data) {
293 154738 : m_bin[data].m_bin_index = flat_index(data, m_inner_ndim, m_ndim);
294 154738 : load_bin_params(data);
295 : }
296 68630 : if (current_bin_is_empty() && !at_end())
297 4114 : seek_bin();
298 68630 : }
299 :
300 3309 : void set_to_end_bin() noexcept {
301 3309 : zero_out_coords(m_ndim);
302 3309 : if (bin_ndim() == 0) {
303 3309 : m_coord[m_inner_ndim] = 1;
304 : } else {
305 0 : m_coord[m_ndim - 1] = std::max(m_shape[m_ndim - 1], scipp::index{1});
306 : }
307 3309 : }
308 :
309 11716191 : scipp::index flat_index(const scipp::index i_data, scipp::index begin_index,
310 : const scipp::index end_index) {
311 11716191 : scipp::index res = 0;
312 24695088 : for (; begin_index < end_index; ++begin_index) {
313 12978897 : res += m_coord[begin_index] * m_stride[begin_index][i_data];
314 : }
315 11716191 : return res;
316 : }
317 :
318 17052956 : void zero_out_coords(const scipp::index ndim) noexcept {
319 17052956 : const auto end = coord_it(ndim);
320 34160675 : for (auto it = coord_it(); it != end; ++it) {
321 17107719 : *it = 0;
322 : }
323 17052956 : }
324 :
325 37796302 : [[nodiscard]] auto coord_it(const scipp::index dim = 0) noexcept {
326 37796302 : return m_coord.begin() + dim;
327 : }
328 :
329 7315459 : [[nodiscard]] auto shape_it(const scipp::index dim = 0) noexcept {
330 7315459 : return std::next(m_shape.begin(), dim);
331 : }
332 :
333 65321 : [[nodiscard]] auto shape_end() noexcept { return m_shape.begin() + m_ndim; }
334 :
335 : /// Current flat index into the operands.
336 : std::array<scipp::index, N> m_data_index = {};
337 : // This does *not* 0-init the inner arrays!
338 : /// Stride for each operand in each dimension.
339 : std::array<std::array<scipp::index, N>, NDIM_OP_MAX> m_stride = {};
340 : /// Current index in iteration dimensions for both bin and inner dims.
341 : std::array<scipp::index, NDIM_OP_MAX + 1> m_coord = {};
342 : /// Shape of the iteration dimensions for both bin and inner dims.
343 : std::array<scipp::index, NDIM_OP_MAX + 1> m_shape = {};
344 : /// Total number of dimensions.
345 : scipp::index m_ndim{0};
346 : /// Number of dense dimensions, i.e. same as m_ndim when not binned,
347 : /// else number of dims in bins.
348 : scipp::index m_inner_ndim{0};
349 : /// Index of dim referred to by bin indices to distinguish, e.g., 2D bins
350 : /// slicing along first or second dim.
351 : /// -1 if not binned.
352 : scipp::index m_nested_dim_index{-1};
353 : /// Parameters of the currently loaded bins.
354 : std::array<BinIterator, N> m_bin = {};
355 : };
356 :
357 : template <class... StridesArgs>
358 : MultiIndex(const Dimensions &, const StridesArgs &...)
359 : -> MultiIndex<sizeof...(StridesArgs)>;
360 : template <class... Params>
361 : MultiIndex(const ElementArrayViewParams &, const Params &...)
362 : -> MultiIndex<sizeof...(Params) + 1>;
363 :
364 : } // namespace scipp::core
|