LCOV - code coverage report
Current view: top level - python - variable_init.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 145 159 91.2 %
Date: 2024-04-28 01:25:40 Functions: 56 56 100.0 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Jan-Lukas Wynen
       5             : 
       6             : #include "pybind11.h"
       7             : 
       8             : #include "scipp/core/dtype.h"
       9             : #include "scipp/core/eigen.h"
      10             : #include "scipp/core/tag_util.h"
      11             : #include "scipp/dataset/dataset.h"
      12             : #include "scipp/units/string.h"
      13             : #include "scipp/variable/structures.h"
      14             : #include "scipp/variable/to_unit.h"
      15             : #include "scipp/variable/variable.h"
      16             : 
      17             : #include "dtype.h"
      18             : #include "format.h"
      19             : #include "numpy.h"
      20             : #include "py_object.h"
      21             : #include "unit.h"
      22             : 
      23             : using namespace scipp;
      24             : using namespace scipp::variable;
      25             : 
      26             : namespace py = pybind11;
      27             : 
      28             : namespace {
      29      113816 : bool is_empty(const py::object &sequence) {
      30      113816 :   if (py::isinstance<py::buffer>(sequence)) {
      31           0 :     return sequence.attr("ndim").cast<scipp::index>() == 0;
      32             :   }
      33      113816 :   return !py::bool_{sequence};
      34             : }
      35             : 
      36       22159 : auto shape_of(const py::object &array) { return py::iter(array.attr("shape")); }
      37             : 
      38           2 : scipp::index n_remaining(const py::iterator &it) {
      39           2 :   return std::distance(it, it.end());
      40             : }
      41             : 
      42           1 : [[noreturn]] void throw_ndim_mismatch_error(const scipp::index a_ndim,
      43             :                                             const std::string_view a_name,
      44             :                                             const scipp::index b_ndim,
      45             :                                             const std::string_view b_name) {
      46           1 :   throw std::invalid_argument(
      47           2 :       python::format("The number of dimensions in '", a_name, "' (", a_ndim,
      48             :                      ") does not match the number of dimensions in '", b_name,
      49           2 :                      "' (", b_ndim, ")."));
      50             : }
      51             : 
      52       21403 : void ensure_same_shape(const py::object &values, const py::object &variances) {
      53       21403 :   if (values.is_none() || variances.is_none()) {
      54       21031 :     return;
      55             :   }
      56             : 
      57         372 :   auto val_shape = shape_of(values);
      58         372 :   auto var_shape = shape_of(variances);
      59             : 
      60         372 :   scipp::index dim = 0;
      61         372 :   std::tuple<scipp::index, scipp::index, scipp::index> mismatch{-1, -1, -1};
      62         990 :   for (; val_shape != val_shape.end() && var_shape != var_shape.end();
      63         618 :        ++val_shape, ++var_shape, ++dim) {
      64         618 :     if (val_shape->cast<scipp::index>() != var_shape->cast<scipp::index>()) {
      65           0 :       if (std::get<0>(mismatch) == -1) {
      66             :         // Defer throwing to let ndim error take precedence.
      67           0 :         mismatch = std::tuple{dim, val_shape->cast<scipp::index>(),
      68           0 :                               var_shape->cast<scipp::index>()};
      69             :       }
      70             :     }
      71             :   }
      72         372 :   if (val_shape != val_shape.end() || var_shape != var_shape.end()) {
      73           0 :     throw_ndim_mismatch_error(dim + n_remaining(val_shape), "values",
      74           0 :                               dim + n_remaining(var_shape), "variances");
      75             :   }
      76         372 :   if (std::get<0>(mismatch) != -1) {
      77           0 :     throw std::invalid_argument(python::format(
      78             :         "The shapes of 'values' and 'variances' differ in dimension ",
      79           0 :         std::get<0>(mismatch), ": ", std::get<1>(mismatch), " vs ",
      80           0 :         std::get<2>(mismatch), '.'));
      81             :   }
      82         372 : }
      83             : 
      84             : namespace detail {
      85       21415 : void consume_extra_dims(py::iterator &shape_it,
      86             :                         const scipp::index n_extra_dims) {
      87       21497 :   for (scipp::index i = 0; i < n_extra_dims; ++i) {
      88          82 :     if (shape_it == shape_it.end())
      89           0 :       throw std::invalid_argument(
      90           0 :           "Data has too few dimensions for given dimension labels.");
      91          82 :     ++shape_it;
      92             :   }
      93       21415 : }
      94             : 
      95       21415 : Dimensions build_dimensions(py::iterator &&label_it, py::iterator &&shape_it,
      96             :                             const scipp::index n_extra_dims,
      97             :                             const std::string_view shape_name) {
      98       21415 :   Dimensions dims;
      99       21415 :   scipp::index dim = 0;
     100       46539 :   for (; label_it != label_it.end() && shape_it != shape_it.end();
     101       25124 :        ++label_it, ++shape_it, ++dim) {
     102       25124 :     dims.addInner(Dim{label_it->cast<std::string>()},
     103             :                   shape_it->cast<scipp::index>());
     104             :   }
     105       21415 :   consume_extra_dims(shape_it, n_extra_dims);
     106       21415 :   if (label_it != label_it.end() || shape_it != shape_it.end()) {
     107           1 :     throw_ndim_mismatch_error(dim + n_remaining(label_it), "dims",
     108           1 :                               dim + n_remaining(shape_it), shape_name);
     109             :   }
     110       21414 :   return dims;
     111           1 : }
     112             : } // namespace detail
     113             : 
     114       38028 : Dimensions build_dimensions(const py::object &dim_labels,
     115             :                             const py::object &values,
     116             :                             const py::object &variances,
     117             :                             const scipp::index n_extra_dims = 0) {
     118       38028 :   if (is_empty(dim_labels)) {
     119       16613 :     return Dimensions{};
     120             :   } else {
     121       21415 :     if (!values.is_none()) {
     122       21403 :       ensure_same_shape(values, variances);
     123       42808 :       return detail::build_dimensions(py::iter(dim_labels), shape_of(values),
     124       64209 :                                       n_extra_dims, "values");
     125             :     } else {
     126          24 :       return detail::build_dimensions(py::iter(dim_labels), shape_of(variances),
     127          36 :                                       n_extra_dims, "variances");
     128             :     }
     129             :   }
     130             : }
     131             : 
     132       75788 : py::object parse_data_sequence(const py::object &dim_labels,
     133             :                                const py::object &data) {
     134             :   // Need to check for None because py::array does not preserve it.
     135       75788 :   if (is_empty(dim_labels) || data.is_none()) {
     136       54065 :     return data;
     137             :   } else {
     138       21723 :     return py::array(data);
     139             :   }
     140             : }
     141             : 
     142        2415 : void ensure_is_scalar(const py::buffer &array) {
     143        2415 :   if (const auto ndim = array.attr("ndim").cast<int64_t>(); ndim != 0) {
     144           1 :     throw except::DimensionError(python::format(
     145           2 :         "Cannot interpret ", ndim, "-dimensional array as a scalar."));
     146             :   }
     147        2414 : }
     148             : 
     149             : template <class T>
     150       16511 : T extract_scalar(const py::object &obj, const units::Unit unit) {
     151             :   using TM = ElementTypeMap<T>;
     152             :   using PyType = typename TM::PyType;
     153       16511 :   TM::check_assignable(obj, unit);
     154       16511 :   if (py::isinstance<py::buffer>(obj)) {
     155        2115 :     ensure_is_scalar(obj);
     156        2113 :     return converting_cast<PyType>::cast(obj.attr("item")());
     157             :   } else {
     158       14397 :     return converting_cast<PyType>::cast(obj);
     159             :   }
     160             : }
     161             : 
     162             : template <>
     163         331 : core::time_point extract_scalar<core::time_point>(const py::object &obj,
     164             :                                                   const units::Unit unit) {
     165             :   using TM = ElementTypeMap<core::time_point>;
     166             :   using PyType = typename TM::PyType;
     167         331 :   TM::check_assignable(obj, unit);
     168         331 :   if (py::isinstance<py::buffer>(obj)) {
     169         301 :     ensure_is_scalar(obj);
     170         903 :     return core::time_point{obj.attr("astype")(py::dtype::of<PyType>())
     171         602 :                                 .attr("item")()
     172         301 :                                 .template cast<PyType>()};
     173             :   } else {
     174          30 :     return core::time_point{obj.cast<PyType>()};
     175             :   }
     176             : }
     177             : 
     178             : template <>
     179          43 : python::PyObject extract_scalar<python::PyObject>(const py::object &obj,
     180             :                                                   const units::Unit unit) {
     181             :   using TM = ElementTypeMap<python::PyObject>;
     182          43 :   TM::check_assignable(obj, unit);
     183          43 :   return obj;
     184             : }
     185             : 
     186             : template <class T>
     187       38776 : auto make_element_array(const Dimensions &dims, const py::object &source,
     188             :                         const units::Unit unit) {
     189       38776 :   if (source.is_none()) {
     190          20 :     return element_array<T>();
     191       38756 :   } else if (dims.ndim() == 0) {
     192       33516 :     return element_array<T>(1, extract_scalar<T>(source, unit));
     193             :   } else {
     194       21871 :     element_array<T> array(dims.volume(), core::init_for_overwrite);
     195       21871 :     copy_array_into_view(cast_to_array_like<T>(source, unit), array, dims);
     196       21871 :     return array;
     197       21871 :   }
     198             : }
     199             : 
     200             : template <class T> struct MakeVariable {
     201       37878 :   static Variable apply(const Dimensions &dims, const py::object &values,
     202             :                         const py::object &variances, const units::Unit unit) {
     203       37878 :     const auto [values_unit, final_unit] = common_unit<T>(values, unit);
     204       37877 :     auto values_array =
     205             :         Values(make_element_array<T>(dims, values, values_unit));
     206       75752 :     auto variable = variances.is_none()
     207       37126 :                         ? makeVariable<T>(dims, std::move(values_array))
     208             :                         // cppcheck-suppress accessMoved  # False-positive.
     209         750 :                         : makeVariable<T>(dims, std::move(values_array),
     210       39380 :                                           Variances(make_element_array<T>(
     211             :                                               dims, variances, values_unit)));
     212       37872 :     variable.setUnit(values_unit);
     213       75744 :     return to_unit(variable, final_unit, CopyPolicy::TryAvoid);
     214       37876 :   }
     215             : };
     216             : 
     217       37894 : Variable make_variable(const py::object &dim_labels, const py::object &values,
     218             :                        const py::object &variances,
     219             :                        const std::optional<units::Unit> &unit_, DType dtype) {
     220       37894 :   const auto converted_values = parse_data_sequence(dim_labels, values);
     221       37894 :   const auto converted_variances = parse_data_sequence(dim_labels, variances);
     222       37894 :   dtype = common_dtype(converted_values, converted_variances, dtype);
     223             :   const auto dims =
     224       37879 :       build_dimensions(dim_labels, converted_values, converted_variances);
     225       37878 :   const auto unit = unit_.value_or(variable::default_unit_for(dtype));
     226             :   return core::CallDType<double, float, int64_t, int32_t, bool,
     227             :                          scipp::core::time_point, std::string, Variable,
     228             :                          DataArray, Dataset,
     229             :                          python::PyObject>::apply<MakeVariable>(dtype, dims,
     230             :                                                                 values,
     231             :                                                                 variances,
     232       75750 :                                                                 unit);
     233       37922 : }
     234             : 
     235         110 : template <int N> Dimensions pad_structure_dimensions(Dimensions dims) {
     236         110 :   dims.addInner(Dim::InternalStructureComponent, N);
     237         110 :   return dims;
     238             : }
     239             : 
     240          39 : template <int M, int N> Dimensions pad_structure_dimensions(Dimensions dims) {
     241          39 :   dims.addInner(Dim::InternalStructureRow, M);
     242          39 :   dims.addInner(Dim::InternalStructureColumn, N);
     243          39 :   return dims;
     244             : }
     245             : 
     246             : template <class T, class Elem, int... N>
     247         149 : Variable make_structured_variable(const py::object &dim_labels,
     248             :                                   const py::object &values_,
     249             :                                   const py::object &variances,
     250             :                                   const std::optional<units::Unit> &unit_) {
     251         149 :   if (!variances.is_none())
     252           0 :     throw except::VariancesError("Variances not supported for dtype " +
     253             :                                  to_string(dtype<Elem>));
     254             : 
     255         149 :   const auto values = py::array(values_);
     256         149 :   const auto unit = unit_.value_or(variable::default_unit_for(dtype<Elem>));
     257         149 :   const auto dims =
     258         149 :       build_dimensions(dim_labels, values, py::none(), sizeof...(N));
     259         149 :   const auto padded_dims = pad_structure_dimensions<N...>(dims);
     260             : 
     261         149 :   auto var = variable::make_structures<T, Elem>(
     262             :       dims, unit, make_element_array<Elem>(padded_dims, values, unit));
     263         298 :   return var;
     264         149 : }
     265             : } // namespace
     266             : 
     267             : /*
     268             :  * It is the init method's responsibility to check that the combination
     269             :  * of arguments is valid. Functions down the line do not check again.
     270             :  */
     271           3 : void bind_init(py::class_<Variable> &cls) {
     272           6 :   cls.def(
     273           3 :       py::init([](const py::object &dim_labels, const py::object &values,
     274             :                   const py::object &variances, const ProtoUnit unit,
     275             :                   const py::object &dtype, const bool aligned) {
     276       38135 :         if (values.is_none() && variances.is_none()) {
     277           0 :           throw std::invalid_argument(
     278           0 :               "At least one argument of 'values' and 'variances' is required.");
     279             :         }
     280       38043 :         const auto [scipp_dtype, actual_unit] =
     281       38135 :             cast_dtype_and_unit(dtype, unit);
     282             : 
     283       38043 :         auto var = [&, scipp_dtype = scipp_dtype, actual_unit = actual_unit]() {
     284       38043 :           if (scipp_dtype == ::dtype<Eigen::Vector3d>)
     285             :             return make_structured_variable<Eigen::Vector3d, double, 3>(
     286          86 :                 dim_labels, values, variances, actual_unit);
     287       37957 :           if (scipp_dtype == ::dtype<Eigen::Matrix3d>)
     288             :             return make_structured_variable<Eigen::Matrix3d, double, 3, 3>(
     289          29 :                 dim_labels, values, variances, actual_unit);
     290       37928 :           if (scipp_dtype == ::dtype<Eigen::Affine3d>)
     291             :             return make_structured_variable<Eigen::Affine3d, double, 4, 4>(
     292          10 :                 dim_labels, values, variances, actual_unit);
     293       37918 :           if (scipp_dtype == ::dtype<core::Quaternion>)
     294             :             return make_structured_variable<core::Quaternion, double, 4>(
     295          15 :                 dim_labels, values, variances, actual_unit);
     296       37903 :           if (scipp_dtype == ::dtype<core::Translation>)
     297             :             return make_structured_variable<core::Translation, double, 3>(
     298           9 :                 dim_labels, values, variances, actual_unit);
     299             : 
     300       37894 :           return make_variable(dim_labels, values, variances, actual_unit,
     301       37894 :                                scipp_dtype);
     302       38065 :         }();
     303             : 
     304       38021 :         var.set_aligned(aligned);
     305       76042 :         return var;
     306             :       }),
     307           6 :       py::kw_only(), py::arg("dims"), py::arg("values") = py::none(),
     308           6 :       py::arg("variances") = py::none(), py::arg("unit") = DefaultUnit{},
     309           6 :       py::arg("dtype") = py::none(), py::arg("aligned") = true,
     310             :       R"raw(
     311             : Initialize a variable with values and/or variances.
     312             : 
     313             : At least one argument of ``values`` and ``variances`` must be used.
     314             : if you want to preallocate memory to fill later, use :py:func:`scipp.empty`.
     315             : 
     316             : Attention
     317             : ---------
     318             : This constructor is meant primarily for internal use.
     319             : Use one of the Specialized
     320             : `creation functions <../../reference/creation-functions.rst>`_ instead.
     321             : See in particular :py:func:`scipp.array` and :py:func:`scipp.scalar`.
     322             : 
     323             : Parameters
     324             : ----------
     325             : dims:
     326             :    Dimension labels.
     327             : values:
     328             :    Sequence of values for constructing an array variable.
     329             : variances:
     330             :    Sequence of variances for constructing an array variable.
     331             : unit:
     332             :    Physical unit, defaults to ``scipp.units.dimensionless``.
     333             : dtype:
     334             :    Type of the variable's elements. Is deduced from other arguments
     335             :    in most cases. Defaults to ``sc.DType.float64`` if no deduction is
     336             :    possible.
     337             : aligned:
     338             :    Initial value for the alignment flag.
     339             : )raw");
     340           3 : }

Generated by: LCOV version 1.14