Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 : #include <algorithm>
6 :
7 : #include "scipp/core/dimensions.h"
8 :
9 : #include "scipp/variable/arithmetic.h"
10 : #include "scipp/variable/bins.h"
11 : #include "scipp/variable/creation.h"
12 : #include "scipp/variable/except.h"
13 : #include "scipp/variable/shape.h"
14 : #include "scipp/variable/util.h"
15 : #include "scipp/variable/variable_concept.h"
16 : #include "scipp/variable/variable_factory.h"
17 :
18 : using namespace scipp::core;
19 :
20 : namespace scipp::variable {
21 :
22 38381 : Variable broadcast(const Variable &var, const Dimensions &dims) {
23 38381 : return var.broadcast(dims);
24 : }
25 :
26 : namespace {
27 34 : auto get_bin_sizes(const scipp::span<const Variable> vars) {
28 34 : std::vector<Variable> sizes;
29 34 : sizes.reserve(vars.size());
30 110 : for (const auto &var : vars)
31 76 : sizes.emplace_back(bin_sizes(var));
32 34 : return sizes;
33 0 : }
34 : } // namespace
35 :
36 2605 : Variable concat(const scipp::span<const Variable> vars, const Dim dim) {
37 2605 : if (vars.empty())
38 1 : throw std::invalid_argument("Cannot concat empty list.");
39 : const auto it =
40 2604 : std::find_if(vars.begin(), vars.end(),
41 2828 : [dim](const auto &var) { return var.dims().contains(dim); });
42 2604 : Dimensions dims;
43 : // Expand dims for inputs that do not contain dim already. Favor order given
44 : // by first input, if not found add as outer dim.
45 2604 : if (it == vars.end()) {
46 129 : dims = vars.front().dims();
47 129 : dims.add(dim, 1);
48 : } else {
49 2475 : dims = it->dims();
50 2475 : dims.resize(dim, 1);
51 : }
52 2604 : std::vector<Variable> tmp;
53 2604 : scipp::index size = 0;
54 7925 : for (const auto &var : vars) {
55 5322 : if (var.dims().contains(dim))
56 4949 : tmp.emplace_back(var);
57 : else
58 373 : tmp.emplace_back(broadcast(var, dims));
59 5321 : size += tmp.back().dims()[dim];
60 : }
61 2603 : dims.resize(dim, size);
62 2603 : Variable out;
63 2603 : if (is_bins(vars.front())) {
64 34 : out = empty_like(vars.front(), {}, concat(get_bin_sizes(vars), dim));
65 : } else {
66 2569 : out = empty_like(vars.front(), dims);
67 : }
68 2603 : scipp::index offset = 0;
69 7899 : for (const auto &var : tmp) {
70 5320 : const auto extent = var.dims()[dim];
71 5344 : out.data().copy(var, out.slice({dim, offset, offset + extent}));
72 5296 : offset += extent;
73 : }
74 5158 : return out;
75 2653 : }
76 :
77 78 : Variable resize(const Variable &var, const Dim dim, const scipp::index size,
78 : const FillValue fill) {
79 78 : auto dims = var.dims();
80 78 : dims.resize(dim, size);
81 234 : return special_like(broadcast(Variable(var, Dimensions{}), dims), fill);
82 78 : }
83 :
84 : /// Return new variable resized to given shape.
85 : ///
86 : /// For bucket variables the values of `shape` are interpreted as bucket sizes
87 : /// to RESERVE and the buffer is also resized accordingly. The emphasis is on
88 : /// "reserve", i.e., buffer size and begin indices are set up accordingly, but
89 : /// end=begin is set, i.e., the buckets are empty, but may be grown up to the
90 : /// requested size. For normal (non-bucket) variable the values of `shape` are
91 : /// ignored, i.e., only `shape.dims()` is used to determine the shape of the
92 : /// output.
93 0 : Variable resize(const Variable &var, const Variable &shape) {
94 0 : return {shape.dims(), var.data().makeDefaultFromParent(shape)};
95 : }
96 :
97 255 : Variable fold(const Variable &view, const Dim from_dim,
98 : const Dimensions &to_dims) {
99 255 : return view.fold(from_dim, to_dims);
100 : }
101 :
102 14506 : Variable flatten(const Variable &view,
103 : const scipp::span<const Dim> &from_labels, const Dim to_dim) {
104 14506 : if (from_labels.empty()) {
105 31 : auto out(view);
106 31 : out.unchecked_dims().addInner(to_dim, 1);
107 31 : out.unchecked_strides().push_back(1);
108 31 : return out;
109 31 : }
110 14475 : const auto &labels = view.dims().labels();
111 14475 : auto it = std::search(labels.begin(), labels.end(), from_labels.begin(),
112 : from_labels.end());
113 14475 : if (it == labels.end())
114 3 : throw except::DimensionError("Can only flatten a contiguous set of "
115 6 : "dimensions in the correct order");
116 14472 : scipp::index size = 1;
117 14472 : auto to = std::distance(labels.begin(), it);
118 14472 : auto out(view);
119 41390 : for (const auto &from : from_labels) {
120 31131 : size *= out.dims().size(to);
121 31131 : if (from == from_labels.back()) {
122 10259 : out.unchecked_dims().resize(from, size);
123 10259 : out.unchecked_dims().replace_key(from, to_dim);
124 : } else {
125 20872 : if (out.strides()[to] != out.dims().size(to + 1) * out.strides()[to + 1])
126 8426 : return flatten(copy(view), from_labels, to_dim);
127 16659 : out.unchecked_dims().erase(from);
128 16659 : out.unchecked_strides().erase(to);
129 : }
130 : }
131 10259 : return out;
132 14472 : }
133 :
134 7449 : Variable transpose(const Variable &var, const scipp::span<const Dim> dims) {
135 7449 : return var.transpose(dims);
136 : }
137 :
138 : std::vector<scipp::Dim>
139 9979 : dims_for_squeezing(const core::Sizes &data_dims,
140 : const std::optional<scipp::span<const Dim>> selected_dims) {
141 9979 : if (selected_dims.has_value()) {
142 11257 : for (const auto &dim : *selected_dims) {
143 2755 : if (const auto size = data_dims[dim]; size != 1)
144 15 : throw except::DimensionError("Cannot squeeze '" + to_string(dim) +
145 20 : "' of length " + std::to_string(size) +
146 10 : ", must be of length 1.");
147 : }
148 8502 : return std::vector<Dim>{selected_dims->begin(), selected_dims->end()};
149 : } else {
150 1472 : std::vector<Dim> length_1_dims;
151 1472 : length_1_dims.reserve(data_dims.size());
152 2982 : for (const auto &dim : data_dims) {
153 1510 : if (data_dims[dim] == 1) {
154 170 : length_1_dims.push_back(dim);
155 : }
156 : }
157 1472 : return length_1_dims;
158 1472 : }
159 : }
160 :
161 9942 : Variable squeeze(const Variable &var,
162 : const std::optional<scipp::span<const Dim>> dims) {
163 9942 : auto squeezed = var;
164 12820 : for (const auto &dim : dims_for_squeezing(var.dims(), dims)) {
165 2878 : squeezed = squeezed.slice({dim, 0});
166 9939 : }
167 9939 : return squeezed;
168 3 : }
169 :
170 : } // namespace scipp::variable
|