Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 : #pragma once
6 :
7 : #include <algorithm>
8 : #include <variant>
9 :
10 : #include "scipp/core/dtype.h"
11 : #include "scipp/core/eigen.h"
12 : #include "scipp/core/spatial_transforms.h"
13 : #include "scipp/core/tag_util.h"
14 : #include "scipp/dataset/dataset.h"
15 : #include "scipp/dataset/except.h"
16 : #include "scipp/variable/shape.h"
17 : #include "scipp/variable/variable.h"
18 : #include "scipp/variable/variable_concept.h"
19 :
20 : #include "dtype.h"
21 : #include "numpy.h"
22 : #include "py_object.h"
23 : #include "pybind11.h"
24 : #include "unit.h"
25 :
26 : namespace py = pybind11;
27 : using namespace scipp;
28 :
29 22 : template <class T> void remove_variances(T &obj) {
30 : if constexpr (std::is_same_v<T, DataArray>)
31 1 : obj.data().setVariances(Variable());
32 : else
33 21 : obj.setVariances(Variable());
34 2 : }
35 :
36 37 : template <class T> void init_variances(T &obj) {
37 : if constexpr (std::is_same_v<T, DataArray>)
38 23 : obj.data().setVariances(Variable(obj.data()));
39 : else
40 14 : obj.setVariances(Variable(obj));
41 37 : }
42 :
43 : /// Add element size as factor to strides.
44 : template <class T>
45 : std::vector<scipp::index>
46 48094 : numpy_strides(const scipp::span<const scipp::index> &s) {
47 48094 : std::vector<scipp::index> strides(s.size());
48 99315 : for (size_t i = 0; i < strides.size(); ++i) {
49 51221 : strides[i] = sizeof(T) * s[i];
50 : }
51 48094 : return strides;
52 : }
53 :
54 181844 : template <typename T> decltype(auto) get_data_variable(T &&x) {
55 : if constexpr (std::is_same_v<std::decay_t<T>, scipp::Variable>) {
56 179779 : return std::forward<T>(x);
57 : } else {
58 2065 : return std::forward<T>(x).data();
59 : }
60 : }
61 :
/// Return a pybind11 handle to the VariableConcept of x.
/// Refers to the data variable if T is a DataArray.
///
/// Callers in this file pass the returned object as the base of numpy arrays
/// and as the keep-alive target of returned views. It is what ties the
/// lifetime of those Python objects to the underlying data.
template <typename T> auto get_data_variable_concept_handle(T &&x) {
  return py::cast(get_data_variable(std::forward<T>(x)).data_handle());
}
67 :
68 : template <class... Ts> class as_ElementArrayViewImpl;
69 :
70 : class DataAccessHelper {
71 : template <class... Ts> friend class as_ElementArrayViewImpl;
72 :
73 : template <class Getter, class T, class View>
74 48095 : static py::object as_py_array_t_impl(View &&view) {
75 96190 : const auto get_dtype = [&view]() {
76 : if constexpr (std::is_same_v<T, scipp::core::time_point>) {
77 : // Need a custom implementation because py::dtype::of only works with
78 : // types supported by the buffer protocol.
79 104 : return py::dtype("datetime64[" + to_numpy_time_string(view.unit()) +
80 106 : ']');
81 : } else {
82 : static_cast<void>(view);
83 47991 : return py::dtype::of<T>();
84 : }
85 : };
86 48095 : auto &&var = get_data_variable(view);
87 48095 : const auto &dims = view.dims();
88 48095 : if (var.is_readonly()) {
89 1440 : auto array =
90 720 : py::array{get_dtype(), dims.shape(), numpy_strides<T>(var.strides()),
91 480 : Getter::template get<T>(std::as_const(view)).data(),
92 : get_data_variable_concept_handle(view)};
93 240 : py::detail::array_proxy(array.ptr())->flags &=
94 : ~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;
95 : // no automatic move because of type mismatch
96 240 : return py::object{std::move(array)};
97 240 : } else {
98 47854 : return py::array{get_dtype(), dims.shape(),
99 47854 : numpy_strides<T>(var.strides()),
100 47854 : Getter::template get<T>(view).data(),
101 191417 : get_data_variable_concept_handle(view)};
102 : }
103 575 : }
104 :
105 : struct get_values {
106 91554 : template <class T, class View> static constexpr auto get(View &&view) {
107 91554 : return view.template values<T>();
108 : }
109 : };
110 :
111 : struct get_variances {
112 535 : template <class T, class View> static constexpr auto get(View &&view) {
113 535 : return view.template variances<T>();
114 : }
115 : };
116 : };
117 :
118 43380 : inline void expect_scalar(const Dimensions &dims, const std::string_view name) {
119 43380 : if (dims != Dimensions{}) {
120 9 : std::ostringstream oss;
121 : oss << "The '" << name << "' property cannot be used with non-scalar "
122 9 : << "Variables. Got dimensions " << to_string(dims) << ". Did you mean '"
123 9 : << name << "s'?";
124 9 : throw except::DimensionError(oss.str());
125 9 : }
126 43371 : }
127 :
128 : template <class... Ts> class as_ElementArrayViewImpl {
129 : using get_values = DataAccessHelper::get_values;
130 : using get_variances = DataAccessHelper::get_variances;
131 :
132 : template <class View>
133 : using outVariant_t = std::variant<ElementArrayView<Ts>...>;
134 :
135 : template <class Getter, class View>
136 43995 : static outVariant_t<View> get(View &view) {
137 43995 : const DType type = view.dtype();
138 43995 : if (type == dtype<double>)
139 17133 : return {Getter::template get<double>(view)};
140 26862 : if (type == dtype<float>)
141 195 : return {Getter::template get<float>(view)};
142 : if constexpr (std::is_same_v<Getter, get_values>) {
143 26667 : if (type == dtype<int64_t>)
144 25218 : return {Getter::template get<int64_t>(view)};
145 1449 : if (type == dtype<int32_t>)
146 178 : return {Getter::template get<int32_t>(view)};
147 1271 : if (type == dtype<bool>)
148 487 : return {Getter::template get<bool>(view)};
149 784 : if (type == dtype<std::string>)
150 100 : return {Getter::template get<std::string>(view)};
151 684 : if (type == dtype<scipp::core::time_point>)
152 193 : return {Getter::template get<scipp::core::time_point>(view)};
153 491 : if (type == dtype<Variable>)
154 28 : return {Getter::template get<Variable>(view)};
155 463 : if (type == dtype<DataArray>)
156 15 : return {Getter::template get<DataArray>(view)};
157 448 : if (type == dtype<Dataset>)
158 10 : return {Getter::template get<Dataset>(view)};
159 438 : if (type == dtype<Eigen::Vector3d>)
160 21 : return {Getter::template get<Eigen::Vector3d>(view)};
161 417 : if (type == dtype<Eigen::Matrix3d>)
162 16 : return {Getter::template get<Eigen::Matrix3d>(view)};
163 401 : if (type == dtype<Eigen::Affine3d>)
164 0 : return {Getter::template get<Eigen::Affine3d>(view)};
165 401 : if (type == dtype<scipp::core::Quaternion>)
166 0 : return {Getter::template get<scipp::core::Quaternion>(view)};
167 401 : if (type == dtype<scipp::core::Translation>)
168 0 : return {Getter::template get<scipp::core::Translation>(view)};
169 401 : if (type == dtype<scipp::python::PyObject>)
170 40 : return {Getter::template get<scipp::python::PyObject>(view)};
171 361 : if (type == dtype<bucket<Variable>>)
172 92 : return {Getter::template get<bucket<Variable>>(view)};
173 269 : if (type == dtype<bucket<DataArray>>)
174 269 : return {Getter::template get<bucket<DataArray>>(view)};
175 0 : if (type == dtype<bucket<Dataset>>)
176 0 : return {Getter::template get<bucket<Dataset>>(view)};
177 : }
178 0 : throw std::runtime_error("Value-access not implemented for this type.");
179 : }
180 :
181 : template <class View>
182 526 : static void set(const Dimensions &dims, const units::Unit unit,
183 : const View &view, const py::object &obj) {
184 526 : std::visit(
185 1007 : [&dims, &unit, &obj](const auto &view_) {
186 : using T =
187 : typename std::remove_reference_t<decltype(view_)>::value_type;
188 526 : copy_array_into_view(cast_to_array_like<T>(obj, unit), view_, dims);
189 : },
190 : view);
191 481 : }
192 :
193 : template <typename View, typename T>
194 : static auto
195 30 : get_matrix_elements(const View &view,
196 : const std::initializer_list<scipp::index> shape) {
197 30 : auto elems = get_data_variable(view).template elements<T>();
198 30 : elems = fold(
199 : elems, Dim::InternalStructureComponent,
200 60 : Dimensions({Dim::InternalStructureRow, Dim::InternalStructureColumn},
201 : shape));
202 60 : std::vector labels(elems.dims().labels().begin(),
203 60 : elems.dims().labels().end());
204 30 : std::iter_swap(labels.end() - 2, labels.end() - 1);
205 60 : return transpose(elems, labels);
206 30 : }
207 :
208 102 : template <class View> static auto structure_elements(View &&view) {
209 102 : if (view.dtype() == dtype<Eigen::Vector3d>) {
210 48 : return get_data_variable(view).template elements<Eigen::Vector3d>();
211 54 : } else if (view.dtype() == dtype<Eigen::Matrix3d>) {
212 20 : return get_matrix_elements<View, Eigen::Matrix3d>(view, {3, 3});
213 34 : } else if (view.dtype() == dtype<scipp::core::Quaternion>) {
214 12 : return get_data_variable(view)
215 12 : .template elements<scipp::core::Quaternion>();
216 22 : } else if (view.dtype() == dtype<scipp::core::Translation>) {
217 12 : return get_data_variable(view)
218 12 : .template elements<scipp::core::Translation>();
219 10 : } else if (view.dtype() == dtype<Eigen::Affine3d>) {
220 10 : return get_matrix_elements<View, Eigen::Affine3d>(view, {4, 4});
221 : } else {
222 0 : throw std::runtime_error("Unsupported structured dtype");
223 : }
224 : }
225 :
226 : public:
227 : template <class Getter, class View>
228 48504 : static py::object get_py_array_t(py::object &obj) {
229 48504 : auto &view = obj.cast<View &>();
230 48263 : if (!std::is_const_v<View> && get_data_variable(view).is_readonly())
231 : return as_ElementArrayViewImpl<const Ts...>::template get_py_array_t<
232 241 : Getter, const View>(obj);
233 48263 : const DType type = view.dtype();
234 48263 : if (type == dtype<double>)
235 34330 : return DataAccessHelper::as_py_array_t_impl<Getter, double>(view);
236 13933 : if (type == dtype<float>)
237 212 : return DataAccessHelper::as_py_array_t_impl<Getter, float>(view);
238 13721 : if (type == dtype<int64_t>)
239 12529 : return DataAccessHelper::as_py_array_t_impl<Getter, int64_t>(view);
240 1192 : if (type == dtype<int32_t>)
241 378 : return DataAccessHelper::as_py_array_t_impl<Getter, int32_t>(view);
242 814 : if (type == dtype<bool>)
243 452 : return DataAccessHelper::as_py_array_t_impl<Getter, bool>(view);
244 362 : if (type == dtype<scipp::core::time_point>)
245 : return DataAccessHelper::as_py_array_t_impl<Getter,
246 : scipp::core::time_point>(
247 104 : view);
248 258 : if (is_structured(type))
249 : return DataAccessHelper::as_py_array_t_impl<Getter, double>(
250 90 : structure_elements(view));
251 : return std::visit(
252 336 : [&view](const auto &data) {
253 168 : const auto &dims = view.dims();
254 : // We return an individual item in two cases:
255 : // 1. For 0-D data (consistent with numpy behavior, e.g., when slicing
256 : // a 1-D array).
257 : // 2. For 1-D event data, where the individual item is then a
258 : // vector-like object.
259 168 : if (dims.ndim() == 0) {
260 : return make_scalar(data[0], get_data_variable_concept_handle(view),
261 55 : view);
262 : } else {
263 : // Returning view (span or ElementArrayView) by value. This
264 : // references data in variable, so it must be kept alive. There is
265 : // no policy that supports this, so we use `keep_alive_impl`
266 : // manually.
267 113 : auto ret = py::cast(data, py::return_value_policy::move);
268 113 : pybind11::detail::keep_alive_impl(
269 : ret, get_data_variable_concept_handle(view));
270 113 : return ret;
271 113 : }
272 : },
273 168 : get<Getter>(view));
274 : }
275 :
276 47859 : template <class Var> static py::object values(py::object &object) {
277 47859 : return get_py_array_t<get_values, Var>(object);
278 : }
279 :
280 16332 : template <class Var> static py::object variances(py::object &object) {
281 16332 : if (!object.cast<Var &>().has_variances())
282 15937 : return py::none();
283 395 : return get_py_array_t<get_variances, Var>(object);
284 : }
285 :
286 : template <class Var>
287 492 : static void set_values(Var &view, const py::object &obj) {
288 492 : if (is_structured(view.dtype())) {
289 7 : auto elems = structure_elements(view);
290 7 : set_values(elems, obj);
291 7 : } else {
292 485 : set(view.dims(), view.unit(), get<get_values>(view), obj);
293 : }
294 447 : }
295 :
296 : template <class Var>
297 63 : static void set_variances(Var &view, const py::object &obj) {
298 63 : if (obj.is_none())
299 22 : return remove_variances(view);
300 41 : if (!view.has_variances())
301 37 : init_variances(view);
302 41 : set(view.dims(), view.unit(), get<get_variances>(view), obj);
303 : }
304 :
305 : private:
306 36158 : static auto numpy_attr(const char *const name) {
307 36158 : return py::module_::import("numpy").attr(name);
308 : }
309 :
310 : template <class Scalar, class View>
311 37105 : static auto make_scalar(Scalar &&scalar, py::object parent,
312 : const View &view) {
313 : if constexpr (std::is_same_v<std::decay_t<Scalar>,
314 : scipp::python::PyObject>) {
315 : // Returning PyObject. This increments the reference counter of
316 : // the element, so it is ok if the parent `parent` (the variable)
317 : // goes out of scope.
318 34 : return scalar.to_pybind();
319 : } else if constexpr (std::is_same_v<std::decay_t<Scalar>,
320 : core::time_point>) {
321 92 : const auto np_datetime64 = numpy_attr("datetime64");
322 93 : return np_datetime64(scalar.time_since_epoch(),
323 368 : to_numpy_time_string(view.unit()));
324 92 : } else if constexpr (std::is_same_v<std::decay_t<Scalar>, int32_t>) {
325 154 : return numpy_attr("int32")(scalar);
326 : } else if constexpr (std::is_same_v<std::decay_t<Scalar>, int64_t>) {
327 21343 : return numpy_attr("int64")(scalar);
328 : } else if constexpr (std::is_same_v<std::decay_t<Scalar>, float>) {
329 175 : return numpy_attr("float32")(scalar);
330 : } else if constexpr (std::is_same_v<std::decay_t<Scalar>, double>) {
331 14394 : return numpy_attr("float64")(scalar);
332 : } else if constexpr (!std::is_reference_v<Scalar>) {
333 : // Views such as slices of data arrays for binned data are
334 : // returned by value and require separate handling to avoid the
335 : // py::return_value_policy::reference_internal in the default case
336 : // below.
337 298 : return py::cast(scalar, py::return_value_policy::move);
338 : } else {
339 : // Returning reference to element in variable. Return-policy
340 : // reference_internal keeps alive `parent`. Note that an attempt to
341 : // pass `keep_alive` as a call policy to `def_property` failed,
342 : // resulting in exception from pybind11, so we have to handle it by
343 : // hand here.
344 : return py::cast(scalar, py::return_value_policy::reference_internal,
345 615 : std::move(parent));
346 : }
347 : }
348 :
349 : // Helper function object to get a scalar value or variance.
350 : template <class View> struct GetScalarVisitor {
351 : py::object &self; // The object we're getting the value / variance from.
352 : std::remove_reference_t<View> &view; // self as a view.
353 :
354 37050 : template <class Data> auto operator()(const Data &&data) const {
355 37050 : return make_scalar(data[0], self, view);
356 : }
357 : };
358 :
359 : // Helper function object to set a scalar value or variance.
360 : template <class View> struct SetScalarVisitor {
361 : const py::object &rhs; // The object we are assigning.
362 : std::remove_reference_t<View> &view; // View of self.
363 :
364 6250 : template <class Data> auto operator()(Data &&data) const {
365 : using T = typename std::decay_t<decltype(data)>::value_type;
366 : if constexpr (std::is_same_v<T, scipp::python::PyObject>)
367 1 : data[0] = rhs;
368 : else if constexpr (std::is_same_v<T, scipp::core::time_point>) {
369 : // TODO support int
370 51 : if (view.unit() != parse_datetime_dtype(rhs)) {
371 : // TODO implement
372 42 : throw std::invalid_argument(
373 : "Conversion of time units is not implemented.");
374 : }
375 9 : data[0] = make_time_point(rhs.template cast<py::buffer>());
376 : } else
377 6198 : data[0] = rhs.cast<T>();
378 6208 : }
379 : };
380 :
381 : public:
382 : // Return a scalar value from a variable, implicitly requiring that the
383 : // variable is 0-dimensional and thus has only a single item.
384 37111 : template <class Var> static py::object value(py::object &obj) {
385 37111 : auto &view = obj.cast<Var &>();
386 36968 : if (!std::is_const_v<Var> && get_data_variable(view).is_readonly())
387 : return as_ElementArrayViewImpl<const Ts...>::template value<const Var>(
388 143 : obj);
389 36968 : expect_scalar(view.dims(), "value");
390 73927 : if (view.dtype() == dtype<scipp::core::Quaternion> ||
391 73927 : view.dtype() == dtype<scipp::core::Translation> ||
392 73924 : view.dtype() == dtype<Eigen::Affine3d>)
393 9 : return get_py_array_t<get_values, Var>(obj);
394 : return std::visit(GetScalarVisitor<decltype(view)>{obj, view},
395 36956 : get<get_values>(view));
396 : }
397 : // Return a scalar variance from a variable, implicitly requiring that the
398 : // variable is 0-dimensional and thus has only a single item.
399 156 : template <class Var> static py::object variance(py::object &obj) {
400 156 : auto &view = obj.cast<Var &>();
401 154 : if (!std::is_const_v<Var> && get_data_variable(view).is_readonly())
402 : return as_ElementArrayViewImpl<const Ts...>::template variance<const Var>(
403 2 : obj);
404 154 : expect_scalar(view.dims(), "variance");
405 150 : if (!view.has_variances())
406 56 : return py::none();
407 : return std::visit(GetScalarVisitor<decltype(view)>{obj, view},
408 94 : get<get_variances>(view));
409 : }
410 : // Set a scalar value in a variable, implicitly requiring that the
411 : // variable is 0-dimensional and thus has only a single item.
412 6252 : template <class Var> static void set_value(Var &view, const py::object &obj) {
413 6252 : expect_scalar(view.dims(), "value");
414 6251 : if (is_structured(view.dtype())) {
415 5 : auto elems = structure_elements(view);
416 5 : set_values(elems, obj);
417 5 : } else {
418 6246 : std::visit(SetScalarVisitor<decltype(view)>{obj, view},
419 : get<get_values>(view));
420 : }
421 6208 : }
422 : // Set a scalar variance in a variable, implicitly requiring that the
423 : // variable is 0-dimensional and thus has only a single item.
424 : template <class Var>
425 6 : static void set_variance(Var &view, const py::object &obj) {
426 6 : expect_scalar(view.dims(), "variance");
427 5 : if (obj.is_none())
428 0 : return remove_variances(view);
429 5 : if (!view.has_variances())
430 0 : init_variances(view);
431 :
432 5 : std::visit(SetScalarVisitor<decltype(view)>{obj, view},
433 : get<get_variances>(view));
434 : }
435 : };
436 :
/// Accessor instantiation used by bind_data_properties for the Python
/// `values`, `variances`, `value` and `variance` properties.
/// The order of the types here defines the alternatives of the variant
/// returned by as_ElementArrayViewImpl::get; reordering them changes that
/// variant type.
using as_ElementArrayView = as_ElementArrayViewImpl<
    double, float, int64_t, int32_t, bool, std::string, scipp::core::time_point,
    Variable, DataArray, Dataset, bucket<Variable>, bucket<DataArray>,
    bucket<Dataset>, Eigen::Vector3d, Eigen::Matrix3d, scipp::python::PyObject,
    Eigen::Affine3d, scipp::core::Quaternion, scipp::core::Translation>;
