LCOV - code coverage report
Current view: top level - python - bins.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 119 142 83.8 %
Date: 2024-11-24 01:48:31 Functions: 29 35 82.9 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Simon Hezbrock
       5             : #include "scipp/dataset/bins.h"
       6             : #include "scipp/core/except.h"
       7             : #include "scipp/dataset/bin.h"
       8             : #include "scipp/dataset/bins_view.h"
       9             : #include "scipp/variable/arithmetic.h"
      10             : #include "scipp/variable/cumulative.h"
      11             : #include "scipp/variable/shape.h"
      12             : #include "scipp/variable/util.h"
      13             : #include "scipp/variable/variable.h"
      14             : #include "scipp/variable/variable_factory.h"
      15             : 
      16             : #include "bind_data_array.h"
      17             : #include "dim.h"
      18             : #include "pybind11.h"
      19             : 
      20             : using namespace scipp;
      21             : 
      22             : namespace py = pybind11;
      23             : 
      24             : namespace {
      25             : 
      26             : template <class T>
                       /// Assemble the begin/end index pairs for binned data and wrap `data` in a
                       /// binned variable.
                       ///
                       /// Index selection, as implemented below:
                       /// - `begin_arg` and `end_arg` given: the pairs are `zip(begin, end)`.
                       /// - only `begin_arg` given: every bin ends where the next one begins and
                       ///   the last bin ends at `data.dims()[dim]`.
                       /// - neither given: `begin` is the exclusive cumulative sum of ones along
                       ///   `dim` (length `data.dims()[dim]`) and every bin ends at `begin + 1`.
                       /// - only `end_arg` given: throws `std::runtime_error`.
                       ///
                       /// @param validate selects `make_bins` (true, the default) or
                       ///        `make_bins_no_validate` (false).
                       /// @return whatever the selected `make_bins*` function returns.
      27        5552 : auto call_make_bins(const std::optional<Variable> &begin_arg,
      28             :                     const std::optional<Variable> &end_arg, const Dim dim,
      29             :                     T &&data, const bool validate = true) {
      30        5552 :   Variable indices;
      31        5552 :   if (begin_arg.has_value()) {
      32        5544 :     const auto &begin = *begin_arg;
      33        5544 :     if (end_arg.has_value()) {
      34        5542 :       const auto &end = *end_arg;
      35        5542 :       indices = zip(begin, end);
      36             :     } else {
                             // Start from (begin, begin) pairs, then overwrite each `second`
                             // through the `indices_` view obtained from `indices`.
      37           2 :       indices = zip(begin, begin);
      38           2 :       const auto indices_ = indices.values<scipp::index_pair>();
      39           2 :       const auto nindex = scipp::size(indices_);
      40           6 :       for (scipp::index i = 0; i < nindex; ++i) {
                               // All but the last bin end at the following bin's start; the
                               // last bin ends at the extent of `data` along `dim`.
      41           4 :         if (i < nindex - 1)
      42           2 :           indices_[i].second = indices_[i + 1].first;
      43             :         else
      44           2 :           indices_[i].second = data.dims()[dim];
      45             :       }
      46           2 :     }
      47           8 :   } else if (!end_arg.has_value()) {
                           // Neither bound given: derive the begin indices from an exclusive
                           // cumulative sum of ones; each end is the matching begin plus one.
      48           7 :     const auto one = scipp::index{1} * units::none;
      49           7 :     const auto ones = broadcast(one, {dim, data.dims()[dim]});
      50           7 :     const auto begin = cumsum(ones, dim, CumSumMode::Exclusive);
      51           7 :     indices = zip(begin, begin + one);
      52           7 :   } else {
                           // `end` without `begin` is rejected.
      53           1 :     throw std::runtime_error("`end` given but not `begin`");
      54             :   }
                         // `indices` is moved and `data` is perfectly forwarded into the factory.
      55        5551 :   if (validate)
      56         216 :     return make_bins(std::move(indices), dim, std::forward<T>(data));
      57             :   else
      58        5335 :     return make_bins_no_validate(std::move(indices), dim,
      59       10670 :                                  std::forward<T>(data));
      60        5552 : }
      61             : 
      62           9 : template <class T> void bind_bins(pybind11::module &m) {
                         // Registers two functions on `m` for buffer type T, both forwarding to
                         // call_make_bins with `Dim{dim}` and a temporary `T(data)`:
                         // - "bins": `begin` and `end` default to None; call_make_bins runs with
                         //   its default `validate = true`.
                         // - "_bins_no_validate": `begin` and `end` are required `Variable`s and
                         //   `validate` is passed as false.
      63          27 :   m.def(
      64             :       "bins",
      65         217 :       [](const std::optional<Variable> &begin,
      66             :          const std::optional<Variable> &end, const std::string &dim,
      67             :          const T &data) {
      68         217 :         return call_make_bins(begin, end, Dim{dim}, T(data));
      69             :       },
      70          27 :       py::arg("begin") = py::none(), py::arg("end") = py::none(),
      71           0 :       py::arg("dim"), py::arg("data")); // do not release GIL since using
      72             :                                         // implicit conversions in functor
      73           9 :   m.def(
      74             :       "_bins_no_validate",
      75        5335 :       [](const Variable &begin, const Variable &end, const std::string &dim,
      76             :          const T &data) {
                             // `begin` and `end` convert implicitly to std::optional<Variable> here.
      77        5335 :         return call_make_bins(begin, end, Dim{dim}, T(data), false);
      78             :       },
      79           0 :       py::arg("begin"), py::arg("end"), py::arg("dim"),
      80           9 :       py::arg("data")); // do not release GIL since using
      81             :                         // implicit conversions in functor
      82           9 : }
      83             : 
      84       13824 : template <class T> py::dict bins_constituents(const Variable &var) {
                         // Splits `var` into its constituents for buffer type T and returns them as
                         // a Python dict with the keys "begin", "end", "dim" and "data".
                         // "begin"/"end" come from unzipping the index pairs; "dim" is the
                         // dimension name converted to std::string; "data" is the buffer.
      85       13824 :   auto &&[indices, dim, buffer] = var.constituents<T>();
      86       13824 :   auto &&[begin, end] = unzip(indices);
      87       13824 :   py::dict out;
      88       13824 :   out["begin"] = std::forward<decltype(begin)>(begin);
      89       13824 :   out["end"] = std::forward<decltype(end)>(end);
      90       13824 :   out["dim"] = std::string(dim.name());
      91       13824 :   out["data"] = std::forward<decltype(buffer)>(buffer);
      92       27648 :   return out;
      93       13824 : }
      94             : 
      95             : template <class T, bool HasAlignment = false>
                       /// Registers a Python class named `name` for the view type T on `m` and
                       /// attaches the common mutable-view operators and `pop` bindings to it.
                       /// When `HasAlignment` is true, `bind_set_aligned` is applied as well.
      96           6 : void bind_bins_map_view(py::module &m, const std::string &name) {
      97           6 :   py::class_<T> c(m, name.c_str());
      98           6 :   bind_common_mutable_view_operators(c);
      99           6 :   bind_pop(c);
                         // Compile-time switch: the branch is only instantiated for HasAlignment.
     100             :   if constexpr (HasAlignment) {
     101             :     bind_set_aligned(c);
     102             :   }
     103           6 : }
     104             : 
     105           3 : template <class T> void bind_bins_view(py::module &m) {
                         // Registers the Python classes behind the bins view of a variable with
                         // buffer type T, plus the "_bins_view" function that creates such a view.
                         // Every bound type is named via decltype on the unevaluated expression
                         // `dataset::bins_view<T>(Variable{})` and its accessors.
                         //
                         // NOTE(review): the items/keys/values helper bindings below all reuse the
                         // base names "_BinsCoords" and "_BinsMasks"; presumably bind_helper_view
                         // derives a distinct Python class name per view kind -- confirm in
                         // bind_data_array.h.

                         // Items views (coords use the str_ variant, masks the plain one).
     106             :   bind_helper_view<str_items_view,
     107           3 :                    decltype(dataset::bins_view<T>(Variable{}).coords())>(
     108             :       m, "_BinsCoords");
     109             :   bind_helper_view<items_view,
     110           3 :                    decltype(dataset::bins_view<T>(Variable{}).masks())>(
     111             :       m, "_BinsMasks");
                         // Keys views (coords use the str_ variant, masks the plain one).
     112             :   bind_helper_view<str_keys_view,
     113           3 :                    decltype(dataset::bins_view<T>(Variable{}).coords())>(
     114             :       m, "_BinsCoords");
     115             :   bind_helper_view<keys_view,
     116           3 :                    decltype(dataset::bins_view<T>(Variable{}).masks())>(
     117             :       m, "_BinsMasks");
                         // Values views (same view template for coords and masks).
     118             :   bind_helper_view<values_view,
     119           3 :                    decltype(dataset::bins_view<T>(Variable{}).coords())>(
     120             :       m, "_BinsCoords");
     121             :   bind_helper_view<values_view,
     122           3 :                    decltype(dataset::bins_view<T>(Variable{}).masks())>(
     123             :       m, "_BinsMasks");
     124             : 
                         // The view class itself, followed by the mapping types returned by its
                         // meta(), coords(), masks() and attrs() accessors.
     125           3 :   py::class_<decltype(dataset::bins_view<T>(Variable{}))> c(
     126             :       m, "_BinsViewDataArray");
     127           3 :   bind_bins_map_view<decltype(dataset::bins_view<T>(Variable{}).meta())>(
     128             :       m, "_BinsMeta");
     129             :   bind_mutable_view_no_dim<
     130           3 :       decltype(dataset::bins_view<T>(Variable{}).coords())>(
     131             :       m, "_BinsCoords", "Dict of event coords.");
     132           3 :   bind_mutable_view<decltype(dataset::bins_view<T>(Variable{}).masks())>(
     133             :       m, "_BinsMasks", "Dict of event masks.");
     134           3 :   bind_bins_map_view<decltype(dataset::bins_view<T>(Variable{}).attrs())>(
     135             :       m, "_BinsAttrs");
                         // Data-array style properties are attached to the view class `c`.
     136           3 :   bind_data_array_properties(c);
                         // Python entry point returning the bins view of `var`.
     137           3 :   m.def("_bins_view",
     138       15143 :         [](const Variable &var) { return dataset::bins_view<T>(var); });
     139           3 : }
     140             : 
     141             : template <class T, class Data>
                       /// Creates a binned variable that reuses the indices and dimension of `bins`
                       /// (buffer type T) and is then filled from `data`.
                       ///
                       /// The new buffer is `empty_like(data, buf.dims())`, i.e. it is built from
                       /// `data` as prototype with the dims of the buffer of `bins`. The result is
                       /// created without validation and `data` is assigned via
                       /// `setSlice(Slice{}, data)`.
     142        3742 : auto bins_like(const Variable &bins, const Data &data) {
     143        3742 :   auto &&[idx, dim, buf] = bins.constituents<T>();
     144        7484 :   auto out = make_bins_no_validate(idx, dim, empty_like(data, buf.dims()));
                         // NOTE(review): presumably an empty Slice{} targets the whole variable so
                         // that `data` is broadcast into all bins -- confirm against setSlice.
     145        3742 :   out.setSlice(Slice{}, data);
     146        7482 :   return out;
     147        3743 : }
     148             : 
     149           3 : template <class Data> void bind_bins_like(py::module &m) {
                         // Registers "bins_like" on `m`. The lambda dispatches on the dtype of
                         // `bins`: bucket<Variable> and bucket<DataArray> are forwarded to the
                         // matching bins_like<T> instantiation; any other dtype raises
                         // except::TypeError.
     150           3 :   m.def("bins_like", [](const Variable &bins, const Data &data) {
     151        3742 :     if (bins.dtype() == dtype<bucket<Variable>>)
     152           4 :       return bins_like<Variable>(bins, data);
     153        3738 :     if (bins.dtype() == dtype<bucket<DataArray>>)
     154        3738 :       return bins_like<DataArray>(bins, data);
                           // Reached for every dtype not handled above.
     155           0 :     throw except::TypeError(
     156             :         "In `bins_like`: Prototype must contain binned data but got dtype=" +
     157             :         to_string(bins.dtype()));
     158             :   });
     159           3 : }
     160             : 
     161             : } // namespace
     162             : 
     163           3 : void init_buckets(py::module &m) {
                         // Entry point that registers all bin-related bindings of this file on `m`.
                         // Bindings that pass py::call_guard<py::gil_scoped_release> run their C++
                         // body with the Python GIL released; the others keep the GIL.

                         // "bins" / "_bins_no_validate" for each supported buffer type.
     164           3 :   bind_bins<Variable>(m);
     165           3 :   bind_bins<DataArray>(m);
     166           3 :   bind_bins<Dataset>(m);
     167             : 
                         // "bins_like" taking a Variable as `data`.
     168           3 :   bind_bins_like<Variable>(m);
     169             : 
                         // "is_bins" overloads: variable::is_bins is bound directly, the DataArray
                         // and Dataset overloads forward to dataset::is_bins.
     170           3 :   m.def("is_bins", variable::is_bins);
     171           3 :   m.def("is_bins",
     172       58445 :         [](const DataArray &array) { return dataset::is_bins(array); });
     173           3 :   m.def("is_bins",
     174           0 :         [](const Dataset &dataset) { return dataset::is_bins(dataset); });
     175             : 
                         // "bins_constituents": dispatch on the dtype of `var` to the matching
                         // bins_constituents<T>; any other dtype raises except::TypeError.
     176           3 :   m.def("bins_constituents", [](const Variable &var) {
     177       13824 :     const auto dt = var.dtype();
     178       13824 :     if (dt == dtype<bucket<Variable>>)
     179         849 :       return bins_constituents<Variable>(var);
     180       12975 :     if (dt == dtype<bucket<DataArray>>)
     181       12974 :       return bins_constituents<DataArray>(var);
     182           1 :     if (dt == dtype<bucket<Dataset>>)
     183           1 :       return bins_constituents<Dataset>(var);
     184           0 :     throw except::TypeError("'constituents' does not support dtype " +
     185           0 :                             to_string(dt));
     186             :   });
     187             : 
                         // "lookup_previous": forwards to dataset::lookup_previous with the
                         // dimension label converted to Dim.
     188           3 :   m.def(
     189             :       "lookup_previous",
     190          40 :       [](const DataArray &function, const Variable &x, const std::string &dim,
     191             :          const std::optional<Variable> &fill_value) {
     192          40 :         return dataset::lookup_previous(function, x, Dim{dim}, fill_value);
     193             :       },
     194           0 :       py::call_guard<py::gil_scoped_release>());
     195             : 
                         // Submodule "buckets": thin wrappers around dataset::buckets functions.
                         // "concatenate" and "append" each get a Variable and a DataArray overload.
     196           3 :   auto buckets = m.def_submodule("buckets");
     197           3 :   buckets.def(
     198             :       "concatenate",
     199           0 :       [](const Variable &a, const Variable &b) {
     200           0 :         return dataset::buckets::concatenate(a, b);
     201             :       },
     202           0 :       py::call_guard<py::gil_scoped_release>());
     203           3 :   buckets.def(
     204             :       "concatenate",
     205           0 :       [](const DataArray &a, const DataArray &b) {
     206           0 :         return dataset::buckets::concatenate(a, b);
     207             :       },
     208           0 :       py::call_guard<py::gil_scoped_release>());
                         // "append" takes its first argument by non-const reference.
     209           3 :   buckets.def(
     210             :       "append",
     211           0 :       [](Variable &a, const Variable &b) {
     212           0 :         return dataset::buckets::append(a, b);
     213             :       },
     214           0 :       py::call_guard<py::gil_scoped_release>());
     215           3 :   buckets.def(
     216             :       "append",
     217           0 :       [](DataArray &a, const DataArray &b) {
     218           0 :         return dataset::buckets::append(a, b);
     219             :       },
     220           0 :       py::call_guard<py::gil_scoped_release>());
                         // "map" and "scale" convert the dimension label to Dim before forwarding.
     221           3 :   buckets.def(
     222             :       "map",
     223         127 :       [](const DataArray &function, const Variable &x, const std::string &dim,
     224             :          const std::optional<Variable> &fill_value) {
     225         127 :         return dataset::buckets::map(function, x, Dim{dim}, fill_value);
     226             :       },
     227           0 :       py::call_guard<py::gil_scoped_release>());
     228           3 :   buckets.def(
     229             :       "scale",
     230         111 :       [](DataArray &array, const DataArray &histogram, const std::string &dim) {
     231         111 :         return dataset::buckets::scale(array, histogram, Dim{dim});
     232             :       },
     233           0 :       py::call_guard<py::gil_scoped_release>());
     234             : 
                         // "bin": `groups` and `erase` default to empty lists; the `erase` labels
                         // are passed through to_dim_type before calling dataset::bin.
     235           3 :   m.def(
     236             :       "bin",
     237        4228 :       [](const DataArray &array, const std::vector<Variable> &edges,
     238             :          const std::vector<Variable> &groups,
     239             :          const std::vector<std::string> &erase) {
     240        8456 :         return dataset::bin(array, edges, groups, to_dim_type(erase));
     241             :       },
     242           0 :       py::arg("array"), py::arg("edges"),
     243           6 :       py::arg("groups") = std::vector<Variable>{},
     244           6 :       py::arg("erase") = std::vector<std::string>{},
     245           0 :       py::call_guard<py::gil_scoped_release>());
     246             : 
                         // Bins view classes and "_bins_view" for DataArray buffers.
     247           3 :   bind_bins_view<DataArray>(m);
     248           3 : }

Generated by: LCOV version 1.14