Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 : #include <numeric>
6 :
7 : #include "scipp/variable/variable_factory.h"
8 :
9 : #include "scipp/dataset/bins.h"
10 : #include "scipp/dataset/except.h"
11 : #include "scipp/dataset/extract.h"
12 : #include "scipp/dataset/util.h"
13 :
14 : namespace scipp {
15 :
16 : namespace {
17 :
18 : /// Transform data of data array or dataset, coord, masks, and and attrs are
19 : /// shallow-copied.
20 : ///
21 : /// Beware of the mask-copy behavior, which is not suitable for data returned to
22 : /// the user.
23 : template <class T, class Func, class... Ts>
24 1054 : T transform_data(const T &obj, Func func, const Ts &...other) {
25 1054 : T out(obj);
26 : if constexpr (std::is_same_v<T, Variable>) {
27 744 : return func(obj, other...);
28 : } else if constexpr (std::is_same_v<T, DataArray>) {
29 254 : out.setData(func(obj.data(), other.data()...));
30 : } else {
31 128 : for (const auto &item : obj)
32 72 : out.setData(item.name(), func(item.data(), other[item.name()].data()...),
33 : dataset::AttrPolicy::Keep);
34 : }
35 310 : return out;
36 744 : }
37 :
38 : template <class Buffer>
39 539 : Variable copy_ranges_from_buffer(const Variable &indices, const Dim dim,
40 : const Buffer &buffer) {
41 539 : return copy(make_bins_no_validate(indices, dim, buffer));
42 : }
43 :
44 12 : Variable copy_ranges_from_bins_buffer(const Variable &indices,
45 : const Variable &data) {
46 12 : if (data.dtype() == dtype<bucket<Variable>>) {
47 6 : const auto &[i, dim, buf] = data.constituents<Variable>();
48 6 : return copy_ranges_from_buffer(indices, dim, buf);
49 12 : } else if (data.dtype() == dtype<bucket<DataArray>>) {
50 6 : const auto &[i, dim, buf] = data.constituents<DataArray>();
51 6 : return copy_ranges_from_buffer(indices, dim, buf);
52 6 : } else {
53 0 : const auto &[i, dim, buf] = data.constituents<Dataset>();
54 0 : return copy_ranges_from_buffer(indices, dim, buf);
55 0 : }
56 : }
57 :
58 535 : Variable dense_or_bin_indices(const Variable &var) {
59 535 : return is_bins(var) ? var.bin_indices() : var;
60 : }
61 :
62 535 : Variable dense_or_copy_bin_elements(const Variable &dense_or_indices,
63 : const Variable &data) {
64 535 : return is_bins(data) ? copy_ranges_from_bins_buffer(dense_or_indices, data)
65 535 : : dense_or_indices;
66 : }
67 : } // namespace
68 :
69 : template <class T>
70 527 : T extract_ranges(const Variable &indices, const T &data, const Dim dim) {
71 527 : T no_edges;
72 : if constexpr (std::is_same_v<T, Variable>)
73 372 : no_edges = data;
74 : else
75 155 : no_edges = strip_edges_along(data, dim);
76 : // 1. Operate on dense data, or equivalent array of indices (if binned) to
77 : // obtain output data of correct shape with proper meta data.
78 527 : auto dense = transform_data(no_edges, dense_or_bin_indices);
79 527 : auto out =
80 527 : copy_ranges_from_buffer(indices, dim, dense).template bin_buffer<T>();
81 : // 2. If we have binned data then the data of the DataArray or Dataset
82 : // obtained in step 1. give the indices into the underlying buffer to be
83 : // copied. This then replaces the data to obtain the final result. Does
84 : // nothing if dense data.
85 1054 : return transform_data(out, dense_or_copy_bin_elements, no_edges);
86 527 : }
87 :
88 : namespace {
89 462 : template <class T> T extract_impl(const T &obj, const Variable &condition) {
90 462 : if (condition.dtype() != dtype<bool>)
91 1 : throw except::TypeError(
92 : "Cannot extract elements based on condition with non-boolean dtype. If "
93 : "you intended to select a range based on a label you must specify the "
94 : "dimension.");
95 461 : if (condition.dims().ndim() != 1)
96 6 : throw except::DimensionError("Condition must by 1-D, but got " +
97 : to_string(condition.dims()) + '.');
98 455 : if (!obj.dims().includes(condition.dims()))
99 7 : throw except::DimensionError(
100 : "Condition dimensions " + to_string(condition.dims()) +
101 : " must be be included in the dimensions of the sliced object " +
102 : to_string(obj.dims()) + '.');
103 :
104 448 : auto values = condition.values<bool>().as_span();
105 448 : std::vector<scipp::index_pair> indices;
106 3173 : for (scipp::index i = 0; i < scipp::size(values); ++i) {
107 2725 : if (i > 0 && values[i - 1] == values[i])
108 1916 : continue; // not an edge
109 809 : if (values[i]) // rising edge
110 434 : indices.emplace_back(i, scipp::size(values));
111 375 : else if (i != 0) // falling edge
112 165 : indices.back().second = i;
113 : }
114 448 : return extract_ranges(makeVariable<scipp::index_pair>(Dims{condition.dim()},
115 896 : Shape{indices.size()},
116 896 : Values(indices)),
117 1344 : obj, condition.dim());
118 448 : }
119 : } // namespace
120 :
121 349 : Variable extract(const Variable &var, const Variable &condition) {
122 349 : return extract_impl(var, condition);
123 : }
124 :
125 107 : DataArray extract(const DataArray &da, const Variable &condition) {
126 107 : return extract_impl(da, condition);
127 : }
128 :
129 6 : Dataset extract(const Dataset &ds, const Variable &condition) {
130 6 : return extract_impl(ds, condition);
131 : }
132 :
133 : template SCIPP_DATASET_EXPORT Variable extract_ranges(const Variable &,
134 : const Variable &,
135 : const Dim);
136 : template SCIPP_DATASET_EXPORT DataArray extract_ranges(const Variable &,
137 : const DataArray &,
138 : const Dim);
139 : template SCIPP_DATASET_EXPORT Dataset extract_ranges(const Variable &,
140 : const Dataset &,
141 : const Dim);
142 :
143 : } // namespace scipp
|