Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Hezbrock
5 : #include "scipp/dataset/bins.h"
6 : #include "scipp/core/except.h"
7 : #include "scipp/dataset/bin.h"
8 : #include "scipp/dataset/bins_view.h"
9 : #include "scipp/variable/arithmetic.h"
10 : #include "scipp/variable/cumulative.h"
11 : #include "scipp/variable/shape.h"
12 : #include "scipp/variable/util.h"
13 : #include "scipp/variable/variable.h"
14 : #include "scipp/variable/variable_factory.h"
15 :
16 : #include "bind_data_array.h"
17 : #include "dim.h"
18 : #include "pybind11.h"
19 :
20 : using namespace scipp;
21 :
22 : namespace py = pybind11;
23 :
24 : namespace {
25 :
26 : template <class T>
27 6024 : auto call_make_bins(const std::optional<Variable> &begin_arg,
28 : const std::optional<Variable> &end_arg, const Dim dim,
29 : T &&data, const bool validate = true) {
30 6024 : Variable indices;
31 6024 : if (begin_arg.has_value()) {
32 6016 : const auto &begin = *begin_arg;
33 6016 : if (end_arg.has_value()) {
34 6014 : const auto &end = *end_arg;
35 6014 : indices = zip(begin, end);
36 : } else {
37 2 : indices = zip(begin, begin);
38 2 : const auto indices_ = indices.values<scipp::index_pair>();
39 2 : const auto nindex = scipp::size(indices_);
40 6 : for (scipp::index i = 0; i < nindex; ++i) {
41 4 : if (i < nindex - 1)
42 2 : indices_[i].second = indices_[i + 1].first;
43 : else
44 2 : indices_[i].second = data.dims()[dim];
45 : }
46 2 : }
47 8 : } else if (!end_arg.has_value()) {
48 7 : const auto one = scipp::index{1} * units::none;
49 7 : const auto ones = broadcast(one, {dim, data.dims()[dim]});
50 7 : const auto begin = cumsum(ones, dim, CumSumMode::Exclusive);
51 7 : indices = zip(begin, begin + one);
52 7 : } else {
53 1 : throw std::runtime_error("`end` given but not `begin`");
54 : }
55 6023 : if (validate)
56 216 : return make_bins(std::move(indices), dim, std::forward<T>(data));
57 : else
58 5807 : return make_bins_no_validate(std::move(indices), dim,
59 11614 : std::forward<T>(data));
60 6024 : }
61 :
62 9 : template <class T> void bind_bins(pybind11::module &m) {
63 27 : m.def(
64 : "bins",
65 217 : [](const std::optional<Variable> &begin,
66 : const std::optional<Variable> &end, const std::string &dim,
67 : const T &data) {
68 217 : return call_make_bins(begin, end, Dim{dim}, T(data));
69 : },
70 27 : py::arg("begin") = py::none(), py::arg("end") = py::none(),
71 0 : py::arg("dim"), py::arg("data")); // do not release GIL since using
72 : // implicit conversions in functor
73 9 : m.def(
74 : "_bins_no_validate",
75 5807 : [](const Variable &begin, const Variable &end, const std::string &dim,
76 : const T &data) {
77 5807 : return call_make_bins(begin, end, Dim{dim}, T(data), false);
78 : },
79 0 : py::arg("begin"), py::arg("end"), py::arg("dim"),
80 9 : py::arg("data")); // do not release GIL since using
81 : // implicit conversions in functor
82 9 : }
83 :
84 14532 : template <class T> py::dict bins_constituents(const Variable &var) {
85 14532 : auto &&[indices, dim, buffer] = var.constituents<T>();
86 14532 : auto &&[begin, end] = unzip(indices);
87 14532 : py::dict out;
88 14532 : out["begin"] = std::forward<decltype(begin)>(begin);
89 14532 : out["end"] = std::forward<decltype(end)>(end);
90 14532 : out["dim"] = std::string(dim.name());
91 14532 : out["data"] = std::forward<decltype(buffer)>(buffer);
92 29064 : return out;
93 14532 : }
94 :
95 : template <class T, bool HasAlignment = false>
96 6 : void bind_bins_map_view(py::module &m, const std::string &name) {
97 6 : py::class_<T> c(m, name.c_str());
98 6 : bind_common_mutable_view_operators(c);
99 6 : bind_pop(c);
100 : if constexpr (HasAlignment) {
101 : bind_set_aligned(c);
102 : }
103 6 : }
104 :
105 3 : template <class T> void bind_bins_view(py::module &m) {
106 : bind_helper_view<str_items_view,
107 3 : decltype(dataset::bins_view<T>(Variable{}).coords())>(
108 : m, "_BinsCoords");
109 : bind_helper_view<items_view,
110 3 : decltype(dataset::bins_view<T>(Variable{}).masks())>(
111 : m, "_BinsMasks");
112 : bind_helper_view<str_keys_view,
113 3 : decltype(dataset::bins_view<T>(Variable{}).coords())>(
114 : m, "_BinsCoords");
115 : bind_helper_view<keys_view,
116 3 : decltype(dataset::bins_view<T>(Variable{}).masks())>(
117 : m, "_BinsMasks");
118 : bind_helper_view<values_view,
119 3 : decltype(dataset::bins_view<T>(Variable{}).coords())>(
120 : m, "_BinsCoords");
121 : bind_helper_view<values_view,
122 3 : decltype(dataset::bins_view<T>(Variable{}).masks())>(
123 : m, "_BinsMasks");
124 :
125 3 : py::class_<decltype(dataset::bins_view<T>(Variable{}))> c(
126 : m, "_BinsViewDataArray");
127 3 : bind_bins_map_view<decltype(dataset::bins_view<T>(Variable{}).meta())>(
128 : m, "_BinsMeta");
129 : bind_mutable_view_no_dim<
130 3 : decltype(dataset::bins_view<T>(Variable{}).coords())>(
131 : m, "_BinsCoords", "Dict of event coords.");
132 3 : bind_mutable_view<decltype(dataset::bins_view<T>(Variable{}).masks())>(
133 : m, "_BinsMasks", "Dict of event masks.");
134 3 : bind_bins_map_view<decltype(dataset::bins_view<T>(Variable{}).attrs())>(
135 : m, "_BinsAttrs");
136 3 : bind_data_array_properties(c);
137 3 : m.def("_bins_view",
138 15143 : [](const Variable &var) { return dataset::bins_view<T>(var); });
139 3 : }
140 :
141 : template <class T, class Data>
142 3742 : auto bins_like(const Variable &bins, const Data &data) {
143 3742 : auto &&[idx, dim, buf] = bins.constituents<T>();
144 7484 : auto out = make_bins_no_validate(idx, dim, empty_like(data, buf.dims()));
145 3742 : out.setSlice(Slice{}, data);
146 7482 : return out;
147 3743 : }
148 :
149 3 : template <class Data> void bind_bins_like(py::module &m) {
150 3 : m.def("bins_like", [](const Variable &bins, const Data &data) {
151 3742 : if (bins.dtype() == dtype<bucket<Variable>>)
152 4 : return bins_like<Variable>(bins, data);
153 3738 : if (bins.dtype() == dtype<bucket<DataArray>>)
154 3738 : return bins_like<DataArray>(bins, data);
155 0 : throw except::TypeError(
156 : "In `bins_like`: Prototype must contain binned data but got dtype=" +
157 : to_string(bins.dtype()));
158 : });
159 3 : }
160 :
161 : } // namespace
162 :
163 3 : void init_buckets(py::module &m) {
164 3 : bind_bins<Variable>(m);
165 3 : bind_bins<DataArray>(m);
166 3 : bind_bins<Dataset>(m);
167 :
168 3 : bind_bins_like<Variable>(m);
169 :
170 3 : m.def("is_bins", variable::is_bins);
171 3 : m.def("is_bins",
172 60347 : [](const DataArray &array) { return dataset::is_bins(array); });
173 3 : m.def("is_bins",
174 0 : [](const Dataset &dataset) { return dataset::is_bins(dataset); });
175 :
176 3 : m.def("bins_constituents", [](const Variable &var) {
177 14532 : const auto dt = var.dtype();
178 14532 : if (dt == dtype<bucket<Variable>>)
179 849 : return bins_constituents<Variable>(var);
180 13683 : if (dt == dtype<bucket<DataArray>>)
181 13682 : return bins_constituents<DataArray>(var);
182 1 : if (dt == dtype<bucket<Dataset>>)
183 1 : return bins_constituents<Dataset>(var);
184 0 : throw except::TypeError("'constituents' does not support dtype " +
185 0 : to_string(dt));
186 : });
187 :
188 3 : m.def(
189 : "lookup_previous",
190 40 : [](const DataArray &function, const Variable &x, const std::string &dim,
191 : const std::optional<Variable> &fill_value) {
192 40 : return dataset::lookup_previous(function, x, Dim{dim}, fill_value);
193 : },
194 0 : py::call_guard<py::gil_scoped_release>());
195 :
196 3 : auto buckets = m.def_submodule("buckets");
197 3 : buckets.def(
198 : "concatenate",
199 0 : [](const Variable &a, const Variable &b) {
200 0 : return dataset::buckets::concatenate(a, b);
201 : },
202 0 : py::call_guard<py::gil_scoped_release>());
203 3 : buckets.def(
204 : "concatenate",
205 0 : [](const DataArray &a, const DataArray &b) {
206 0 : return dataset::buckets::concatenate(a, b);
207 : },
208 0 : py::call_guard<py::gil_scoped_release>());
209 3 : buckets.def(
210 : "append",
211 0 : [](Variable &a, const Variable &b) {
212 0 : return dataset::buckets::append(a, b);
213 : },
214 0 : py::call_guard<py::gil_scoped_release>());
215 3 : buckets.def(
216 : "append",
217 0 : [](DataArray &a, const DataArray &b) {
218 0 : return dataset::buckets::append(a, b);
219 : },
220 0 : py::call_guard<py::gil_scoped_release>());
221 3 : buckets.def(
222 : "map",
223 127 : [](const DataArray &function, const Variable &x, const std::string &dim,
224 : const std::optional<Variable> &fill_value) {
225 127 : return dataset::buckets::map(function, x, Dim{dim}, fill_value);
226 : },
227 0 : py::call_guard<py::gil_scoped_release>());
228 3 : buckets.def(
229 : "scale",
230 111 : [](DataArray &array, const DataArray &histogram, const std::string &dim) {
231 111 : return dataset::buckets::scale(array, histogram, Dim{dim});
232 : },
233 0 : py::call_guard<py::gil_scoped_release>());
234 :
235 3 : m.def(
236 : "bin",
237 4228 : [](const DataArray &array, const std::vector<Variable> &edges,
238 : const std::vector<Variable> &groups,
239 : const std::vector<std::string> &erase) {
240 8456 : return dataset::bin(array, edges, groups, to_dim_type(erase));
241 : },
242 0 : py::arg("array"), py::arg("edges"),
243 6 : py::arg("groups") = std::vector<Variable>{},
244 6 : py::arg("erase") = std::vector<std::string>{},
245 0 : py::call_guard<py::gil_scoped_release>());
246 :
247 3 : bind_bins_view<DataArray>(m);
248 3 : }
|