Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 :
6 : #include "scipp/dataset/groupby.h"
7 : #include "scipp/dataset/dataset.h"
8 :
9 : #include "docstring.h"
10 : #include "pybind11.h"
11 :
12 : using namespace scipp;
13 : using namespace scipp::dataset;
14 :
15 : namespace py = pybind11;
16 :
/// Build the docstring shared by all GroupBy reduction bindings.
///
/// The operation name is spliced into the description, the "returns" text and
/// the text for the single `dim` parameter, so every reduction gets a
/// consistently worded docstring.
///
/// @tparam T Type forwarded to `rtype<T>()` to document the return type.
/// @param op Name of the operation as it should appear in the text.
/// @return The assembled `Docstring` object.
17 60 : template <class T> Docstring docstring_groupby(const std::string &op) {
18 120 : return Docstring()
19 120 : .description("Element-wise " + op +
20 : " over the specified dimension "
21 : "within a group.")
22 120 : .returns("The computed " + op +
23 : " over each group, combined along "
24 : "the dimension specified when calling :py:func:`scipp.groupby`.")
25 60 : .rtype<T>()
// The third argument "Dim" is passed alongside the parameter name and text;
// presumably it is the documented parameter type -- confirm in docstring.h.
26 120 : .param("dim", "Dimension to reduce when computing the " + op + ".",
27 120 : "Dim");
28 : }
29 :
// STRINGIFY turns its argument into a string literal without expanding it;
// TOSTRING adds one level of indirection so that a macro argument is expanded
// before it is stringified.
//
// BIND_GROUPBY_OP(CLS, NAME) registers a method called NAME on the pybind11
// class object CLS. The bound lambda takes the dimension as a string, wraps it
// in `Dim{...}` and forwards to `GroupBy<T>::NAME`. The call is made with
// `py::gil_scoped_release` as call guard, and the docstring comes from
// `docstring_groupby<T>` using the stringified NAME.
//
// NOTE: the macro refers to `T`, so it can only be expanded where a type named
// `T` is in scope. The expansion already ends with a semicolon, so a trailing
// `;` at the use site yields an additional empty statement.
30 : #define STRINGIFY(x) #x
31 : #define TOSTRING(x) STRINGIFY(x)
32 : #define BIND_GROUPBY_OP(CLS, NAME) \
33 : CLS.def( \
34 : TOSTRING(NAME), \
35 : [](const GroupBy<T> &self, const std::string &dim) { \
36 : return self.NAME(Dim{dim}); \
37 : }, \
38 : py::arg("dim"), py::call_guard<py::gil_scoped_release>(), \
39 : docstring_groupby<T>(TOSTRING(NAME)).c_str());
40 :
/// Register the `groupby` free-function overloads and the `GroupBy<T>` class
/// for one data type.
///
/// @tparam T Type of the grouped object; also the template argument of the
///           bound `GroupBy<T>` class.
/// @param m Module that receives the function and class definitions.
/// @param name Name under which the `GroupBy<T>` class is registered.
41 6 : template <class T> void bind_groupby(py::module &m, const std::string &name) {
// Overload 1: group given as a dimension string, no bins.
42 6 : m.def(
43 : "groupby",
44 8 : [](const T &x, const std::string &dim) { return groupby(x, Dim{dim}); },
45 0 : py::arg("data"), py::arg("group"),
46 0 : py::call_guard<py::gil_scoped_release>());
47 :
// Overload 2: group given as a dimension string, with bins as a Variable.
48 6 : m.def(
49 : "groupby",
50 0 : [](const T &x, const std::string &dim, const Variable &bins) {
51 0 : return groupby(x, Dim{dim}, bins);
52 : },
53 0 : py::arg("data"), py::arg("group"), py::arg("bins"),
54 0 : py::call_guard<py::gil_scoped_release>());
55 :
// Overload 3: group and bins both given as Variables; overload_cast selects
// the matching C++ `groupby` overload directly, so no wrapping lambda is used.
56 6 : m.def("groupby",
57 6 : py::overload_cast<const T &, const Variable &, const Variable &>(
58 : &groupby),
59 0 : py::arg("data"), py::arg("group"), py::arg("bins"),
60 0 : py::call_guard<py::gil_scoped_release>());
61 :
// Python-visible class wrapping GroupBy<T>, registered under `name`.
62 6 : py::class_<GroupBy<T>> groupBy(m, name.c_str(), R"(
63 : GroupBy object implementing split-apply-combine mechanism.)");
64 :
// One method per reduction; each takes the dimension name as a string
// (see BIND_GROUPBY_OP).
65 6 : BIND_GROUPBY_OP(groupBy, mean);
66 14 : BIND_GROUPBY_OP(groupBy, sum);
67 6 : BIND_GROUPBY_OP(groupBy, nansum);
68 6 : BIND_GROUPBY_OP(groupBy, all);
69 6 : BIND_GROUPBY_OP(groupBy, any);
70 6 : BIND_GROUPBY_OP(groupBy, min);
71 6 : BIND_GROUPBY_OP(groupBy, nanmin);
72 6 : BIND_GROUPBY_OP(groupBy, max);
73 6 : BIND_GROUPBY_OP(groupBy, nanmax);
74 6 : BIND_GROUPBY_OP(groupBy, concat);
75 6 : }
76 :
/// Register the groupby bindings on module `m` for both supported data types.
///
/// Instantiates `bind_groupby` for `DataArray` and `Dataset`, registering the
/// classes as "GroupByDataArray" and "GroupByDataset" respectively.
///
/// @param m Module that receives the bindings.
77 3 : void init_groupby(py::module &m) {
78 3 : bind_groupby<DataArray>(m, "GroupByDataArray");
79 3 : bind_groupby<Dataset>(m, "GroupByDataset");
80 3 : }
|