LCOV - code coverage report
Current view: top level - python - bind_units.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 125 144 86.8 %
Date: 2024-04-28 01:25:40 Functions: 23 25 92.0 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Simon Heybrock
#include <algorithm>
#include <array>
#include <limits>
#include <sstream>

#include "scipp/core/tag_util.h"
#include "scipp/units/unit.h"

#include "pybind11.h"
#include "unit.h"
      12             : 
      13             : using namespace scipp;
      14             : namespace py = pybind11;
      15             : 
      16             : constexpr int UNIT_DICT_VERSION = 2;
      17             : constexpr std::array SUPPORTED_UNIT_DICT_VERSIONS = {1, 2};
      18             : 
      19             : namespace {
      20             : 
      21         282 : bool is_supported_unit(const units::Unit &unit) {
      22         282 :   return unit.underlying().commodity() == 0;
      23             : }
      24             : 
      25             : // We only support units where we are confident that we can encode them using
      26             : // a different unit library, in order to ensure that we can switch
      27             : // implementations in the future if necessary.
      28         138 : void assert_supported_unit_for_dict(const units::Unit &unit) {
      29         138 :   if (!is_supported_unit(unit)) {
      30           0 :     throw std::invalid_argument("Unit cannot be converted to dict: '" +
      31           0 :                                 to_string(unit) +
      32           0 :                                 "' Commodities are not supported.");
      33             :   }
      34         138 : }
      35             : 
      36         138 : py::dict to_dict(const units::Unit &unit) {
      37         138 :   assert_supported_unit_for_dict(unit);
      38             : 
      39         138 :   py::dict dict;
      40         138 :   dict["__version__"] = UNIT_DICT_VERSION;
      41         138 :   dict["multiplier"] = unit.underlying().multiplier();
      42             : 
      43         138 :   unit.map_over_flags([&dict](const char *const name, const auto flag) mutable {
      44         552 :     if (flag) {
      45           9 :       dict[name] = true;
      46             :     }
      47         552 :   });
      48             : 
      49         138 :   py::dict powers;
      50         138 :   unit.map_over_bases(
      51        1515 :       [&powers](const char *const base, const auto power) mutable {
      52        1380 :         if (power != 0) {
      53         135 :           powers[base] = power;
      54             :         }
      55        1380 :       });
      56         138 :   if (!powers.empty())
      57         123 :     dict["powers"] = powers;
      58             : 
      59         276 :   return dict;
      60         138 : }
      61             : 
      62        2128 : template <class T = int> T get(const py::dict &dict, const char *const name) {
      63        2128 :   if (dict.contains(name)) {
      64         156 :     return dict[name].cast<T>();
      65             :   }
      66        1972 :   return T{};
      67             : }
      68             : 
      69         152 : void assert_dict_version_supported(const py::dict &dict) {
      70         152 :   if (const auto ver = dict["__version__"].cast<int>();
      71         152 :       std::find(SUPPORTED_UNIT_DICT_VERSIONS.cbegin(),
      72             :                 SUPPORTED_UNIT_DICT_VERSIONS.cend(),
      73         152 :                 ver) == SUPPORTED_UNIT_DICT_VERSIONS.cend()) {
      74           0 :     std::ostringstream oss;
      75           0 :     oss << "Unit dict has version " << std::to_string(ver)
      76           0 :         << " but the current installation of scipp only supports versions [";
      77           0 :     for (const auto v : SUPPORTED_UNIT_DICT_VERSIONS)
      78           0 :       oss << v << ", ";
      79           0 :     oss << "]";
      80           0 :     throw std::invalid_argument(oss.str());
      81           0 :   }
      82         152 : }
      83             : 
      84         152 : units::Unit from_dict(const py::dict &dict) {
      85         152 :   assert_dict_version_supported(dict);
      86             : 
      87         152 :   const py::dict powers = dict.contains("powers") ? dict["powers"] : py::dict();
      88         304 :   return units::Unit(llnl::units::precise_unit(
      89         304 :       llnl::units::detail::unit_data{
      90             :           get(powers, "m"), get(powers, "kg"), get(powers, "s"),
      91             :           get(powers, "A"), get(powers, "K"), get(powers, "mol"),
      92             :           get(powers, "cd"), get(powers, "$"), get(powers, "counts"),
      93         152 :           get(powers, "rad"), get<bool>(dict, "per_unit"),
      94         152 :           get<bool>(dict, "i_flag"), get<bool>(dict, "e_flag"),
      95         152 :           get<bool>(dict, "equation")},
      96         304 :       dict["multiplier"].cast<double>()));
      97         152 : }
      98             : 
      99         144 : std::string repr(const units::Unit &unit) {
     100         144 :   if (!is_supported_unit(unit)) {
     101           0 :     return "<unsupported unit: " + to_string(unit) + '>';
     102             :   }
     103             : 
     104         144 :   std::ostringstream oss;
     105         144 :   oss << "Unit(";
     106             : 
     107         144 :   bool first = true;
     108         144 :   if (const auto mult = unit.underlying().multiplier(); mult != 1.0) {
     109          11 :     oss << mult;
     110          11 :     first = false;
     111             :   }
     112             : 
     113         144 :   unit.map_over_bases(
     114        1838 :       [&oss, &first](const char *const base, const auto power) mutable {
     115        1440 :         if (power != 0) {
     116         127 :           if (!first) {
     117          20 :             oss << "*";
     118             :           } else {
     119         107 :             first = false;
     120             :           }
     121         127 :           oss << base;
     122         127 :           if (power != 1)
     123          17 :             oss << "**" << power;
     124             :         }
     125        1440 :       });
     126         144 :   if (first)
     127          26 :     oss << "1"; // multiplier == 1 and all powers == 0
     128             : 
     129         144 :   unit.map_over_flags([&oss](const char *const name, const auto flag) mutable {
     130         576 :     if (flag)
     131           4 :       oss << ", " << name << "=True";
     132         576 :   });
     133         144 :   oss << ')';
     134         144 :   return oss.str();
     135         144 : }
     136             : 
     137           0 : std::string repr_html(const units::Unit &unit) {
     138             :   // Regular string output is in a div with data-mime-type="text/plain"
     139             :   // But html output is in a div with data-mime-type="text/html"
     140             :   // Jupyter applies different padding to those, so hack the inner pre element
     141             :   // to match the padding of text/plain.
     142           0 :   return "<pre style=\"margin-bottom:0; padding-top:var(--jp-code-padding)\">" +
     143           0 :          unit.name() + "</pre>";
     144             : }
     145             : 
     146           0 : void repr_pretty(const units::Unit &unit, py::object &p,
     147             :                  [[maybe_unused]] const bool cycle) {
     148           0 :   p.attr("text")(unit.name());
     149           0 : }
     150             : 
     151             : } // namespace
     152             : 
     153           3 : void init_units(py::module &m) {
     154           3 :   py::class_<DefaultUnit>(m, "DefaultUnit")
     155           3 :       .def("__repr__",
     156          36 :            [](const DefaultUnit &) { return "<automatically deduced unit>"; });
     157           3 :   py::class_<units::Unit>(m, "Unit", "A physical unit.")
     158           3 :       .def(py::init<const std::string &>())
     159         735 :       .def("__str__", [](const units::Unit &u) { return u.name(); })
     160           3 :       .def("__repr__", repr)
     161           3 :       .def("_repr_html_", repr_html)
     162           3 :       .def("_repr_pretty_", repr_pretty)
     163           3 :       .def_property_readonly("name", &units::Unit::name,
     164             :                              "A read-only string describing the "
     165             :                              "type of unit.")
     166           3 :       .def(py::self + py::self)
     167           3 :       .def(py::self - py::self)
     168           3 :       .def(py::self * py::self)
     169             :       // cppcheck-suppress duplicateExpression
     170           3 :       .def(py::self / py::self)
     171           3 :       .def("__pow__", [](const units::Unit &self,
     172         598 :                          const int64_t power) { return pow(self, power); })
     173           4 :       .def("__abs__", [](const units::Unit &self) { return abs(self); })
     174           3 :       .def(py::self == py::self)
     175           3 :       .def(py::self != py::self)
     176           6 :       .def(hash(py::self))
     177           3 :       .def("to_dict", to_dict,
     178             :            "Serialize a unit to a dict.\n\nThis function is meant to be used "
     179             :            "with :meth:`scipp.Unit.from_dict` to serialize units.\n\n"
     180             :            "Warning\n"
     181             :            "-------\n"
     182             :            "The structure of the returned dict is an implementation detail and "
     183             :            "may change without warning at any time! "
     184             :            "It should not be used to access the internal representation of "
     185             :            "``Unit``.")
     186           3 :       .def("from_dict", from_dict,
     187             :            "Deserialize a unit from a dict.\n\nThis function is meant to be "
     188             :            "used in combination with :meth:`scipp.Unit.to_dict`.");
     189             : 
     190           4 :   m.def("abs", [](const units::Unit &u) { return abs(u); });
     191           3 :   m.def("pow", [](const units::Unit &u, const int64_t power) {
     192           1 :     return pow(u, power);
     193             :   });
     194           3 :   m.def("pow",
     195           1 :         [](const units::Unit &u, const double power) { return pow(u, power); });
     196           4 :   m.def("reciprocal", [](const units::Unit &u) { return units::one / u; });
     197           4 :   m.def("sqrt", [](const units::Unit &u) { return sqrt(u); });
     198             : 
     199           3 :   py::implicitly_convertible<std::string, units::Unit>();
     200             : 
     201           3 :   auto units = m.def_submodule("units");
     202           3 :   units.attr("angstrom") = units::angstrom;
     203           3 :   units.attr("counts") = units::counts;
     204           3 :   units.attr("deg") = units::deg;
     205           3 :   units.attr("dimensionless") = units::dimensionless;
     206           3 :   units.attr("kg") = units::kg;
     207           3 :   units.attr("K") = units::K;
     208           3 :   units.attr("meV") = units::meV;
     209           3 :   units.attr("m") = units::m;
     210             :   // Note: No binding to units::none here, use None in Python!
     211           3 :   units.attr("one") = units::one;
     212           3 :   units.attr("rad") = units::rad;
     213           3 :   units.attr("s") = units::s;
     214           3 :   units.attr("us") = units::us;
     215           3 :   units.attr("ns") = units::ns;
     216           3 :   units.attr("mm") = units::mm;
     217             : 
     218           3 :   units.attr("default_unit") = DefaultUnit{};
     219             : 
     220           6 :   m.def("to_numpy_time_string",
     221           3 :         py::overload_cast<const ProtoUnit &>(to_numpy_time_string))
     222           3 :       .def(
     223             :           "units_identical",
     224          11 :           [](const units::Unit &a, const units::Unit &b) {
     225          11 :             return identical(a, b);
     226             :           },
     227             :           "Check if two units are numerically identical.\n\n"
     228             :           "The regular equality operator allows for small differences "
     229             :           "in the unit's floating point multiplier. ``units_identical`` "
     230             :           "checks for exact identity.")
     231           3 :       .def("add_unit_alias", scipp::units::add_unit_alias, py::kw_only(),
     232           0 :            py::arg("name"), py::arg("unit"))
     233           3 :       .def("clear_unit_aliases", scipp::units::clear_unit_aliases);
     234           3 : }

Generated by: LCOV version 1.14