Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 : #include "dtype.h"
6 :
7 : #include <regex>
8 :
9 : #include "scipp/core/eigen.h"
10 : #include "scipp/core/string.h"
11 : #include "scipp/dataset/dataset.h"
12 : #include "scipp/variable/variable.h"
13 :
14 : #include "format.h"
15 : #include "py_object.h"
16 : #include "pybind11.h"
17 :
18 : using namespace scipp;
19 : using namespace scipp::core;
20 :
21 : namespace py = pybind11;
22 :
23 : namespace {
/// Single-character 'kind' codes by which numpy classifies its dtypes.
/// Only the kinds this file needs to recognise are listed.
enum class DTypeKind : char {
  Float = 'f',
  Int = 'i',
  Bool = 'b',
  Datetime = 'M',
  Object = 'O',
  String = 'U',
  RawData = 'V',
};

/// Compare a raw kind character (left) with a DTypeKind enumerator (right).
constexpr bool operator==(const char a, const DTypeKind b) {
  const char code = static_cast<char>(b);
  return code == a;
}
38 :
39 : enum class DTypeSize : scipp::index {
40 : Float64 = 8,
41 : Float32 = 4,
42 : Int64 = 8,
43 : Int32 = 4,
44 : };
45 :
46 43399 : constexpr bool operator==(const scipp::index a, const DTypeSize b) {
47 43399 : return a == static_cast<scipp::index>(b);
48 : }
49 : } // namespace
50 :
51 3 : void init_dtype(py::module &m) {
52 : py::class_<DType> PyDType(m, "DType", R"(
53 : Representation of a data type of a Variable in Scipp.
54 : See https://scipp.github.io/reference/dtype.html for details.
55 :
56 : The data types ``VariableView``, ``DataArrayView``, and ``DatasetView`` are used for
57 : objects containing binned data. They cannot be used directly to create arrays of bins.
58 3 : )");
59 32 : PyDType.def(py::init([](const py::object &x) { return scipp_dtype(x); }))
60 3 : .def("__eq__",
61 32172 : [](const DType &self, const py::object &other) {
62 32172 : return self == scipp_dtype(other);
63 : })
64 1787 : .def("__str__", [](const DType &self) { return to_string(self); })
65 3 : .def("__repr__", [](const DType &self) {
66 60 : return "DType('" + to_string(self) + "')";
67 : });
68 :
69 : // Explicit list of dtypes to bind since core::dtypeNameRegistry contains
70 : // types that are for internal use only and are never returned to Python.
71 3 : for (const auto &t : {
72 : dtype<bool>,
73 : dtype<int32_t>,
74 : dtype<int64_t>,
75 : dtype<float>,
76 : dtype<double>,
77 : dtype<std::string>,
78 : dtype<Eigen::Vector3d>,
79 : dtype<Eigen::Matrix3d>,
80 : dtype<Eigen::Affine3d>,
81 : dtype<core::Quaternion>,
82 : dtype<core::Translation>,
83 : dtype<core::time_point>,
84 : dtype<Variable>,
85 : dtype<DataArray>,
86 : dtype<Dataset>,
87 : dtype<core::bin<Variable>>,
88 : dtype<core::bin<DataArray>>,
89 : dtype<core::bin<Dataset>>,
90 : dtype<python::PyObject>,
91 63 : })
92 57 : PyDType.def_property_readonly_static(
93 57 : core::dtypeNameRegistry().at(t).c_str(),
94 20612 : [t](const py::object &) { return t; });
95 3 : }
96 :
97 81524 : DType dtype_of(const py::object &x) {
98 81524 : if (x.is_none()) {
99 38982 : return dtype<void>;
100 42542 : } else if (py::isinstance<py::buffer>(x)) {
101 : // Cannot use hasattr(x, "dtype") as that would catch Variables as well.
102 24608 : return scipp_dtype(x.attr("dtype"));
103 17934 : } else if (py::isinstance<py::bool_>(x)) {
104 : // bool needs to come before int because bools are instances of int.
105 256 : return core::dtype<bool>;
106 17678 : } else if (py::isinstance<py::float_>(x)) {
107 12699 : return core::dtype<double>;
108 4979 : } else if (py::isinstance<py::int_>(x)) {
109 4723 : return core::dtype<int64_t>;
110 256 : } else if (py::isinstance<py::str>(x)) {
111 177 : return core::dtype<std::string>;
112 79 : } else if (py::isinstance<variable::Variable>(x)) {
113 22 : return core::dtype<variable::Variable>;
114 57 : } else if (py::isinstance<dataset::DataArray>(x)) {
115 7 : return core::dtype<dataset::DataArray>;
116 50 : } else if (py::isinstance<dataset::Dataset>(x)) {
117 8 : return core::dtype<dataset::Dataset>;
118 : } else {
119 42 : return core::dtype<python::PyObject>;
120 : }
121 : }
122 :
123 40603 : scipp::core::DType scipp_dtype(const py::dtype &type) {
124 40603 : if (type.kind() == DTypeKind::Float) {
125 13853 : if (type.itemsize() == DTypeSize::Float64)
126 13428 : return scipp::core::dtype<double>;
127 425 : if (type.itemsize() == DTypeSize::Float32)
128 425 : return scipp::core::dtype<float>;
129 : }
130 26750 : if (type.kind() == DTypeKind::Int) {
131 23473 : if (type.itemsize() == DTypeSize::Int64)
132 17825 : return scipp::core::dtype<std::int64_t>;
133 5648 : if (type.itemsize() == DTypeSize::Int32)
134 5648 : return scipp::core::dtype<std::int32_t>;
135 : }
136 3277 : if (type.kind() == DTypeKind::Bool)
137 1448 : return scipp::core::dtype<bool>;
138 1829 : if (type.kind() == DTypeKind::String)
139 117 : return scipp::core::dtype<std::string>;
140 1712 : if (type.kind() == DTypeKind::Datetime) {
141 1709 : return scipp::core::dtype<scipp::core::time_point>;
142 : }
143 3 : if (type.kind() == DTypeKind::Object) {
144 3 : return scipp::core::dtype<scipp::python::PyObject>;
145 : }
146 0 : throw std::runtime_error(
147 0 : "Unsupported numpy dtype: " +
148 0 : py::str(static_cast<py::handle>(type)).cast<std::string>() +
149 : "\n"
150 : "Supported types are: bool, float32, float64,"
151 0 : " int32, int64, string, datetime64, and object");
152 : }
153 :
154 9 : scipp::core::DType dtype_from_scipp_class(const py::object &type) {
155 : // Using the __name__ because we would otherwise have to get a handle
156 : // to the Python classes for our C++ classes. And I don't know how
157 : // to do that. This approach can break if people (including us) pull
158 : // shenanigans with the classes in Python!
159 9 : if (type.attr("__name__").cast<std::string>() == "Variable") {
160 3 : return dtype<Variable>;
161 6 : } else if (type.attr("__name__").cast<std::string>() == "DataArray") {
162 3 : return dtype<DataArray>;
163 3 : } else if (type.attr("__name__").cast<std::string>() == "Dataset") {
164 3 : return dtype<Dataset>;
165 : } else {
166 0 : throw std::invalid_argument("Invalid dtype");
167 : }
168 : }
169 :
170 104157 : scipp::core::DType scipp_dtype(const py::object &type) {
171 : // Check None first, then native scipp Dtype, then numpy.dtype
172 104157 : if (type.is_none())
173 35685 : return dtype<void>;
174 : try {
175 68472 : return type.cast<DType>();
176 40618 : } catch (const py::cast_error &) {
177 41534 : if (py::isinstance<py::type>(type) &&
178 41534 : type.attr("__module__").cast<std::string>() == "scipp._scipp.core") {
179 9 : return dtype_from_scipp_class(type);
180 : }
181 :
182 40609 : auto np_dtype = py::dtype::from_args(type);
183 40603 : if (np_dtype.kind() == DTypeKind::RawData) {
184 0 : throw std::invalid_argument(
185 : "Unsupported numpy dtype: raw data. This can happen when you pass a "
186 0 : "Python object instead of a class. Got dtype=`" +
187 0 : py::str(type).cast<std::string>() + '`');
188 : }
189 40603 : return scipp_dtype(np_dtype);
190 40618 : }
191 : }
192 :
193 : namespace {
194 45004 : bool is_default(const ProtoUnit &unit) {
195 45004 : return std::holds_alternative<DefaultUnit>(unit);
196 : }
197 : } // namespace
198 :
199 : std::tuple<scipp::core::DType, std::optional<scipp::units::Unit>>
200 45010 : cast_dtype_and_unit(const pybind11::object &dtype, const ProtoUnit &unit) {
201 45010 : const auto scipp_dtype = ::scipp_dtype(dtype);
202 45004 : if (scipp_dtype == core::dtype<core::time_point>) {
203 192 : units::Unit deduced_unit = parse_datetime_dtype(dtype);
204 192 : if (!is_default(unit)) {
205 143 : const auto unit_ = unit_or_default(unit, scipp_dtype);
206 143 : if (deduced_unit != units::one && unit_ != deduced_unit) {
207 84 : throw std::invalid_argument(
208 168 : python::format("The unit encoded in the dtype (", deduced_unit,
209 168 : ") conflicts with the given unit (", unit_, ")."));
210 : } else {
211 59 : deduced_unit = unit_;
212 : }
213 : }
214 108 : return std::tuple{scipp_dtype, deduced_unit};
215 : } else {
216 : // Concrete dtype not known at this point so we cannot determine the default
217 : // unit here. Therefore nullopt is returned.
218 44812 : return std::tuple{scipp_dtype, is_default(unit)
219 74629 : ? std::optional<scipp::units::Unit>()
220 74629 : : unit_or_default(unit)};
221 : }
222 : }
223 :
224 4098 : void ensure_conversion_possible(const DType from, const DType to,
225 : const std::string &data_name) {
226 5483 : if (from == to || (core::is_fundamental(from) && core::is_fundamental(to)) ||
227 5483 : to == dtype<python::PyObject> ||
228 61 : (core::is_int(from) && to == dtype<core::time_point>)) {
229 4083 : return; // These are allowed.
230 : }
231 15 : throw std::invalid_argument(python::format("Cannot convert ", data_name,
232 30 : " from type ", from, " to ", to));
233 : }
234 :
235 39762 : DType common_dtype(const py::object &values, const py::object &variances,
236 : const DType dtype, const DType default_dtype) {
237 39762 : const DType values_dtype = dtype_of(values);
238 39762 : const DType variances_dtype = dtype_of(variances);
239 39762 : if (dtype == core::dtype<void>) {
240 : // Get dtype solely from data.
241 35683 : if (values_dtype == core::dtype<void>) {
242 20 : if (variances_dtype == core::dtype<void>) {
243 0 : return default_dtype;
244 : }
245 20 : return variances_dtype;
246 : } else {
247 36424 : if (variances_dtype != core::dtype<void> &&
248 761 : values_dtype != variances_dtype) {
249 0 : throw std::invalid_argument(python::format(
250 : "The dtypes of the 'values' (", values_dtype, ") and 'variances' (",
251 : variances_dtype,
252 : ") arguments do not match. You can specify a dtype explicitly to"
253 0 : " trigger a conversion if applicable."));
254 : }
255 35663 : return values_dtype;
256 : }
257 : } else { // dtype != core::dtype<void>
258 : // Combine data and explicit dtype with potential conversion.
259 4079 : if (values_dtype != core::dtype<void>) {
260 4109 : ensure_conversion_possible(values_dtype, dtype, "values");
261 : }
262 4064 : if (variances_dtype != core::dtype<void>) {
263 19 : ensure_conversion_possible(variances_dtype, dtype, "variances");
264 : }
265 4064 : return dtype;
266 : }
267 : }
268 :
269 612 : bool has_datetime_dtype(const py::object &obj) {
270 612 : if (py::hasattr(obj, "dtype")) {
271 578 : return obj.attr("dtype").attr("kind").cast<char>() == DTypeKind::Datetime;
272 : } else {
273 : // numpy.datetime64 and numpy.ndarray both have 'dtype' attributes.
274 : // Mark everything else as not-datetime.
275 34 : return false;
276 : }
277 : }
278 :
279 : [[nodiscard]] scipp::units::Unit
280 1421 : parse_datetime_dtype(const std::string &dtype_name) {
281 : static std::regex datetime_regex{R"(datetime64(\[(\w+)\])?)",
282 1421 : std::regex_constants::optimize};
283 1421 : constexpr size_t unit_idx = 2;
284 1421 : std::smatch match;
285 2842 : if (!std::regex_match(dtype_name, match, datetime_regex) ||
286 1421 : match.size() != 3) {
287 0 : throw std::invalid_argument("Invalid dtype, expected datetime64, got " +
288 0 : dtype_name);
289 : }
290 :
291 1421 : if (match.length(unit_idx) == 0) {
292 50 : return scipp::units::dimensionless;
293 1371 : } else if (match[unit_idx] == "s") {
294 293 : return scipp::units::s;
295 1078 : } else if (match[unit_idx] == "us") {
296 191 : return scipp::units::us;
297 887 : } else if (match[unit_idx] == "ns") {
298 243 : return scipp::units::ns;
299 644 : } else if (match[unit_idx] == "m") {
300 : // In np.datetime64, m means minute.
301 10 : return units::Unit("min");
302 : } else {
303 1922 : for (const char *name : {"ms", "h", "D", "M", "Y"}) {
304 1922 : if (match[unit_idx] == name) {
305 634 : return units::Unit(name);
306 : }
307 : }
308 : }
309 :
310 0 : throw std::invalid_argument(std::string("Unsupported unit in datetime: ") +
311 0 : std::string(match[unit_idx]));
312 1421 : }
313 :
314 : [[nodiscard]] scipp::units::Unit
315 1420 : parse_datetime_dtype(const pybind11::object &dtype) {
316 1420 : if (py::isinstance<py::type>(dtype)) {
317 : // This handles dtype=np.datetime64, i.e. passing the class.
318 1 : return units::one;
319 1419 : } else if (py::hasattr(dtype, "dtype")) {
320 614 : return parse_datetime_dtype(dtype.attr("dtype"));
321 805 : } else if (py::hasattr(dtype, "name")) {
322 618 : return parse_datetime_dtype(dtype.attr("name").cast<std::string>());
323 : } else {
324 187 : return parse_datetime_dtype(py::str(dtype).cast<std::string>());
325 : }
326 : }
|