Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 : #include <algorithm>
6 :
7 : #include "scipp/core/except.h"
8 : #include "scipp/variable/creation.h"
9 : #include "scipp/variable/shape.h"
10 :
11 : #include "scipp/dataset/except.h"
12 : #include "scipp/dataset/shape.h"
13 :
14 : #include "dataset_operations_common.h"
15 :
16 : using namespace scipp::variable;
17 :
18 : namespace scipp::dataset {
19 :
20 : /// Map `op` over `items`, return vector of results
21 6368 : template <class T, class Op> auto map(const T &items, Op op) {
22 6368 : std::vector<std::decay_t<decltype(op(items.front()))>> out;
23 6368 : out.reserve(items.size());
24 19345 : for (const auto &i : items)
25 12977 : out.emplace_back(op(i));
26 6368 : return out;
27 0 : }
28 :
29 : /// Concatenate a and b, assuming that a and b contain bin edges.
30 : ///
31 : /// Checks that the last edges in `a` match the first edges in `b`. The
32 : /// Concatenates the input edges, removing duplicate bin edges.
33 11 : Variable join_edges(const scipp::span<const Variable> vars, const Dim dim) {
34 11 : std::vector<Variable> tmp;
35 11 : tmp.reserve(vars.size());
36 31 : for (const auto &var : vars) {
37 23 : if (tmp.empty()) {
38 11 : tmp.emplace_back(var);
39 : } else {
40 45 : core::expect::equals(tmp.back().slice({dim, tmp.back().dims()[dim] - 1}),
41 27 : var.slice({dim, 0}));
42 9 : tmp.emplace_back(var.slice({dim, 1, var.dims()[dim]}));
43 : }
44 : }
45 16 : return concat(tmp, dim);
46 11 : }
47 :
48 : namespace {
49 : template <class T, class Key>
50 1183 : bool equal_is_edges(const T &maps, const Key &key, const Dim dim) {
51 1183 : return std::adjacent_find(maps.begin(), maps.end(),
52 2442 : [&key, dim](auto &a, auto &b) {
53 1221 : return is_edges(a.sizes(), a[key].dims(), dim) !=
54 1221 : is_edges(b.sizes(), b[key].dims(), dim);
55 2366 : }) == maps.end();
56 : }
57 : template <class T, class Key>
58 1176 : bool all_is_edges(const T &maps, const Key &key, const Dim dim) {
59 2366 : return std::all_of(maps.begin(), maps.end(), [&key, dim](auto &var) {
60 1190 : return is_edges(var.sizes(), var[key].dims(), dim);
61 1176 : });
62 : }
63 :
64 : template <class T, class Key>
65 58 : auto broadcast_along_dim(const T &maps, const Key &key, const Dim dim) {
66 127 : return map(maps, [&key, dim](const auto &map) {
67 127 : const auto &var = map[key];
68 254 : return broadcast(var, merge(Dimensions(dim, map.sizes().contains(dim)
69 127 : ? map.sizes().at(dim)
70 : : 1),
71 381 : var.dims()));
72 58 : });
73 : }
74 :
75 3487 : template <class Maps> auto concat_maps(const Maps &maps, const Dim dim) {
76 3487 : if (maps.empty())
77 1 : throw std::invalid_argument("Cannot concat empty list.");
78 : using T = typename Maps::value_type;
79 3486 : core::Dict<typename T::key_type, typename T::mapped_type> out;
80 3486 : const auto &a = maps.front();
81 5089 : for (const auto &[key, a_] : a) {
82 4921 : auto vars = map(maps, [&key = key](auto &&map) { return map[key]; });
83 1603 : if (a.dim_of(key) == dim) {
84 1183 : if (!equal_is_edges(maps, key, dim)) {
85 7 : throw except::BinEdgeError(
86 : "Either all or none of the inputs must have bin edge coordinates.");
87 1176 : } else if (!all_is_edges(maps, key, dim)) {
88 1165 : out.insert_or_assign(key, concat(vars, dim));
89 : } else {
90 11 : out.insert_or_assign(key, join_edges(vars, dim));
91 : }
92 : } else {
93 : // 1D coord is kept only if all inputs have matching 1D coords.
94 2169 : if (std::any_of(vars.begin(), vars.end(), [dim, &vars](auto &var) {
95 894 : return var.dims().contains(dim) || !equals_nan(var, vars.front());
96 : })) {
97 : // Mismatching 1D coords must be broadcast to ensure new coord shape
98 : // matches new data shape.
99 58 : out.insert_or_assign(key,
100 : concat(broadcast_along_dim(maps, key, dim), dim));
101 : } else {
102 : if constexpr (std::is_same_v<T, Masks>)
103 130 : out.insert_or_assign(key, copy(a_));
104 : else
105 232 : out.insert_or_assign(key, a_);
106 : }
107 : }
108 : }
109 3476 : return out;
110 10 : }
111 :
112 : } // namespace
113 :
114 1164 : DataArray concat(const scipp::span<const DataArray> das, const Dim dim) {
115 2326 : auto out = DataArray(concat(map(das, get_data), dim), {},
116 3493 : concat_maps(map(das, get_masks), dim));
117 1163 : const auto &coords = map(das, get_coords);
118 2392 : for (auto &&[d, coord] : concat_maps(coords, dim)) {
119 2377 : coord.set_aligned(d == dim ||
120 1148 : std::any_of(coords.begin(), coords.end(),
121 1148 : [&d = d](auto &_) { return _.contains(d); }));
122 1229 : out.coords().set(d, std::move(coord));
123 1153 : }
124 1313 : for (auto &&[d, attr] : concat_maps(map(das, get_attrs), dim)) {
125 160 : out.attrs().set(d, std::move(attr));
126 1153 : }
127 2306 : return out;
128 1173 : }
129 :
130 48 : Dataset concat(const scipp::span<const Dataset> dss, const Dim dim) {
131 48 : if (dss.empty())
132 1 : throw std::invalid_argument("Cannot concat empty list.");
133 47 : Dataset result;
134 98 : for (const auto &first : dss.front())
135 53 : if (std::all_of(dss.begin(), dss.end(),
136 110 : [&first](auto &ds) { return ds.contains(first.name()); })) {
137 154 : auto das = map(dss, [&first](auto &&ds) { return ds[first.name()]; });
138 50 : result.setDataInit(first.name(), concat(das, dim));
139 103 : }
140 45 : if (result.is_valid())
141 38 : return result;
142 14 : return Dataset({}, Coords(concat(map(dss, get_sizes), dim),
143 21 : concat_maps(map(dss, get_coords), dim)));
144 47 : }
145 :
146 78 : DataArray resize(const DataArray &a, const Dim dim, const scipp::index size,
147 : const FillValue fill) {
148 : return apply_to_data_and_drop_dim(
149 156 : a, [](auto &&..._) { return resize(_...); }, dim, size, fill);
150 : }
151 :
152 3 : Dataset resize(const Dataset &d, const Dim dim, const scipp::index size,
153 : const FillValue fill) {
154 : return apply_to_items(
155 6 : d, [](auto &&..._) { return resize(_...); }, dim, size, fill);
156 : }
157 :
158 0 : DataArray resize(const DataArray &a, const Dim dim, const DataArray &shape) {
159 : return apply_to_data_and_drop_dim(
160 0 : a, [](auto &&v, const Dim, auto &&s) { return resize(v, s); }, dim,
161 0 : shape.data());
162 : }
163 :
164 0 : Dataset resize(const Dataset &d, const Dim dim, const Dataset &shape) {
165 0 : Dataset result;
166 0 : for (const auto &data : d)
167 0 : result.setData(data.name(), resize(data, dim, shape[data.name()]));
168 0 : return result;
169 0 : }
170 :
171 : namespace {
172 :
173 : /// Either broadcast variable to from_dims before a reshape or not:
174 : ///
175 : /// 1. If all from_dims are contained in the variable's dims, no broadcast
176 : /// 2. If at least one (but not all) of the from_dims is contained in the
177 : /// variable's dims, broadcast
178 : /// 3. If none of the variables's dimensions are contained, no broadcast
179 7398 : Variable maybe_broadcast(const Variable &var,
180 : const scipp::span<const Dim> &from_labels,
181 : const Dimensions &data_dims) {
182 7398 : const auto &var_dims = var.dims();
183 7398 : Dimensions broadcast_dims;
184 19001 : for (const auto &dim : var_dims.labels())
185 23206 : if (std::find(from_labels.begin(), from_labels.end(), dim) ==
186 11603 : from_labels.end())
187 55 : broadcast_dims.addInner(dim, var_dims[dim]);
188 : else
189 43763 : for (const auto &lab : from_labels)
190 32215 : if (!broadcast_dims.contains(lab)) {
191 : // Need to check if the variable contains that dim, and use the
192 : // variable shape in case we have a bin edge.
193 19595 : if (var_dims.contains(lab))
194 11548 : broadcast_dims.addInner(lab, var_dims[lab]);
195 : else
196 8047 : broadcast_dims.addInner(lab, data_dims[lab]);
197 : }
198 14796 : return broadcast(var, broadcast_dims);
199 7398 : }
200 :
201 : /// Special handling for folding coord along a dim that contains bin edges.
202 4 : Variable fold_bin_edge(const Variable &var, const Dim from_dim,
203 : const Dimensions &to_dims) {
204 4 : auto out = var.slice({from_dim, 0, var.dims()[from_dim] - 1})
205 8 : .fold(from_dim, to_dims) // fold non-overlapping part
206 4 : .as_const(); // mark readonly since we add overlap
207 : // Increase dims without changing strides to obtain first == last
208 4 : out.unchecked_dims().resize(to_dims.inner(), to_dims[to_dims.inner()] + 1);
209 4 : return out;
210 0 : }
211 :
212 : /// Special handling for flattening coord along a dim that contains bin edges.
213 48 : Variable flatten_bin_edge(const Variable &var,
214 : const scipp::span<const Dim> &from_labels,
215 : const Dim to_dim, const Dim bin_edge_dim) {
216 48 : const auto data_shape = var.dims()[bin_edge_dim] - 1;
217 :
218 : // Make sure that the bin edges can be joined together
219 48 : const auto front = var.slice({bin_edge_dim, 0});
220 48 : const auto back = var.slice({bin_edge_dim, data_shape});
221 48 : const auto front_flat = flatten(front, front.dims().labels(), to_dim);
222 48 : const auto back_flat = flatten(back, back.dims().labels(), to_dim);
223 48 : if (front_flat.slice({to_dim, 1, front.dims().volume()}) !=
224 96 : back_flat.slice({to_dim, 0, back.dims().volume() - 1}))
225 34 : return {};
226 :
227 : // Make the bulk slice of the coord, leaving out the last bin edge
228 : const auto bulk =
229 14 : flatten(var.slice({bin_edge_dim, 0, data_shape}), from_labels, to_dim);
230 14 : auto out_dims = bulk.dims();
231 : // To make the container of the right size, we increase to_dim by 1
232 14 : out_dims.resize(to_dim, out_dims[to_dim] + 1);
233 14 : auto out = empty(out_dims, var.unit(), var.dtype(), var.has_variances());
234 14 : copy(bulk, out.slice({to_dim, 0, out_dims[to_dim] - 1}));
235 14 : copy(back_flat.slice({to_dim, back.dims().volume() - 1}),
236 28 : out.slice({to_dim, out_dims[to_dim] - 1}));
237 14 : return out;
238 48 : }
239 :
240 : /// Check if one of the from_labels is a bin edge
241 9685 : Dim bin_edge_in_from_labels(const Variable &var, const Dimensions &array_dims,
242 : const scipp::span<const Dim> &from_labels) {
243 33974 : for (const auto &dim : from_labels)
244 24337 : if (is_edges(array_dims, var.dims(), dim))
245 48 : return dim;
246 9637 : return Dim::Invalid;
247 : }
248 :
249 : } // end anonymous namespace
250 :
251 : /// Fold a single dimension into multiple dimensions
252 : /// ['x': 6] -> ['y': 2, 'z': 3]
253 23 : DataArray fold(const DataArray &a, const Dim from_dim,
254 : const Dimensions &to_dims) {
255 69 : return dataset::transform(a, [&](const auto &var) {
256 69 : if (is_edges(a.dims(), var.dims(), from_dim))
257 4 : return fold_bin_edge(var, from_dim, to_dims);
258 65 : else if (var.dims().contains(from_dim))
259 49 : return fold(var, from_dim, to_dims);
260 : else
261 16 : return var;
262 23 : });
263 : }
264 :
265 : namespace {
266 2290 : void expect_dimension_subset(const core::Dimensions &full_set,
267 : const scipp::span<const Dim> &subset) {
268 6974 : for (const auto &dim : subset) {
269 4687 : if (!full_set.contains(dim)) {
270 9 : throw except::DimensionError{"Expected dimension " + to_string(dim) +
271 12 : "to be in " + to_string(full_set)};
272 : }
273 : }
274 2287 : }
275 : } // namespace
276 :
277 : /// Flatten multiple dimensions into a single dimension:
278 : /// ['y', 'z'] -> ['x']
279 2292 : DataArray flatten(const DataArray &a,
280 : const std::optional<scipp::span<const Dim>> &from_labels,
281 : const Dim to_dim) {
282 2292 : const auto &labels = from_labels.value_or(a.dims().labels());
283 2292 : if (from_labels.has_value() && labels.empty())
284 8 : return DataArray(flatten(a.data(), labels, to_dim), a.coords(), a.masks(),
285 10 : a.attrs());
286 2290 : expect_dimension_subset(a.dims(), labels);
287 9685 : return dataset::transform(a, [&](const auto &in) {
288 9685 : auto var = (&in == &a.data()) ? in : maybe_broadcast(in, labels, a.dims());
289 9685 : const auto bin_edge_dim = bin_edge_in_from_labels(in, a.dims(), labels);
290 9685 : if (bin_edge_dim != Dim::Invalid) {
291 48 : return flatten_bin_edge(var, labels, to_dim, bin_edge_dim);
292 9637 : } else if (a.dims().empty() || var.dims().contains(labels.front())) {
293 : // maybe_broadcast ensures that all variables contain
294 : // all dims in labels, so only need to check labels.front().
295 9568 : return flatten(var, labels, to_dim);
296 : } else {
297 : // This only happens for metadata.
298 69 : return var;
299 : }
300 11972 : });
301 : }
302 :
303 1336 : DataArray transpose(const DataArray &a, const scipp::span<const Dim> dims) {
304 5344 : return {transpose(a.data(), dims), a.coords(), a.masks(), a.attrs(),
305 5344 : a.name()};
306 : }
307 :
308 2 : Dataset transpose(const Dataset &d, const scipp::span<const Dim> dims) {
309 6 : return apply_to_items(d, [](auto &&..._) { return transpose(_...); }, dims);
310 : }
311 :
312 : namespace {
313 : template <class T>
314 37 : T squeeze_impl(const T &x, const std::optional<scipp::span<const Dim>> dims) {
315 37 : auto squeezed = x;
316 77 : for (const auto &dim : dims_for_squeezing(x.dims(), dims)) {
317 40 : squeezed = squeezed.slice({dim, 0});
318 : }
319 : // Copy explicitly to make sure the output does not have its read-only flag
320 : // set.
321 70 : return T(squeezed);
322 37 : }
323 : } // namespace
324 32 : DataArray squeeze(const DataArray &a,
325 : const std::optional<scipp::span<const Dim>> dims) {
326 32 : return squeeze_impl(a, dims);
327 : }
328 :
329 5 : Dataset squeeze(const Dataset &d,
330 : const std::optional<scipp::span<const Dim>> dims) {
331 5 : return squeeze_impl(d, dims);
332 : }
333 :
334 : } // namespace scipp::dataset
|