LCOV - code coverage report
Current view: top level - python - dtype.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 148 163 90.8 %
Date: 2024-04-28 01:25:40 Functions: 19 19 100.0 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Simon Heybrock
       5             : #include "dtype.h"
       6             : 
       7             : #include <regex>
       8             : 
       9             : #include "scipp/core/eigen.h"
      10             : #include "scipp/core/string.h"
      11             : #include "scipp/dataset/dataset.h"
      12             : #include "scipp/variable/variable.h"
      13             : 
      14             : #include "format.h"
      15             : #include "py_object.h"
      16             : #include "pybind11.h"
      17             : 
      18             : using namespace scipp;
      19             : using namespace scipp::core;
      20             : 
      21             : namespace py = pybind11;
      22             : 
      23             : namespace {
      24             : /// 'kind' character codes for numpy dtypes
      25             : enum class DTypeKind : char {
      26             :   Float = 'f',
      27             :   Int = 'i',
      28             :   Bool = 'b',
      29             :   Datetime = 'M',
      30             :   Object = 'O',
      31             :   String = 'U',
      32             :   RawData = 'V',
      33             : };
      34             : 
      35      112434 : constexpr bool operator==(const char a, const DTypeKind b) {
      36      112434 :   return a == static_cast<char>(b);
      37             : }
      38             : 
      39             : enum class DTypeSize : scipp::index {
      40             :   Float64 = 8,
      41             :   Float32 = 4,
      42             :   Int64 = 8,
      43             :   Int32 = 4,
      44             : };
      45             : 
      46       42727 : constexpr bool operator==(const scipp::index a, const DTypeSize b) {
      47       42727 :   return a == static_cast<scipp::index>(b);
      48             : }
      49             : } // namespace
      50             : 
      51           3 : void init_dtype(py::module &m) {
      52             :   py::class_<DType> PyDType(m, "DType", R"(
      53             : Representation of a data type of a Variable in Scipp.
      54             : See https://scipp.github.io/reference/dtype.html for details.
      55             : 
      56             : The data types ``VariableView``, ``DataArrayView``, and ``DatasetView`` are used for
      57             : objects containing binned data. They cannot be used directly to create arrays of bins.
      58           3 : )");
      59          32 :   PyDType.def(py::init([](const py::object &x) { return scipp_dtype(x); }))
      60           3 :       .def("__eq__",
      61       31527 :            [](const DType &self, const py::object &other) {
      62       31527 :              return self == scipp_dtype(other);
      63             :            })
      64        1754 :       .def("__str__", [](const DType &self) { return to_string(self); })
      65           3 :       .def("__repr__", [](const DType &self) {
      66          60 :         return "DType('" + to_string(self) + "')";
      67             :       });
      68             : 
      69             :   // Explicit list of dtypes to bind since core::dtypeNameRegistry contains
      70             :   // types that are for internal use only and are never returned to Python.
      71           3 :   for (const auto &t : {
      72             :            dtype<bool>,
      73             :            dtype<int32_t>,
      74             :            dtype<int64_t>,
      75             :            dtype<float>,
      76             :            dtype<double>,
      77             :            dtype<std::string>,
      78             :            dtype<Eigen::Vector3d>,
      79             :            dtype<Eigen::Matrix3d>,
      80             :            dtype<Eigen::Affine3d>,
      81             :            dtype<core::Quaternion>,
      82             :            dtype<core::Translation>,
      83             :            dtype<core::time_point>,
      84             :            dtype<Variable>,
      85             :            dtype<DataArray>,
      86             :            dtype<Dataset>,
      87             :            dtype<core::bin<Variable>>,
      88             :            dtype<core::bin<DataArray>>,
      89             :            dtype<core::bin<Dataset>>,
      90             :            dtype<python::PyObject>,
      91          63 :        })
      92          57 :     PyDType.def_property_readonly_static(
      93          57 :         core::dtypeNameRegistry().at(t).c_str(),
      94       20432 :         [t](const py::object &) { return t; });
      95           3 : }
      96             : 
      97       77635 : DType dtype_of(const py::object &x) {
      98       77635 :   if (x.is_none()) {
      99       37164 :     return dtype<void>;
     100       40471 :   } else if (py::isinstance<py::buffer>(x)) {
     101             :     // Cannot use hasattr(x, "dtype") as that would catch Variables as well.
     102       24140 :     return scipp_dtype(x.attr("dtype"));
     103       16331 :   } else if (py::isinstance<py::bool_>(x)) {
     104             :     // bool needs to come before int because bools are instances of int.
     105         252 :     return core::dtype<bool>;
     106       16079 :   } else if (py::isinstance<py::float_>(x)) {
     107       11537 :     return core::dtype<double>;
     108        4542 :   } else if (py::isinstance<py::int_>(x)) {
     109        4286 :     return core::dtype<int64_t>;
     110         256 :   } else if (py::isinstance<py::str>(x)) {
     111         177 :     return core::dtype<std::string>;
     112          79 :   } else if (py::isinstance<variable::Variable>(x)) {
     113          22 :     return core::dtype<variable::Variable>;
     114          57 :   } else if (py::isinstance<dataset::DataArray>(x)) {
     115           7 :     return core::dtype<dataset::DataArray>;
     116          50 :   } else if (py::isinstance<dataset::Dataset>(x)) {
     117           8 :     return core::dtype<dataset::Dataset>;
     118             :   } else {
     119          42 :     return core::dtype<python::PyObject>;
     120             :   }
     121             : }
     122             : 
     123       39636 : scipp::core::DType scipp_dtype(const py::dtype &type) {
     124       39636 :   if (type.kind() == DTypeKind::Float) {
     125       13350 :     if (type.itemsize() == DTypeSize::Float64)
     126       12910 :       return scipp::core::dtype<double>;
     127         440 :     if (type.itemsize() == DTypeSize::Float32)
     128         440 :       return scipp::core::dtype<float>;
     129             :   }
     130       26286 :   if (type.kind() == DTypeKind::Int) {
     131       23239 :     if (type.itemsize() == DTypeSize::Int64)
     132       17541 :       return scipp::core::dtype<std::int64_t>;
     133        5698 :     if (type.itemsize() == DTypeSize::Int32)
     134        5698 :       return scipp::core::dtype<std::int32_t>;
     135             :   }
     136        3047 :   if (type.kind() == DTypeKind::Bool)
     137        1367 :     return scipp::core::dtype<bool>;
     138        1680 :   if (type.kind() == DTypeKind::String)
     139         112 :     return scipp::core::dtype<std::string>;
     140        1568 :   if (type.kind() == DTypeKind::Datetime) {
     141        1565 :     return scipp::core::dtype<scipp::core::time_point>;
     142             :   }
     143           3 :   if (type.kind() == DTypeKind::Object) {
     144           3 :     return scipp::core::dtype<scipp::python::PyObject>;
     145             :   }
     146           0 :   throw std::runtime_error(
     147           0 :       "Unsupported numpy dtype: " +
     148           0 :       py::str(static_cast<py::handle>(type)).cast<std::string>() +
     149             :       "\n"
     150             :       "Supported types are: bool, float32, float64,"
     151           0 :       " int32, int64, string, datetime64, and object");
     152             : }
     153             : 
     154           9 : scipp::core::DType dtype_from_scipp_class(const py::object &type) {
     155             :   // Using the __name__ because we would otherwise have to get a handle
     156             :   // to the Python classes for our C++ classes. And I don't know how
     157             :   // to do that. This approach can break if people (including us) pull
     158             :   // shenanigans with the classes in Python!
     159           9 :   if (type.attr("__name__").cast<std::string>() == "Variable") {
     160           3 :     return dtype<Variable>;
     161           6 :   } else if (type.attr("__name__").cast<std::string>() == "DataArray") {
     162           3 :     return dtype<DataArray>;
     163           3 :   } else if (type.attr("__name__").cast<std::string>() == "Dataset") {
     164           3 :     return dtype<Dataset>;
     165             :   } else {
     166           0 :     throw std::invalid_argument("Invalid dtype");
     167             :   }
     168             : }
     169             : 
     170      100925 : scipp::core::DType scipp_dtype(const py::object &type) {
     171             :   // Check None first, then native scipp Dtype, then numpy.dtype
     172      100925 :   if (type.is_none())
     173       33981 :     return dtype<void>;
     174             :   try {
     175       66944 :     return type.cast<DType>();
     176       39651 :   } catch (const py::cast_error &) {
     177       40519 :     if (py::isinstance<py::type>(type) &&
     178       40519 :         type.attr("__module__").cast<std::string>() == "scipp._scipp.core") {
     179           9 :       return dtype_from_scipp_class(type);
     180             :     }
     181             : 
     182       39642 :     auto np_dtype = py::dtype::from_args(type);
     183       39636 :     if (np_dtype.kind() == DTypeKind::RawData) {
     184           0 :       throw std::invalid_argument(
     185             :           "Unsupported numpy dtype: raw data. This can happen when you pass a "
     186           0 :           "Python object instead of a class. Got dtype=`" +
     187           0 :           py::str(type).cast<std::string>() + '`');
     188             :     }
     189       39636 :     return scipp_dtype(np_dtype);
     190       39651 :   }
     191             : }
     192             : 
     193             : namespace {
     194       42984 : bool is_default(const ProtoUnit &unit) {
     195       42984 :   return std::holds_alternative<DefaultUnit>(unit);
     196             : }
     197             : } // namespace
     198             : 
     199             : std::tuple<scipp::core::DType, std::optional<scipp::units::Unit>>
     200       42990 : cast_dtype_and_unit(const pybind11::object &dtype, const ProtoUnit &unit) {
     201       42990 :   const auto scipp_dtype = ::scipp_dtype(dtype);
     202       42984 :   if (scipp_dtype == core::dtype<core::time_point>) {
     203         192 :     units::Unit deduced_unit = parse_datetime_dtype(dtype);
     204         192 :     if (!is_default(unit)) {
     205         143 :       const auto unit_ = unit_or_default(unit, scipp_dtype);
     206         143 :       if (deduced_unit != units::one && unit_ != deduced_unit) {
     207          84 :         throw std::invalid_argument(
     208         168 :             python::format("The unit encoded in the dtype (", deduced_unit,
     209         168 :                            ") conflicts with the given unit (", unit_, ")."));
     210             :       } else {
     211          59 :         deduced_unit = unit_;
     212             :       }
     213             :     }
     214         108 :     return std::tuple{scipp_dtype, deduced_unit};
     215             :   } else {
     216             :     // Concrete dtype not known at this point so we cannot determine the default
     217             :     // unit here. Therefore nullopt is returned.
     218       42792 :     return std::tuple{scipp_dtype, is_default(unit)
     219       70619 :                                        ? std::optional<scipp::units::Unit>()
     220       70619 :                                        : unit_or_default(unit)};
     221             :   }
     222             : }
     223             : 
     224        3934 : void ensure_conversion_possible(const DType from, const DType to,
     225             :                                 const std::string &data_name) {
     226        5174 :   if (from == to || (core::is_fundamental(from) && core::is_fundamental(to)) ||
     227        5174 :       to == dtype<python::PyObject> ||
     228          61 :       (core::is_int(from) && to == dtype<core::time_point>)) {
     229        3919 :     return; // These are allowed.
     230             :   }
     231          15 :   throw std::invalid_argument(python::format("Cannot convert ", data_name,
     232          30 :                                              " from type ", from, " to ", to));
     233             : }
     234             : 
     235       37894 : DType common_dtype(const py::object &values, const py::object &variances,
     236             :                    const DType dtype, const DType default_dtype) {
     237       37894 :   const DType values_dtype = dtype_of(values);
     238       37894 :   const DType variances_dtype = dtype_of(variances);
     239       37894 :   if (dtype == core::dtype<void>) {
     240             :     // Get dtype solely from data.
     241       33979 :     if (values_dtype == core::dtype<void>) {
     242          20 :       if (variances_dtype == core::dtype<void>) {
     243           0 :         return default_dtype;
     244             :       }
     245          20 :       return variances_dtype;
     246             :     } else {
     247       34670 :       if (variances_dtype != core::dtype<void> &&
     248         711 :           values_dtype != variances_dtype) {
     249           0 :         throw std::invalid_argument(python::format(
     250             :             "The dtypes of the 'values' (", values_dtype, ") and 'variances' (",
     251             :             variances_dtype,
     252             :             ") arguments do not match. You can specify a dtype explicitly to"
     253           0 :             " trigger a conversion if applicable."));
     254             :       }
     255       33959 :       return values_dtype;
     256             :     }
     257             :   } else { // dtype != core::dtype<void>
     258             :     // Combine data and explicit dtype with potential conversion.
     259        3915 :     if (values_dtype != core::dtype<void>) {
     260        3945 :       ensure_conversion_possible(values_dtype, dtype, "values");
     261             :     }
     262        3900 :     if (variances_dtype != core::dtype<void>) {
     263          19 :       ensure_conversion_possible(variances_dtype, dtype, "variances");
     264             :     }
     265        3900 :     return dtype;
     266             :   }
     267             : }
     268             : 
     269         612 : bool has_datetime_dtype(const py::object &obj) {
     270         612 :   if (py::hasattr(obj, "dtype")) {
     271         578 :     return obj.attr("dtype").attr("kind").cast<char>() == DTypeKind::Datetime;
     272             :   } else {
     273             :     // numpy.datetime64 and numpy.ndarray both have 'dtype' attributes.
     274             :     // Mark everything else as not-datetime.
     275          34 :     return false;
     276             :   }
     277             : }
     278             : 
     279             : [[nodiscard]] scipp::units::Unit
     280        1421 : parse_datetime_dtype(const std::string &dtype_name) {
     281             :   static std::regex datetime_regex{R"(datetime64(\[(\w+)\])?)",
     282        1421 :                                    std::regex_constants::optimize};
     283        1421 :   constexpr size_t unit_idx = 2;
     284        1421 :   std::smatch match;
     285        2842 :   if (!std::regex_match(dtype_name, match, datetime_regex) ||
     286        1421 :       match.size() != 3) {
     287           0 :     throw std::invalid_argument("Invalid dtype, expected datetime64, got " +
     288           0 :                                 dtype_name);
     289             :   }
     290             : 
     291        1421 :   if (match.length(unit_idx) == 0) {
     292          50 :     return scipp::units::dimensionless;
     293        1371 :   } else if (match[unit_idx] == "s") {
     294         293 :     return scipp::units::s;
     295        1078 :   } else if (match[unit_idx] == "us") {
     296         191 :     return scipp::units::us;
     297         887 :   } else if (match[unit_idx] == "ns") {
     298         243 :     return scipp::units::ns;
     299         644 :   } else if (match[unit_idx] == "m") {
     300             :     // In np.datetime64, m means minute.
     301          10 :     return units::Unit("min");
     302             :   } else {
     303        1922 :     for (const char *name : {"ms", "h", "D", "M", "Y"}) {
     304        1922 :       if (match[unit_idx] == name) {
     305         634 :         return units::Unit(name);
     306             :       }
     307             :     }
     308             :   }
     309             : 
     310           0 :   throw std::invalid_argument(std::string("Unsupported unit in datetime: ") +
     311           0 :                               std::string(match[unit_idx]));
     312        1421 : }
     313             : 
     314             : [[nodiscard]] scipp::units::Unit
     315        1420 : parse_datetime_dtype(const pybind11::object &dtype) {
     316        1420 :   if (py::isinstance<py::type>(dtype)) {
     317             :     // This handles dtype=np.datetime64, i.e. passing the class.
     318           1 :     return units::one;
     319        1419 :   } else if (py::hasattr(dtype, "dtype")) {
     320         614 :     return parse_datetime_dtype(dtype.attr("dtype"));
     321         805 :   } else if (py::hasattr(dtype, "name")) {
     322         618 :     return parse_datetime_dtype(dtype.attr("name").cast<std::string>());
     323             :   } else {
     324         187 :     return parse_datetime_dtype(py::str(dtype).cast<std::string>());
     325             :   }
     326             : }

Generated by: LCOV version 1.14