Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 : #pragma once
6 :
7 : #include <algorithm>
8 : #include <variant>
9 :
10 : #include <pybind11/typing.h>
11 :
12 : #include "scipp/core/dtype.h"
13 : #include "scipp/core/eigen.h"
14 : #include "scipp/core/spatial_transforms.h"
15 : #include "scipp/core/tag_util.h"
16 : #include "scipp/dataset/dataset.h"
17 : #include "scipp/dataset/except.h"
18 : #include "scipp/variable/shape.h"
19 : #include "scipp/variable/variable.h"
20 : #include "scipp/variable/variable_concept.h"
21 :
22 : #include "dtype.h"
23 : #include "numpy.h"
24 : #include "py_object.h"
25 : #include "pybind11.h"
26 : #include "unit.h"
27 :
28 : namespace py = pybind11;
29 : using namespace scipp;
30 :
31 30 : template <class T> void remove_variances(T &obj) {
32 : if constexpr (std::is_same_v<T, DataArray>)
33 1 : obj.data().setVariances(Variable());
34 : else
35 29 : obj.setVariances(Variable());
36 10 : }
37 :
38 37 : template <class T> void init_variances(T &obj) {
39 : if constexpr (std::is_same_v<T, DataArray>)
40 23 : obj.data().setVariances(Variable(obj.data()));
41 : else
42 14 : obj.setVariances(Variable(obj));
43 37 : }
44 :
45 : /// Add element size as factor to strides.
46 : template <class T>
47 : std::vector<scipp::index>
48 46558 : numpy_strides(const scipp::span<const scipp::index> &s) {
49 46558 : std::vector<scipp::index> strides(s.size());
50 96616 : for (size_t i = 0; i < strides.size(); ++i) {
51 50058 : strides[i] = sizeof(T) * s[i];
52 : }
53 46558 : return strides;
54 : }
55 :
56 184214 : template <typename T> decltype(auto) get_data_variable(T &&x) {
57 : if constexpr (std::is_same_v<std::decay_t<T>, scipp::Variable>) {
58 182038 : return std::forward<T>(x);
59 : } else {
60 2176 : return std::forward<T>(x).data();
61 : }
62 : }
63 :
64 : /// Return a pybind11 handle to the VariableConcept of x.
65 : /// Refers to the data variable if T is a DataArray.
66 46735 : template <typename T> auto get_data_variable_concept_handle(T &&x) {
67 47336 : return py::cast(get_data_variable(std::forward<T>(x)).data_handle());
68 : }
69 :
70 : template <class... Ts> class as_ElementArrayViewImpl;
71 :
72 : class DataAccessHelper {
73 : template <class... Ts> friend class as_ElementArrayViewImpl;
74 :
75 : template <class Getter, class T, class View>
76 46559 : static py::object as_py_array_t_impl(View &&view) {
77 93118 : const auto get_dtype = [&view]() {
78 : if constexpr (std::is_same_v<T, scipp::core::time_point>) {
79 : // Need a custom implementation because py::dtype::of only works with
80 : // types supported by the buffer protocol.
81 104 : return py::dtype("datetime64[" + to_numpy_time_string(view.unit()) +
82 106 : ']');
83 : } else {
84 : static_cast<void>(view);
85 46455 : return py::dtype::of<T>();
86 : }
87 : };
88 46559 : auto &&var = get_data_variable(view);
89 46559 : const auto &dims = view.dims();
90 46559 : if (var.is_readonly()) {
91 1260 : auto array =
92 630 : py::array{get_dtype(), dims.shape(), numpy_strides<T>(var.strides()),
93 420 : Getter::template get<T>(std::as_const(view)).data(),
94 : get_data_variable_concept_handle(view)};
95 210 : py::detail::array_proxy(array.ptr())->flags &=
96 : ~py::detail::npy_api::NPY_ARRAY_WRITEABLE_;
97 : // no automatic move because of type mismatch
98 210 : return py::object{std::move(array)};
99 210 : } else {
100 46348 : return py::array{get_dtype(), dims.shape(),
101 46348 : numpy_strides<T>(var.strides()),
102 46348 : Getter::template get<T>(view).data(),
103 185393 : get_data_variable_concept_handle(view)};
104 : }
105 600 : }
106 :
107 : struct get_values {
108 167913 : template <class T, class View> static constexpr auto get(View &&view) {
109 167913 : return view.template values<T>();
110 : }
111 : };
112 :
113 : struct get_variances {
114 589 : template <class T, class View> static constexpr auto get(View &&view) {
115 589 : return view.template variances<T>();
116 : }
117 : };
118 : };
119 :
120 121300 : inline void expect_scalar(const Dimensions &dims, const std::string_view name) {
121 121300 : if (dims != Dimensions{}) {
122 9 : std::ostringstream oss;
123 : oss << "The '" << name << "' property cannot be used with non-scalar "
124 9 : << "Variables. Got dimensions " << to_string(dims) << ". Did you mean '"
125 9 : << name << "s'?";
126 9 : throw except::DimensionError(oss.str());
127 9 : }
128 121291 : }
129 :
130 : template <class... Ts> class as_ElementArrayViewImpl {
131 : using get_values = DataAccessHelper::get_values;
132 : using get_variances = DataAccessHelper::get_variances;
133 :
134 : template <class View>
135 : using outVariant_t = std::variant<ElementArrayView<Ts>...>;
136 :
137 : template <class Getter, class View>
138 121944 : static outVariant_t<View> get(View &view) {
139 121944 : const DType type = view.dtype();
140 121944 : if (type == dtype<double>)
141 93100 : return {Getter::template get<double>(view)};
142 28844 : if (type == dtype<float>)
143 165 : return {Getter::template get<float>(view)};
144 : if constexpr (std::is_same_v<Getter, get_values>) {
145 28679 : if (type == dtype<int64_t>)
146 27168 : return {Getter::template get<int64_t>(view)};
147 1511 : if (type == dtype<int32_t>)
148 198 : return {Getter::template get<int32_t>(view)};
149 1313 : if (type == dtype<bool>)
150 470 : return {Getter::template get<bool>(view)};
151 843 : if (type == dtype<std::string>)
152 109 : return {Getter::template get<std::string>(view)};
153 734 : if (type == dtype<scipp::core::time_point>)
154 193 : return {Getter::template get<scipp::core::time_point>(view)};
155 541 : if (type == dtype<Variable>)
156 28 : return {Getter::template get<Variable>(view)};
157 513 : if (type == dtype<DataArray>)
158 15 : return {Getter::template get<DataArray>(view)};
159 498 : if (type == dtype<Dataset>)
160 10 : return {Getter::template get<Dataset>(view)};
161 488 : if (type == dtype<Eigen::Vector3d>)
162 21 : return {Getter::template get<Eigen::Vector3d>(view)};
163 467 : if (type == dtype<Eigen::Matrix3d>)
164 16 : return {Getter::template get<Eigen::Matrix3d>(view)};
165 451 : if (type == dtype<Eigen::Affine3d>)
166 0 : return {Getter::template get<Eigen::Affine3d>(view)};
167 451 : if (type == dtype<scipp::core::Quaternion>)
168 0 : return {Getter::template get<scipp::core::Quaternion>(view)};
169 451 : if (type == dtype<scipp::core::Translation>)
170 0 : return {Getter::template get<scipp::core::Translation>(view)};
171 451 : if (type == dtype<scipp::python::PyObject>)
172 40 : return {Getter::template get<scipp::python::PyObject>(view)};
173 411 : if (type == dtype<bucket<Variable>>)
174 92 : return {Getter::template get<bucket<Variable>>(view)};
175 319 : if (type == dtype<bucket<DataArray>>)
176 319 : return {Getter::template get<bucket<DataArray>>(view)};
177 0 : if (type == dtype<bucket<Dataset>>)
178 0 : return {Getter::template get<bucket<Dataset>>(view)};
179 : }
180 0 : throw std::runtime_error("Value-access not implemented for this type.");
181 : }
182 :
183 : template <class View>
184 563 : static void set(const Dimensions &dims, const units::Unit unit,
185 : const View &view, const py::object &obj) {
186 563 : std::visit(
187 1081 : [&dims, &unit, &obj](const auto &view_) {
188 : using T =
189 : typename std::remove_reference_t<decltype(view_)>::value_type;
190 563 : copy_array_into_view(cast_to_array_like<T>(obj, unit), view_, dims);
191 : },
192 : view);
193 518 : }
194 :
195 : template <typename View, typename T>
196 : static auto
197 30 : get_matrix_elements(const View &view,
198 : const std::initializer_list<scipp::index> shape) {
199 30 : auto elems = get_data_variable(view).template elements<T>();
200 30 : elems = fold(
201 : elems, Dim::InternalStructureComponent,
202 60 : Dimensions({Dim::InternalStructureRow, Dim::InternalStructureColumn},
203 : shape));
204 60 : std::vector labels(elems.dims().labels().begin(),
205 60 : elems.dims().labels().end());
206 30 : std::iter_swap(labels.end() - 2, labels.end() - 1);
207 60 : return transpose(elems, labels);
208 30 : }
209 :
210 102 : template <class View> static auto structure_elements(View &&view) {
211 102 : if (view.dtype() == dtype<Eigen::Vector3d>) {
212 48 : return get_data_variable(view).template elements<Eigen::Vector3d>();
213 54 : } else if (view.dtype() == dtype<Eigen::Matrix3d>) {
214 20 : return get_matrix_elements<View, Eigen::Matrix3d>(view, {3, 3});
215 34 : } else if (view.dtype() == dtype<scipp::core::Quaternion>) {
216 12 : return get_data_variable(view)
217 12 : .template elements<scipp::core::Quaternion>();
218 22 : } else if (view.dtype() == dtype<scipp::core::Translation>) {
219 12 : return get_data_variable(view)
220 12 : .template elements<scipp::core::Translation>();
221 10 : } else if (view.dtype() == dtype<Eigen::Affine3d>) {
222 10 : return get_matrix_elements<View, Eigen::Affine3d>(view, {4, 4});
223 : } else {
224 0 : throw std::runtime_error("Unsupported structured dtype");
225 : }
226 : }
227 :
228 : public:
229 : template <class Getter, class View>
230 46947 : static py::object get_py_array_t(py::object &obj) {
231 46947 : auto &view = obj.cast<View &>();
232 46736 : if (!std::is_const_v<View> && get_data_variable(view).is_readonly())
233 : return as_ElementArrayViewImpl<const Ts...>::template get_py_array_t<
234 211 : Getter, const View>(obj);
235 46736 : const DType type = view.dtype();
236 46736 : if (type == dtype<double>)
237 37739 : return DataAccessHelper::as_py_array_t_impl<Getter, double>(view);
238 8997 : if (type == dtype<float>)
239 223 : return DataAccessHelper::as_py_array_t_impl<Getter, float>(view);
240 8774 : if (type == dtype<int64_t>)
241 7793 : return DataAccessHelper::as_py_array_t_impl<Getter, int64_t>(view);
242 981 : if (type == dtype<int32_t>)
243 228 : return DataAccessHelper::as_py_array_t_impl<Getter, int32_t>(view);
244 753 : if (type == dtype<bool>)
245 382 : return DataAccessHelper::as_py_array_t_impl<Getter, bool>(view);
246 371 : if (type == dtype<scipp::core::time_point>)
247 : return DataAccessHelper::as_py_array_t_impl<Getter,
248 : scipp::core::time_point>(
249 104 : view);
250 267 : if (is_structured(type))
251 : return DataAccessHelper::as_py_array_t_impl<Getter, double>(
252 90 : structure_elements(view));
253 : return std::visit(
254 354 : [&view](const auto &data) {
255 177 : const auto &dims = view.dims();
256 : // We return an individual item in two cases:
257 : // 1. For 0-D data (consistent with numpy behavior, e.g., when slicing
258 : // a 1-D array).
259 : // 2. For 1-D event data, where the individual item is then a
260 : // vector-like object.
261 177 : if (dims.ndim() == 0) {
262 : return make_scalar(data[0], get_data_variable_concept_handle(view),
263 55 : view);
264 : } else {
265 : // Returning view (span or ElementArrayView) by value. This
266 : // references data in variable, so it must be kept alive. There is
267 : // no policy that supports this, so we use `keep_alive_impl`
268 : // manually.
269 122 : auto ret = py::cast(data, py::return_value_policy::move);
270 122 : pybind11::detail::keep_alive_impl(
271 : ret, get_data_variable_concept_handle(view));
272 122 : return ret;
273 122 : }
274 : },
275 177 : get<Getter>(view));
276 : }
277 :
278 46240 : template <class Var> static py::object values(py::object &object) {
279 46240 : return get_py_array_t<get_values, Var>(object);
280 : }
281 :
282 16888 : template <class Var> static py::object variances(py::object &object) {
283 16888 : if (!object.cast<Var &>().has_variances())
284 16401 : return py::none();
285 487 : return get_py_array_t<get_variances, Var>(object);
286 : }
287 :
288 : template <class Var>
289 529 : static void set_values(Var &view, const py::object &obj) {
290 529 : if (is_structured(view.dtype())) {
291 7 : auto elems = structure_elements(view);
292 7 : set_values(elems, obj);
293 7 : } else {
294 522 : set(view.dims(), view.unit(), get<get_values>(view), obj);
295 : }
296 484 : }
297 :
298 : template <class Var>
299 71 : static void set_variances(Var &view, const py::object &obj) {
300 71 : if (obj.is_none())
301 30 : return remove_variances(view);
302 41 : if (!view.has_variances())
303 37 : init_variances(view);
304 41 : set(view.dims(), view.unit(), get<get_variances>(view), obj);
305 : }
306 :
307 : private:
308 43068 : static auto numpy_attr(const char *const name) {
309 43068 : return py::module_::import("numpy").attr(name);
310 : }
311 :
312 : template <class Scalar, class View>
313 44048 : static auto make_scalar(Scalar &&scalar, py::object parent,
314 : const View &view) {
315 : if constexpr (std::is_same_v<std::decay_t<Scalar>,
316 : scipp::python::PyObject>) {
317 : // Returning PyObject. This increments the reference counter of
318 : // the element, so it is ok if the parent `parent` (the variable)
319 : // goes out of scope.
320 34 : return scalar.to_pybind();
321 : } else if constexpr (std::is_same_v<std::decay_t<Scalar>,
322 : core::time_point>) {
323 92 : const auto np_datetime64 = numpy_attr("datetime64");
324 93 : return np_datetime64(scalar.time_since_epoch(),
325 368 : to_numpy_time_string(view.unit()));
326 92 : } else if constexpr (std::is_same_v<std::decay_t<Scalar>, int32_t>) {
327 174 : return numpy_attr("int32")(scalar);
328 : } else if constexpr (std::is_same_v<std::decay_t<Scalar>, int64_t>) {
329 23288 : return numpy_attr("int64")(scalar);
330 : } else if constexpr (std::is_same_v<std::decay_t<Scalar>, float>) {
331 144 : return numpy_attr("float32")(scalar);
332 : } else if constexpr (std::is_same_v<std::decay_t<Scalar>, double>) {
333 19370 : return numpy_attr("float64")(scalar);
334 : } else if constexpr (!std::is_reference_v<Scalar>) {
335 : // Views such as slices of data arrays for binned data are
336 : // returned by value and require separate handling to avoid the
337 : // py::return_value_policy::reference_internal in the default case
338 : // below.
339 348 : return py::cast(scalar, py::return_value_policy::move);
340 : } else {
341 : // Returning reference to element in variable. Return-policy
342 : // reference_internal keeps alive `parent`. Note that an attempt to
343 : // pass `keep_alive` as a call policy to `def_property` failed,
344 : // resulting in exception from pybind11, so we have to handle it by
345 : // hand here.
346 : return py::cast(scalar, py::return_value_policy::reference_internal,
347 598 : std::move(parent));
348 : }
349 : }
350 :
351 : // Helper function object to get a scalar value or variance.
352 : template <class View> struct GetScalarVisitor {
353 : py::object &self; // The object we're getting the value / variance from.
354 : std::remove_reference_t<View> &view; // self as a view.
355 :
356 43993 : template <class Data> auto operator()(const Data &&data) const {
357 43993 : return make_scalar(data[0], self, view);
358 : }
359 : };
360 :
361 : // Helper function object to set a scalar value or variance.
362 : template <class View> struct SetScalarVisitor {
363 : const py::object &rhs; // The object we are assigning.
364 : std::remove_reference_t<View> &view; // View of self.
365 :
366 77210 : template <class Data> auto operator()(Data &&data) const {
367 : using T = typename std::decay_t<decltype(data)>::value_type;
368 : if constexpr (std::is_same_v<T, scipp::python::PyObject>)
369 1 : data[0] = rhs;
370 : else if constexpr (std::is_same_v<T, scipp::core::time_point>) {
371 : // TODO support int
372 51 : if (view.unit() != parse_datetime_dtype(rhs)) {
373 : // TODO implement
374 42 : throw std::invalid_argument(
375 : "Conversion of time units is not implemented.");
376 : }
377 9 : data[0] = make_time_point(rhs.template cast<py::buffer>());
378 : } else
379 77158 : data[0] = rhs.cast<T>();
380 77168 : }
381 : };
382 :
383 : public:
384 : // Return a scalar value from a variable, implicitly requiring that the
385 : // variable is 0-dimensional and thus has only a single item.
386 44092 : template <class Var> static py::object value(py::object &obj) {
387 44092 : auto &view = obj.cast<Var &>();
388 43949 : if (!std::is_const_v<Var> && get_data_variable(view).is_readonly())
389 : return as_ElementArrayViewImpl<const Ts...>::template value<const Var>(
390 143 : obj);
391 43949 : expect_scalar(view.dims(), "value");
392 87889 : if (view.dtype() == dtype<scipp::core::Quaternion> ||
393 87889 : view.dtype() == dtype<scipp::core::Translation> ||
394 87886 : view.dtype() == dtype<Eigen::Affine3d>)
395 9 : return get_py_array_t<get_values, Var>(obj);
396 : return std::visit(GetScalarVisitor<decltype(view)>{obj, view},
397 43937 : get<get_values>(view));
398 : }
399 : // Return a scalar variance from a variable, implicitly requiring that the
400 : // variable is 0-dimensional and thus has only a single item.
401 135 : template <class Var> static py::object variance(py::object &obj) {
402 135 : auto &view = obj.cast<Var &>();
403 133 : if (!std::is_const_v<Var> && get_data_variable(view).is_readonly())
404 : return as_ElementArrayViewImpl<const Ts...>::template variance<const Var>(
405 2 : obj);
406 133 : expect_scalar(view.dims(), "variance");
407 129 : if (!view.has_variances())
408 73 : return py::none();
409 : return std::visit(GetScalarVisitor<decltype(view)>{obj, view},
410 56 : get<get_variances>(view));
411 : }
412 : // Set a scalar value in a variable, implicitly requiring that the
413 : // variable is 0-dimensional and thus has only a single item.
414 77212 : template <class Var> static void set_value(Var &view, const py::object &obj) {
415 77212 : expect_scalar(view.dims(), "value");
416 77211 : if (is_structured(view.dtype())) {
417 5 : auto elems = structure_elements(view);
418 5 : set_values(elems, obj);
419 5 : } else {
420 77206 : std::visit(SetScalarVisitor<decltype(view)>{obj, view},
421 : get<get_values>(view));
422 : }
423 77168 : }
424 : // Set a scalar variance in a variable, implicitly requiring that the
425 : // variable is 0-dimensional and thus has only a single item.
426 : template <class Var>
427 6 : static void set_variance(Var &view, const py::object &obj) {
428 6 : expect_scalar(view.dims(), "variance");
429 5 : if (obj.is_none())
430 0 : return remove_variances(view);
431 5 : if (!view.has_variances())
432 0 : init_variances(view);
433 :
434 5 : std::visit(SetScalarVisitor<decltype(view)>{obj, view},
435 : get<get_variances>(view));
436 : }
437 : };
438 :
439 : using as_ElementArrayView = as_ElementArrayViewImpl<
440 : double, float, int64_t, int32_t, bool, std::string, scipp::core::time_point,
441 : Variable, DataArray, Dataset, bucket<Variable>, bucket<DataArray>,
442 : bucket<Dataset>, Eigen::Vector3d, Eigen::Matrix3d, scipp::python::PyObject,
443 : Eigen::Affine3d, scipp::core::Quaternion, scipp::core::Translation>;
444 :
445 : template <class T, class... Ignored>
446 9 : void bind_common_data_properties(pybind11::class_<T, Ignored...> &c) {
447 0 : c.def_property_readonly(
448 : "dims",
449 86582 : [](const T &self) {
450 86582 : const auto &labels = self.dims().labels();
451 86582 : const auto ndim = static_cast<size_t>(self.ndim());
452 86582 : py::typing::Tuple<py::str, py::ellipsis> dims(ndim);
453 211409 : for (size_t i = 0; i < ndim; ++i) {
454 124827 : dims[i] = labels[i].name();
455 : }
456 173164 : return dims;
457 0 : },
458 : "Dimension labels of the data (read-only).",
459 9 : py::return_value_policy::move);
460 9 : c.def_property_readonly(
461 53526 : "dim", [](const T &self) { return self.dim().name(); },
462 : "The only dimension label for 1-dimensional data, raising an exception "
463 : "if the data is not 1-dimensional.");
464 0 : c.def_property_readonly(
465 10710 : "ndim", [](const T &self) { return self.ndim(); },
466 : "Number of dimensions of the data (read-only).",
467 9 : py::return_value_policy::move);
468 0 : c.def_property_readonly(
469 : "shape",
470 14920 : [](const T &self) {
471 14920 : const auto &sizes = self.dims().sizes();
472 14920 : const auto ndim = static_cast<size_t>(self.ndim());
473 14920 : py::typing::Tuple<int, py::ellipsis> shape(ndim);
474 26832 : for (size_t i = 0; i < ndim; ++i) {
475 11912 : shape[i] = sizes[i];
476 : }
477 29840 : return shape;
478 0 : },
479 9 : "Shape of the data (read-only).", py::return_value_policy::move);
480 0 : c.def_property_readonly(
481 : "sizes",
482 12439 : [](const T &self) {
483 12439 : const auto &dims = self.dims();
484 : // Use py::dict directly instead of std::map in order to guarantee
485 : // that items are stored in the order of insertion.
486 12439 : py::typing::Dict<py::str, int> sizes;
487 43032 : for (const auto label : dims.labels()) {
488 30593 : sizes[label.name().c_str()] = dims[label];
489 : }
490 12439 : return sizes;
491 0 : },
492 : "dict mapping dimension labels to dimension sizes (read-only).",
493 9 : py::return_value_policy::move);
494 9 : }
495 :
496 : template <class T, class... Ignored>
497 6 : void bind_data_properties(pybind11::class_<T, Ignored...> &c) {
498 6 : bind_common_data_properties(c);
499 6 : c.def_property_readonly(
500 25228 : "dtype", [](const T &self) { return self.dtype(); },
501 : "Data type contained in the variable.");
502 6 : c.def_property(
503 : "unit",
504 34670 : [](const T &self) {
505 34670 : return self.unit() == units::none ? std::optional<units::Unit>()
506 35182 : : self.unit();
507 : },
508 596 : [](T &self, const ProtoUnit &unit) {
509 310 : self.setUnit(unit_or_default(unit, self.dtype()));
510 : },
511 : "Physical unit of the data.");
512 6 : c.def_property("values", &as_ElementArrayView::values<T>,
513 : &as_ElementArrayView::set_values<T>,
514 : "Array of values of the data.");
515 6 : c.def_property("variances", &as_ElementArrayView::variances<T>,
516 : &as_ElementArrayView::set_variances<T>,
517 : "Array of variances of the data.");
518 6 : c.def_property(
519 : "value", &as_ElementArrayView::value<T>,
520 : &as_ElementArrayView::set_value<T>,
521 : "The only value for 0-dimensional data, raising an exception if the data "
522 : "is not 0-dimensional.");
523 6 : c.def_property(
524 : "variance", &as_ElementArrayView::variance<T>,
525 : &as_ElementArrayView::set_variance<T>,
526 : "The only variance for 0-dimensional data, raising an exception if the "
527 : "data is not 0-dimensional.");
528 : if constexpr (std::is_same_v<T, DataArray> || std::is_same_v<T, Variable>) {
529 0 : c.def_property_readonly(
530 8 : "size", [](const T &self) { return self.dims().volume(); },
531 : "Number of elements in the data (read-only).",
532 6 : py::return_value_policy::move);
533 : }
534 6 : }
|