442 :
443 : template <class T, class... Ignored>
444 9 : void bind_common_data_properties(pybind11::class_<T, Ignored...> &c) {
445 0 : c.def_property_readonly(
446 : "dims",
447 47009 : [](const T &self) {
448 47009 : const auto &labels = self.dims().labels();
449 47009 : const auto ndim = static_cast<size_t>(self.ndim());
450 47009 : py::tuple dims(ndim);
451 119863 : for (size_t i = 0; i < ndim; ++i) {
452 72854 : dims[i] = labels[i].name();
453 : }
454 94018 : return dims;
455 0 : },
456 : "Dimension labels of the data (read-only).",
457 9 : py::return_value_policy::move);
458 9 : c.def_property_readonly(
459 53302 : "dim", [](const T &self) { return self.dim().name(); },
460 : "The only dimension label for 1-dimensional data, raising an exception "
461 : "if the data is not 1-dimensional.");
462 0 : c.def_property_readonly(
463 7632 : "ndim", [](const T &self) { return self.ndim(); },
464 : "Number of dimensions of the data (read-only).",
465 9 : py::return_value_policy::move);
466 0 : c.def_property_readonly(
467 : "shape",
468 3502 : [](const T &self) {
469 3502 : const auto &sizes = self.dims().sizes();
470 3502 : const auto ndim = static_cast<size_t>(self.ndim());
471 3502 : py::tuple shape(ndim);
472 10083 : for (size_t i = 0; i < ndim; ++i) {
473 6581 : shape[i] = sizes[i];
474 : }
475 7004 : return shape;
476 0 : },
477 9 : "Shape of the data (read-only).", py::return_value_policy::move);
478 0 : c.def_property_readonly(
479 : "sizes",
480 12268 : [](const T &self) {
481 12268 : const auto &dims = self.dims();
482 : // Use py::dict directly instead of std::map in order to guarantee
483 : // that items are stored in the order of insertion.
484 12268 : py::dict sizes;
485 42510 : for (const auto label : dims.labels()) {
486 30242 : sizes[label.name().c_str()] = dims[label];
487 : }
488 12268 : return sizes;
489 0 : },
490 : "dict mapping dimension labels to dimension sizes (read-only).",
491 9 : py::return_value_policy::move);
492 9 : }
493 :
494 : template <class T, class... Ignored>
495 6 : void bind_data_properties(pybind11::class_<T, Ignored...> &c) {
496 6 : bind_common_data_properties(c);
497 6 : c.def_property_readonly(
498 24541 : "dtype", [](const T &self) { return self.dtype(); },
499 : "Data type contained in the variable.");
500 6 : c.def_property(
501 : "unit",
502 31850 : [](const T &self) {
503 31850 : return self.unit() == units::none ? std::optional<units::Unit>()
504 32337 : : self.unit();
505 : },
506 596 : [](T &self, const ProtoUnit &unit) {
507 310 : self.setUnit(unit_or_default(unit, self.dtype()));
508 : },
509 : "Physical unit of the data.");
510 6 : c.def_property("values", &as_ElementArrayView::values<T>,
511 : &as_ElementArrayView::set_values<T>,
512 : "Array of values of the data.");
513 6 : c.def_property("variances", &as_ElementArrayView::variances<T>,
514 : &as_ElementArrayView::set_variances<T>,
515 : "Array of variances of the data.");
516 6 : c.def_property(
517 : "value", &as_ElementArrayView::value<T>,
518 : &as_ElementArrayView::set_value<T>,
519 : "The only value for 0-dimensional data, raising an exception if the data "
520 : "is not 0-dimensional.");
521 6 : c.def_property(
522 : "variance", &as_ElementArrayView::variance<T>,
523 : &as_ElementArrayView::set_variance<T>,
524 : "The only variance for 0-dimensional data, raising an exception if the "
525 : "data is not 0-dimensional.");
526 : if constexpr (std::is_same_v<T, DataArray> || std::is_same_v<T, Variable>) {
527 0 : c.def_property_readonly(
528 8 : "size", [](const T &self) { return self.dims().volume(); },
529 : "Number of elements in the data (read-only).",
530 6 : py::return_value_policy::move);
531 : }
532 6 : }
|