Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 :
6 : #include "scipp/core/multi_index.h"
7 : #include "scipp-core_export.h"
8 : #include "scipp/core/except.h"
9 :
10 : namespace scipp::core {
11 :
12 : template class SCIPP_CORE_EXPORT MultiIndex<1>;
13 : template class SCIPP_CORE_EXPORT MultiIndex<2>;
14 : template class SCIPP_CORE_EXPORT MultiIndex<3>;
15 : template class SCIPP_CORE_EXPORT MultiIndex<4>;
16 : template class SCIPP_CORE_EXPORT MultiIndex<5>;
17 :
18 : namespace {
19 37508 : void validate_bin_indices_impl(const ElementArrayViewParams ¶m0,
20 : const ElementArrayViewParams ¶m1) {
21 37508 : const auto &iterDims = param0.dims();
22 37508 : auto index = MultiIndex(iterDims, param0.strides(), param1.strides());
23 37508 : const auto indices0 = param0.bucketParams().indices;
24 37508 : const auto indices1 = param1.bucketParams().indices;
25 29388344 : constexpr auto size = [](const auto range) {
26 29388344 : return range.second - range.first;
27 : };
28 14731663 : for (scipp::index i = 0; i < iterDims.volume(); ++i) {
29 14694172 : const auto [i0, i1] = index.get();
30 14694172 : if (size(indices0[i0]) != size(indices1[i1]))
31 17 : throw except::BinnedDataError(
32 : "Bin size mismatch in operation with binned data. Refer to "
33 : "https://scipp.github.io/user-guide/binned-data/"
34 : "computation.html#Overview-and-Quick-Reference for equivalent "
35 34 : "operations for binned data (event data).");
36 14694155 : index.increment();
37 : }
38 37491 : }
39 :
40 72299 : template <class Param> void validate_bin_indices(const Param &) {}
41 :
42 : /// Check that corresponding bins have matching sizes.
43 : template <class Param0, class Param1, class... Params>
44 69454 : void validate_bin_indices(const Param0 ¶m0, const Param1 ¶m1,
45 : const Params &...params) {
46 69454 : if (param0.bucketParams() && param1.bucketParams())
47 37508 : validate_bin_indices_impl(param0, param1);
48 69437 : if (param0.bucketParams())
49 51427 : validate_bin_indices(param0, params...);
50 : else
51 18010 : validate_bin_indices(param1, params...);
52 69434 : }
53 :
54 0 : inline auto get_slice_dim() { return Dim::Invalid; }
55 :
56 : template <class T, class... Ts>
57 90309 : auto get_slice_dim(const T ¶m, const Ts &...params) {
58 90309 : return param ? param.dim : get_slice_dim(params...);
59 : }
60 :
61 : template <class T>
62 4113380 : [[nodiscard]] auto make_span(T &&array, const scipp::index begin) {
63 4113380 : return scipp::span{array.data() + begin,
64 4113380 : static_cast<size_t>(NDIM_OP_MAX - begin)};
65 : }
66 :
67 : template <class StridesArg>
68 3281759 : [[nodiscard]] scipp::index value_or_default(const StridesArg &strides,
69 : const scipp::index i) {
70 3281759 : return i < strides.size() ? strides[i] : 0;
71 : }
72 :
73 : template <size_t... I, class... StridesArgs>
74 161920 : bool can_be_flattened(
75 : const scipp::index dim, const scipp::index size, std::index_sequence<I...>,
76 : std::array<scipp::index, sizeof...(I)> &strides_for_contiguous,
77 : const StridesArgs &...strides) {
78 161920 : const bool res =
79 161920 : ((value_or_default(strides, dim) == strides_for_contiguous[I]) && ...);
80 161920 : ((strides_for_contiguous[I] = size * value_or_default(strides, dim)), ...);
81 161920 : return res;
82 : }
83 :
84 : // non_flattenable_dim is in the storage order of Dimensions & Strides.
85 : // It is not possible to flatten dimensions outside of the bin-slice dim
86 : // because they are sliced by that dim and their layout changes depending on
87 : // the current bin.
88 : // But the inner dimensions always have the same layout.
89 : template <class... StridesArgs>
90 : [[nodiscard]] scipp::index
91 2056690 : flatten_dims(const scipp::span<std::array<scipp::index, sizeof...(StridesArgs)>>
92 : &out_strides,
93 : const scipp::span<scipp::index> &out_shape, const Dimensions &dims,
94 : const scipp::index non_flattenable_dim,
95 : const StridesArgs &...strides) {
96 2056690 : constexpr scipp::index N = sizeof...(StridesArgs);
97 2056690 : std::array strides_array{std::ref(strides)...};
98 2056690 : std::array<scipp::index, N> strides_for_contiguous{};
99 2056690 : scipp::index dim_write = 0;
100 3295604 : for (scipp::index dim_read = dims.ndim() - 1; dim_read >= 0; --dim_read) {
101 1238915 : if (dim_write >= static_cast<scipp::index>(out_shape.size()))
102 1 : throw std::runtime_error("Operations with more than " +
103 : std::to_string(NDIM_OP_MAX) +
104 : " dimensions are not supported. "
105 : "For binned data, the combined bin+event "
106 : "dimensions count");
107 1238914 : const auto size = dims.size(dim_read);
108 396182 : if (dim_read > non_flattenable_dim &&
109 1635096 : dim_write > 0 && // need to write at least one inner dim
110 161920 : can_be_flattened(dim_read, size, std::make_index_sequence<N>{},
111 : strides_for_contiguous, strides...)) {
112 24607 : out_shape[dim_write - 1] *= size;
113 : } else {
114 1214307 : out_shape[dim_write] = size;
115 3945179 : for (scipp::index data = 0; data < N; ++data) {
116 2730872 : out_strides[dim_write][data] =
117 2730872 : value_or_default(strides_array[data].get(), dim_read);
118 : }
119 1214307 : ++dim_write;
120 : }
121 : }
122 2056689 : return dim_write;
123 : }
124 : } // namespace
125 :
126 : template <scipp::index N>
127 : template <class... StridesArgs>
128 1912092 : MultiIndex<N>::MultiIndex(const Dimensions &iter_dims,
129 : const StridesArgs &...strides)
130 1912092 : : m_ndim{flatten_dims(make_span(m_stride, 0), make_span(m_shape, 0),
131 : iter_dims, 0, strides...)},
132 3824184 : m_inner_ndim{m_ndim} {}
133 :
134 : template <scipp::index N>
135 : template <class... Params>
136 72316 : MultiIndex<N>::MultiIndex(binned_tag, const Dimensions &inner_dims,
137 : const Dimensions &bin_dims, const Params &...params)
138 72316 : : m_bin{BinIterator(params.bucketParams(), bin_dims.volume())...} {
139 72316 : validate_bin_indices(params...);
140 :
141 72299 : const Dim slice_dim = get_slice_dim(params.bucketParams()...);
142 :
143 : // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
144 144598 : m_inner_ndim = flatten_dims(
145 72299 : make_span(m_stride, 0), make_span(m_shape, 0), inner_dims,
146 : inner_dims.index(slice_dim),
147 144598 : params.bucketParams() ? params.bucketParams().strides : Strides{}...);
148 : // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
149 144598 : m_ndim = m_inner_ndim + flatten_dims(make_span(m_stride, m_inner_ndim),
150 72300 : make_span(m_shape, m_inner_ndim),
151 : bin_dims, 0, params.strides()...);
152 :
153 : // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
154 72298 : m_nested_dim_index = m_inner_ndim - inner_dims.index(slice_dim) - 1;
155 :
156 214029 : for (scipp::index data = 0; data < N; ++data) {
157 141731 : load_bin_params(data);
158 : }
159 72298 : if (m_shape[m_nested_dim_index] == 0 || bin_dims.volume() == 0)
160 10336 : seek_bin();
161 72298 : }
162 :
163 : template SCIPP_CORE_EXPORT MultiIndex<1>::MultiIndex(const Dimensions &,
164 : const Strides &);
165 : template SCIPP_CORE_EXPORT
166 : MultiIndex<2>::MultiIndex(const Dimensions &, const Strides &, const Strides &);
167 : template SCIPP_CORE_EXPORT MultiIndex<3>::MultiIndex(const Dimensions &,
168 : const Strides &,
169 : const Strides &,
170 : const Strides &);
171 : template SCIPP_CORE_EXPORT
172 : MultiIndex<4>::MultiIndex(const Dimensions &, const Strides &, const Strides &,
173 : const Strides &, const Strides &);
174 : template SCIPP_CORE_EXPORT
175 : MultiIndex<5>::MultiIndex(const Dimensions &, const Strides &, const Strides &,
176 : const Strides &, const Strides &, const Strides &);
177 :
178 : template SCIPP_CORE_EXPORT
179 : MultiIndex<1>::MultiIndex(binned_tag, const Dimensions &, const Dimensions &,
180 : const ElementArrayViewParams &);
181 : template SCIPP_CORE_EXPORT
182 : MultiIndex<2>::MultiIndex(binned_tag, const Dimensions &, const Dimensions &,
183 : const ElementArrayViewParams &,
184 : const ElementArrayViewParams &);
185 : template SCIPP_CORE_EXPORT
186 : MultiIndex<3>::MultiIndex(binned_tag, const Dimensions &, const Dimensions &,
187 : const ElementArrayViewParams &,
188 : const ElementArrayViewParams &,
189 : const ElementArrayViewParams &);
190 : template SCIPP_CORE_EXPORT MultiIndex<4>::MultiIndex(
191 : binned_tag, const Dimensions &, const Dimensions &,
192 : const ElementArrayViewParams &, const ElementArrayViewParams &,
193 : const ElementArrayViewParams &, const ElementArrayViewParams &);
194 : template SCIPP_CORE_EXPORT MultiIndex<5>::MultiIndex(
195 : binned_tag, const Dimensions &, const Dimensions &,
196 : const ElementArrayViewParams &, const ElementArrayViewParams &,
197 : const ElementArrayViewParams &, const ElementArrayViewParams &,
198 : const ElementArrayViewParams &);
199 :
200 : } // namespace scipp::core
|