LCOV - code coverage report
Current view: top level - python - bind_units.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 125 144 86.8 %
Date: 2024-12-01 01:56:34 Functions: 23 25 92.0 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Simon Heybrock
       5             : #include <sstream>
       6             : 
       7             : #include "scipp/core/tag_util.h"
       8             : #include "scipp/units/unit.h"
       9             : 
      10             : #include "pybind11.h"
      11             : #include "unit.h"
      12             : 
      13             : using namespace scipp;
      14             : namespace py = pybind11;
      15             : 
      // Version number that to_dict stores under the "__version__" key.
      16             : constexpr int UNIT_DICT_VERSION = 2;
      // Versions that assert_dict_version_supported accepts when a dict is read.
      // NOTE(review): <array> is not among the includes visible at the top of this
      // file; presumably it is pulled in transitively - confirm.
      17             : constexpr std::array SUPPORTED_UNIT_DICT_VERSIONS = {1, 2};
      18             : 
      // Type annotation for the values of a unit dict: int, double, bool, or a
      // nested str -> double dict.
      19             : using UnitDictValue =
      20             :     py::typing::Union<int, double, bool, py::typing::Dict<py::str, double>>;
      // Dict type produced by to_dict and consumed by from_dict (str keys).
      21             : using UnitDict = py::typing::Dict<py::str, UnitDictValue>;
      22             : 
      23             : namespace {
      24             : 
      /// Reports whether `unit` can be handled by the dict / repr helpers in this
      /// file: true exactly when the underlying unit's commodity() equals 0.
      25         282 : bool is_supported_unit(const units::Unit &unit) {
      26         282 :   return unit.underlying().commodity() == 0;
      27             : }
      28             : 
      29             : // We only support units where we are confident that we can encode them using
      30             : // a different unit library, in order to ensure that we can switch
      31             : // implementations in the future if necessary.
      //
      // Guard called at the start of to_dict. Throws std::invalid_argument, with
      // to_string(unit) embedded in the message, when is_supported_unit(unit) is
      // false; otherwise it does nothing.
      32         138 : void assert_supported_unit_for_dict(const units::Unit &unit) {
      33         138 :   if (!is_supported_unit(unit)) {
      34           0 :     throw std::invalid_argument("Unit cannot be converted to dict: '" +
      35           0 :                                 to_string(unit) +
      36           0 :                                 "' Commodities are not supported.");
      37             :   }
      38         138 : }
      39             : 
      /// Serializes `unit` into a UnitDict.
      ///
      /// Entries written: "__version__" (UNIT_DICT_VERSION), "multiplier" (the
      /// underlying unit's multiplier()), one `<flag name> = true` entry for every
      /// flag that is set, and "powers" (base name -> power, non-zero powers only).
      /// "powers" is only added when at least one power is non-zero.
      ///
      /// Throws std::invalid_argument (via assert_supported_unit_for_dict) when the
      /// unit is not supported.
      40         138 : auto to_dict(const units::Unit &unit) {
      41         138 :   assert_supported_unit_for_dict(unit);
      42             : 
      43         138 :   UnitDict dict;
      44         138 :   dict["__version__"] = UNIT_DICT_VERSION;
      45         138 :   dict["multiplier"] = unit.underlying().multiplier();
      46             : 
                         // Only flags that are set are stored; from_dict reads a missing flag
                         // as false, so unset flags need no entry.
      47         138 :   unit.map_over_flags([&dict](const char *const name, const auto flag) mutable {
      48         552 :     if (flag) {
      49           9 :       dict[name] = true;
      50             :     }
      51         552 :   });
      52             : 
                         // Collect the non-zero base powers in a separate dict first so that
                         // "powers" can be left out completely when there are none.
      53         138 :   py::dict powers;
      54         138 :   unit.map_over_bases(
      55        1515 :       [&powers](const char *const base, const auto power) mutable {
      56        1380 :         if (power != 0) {
      57         135 :           powers[base] = power;
      58             :         }
      59        1380 :       });
      60         138 :   if (!powers.empty())
      61         123 :     dict["powers"] = powers;
      62             : 
      63         276 :   return dict;
      64         138 : }
      65             : 
      /// Returns dict[name] cast to T, or a value-initialized T (`T{}`) when `name`
      /// is not a key of `dict`. T defaults to int.
      66        2128 : template <class T = int> T get(const py::dict &dict, const char *const name) {
      67        2128 :   if (dict.contains(name)) {
      68         156 :     return dict[name].cast<T>();
      69             :   }
      70        1972 :   return T{};
      71             : }
      72             : 
      /// Checks that dict["__version__"], cast to int, is one of
      /// SUPPORTED_UNIT_DICT_VERSIONS and throws std::invalid_argument otherwise.
      /// The exception message names the offending version and lists the supported
      /// ones.
      ///
      /// NOTE(review): a dict without a "__version__" key is not handled explicitly
      /// here; presumably the item access / cast raises on the pybind11 side -
      /// confirm.
      73         152 : void assert_dict_version_supported(const py::dict &dict) {
      74         152 :   if (const auto ver = dict["__version__"].cast<int>();
      75         152 :       std::find(SUPPORTED_UNIT_DICT_VERSIONS.cbegin(),
      76             :                 SUPPORTED_UNIT_DICT_VERSIONS.cend(),
      77         152 :                 ver) == SUPPORTED_UNIT_DICT_VERSIONS.cend()) {
      78           0 :     std::ostringstream oss;
      79           0 :     oss << "Unit dict has version " << std::to_string(ver)
      80           0 :         << " but the current installation of scipp only supports versions [";
                           // NOTE(review): ", " is written after every version, including the
                           // last one, so the list in the message ends with ", ]".
      81           0 :     for (const auto v : SUPPORTED_UNIT_DICT_VERSIONS)
      82           0 :       oss << v << ", ";
      83           0 :     oss << "]";
      84           0 :     throw std::invalid_argument(oss.str());
      85           0 :   }
      86         152 : }
      87             : 
      /// Builds a units::Unit from a dict such as the one produced by to_dict.
      ///
      /// The version entry is validated first (assert_dict_version_supported may
      /// throw std::invalid_argument). Base powers are read from the optional
      /// "powers" sub-dict and the flags from the top level of `dict`; entries that
      /// are absent fall back to get()'s default (0 for powers, false for flags).
      /// "multiplier" is read without a presence check.
      ///
      /// NOTE(review): the order of the initializers below presumably has to match
      /// the parameter order of llnl::units::detail::unit_data - confirm before
      /// reordering anything.
      88         152 : units::Unit from_dict(const UnitDict &dict) {
      89         152 :   assert_dict_version_supported(dict);
      90             : 
                         // A dict without "powers" is treated like one with an empty "powers".
      91         152 :   const py::dict powers = dict.contains("powers") ? dict["powers"] : py::dict();
      92         304 :   return units::Unit(llnl::units::precise_unit(
      93         304 :       llnl::units::detail::unit_data{
      94             :           get(powers, "m"), get(powers, "kg"), get(powers, "s"),
      95             :           get(powers, "A"), get(powers, "K"), get(powers, "mol"),
      96             :           get(powers, "cd"), get(powers, "$"), get(powers, "counts"),
      97         152 :           get(powers, "rad"), get<bool>(dict, "per_unit"),
      98         152 :           get<bool>(dict, "i_flag"), get<bool>(dict, "e_flag"),
      99         152 :           get<bool>(dict, "equation")},
     100         304 :       dict["multiplier"].cast<double>()));
     101         152 : }
     102             : 
     /// Builds the string bound as Unit.__repr__.
     ///
     /// Supported units are rendered as "Unit(<body><flags>)". <body> consists of
     /// the multiplier (only when it is != 1.0) followed by every base with a
     /// non-zero power, joined by "*", with "**<power>" appended when the power is
     /// not 1. If nothing was written, <body> is "1". Every set flag appends
     /// ", <name>=True".
     ///
     /// Unsupported units (see is_supported_unit) yield
     /// "<unsupported unit: " + to_string(unit) + ">".
     103         144 : std::string repr(const units::Unit &unit) {
     104         144 :   if (!is_supported_unit(unit)) {
     105           0 :     return "<unsupported unit: " + to_string(unit) + '>';
     106             :   }
     107             : 
     108         144 :   std::ostringstream oss;
     109         144 :   oss << "Unit(";
     110             : 
                         // `first` stays true until something has been written after "Unit(";
                         // it decides whether the next base needs a "*" separator.
     111         144 :   bool first = true;
     112         144 :   if (const auto mult = unit.underlying().multiplier(); mult != 1.0) {
                           // NOTE(review): streamed with oss's default floating-point
                           // formatting; confirm that precision is sufficient for a repr.
     113          11 :     oss << mult;
     114          11 :     first = false;
     115             :   }
     116             : 
     117         144 :   unit.map_over_bases(
     118        1838 :       [&oss, &first](const char *const base, const auto power) mutable {
     119        1440 :         if (power != 0) {
     120         127 :           if (!first) {
     121          20 :             oss << "*";
     122             :           } else {
     123         107 :             first = false;
     124             :           }
     125         127 :           oss << base;
     126         127 :           if (power != 1)
     127          17 :             oss << "**" << power;
     128             :         }
     129        1440 :       });
     130         144 :   if (first)
     131          26 :     oss << "1"; // multiplier == 1 and all powers == 0
     132             : 
     133         144 :   unit.map_over_flags([&oss](const char *const name, const auto flag) mutable {
     134         576 :     if (flag)
     135           4 :       oss << ", " << name << "=True";
     136         576 :   });
     137         144 :   oss << ')';
     138         144 :   return oss.str();
     139         144 : }
     140             : 
     /// Returns unit.name() wrapped in a <pre> element that carries an inline
     /// style. Bound as `_repr_html_` in init_units. The name is concatenated
     /// as-is; no HTML escaping is applied in this function.
     141           0 : std::string repr_html(const units::Unit &unit) {
     142             :   // Regular string output is in a div with data-mime-type="text/plain"
     143             :   // But html output is in a div with data-mime-type="text/html"
     144             :   // Jupyter applies different padding to those, so hack the inner pre element
     145             :   // to match the padding of text/plain.
     146           0 :   return "<pre style=\"margin-bottom:0; padding-top:var(--jp-code-padding)\">" +
     147           0 :          unit.name() + "</pre>";
     148             : }
     149             : 
     /// Calls the `text` attribute of the Python object `p` with unit.name().
     /// `cycle` is accepted but not used. Bound as `_repr_pretty_` in init_units;
     /// presumably this is IPython's pretty-printer hook with `p` being the
     /// printer object - confirm.
     150           0 : void repr_pretty(const units::Unit &unit, py::object &p,
     151             :                  [[maybe_unused]] const bool cycle) {
     152           0 :   p.attr("text")(unit.name());
     153           0 : }
     154             : 
     155             : } // namespace
     156             : 
     /// Registers the unit-related Python bindings on module `m`: the DefaultUnit
     /// and Unit classes, module-level helper functions, an implicit
     /// std::string -> Unit conversion, and a `units` submodule holding predefined
     /// unit objects.
     157           3 : void init_units(py::module &m) {
                         // DefaultUnit only gets a fixed __repr__.
     158           3 :   py::class_<DefaultUnit>(m, "DefaultUnit")
     159           3 :       .def("__repr__",
     160          36 :            [](const DefaultUnit &) { return "<automatically deduced unit>"; });
                         // Unit: constructible from a string, with string conversions,
                         // arithmetic and comparison operators, hashing and dict
                         // (de)serialization.
     161           3 :   py::class_<units::Unit>(m, "Unit", "A physical unit.")
     162           3 :       .def(py::init<const std::string &>())
     163         777 :       .def("__str__", [](const units::Unit &u) { return u.name(); })
     164           3 :       .def("__repr__", repr)
     165           3 :       .def("_repr_html_", repr_html)
     166           3 :       .def("_repr_pretty_", repr_pretty)
     167           3 :       .def_property_readonly("name", &units::Unit::name,
     168             :                              "A read-only string describing the "
     169             :                              "type of unit.")
     170           3 :       .def(py::self + py::self)
     171           3 :       .def(py::self - py::self)
     172           3 :       .def(py::self * py::self)
     173             :       // cppcheck-suppress duplicateExpression
     174           3 :       .def(py::self / py::self)
                             // The method form of pow takes integer exponents only; a double
                             // overload exists as the module-level `pow` further down.
     175           3 :       .def("__pow__", [](const units::Unit &self,
     176         550 :                          const int64_t power) { return pow(self, power); })
     177           4 :       .def("__abs__", [](const units::Unit &self) { return abs(self); })
     178           3 :       .def(py::self == py::self)
     179           3 :       .def(py::self != py::self)
     180           6 :       .def(hash(py::self))
     181           3 :       .def("to_dict", to_dict,
     182             :            "Serialize a unit to a dict.\n\nThis function is meant to be used "
     183             :            "with :meth:`scipp.Unit.from_dict` to serialize units.\n\n"
     184             :            "Warning\n"
     185             :            "-------\n"
     186             :            "The structure of the returned dict is an implementation detail and "
     187             :            "may change without warning at any time! "
     188             :            "It should not be used to access the internal representation of "
     189             :            "``Unit``.")
                             // NOTE(review): registered with .def like the instance methods
                             // above, although from_dict takes only the dict argument; confirm
                             // how callers are expected to invoke it (def_static may be
                             // intended).
     190           3 :       .def("from_dict", from_dict,
     191             :            "Deserialize a unit from a dict.\n\nThis function is meant to be "
     192             :            "used in combination with :meth:`scipp.Unit.to_dict`.");
     193             : 
                         // Module-level functions operating on units. "pow" is registered
                         // twice: first with an int64_t exponent, then with a double exponent.
     194           4 :   m.def("abs", [](const units::Unit &u) { return abs(u); });
     195           3 :   m.def("pow", [](const units::Unit &u, const int64_t power) {
     196           1 :     return pow(u, power);
     197             :   });
     198           3 :   m.def("pow",
     199           1 :         [](const units::Unit &u, const double power) { return pow(u, power); });
     200           4 :   m.def("reciprocal", [](const units::Unit &u) { return units::one / u; });
     201           4 :   m.def("sqrt", [](const units::Unit &u) { return sqrt(u); });
     202             : 
                         // Lets a Python str be passed where a Unit argument is expected.
     203           3 :   py::implicitly_convertible<std::string, units::Unit>();
     204             : 
                         // The local variable `units` has the same name as the `units`
                         // namespace. The qualified names below (units::angstrom, ...) still
                         // refer to the namespace, because lookup of a name in front of `::`
                         // ignores variables.
     205           3 :   auto units = m.def_submodule("units");
     206           3 :   units.attr("angstrom") = units::angstrom;
     207           3 :   units.attr("counts") = units::counts;
     208           3 :   units.attr("deg") = units::deg;
     209           3 :   units.attr("dimensionless") = units::dimensionless;
     210           3 :   units.attr("kg") = units::kg;
     211           3 :   units.attr("K") = units::K;
     212           3 :   units.attr("meV") = units::meV;
     213           3 :   units.attr("m") = units::m;
     214             :   // Note: No binding to units::none here, use None in Python!
     215           3 :   units.attr("one") = units::one;
     216           3 :   units.attr("rad") = units::rad;
     217           3 :   units.attr("s") = units::s;
     218           3 :   units.attr("us") = units::us;
     219           3 :   units.attr("ns") = units::ns;
     220           3 :   units.attr("mm") = units::mm;
     221             : 
     222           3 :   units.attr("default_unit") = DefaultUnit{};
     223             : 
                         // The chained .def calls below all register functions on module `m`.
     224           6 :   m.def("to_numpy_time_string",
     225           3 :         py::overload_cast<const ProtoUnit &>(to_numpy_time_string))
     226           3 :       .def(
     227             :           "units_identical",
     228          11 :           [](const units::Unit &a, const units::Unit &b) {
     229          11 :             return identical(a, b);
     230             :           },
     231             :           "Check if two units are numerically identical.\n\n"
     232             :           "The regular equality operator allows for small differences "
     233             :           "in the unit's floating point multiplier. ``units_identical`` "
     234             :           "checks for exact identity.")
                             // py::kw_only() makes the arguments listed after it ("name" and
                             // "unit") keyword-only on the Python side.
     235           3 :       .def("add_unit_alias", scipp::units::add_unit_alias, py::kw_only(),
     236           0 :            py::arg("name"), py::arg("unit"))
     237           3 :       .def("clear_unit_aliases", scipp::units::clear_unit_aliases);
     238           3 : }

Generated by: LCOV version 1.14