Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 : #include <algorithm>
6 : #include <limits>
7 :
8 : #include "scipp/core/bucket.h"
9 : #include "scipp/core/element/event_operations.h"
10 : #include "scipp/core/element/histogram.h"
11 : #include "scipp/core/except.h"
12 :
13 : #include "scipp/variable/arithmetic.h"
14 : #include "scipp/variable/bins.h"
15 : #include "scipp/variable/creation.h"
16 : #include "scipp/variable/cumulative.h"
17 : #include "scipp/variable/reduction.h"
18 : #include "scipp/variable/subspan_view.h"
19 : #include "scipp/variable/transform.h"
20 : #include "scipp/variable/transform_subspan.h"
21 : #include "scipp/variable/util.h"
22 : #include "scipp/variable/variable.h"
23 : #include "scipp/variable/variable_factory.h"
24 :
25 : #include "scipp/dataset/bins.h"
26 : #include "scipp/dataset/bins_view.h"
27 : #include "scipp/dataset/dataset.h"
28 : #include "scipp/dataset/histogram.h"
29 :
30 : #include "../variable/operations_common.h"
31 : #include "bin_common.h"
32 : #include "bin_detail.h"
33 : #include "dataset_operations_common.h"
34 :
35 : namespace scipp::dataset {
36 : namespace {
37 10356 : constexpr auto copy_or_match = [](const auto &a, auto &&b, const Dim dim,
38 : const Variable &srcIndices,
39 : const Variable &dstIndices) {
40 10356 : if (a.dims().contains(dim))
41 10258 : copy_slices(a, b, dim, srcIndices, dstIndices);
42 : else
43 98 : core::expect::equals(a, b);
44 10356 : };
45 :
46 7774 : constexpr auto expect_matching_keys = [](const auto &a, const auto &b) {
47 7774 : bool ok = true;
48 20761 : constexpr auto key = [](const auto &x_) {
49 : if constexpr (std::is_base_of_v<DataArray, std::decay_t<decltype(x_)>>)
50 160 : return x_.name();
51 : else
52 20601 : return x_.first;
53 : };
54 18152 : for (const auto &x : a)
55 10378 : ok &= b.contains(key(x));
56 18157 : for (const auto &x : b)
57 10383 : ok &= a.contains(key(x));
58 7774 : if (!ok)
59 21 : throw std::runtime_error("Mismatching keys in\n" + to_string(a) + " and\n" +
60 : to_string(b));
61 7753 : };
62 :
63 172 : auto make_fill(const DataArray &function,
64 : const std::optional<Variable> &fill_value) {
65 172 : Variable fill = fill_value.value_or(zero_like(function.data()));
66 172 : if (fill_value) {
67 138 : if (fill.dtype() != function.dtype())
68 0 : throw except::TypeError(
69 0 : "The fill_value (dtype=" + to_string(fill.dtype()) +
70 0 : ") must have the same dtype as the function values (dtype=" +
71 0 : to_string(function.dtype()) + ").");
72 34 : } else if (fill.dtype() == dtype<double>) {
73 10 : fill.value<double>() = std::numeric_limits<double>::quiet_NaN();
74 24 : } else if (fill.dtype() == dtype<float>) {
75 6 : fill.value<float>() = std::numeric_limits<float>::quiet_NaN();
76 : }
77 172 : return fill;
78 0 : }
79 :
80 : } // namespace
81 :
82 3761 : void copy_slices(const DataArray &src, DataArray dst, const Dim dim,
83 : const Variable &srcIndices, const Variable &dstIndices) {
84 3765 : copy_slices(src.data(), dst.data(), dim, srcIndices, dstIndices);
85 3781 : expect_matching_keys(src.meta(), dst.meta());
86 3745 : expect_matching_keys(src.masks(), dst.masks());
87 13807 : for (const auto &[name, coord] : src.meta())
88 13807 : copy_or_match(coord, dst.meta()[name], dim, srcIndices, dstIndices);
89 3829 : for (const auto &[name, mask] : src.masks())
90 88 : copy_or_match(mask, dst.masks()[name], dim, srcIndices, dstIndices);
91 3741 : }
92 :
93 61 : void copy_slices(const Dataset &src, Dataset dst, const Dim dim,
94 : const Variable &srcIndices, const Variable &dstIndices) {
95 154 : for (const auto &[name, var] : src.coords())
96 94 : copy_or_match(var, dst.coords()[name], dim, srcIndices, dstIndices);
97 60 : expect_matching_keys(src.coords(), dst.coords());
98 59 : expect_matching_keys(src, dst);
99 132 : for (const auto &item : src) {
100 77 : const auto &dst_ = dst[item.name()];
101 77 : expect_matching_keys(item.attrs(), dst_.attrs());
102 76 : expect_matching_keys(item.masks(), dst_.masks());
103 75 : copy_or_match(item.data(), dst_.data(), dim, srcIndices, dstIndices);
104 94 : for (const auto &[name, var] : item.masks())
105 19 : copy_or_match(var, dst_.masks()[name], dim, srcIndices, dstIndices);
106 90 : for (const auto &[name, var] : item.attrs())
107 15 : copy_or_match(var, dst_.attrs()[name], dim, srcIndices, dstIndices);
108 79 : }
109 55 : }
110 :
111 : namespace {
112 13981 : constexpr auto copy_or_resize = [](const auto &var, const Dim dim,
113 : const scipp::index size) {
114 13981 : auto dims = var.dims();
115 13981 : if (dims.contains(dim))
116 13882 : dims.resize(dim, size);
117 : // Using variableFactory instead of variable::resize for creating
118 : // _uninitialized_ variable.
119 13981 : return var.dims().contains(dim)
120 13882 : ? variable::variableFactory().create(var.dtype(), dims, var.unit(),
121 13882 : var.has_variances())
122 41844 : : copy(var);
123 13981 : };
124 : } // namespace
125 :
126 : // TODO These functions are an unfortunate near-duplicate of `resize`. However,
127 : // the latter drops coords along the resized dimension. Is there a way to unify
128 : // this? Can the need to drop coords in resize be avoided?
129 3708 : DataArray resize_default_init(const DataArray &parent, const Dim dim,
130 : const scipp::index size) {
131 7416 : DataArray buffer(copy_or_resize(parent.data(), dim, size));
132 13704 : for (const auto &[name, var] : parent.coords())
133 9996 : buffer.coords().set(name, copy_or_resize(var, dim, size));
134 3799 : for (const auto &[name, var] : parent.masks())
135 91 : buffer.masks().set(name, copy_or_resize(var, dim, size));
136 3722 : for (const auto &[name, var] : parent.attrs())
137 14 : buffer.attrs().set(name, copy_or_resize(var, dim, size));
138 3708 : return buffer;
139 0 : }
140 :
141 49 : Dataset resize_default_init(const Dataset &parent, const Dim dim,
142 : const scipp::index size) {
143 49 : auto new_sizes = parent.sizes();
144 49 : if (new_sizes.contains(dim))
145 49 : new_sizes.resize(dim, size);
146 :
147 98 : Dataset buffer({}, Coords(new_sizes, {}));
148 128 : for (const auto &[name, var] : parent.coords())
149 79 : buffer.setCoord(name, copy_or_resize(var, dim, size));
150 114 : for (const auto &item : parent) {
151 65 : buffer.setData(item.name(), copy_or_resize(item.data(), dim, size));
152 80 : for (const auto &[name, var] : item.masks())
153 15 : buffer[item.name()].masks().set(name, copy_or_resize(var, dim, size));
154 78 : for (const auto &[name, var] : item.attrs())
155 13 : buffer[item.name()].attrs().set(name, copy_or_resize(var, dim, size));
156 65 : }
157 98 : return buffer;
158 49 : }
159 :
160 : /// Construct a bin-variable over a data array.
161 : ///
162 : /// Each bin is represented by a Variable slice. `indices` defines the array of
163 : /// bins as slices of `buffer` along `dim`.
164 8739 : Variable make_bins(Variable indices, const Dim dim, DataArray buffer) {
165 8739 : expect_valid_bin_indices(indices, dim, buffer.dims());
166 8739 : return make_bins_no_validate(std::move(indices), dim, std::move(buffer));
167 : }
168 :
169 : /// Construct a bin-variable over a data array without index validation.
170 : ///
171 : /// Must be used only when it is guaranteed that indices are valid or overlap of
172 : /// bins is acceptable.
173 35403 : Variable make_bins_no_validate(Variable indices, const Dim dim,
174 : DataArray buffer) {
175 35403 : return variable::make_bins_impl(std::move(indices), dim, std::move(buffer));
176 : }
177 :
178 : /// Construct a bin-variable over a dataset.
179 : ///
180 : /// Each bin is represented by a Variable slice. `indices` defines the array of
181 : /// bins as slices of `buffer` along `dim`.
182 32 : Variable make_bins(Variable indices, const Dim dim, Dataset buffer) {
183 32 : expect_valid_bin_indices(indices, dim, buffer.sizes());
184 32 : return make_bins_no_validate(std::move(indices), dim, std::move(buffer));
185 : }
186 :
187 : /// Construct a bin-variable over a dataset without index validation.
188 : ///
189 : /// Must be used only when it is guaranteed that indices are valid or overlap of
190 : /// bins is acceptable.
191 103 : Variable make_bins_no_validate(Variable indices, const Dim dim,
192 : Dataset buffer) {
193 103 : return variable::make_bins_impl(std::move(indices), dim, std::move(buffer));
194 : }
195 :
196 77387 : bool is_bins(const DataArray &array) { return is_bins(array.data()); }
197 :
198 0 : bool is_bins(const Dataset &dataset) {
199 0 : return std::any_of(dataset.begin(), dataset.end(),
200 0 : [](const auto &item) { return is_bins(item); });
201 : }
202 :
203 40 : Variable lookup_previous(const DataArray &function, const Variable &x, Dim dim,
204 : const std::optional<Variable> &fill_value) {
205 40 : const auto fill = make_fill(function, fill_value);
206 40 : const auto &coord = function.meta()[dim];
207 40 : const auto data = masked_data(function, dim, fill);
208 40 : const auto weights = subspan_view(data, dim);
209 40 : if (!allsorted(coord, dim))
210 0 : throw except::DataArrayError(
211 0 : "Coordinate of lookup function must be sorted.");
212 : // Note that we could do a linspace optimization similar to buckets::map here.
213 : // Add this if we have real world application that would benefit.
214 80 : return variable::transform(x, subspan_view(coord, dim), weights, fill,
215 : core::element::event::lookup_previous,
216 120 : "lookup_previous");
217 40 : }
218 :
219 6143 : Variable pretend_bins_for_threading(const DataArray &da, Dim bin_dim) {
220 6143 : const auto dim = da.dims().inner();
221 6143 : const auto size = std::max(scipp::index(1), da.dims()[dim]);
222 12286 : const auto nthread = size > 10000000 ? 24
223 12282 : : size > 1000000 ? 4
224 6139 : : size > 100000 ? 2
225 : : 1;
226 :
227 6143 : const auto stride = std::max(scipp::index(1), size / nthread);
228 6143 : auto begin = bin_detail::make_range(0, size, stride, bin_dim);
229 6143 : auto end = begin + stride * units::none;
230 6143 : end.values<scipp::index>().as_span().back() = da.dims()[dim];
231 6143 : const auto indices = zip(begin, end);
232 12286 : return make_bins_no_validate(indices, dim, da);
233 6143 : }
234 :
235 : } // namespace scipp::dataset
236 :
237 : namespace scipp::dataset::buckets {
238 : namespace {
239 :
240 28 : template <class T> auto combine(const Variable &var0, const Variable &var1) {
241 28 : const auto &[indices0, dim0, buffer0] = var0.constituents<T>();
242 28 : const auto &[indices1, dim1, buffer1] = var1.constituents<T>();
243 : static_cast<void>(buffer1);
244 : static_cast<void>(dim1);
245 28 : const Dim dim = dim0;
246 28 : const auto [begin0, end0] = unzip(indices0);
247 28 : const auto [begin1, end1] = unzip(indices1);
248 28 : const auto sizes0 = end0 - begin0;
249 28 : const auto sizes1 = end1 - begin1;
250 28 : const auto sizes = sizes0 + sizes1;
251 28 : const auto end = cumsum(sizes);
252 28 : const auto begin = end - sizes;
253 28 : const auto total_size =
254 28 : end.dims().volume() > 0
255 28 : ? end.template values<scipp::index>().as_span().back()
256 : : 0;
257 28 : auto buffer = resize_default_init(buffer0, dim, total_size);
258 28 : copy_slices(buffer0, buffer, dim, indices0, zip(begin, end - sizes1));
259 46 : copy_slices(buffer1, buffer, dim, indices1, zip(begin + sizes0, end));
260 44 : return make_bins_no_validate(zip(begin, end), dim, std::move(buffer));
261 82 : }
262 :
263 : template <class T>
264 21 : auto concatenate_impl(const Variable &var0, const Variable &var1) {
265 21 : return combine<T>(var0, var1);
266 : }
267 :
268 : } // namespace
269 :
270 21 : Variable concatenate(const Variable &var0, const Variable &var1) {
271 21 : if (var0.dtype() == dtype<bucket<Variable>>)
272 0 : return concatenate_impl<Variable>(var0, var1);
273 21 : else if (var0.dtype() == dtype<bucket<DataArray>>)
274 9 : return concatenate_impl<DataArray>(var0, var1);
275 : else
276 12 : return concatenate_impl<Dataset>(var0, var1);
277 : }
278 :
279 7 : DataArray concatenate(const DataArray &a, const DataArray &b) {
280 14 : return DataArray{buckets::concatenate(a.data(), b.data()),
281 14 : union_(a.coords(), b.coords(), "concatenate"),
282 14 : union_or(a.masks(), b.masks()),
283 14 : intersection(a.attrs(), b.attrs())};
284 : }
285 :
286 : /// Reduce a dimension by concatenating all elements along the dimension.
287 : ///
288 : /// This is the analogue to summing non-bucket data.
289 8 : Variable concatenate(const Variable &var, const Dim dim) {
290 8 : if (var.dtype() == dtype<bucket<Variable>>)
291 1 : return concat_bins<Variable>(var, dim);
292 : else
293 7 : return concat_bins<DataArray>(var, dim);
294 : }
295 :
296 : /// Reduce a dimension by concatenating all elements along the dimension.
297 : ///
298 : /// This is the analogue to summing non-bucket data.
299 5 : DataArray concatenate(const DataArray &array, const Dim dim) {
300 5 : return groupby_concat_bins(array, {}, {}, {dim});
301 : }
302 :
303 7 : void append(Variable &var0, const Variable &var1) {
304 7 : if (var0.dtype() == dtype<bucket<Variable>>)
305 0 : var0.setDataHandle(combine<Variable>(var0, var1).data_handle());
306 7 : else if (var0.dtype() == dtype<bucket<DataArray>>)
307 9 : var0.setDataHandle(combine<DataArray>(var0, var1).data_handle());
308 : else
309 0 : var0.setDataHandle(combine<Dataset>(var0, var1).data_handle());
310 6 : }
311 :
312 0 : void append(Variable &&var0, const Variable &var1) { append(var0, var1); }
313 :
314 4 : void append(DataArray &a, const DataArray &b) {
315 4 : expect::coords_are_superset(a, b, "bins.append");
316 4 : union_or_in_place(a.masks(), b.masks());
317 4 : auto data = a.data();
318 4 : append(data, b.data());
319 4 : a.setData(data);
320 4 : }
321 :
322 128 : Variable histogram(const Variable &data, const Variable &binEdges) {
323 : using namespace scipp::core;
324 128 : auto hist_dim = binEdges.dims().inner();
325 128 : auto &&[indices, dim, buffer] = data.constituents<DataArray>();
326 : // `hist_dim` may be the same as a dim of data if there is existing binning.
327 : // We rename to a dummy to avoid duplicate dimensions, perform histogramming,
328 : // and then sum over the dummy dimensions, i.e., sum contributions from all
329 : // inputs bins to the same output histogram. This also allows for threading of
330 : // 1-D histogramming provided that the input has multiple bins along
331 : // `hist_dim`.
332 128 : const Dim dummy = Dim::InternalHistogram;
333 128 : const auto nbin = binEdges.dims()[hist_dim] - 1;
334 128 : if (indices.dims().contains(hist_dim)) {
335 : // With large existing dim matching the new dim, we would create a large
336 : // intermediate histogrammed result, which leads to performance and memory
337 : // issues. This is a suboptimal (since it concatenates first) but simple way
338 : // to avoid the problem.
339 47 : if (indices.dims().volume() * nbin > 100000000) { // about 1 GByte
340 0 : const auto tmp = concatenate(data, hist_dim);
341 0 : if (tmp.ndim() == 0) // Operate on buffer so we get multi-threading
342 0 : return histogram(tmp.bin_buffer<DataArray>(), binEdges).data();
343 : else
344 0 : return histogram(tmp, binEdges);
345 0 : }
346 47 : indices = indices.rename_dims({{hist_dim, dummy}});
347 : }
348 :
349 128 : const auto masked = masked_data(buffer, dim);
350 128 : const auto coord = buffer.meta()[hist_dim];
351 128 : const auto dt = common_type(binEdges, coord);
352 128 : const auto promoted_coord = astype(coord, dt, CopyPolicy::TryAvoid);
353 128 : const auto promoted_edges = astype(binEdges, dt, CopyPolicy::TryAvoid);
354 : auto hist = variable::transform_subspan(
355 : buffer.dtype(), hist_dim, nbin,
356 256 : subspan_view(promoted_coord, dim, indices),
357 129 : subspan_view(masked, dim, indices), promoted_edges, element::histogram,
358 257 : "histogram");
359 127 : if (hist.dims().contains(dummy))
360 57 : return sum(hist, dummy);
361 : else
362 70 : return hist;
363 132 : }
364 :
365 132 : Variable map(const DataArray &function, const Variable &x, Dim dim,
366 : const std::optional<Variable> &fill_value) {
367 132 : const auto fill = make_fill(function, fill_value);
368 132 : if (dim == Dim::Invalid)
369 0 : dim = edge_dimension(function);
370 132 : const auto &edges = function.meta()[dim];
371 132 : if (!is_edges(function.dims(), edges.dims(), dim))
372 1 : throw except::BinEdgeError(
373 2 : "Function used as lookup table in map operation must be a histogram");
374 131 : const auto data = masked_data(function, dim, fill);
375 131 : const auto weights = subspan_view(data, dim);
376 131 : if (all(islinspace(edges, dim)).value<bool>()) {
377 231 : return variable::transform(x, subspan_view(edges, dim), weights, fill,
378 345 : core::element::event::map_linspace, "map");
379 : } else {
380 16 : if (!allsorted(edges, dim))
381 0 : throw except::BinEdgeError("Bin edges of histogram must be sorted.");
382 32 : return variable::transform(x, subspan_view(edges, dim), weights, fill,
383 48 : core::element::event::map_sorted_edges, "map");
384 : }
385 136 : }
386 :
387 122 : void scale(DataArray &array, const DataArray &histogram, Dim dim) {
388 122 : if (dim == Dim::Invalid)
389 11 : dim = edge_dimension(histogram);
390 : // Coords along dim are ignored since "binning" is dynamic for buckets.
391 119 : expect::coords_are_superset(array, histogram.slice({dim, 0}), "bins.scale");
392 : // scale applies masks along dim but others are kept
393 119 : union_or_in_place(array.masks(), histogram.slice({dim, 0}).masks());
394 238 : auto data = bins_view<DataArray>(array.data()).data();
395 119 : const auto &coord = bins_view<DataArray>(array.data()).meta()[dim];
396 119 : const auto &edges = histogram.meta()[dim];
397 119 : const auto masked = masked_data(histogram, dim);
398 119 : const auto weights = subspan_view(masked, dim);
399 119 : if (all(islinspace(edges, dim)).value<bool>()) {
400 119 : transform_in_place(data, coord, subspan_view(edges, dim), weights,
401 : core::element::event::map_and_mul_linspace,
402 : "bins.scale");
403 : } else {
404 2 : if (!allsorted(edges, dim))
405 0 : throw except::BinEdgeError("Bin edges of histogram must be sorted.");
406 2 : transform_in_place(data, coord, subspan_view(edges, dim), weights,
407 : core::element::event::map_and_mul_sorted_edges,
408 : "bins.scale");
409 : }
410 127 : }
411 : } // namespace scipp::dataset::buckets
|