Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 : #include "dim.h"
6 : #include "pybind11.h"
7 : #include "slice_utils.h"
8 :
9 : #include "scipp/dataset/dataset.h"
10 : #include "scipp/dataset/sort.h"
11 : #include "scipp/variable/math.h"
12 : #include "scipp/variable/operations.h"
13 : #include "scipp/variable/slice.h"
14 : #include "scipp/variable/sort.h"
15 : #include "scipp/variable/util.h"
16 :
17 : using namespace scipp;
18 : using namespace scipp::variable;
19 : using namespace scipp::dataset;
20 :
21 : namespace py = pybind11;
22 :
23 7 : auto get_sort_order(const std::string &order) {
24 7 : if (order == "ascending")
25 4 : return SortOrder::Ascending;
26 3 : else if (order == "descending")
27 3 : return SortOrder::Descending;
28 : else
29 0 : throw std::runtime_error("Sort order must be 'ascending' or 'descending'");
30 : }
31 :
32 3 : template <typename T> void bind_dot(py::module &m) {
33 3 : m.def(
34 1 : "dot", [](const T &x, const T &y) { return dot(x, y); }, py::arg("x"),
35 0 : py::arg("y"), py::call_guard<py::gil_scoped_release>());
36 3 : }
37 :
38 9 : template <typename T> void bind_sort(py::module &m) {
39 9 : m.def(
40 : "sort",
41 1 : [](const T &x, const Variable &key, const std::string &order) {
42 1 : return sort(x, key, get_sort_order(order));
43 : },
44 0 : py::arg("x"), py::arg("key"), py::arg("order"),
45 9 : py::call_guard<py::gil_scoped_release>());
46 9 : }
47 :
48 9 : template <typename T> void bind_sort_dim(py::module &m) {
49 9 : m.def(
50 : "sort",
51 2 : [](const T &x, const std::string &dim, const std::string &order) {
52 2 : return sort(x, Dim{dim}, get_sort_order(order));
53 : },
54 0 : py::arg("x"), py::arg("key"), py::arg("order"),
55 9 : py::call_guard<py::gil_scoped_release>());
56 9 : }
57 :
58 3 : void bind_issorted(py::module &m) {
59 3 : m.def(
60 : "issorted",
61 2 : [](const Variable &x, const std::string &dim, const std::string &order) {
62 2 : return issorted(x, Dim{dim}, get_sort_order(order));
63 : },
64 6 : py::arg("x"), py::arg("dim"), py::arg("order") = "ascending",
65 3 : py::call_guard<py::gil_scoped_release>());
66 3 : }
67 :
68 3 : void bind_allsorted(py::module &m) {
69 3 : m.def(
70 : "allsorted",
71 2 : [](const Variable &x, const std::string &dim, const std::string &order) {
72 2 : return allsorted(x, Dim{dim}, get_sort_order(order));
73 : },
74 6 : py::arg("x"), py::arg("dim"), py::arg("order") = "ascending",
75 3 : py::call_guard<py::gil_scoped_release>());
76 3 : }
77 :
78 3 : void bind_midpoints(py::module &m) {
79 3 : m.def("midpoints", [](const Variable &var,
80 : const std::optional<std::string> &dim) {
81 24 : return midpoints(var, dim.has_value() ? Dim{*dim} : std::optional<Dim>{});
82 : });
83 3 : }
84 :
85 3 : void init_operations(py::module &m) {
86 3 : bind_dot<Variable>(m);
87 :
88 3 : bind_sort<Variable>(m);
89 3 : bind_sort<DataArray>(m);
90 3 : bind_sort<Dataset>(m);
91 3 : bind_sort_dim<Variable>(m);
92 3 : bind_sort_dim<DataArray>(m);
93 3 : bind_sort_dim<Dataset>(m);
94 3 : bind_issorted(m);
95 3 : bind_allsorted(m);
96 3 : bind_midpoints(m);
97 :
98 3 : m.def(
99 : "label_based_index_to_positional_index",
100 9 : [](const std::vector<std::string> &dims,
101 : const std::vector<scipp::index> &shape, const Variable &coord,
102 : const Variable &value) {
103 9 : const auto [dim, index] =
104 9 : get_slice_params(make_dims(dims, shape), coord, value);
105 9 : return std::tuple{dim.name(), index};
106 : },
107 0 : py::call_guard<py::gil_scoped_release>());
108 3 : m.def("label_based_index_to_positional_index",
109 39 : [](const std::vector<std::string> &dims,
110 : const std::vector<scipp::index> &shape, const Variable &coord,
111 : const py::slice &py_slice) {
112 : try {
113 39 : auto [start_var, stop_var] = label_bounds_from_pyslice(py_slice);
114 39 : const auto [dim, start, stop] = get_slice_params(
115 39 : make_dims(dims, shape), coord, start_var, stop_var);
116 78 : return std::tuple{dim.name(), start, stop};
117 39 : } catch (const py::cast_error &) {
118 0 : throw std::runtime_error(
119 0 : "Value based slice must contain variables.");
120 0 : }
121 : });
122 :
123 3 : m.def("where", &variable::where, py::arg("condition"), py::arg("x"),
124 0 : py::arg("y"), py::call_guard<py::gil_scoped_release>());
125 3 : }
|