Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Owen Arnold
5 : #include "scipp/variable/shape.h"
6 : #include "docstring.h"
7 : #include "pybind11.h"
8 : #include "scipp/dataset/shape.h"
9 : #include "scipp/variable/variable.h"
10 :
11 : #include "dim.h"
12 :
13 : using namespace scipp;
14 : using namespace scipp::variable;
15 :
16 : namespace py = pybind11;
17 :
18 : namespace {
19 :
20 3 : template <class T> void bind_broadcast(py::module &m) {
21 3 : m.def(
22 : "broadcast",
23 1312 : [](const T &self, const std::vector<std::string> &labels,
24 : const std::vector<scipp::index> &shape) {
25 1312 : return broadcast(self, make_dims(labels, shape));
26 : },
27 0 : py::arg("x"), py::arg("dims"), py::arg("shape"));
28 3 : }
29 :
30 9 : template <class T> void bind_concat(py::module &m) {
31 9 : m.def(
32 : "concat",
33 1150 : [](const std::vector<T> &x, const std::string &dim) {
34 1150 : return concat(x, Dim{dim});
35 : },
36 0 : py::arg("x"), py::arg("dim"), py::call_guard<py::gil_scoped_release>());
37 9 : }
38 :
39 6 : template <class T> void bind_fold(pybind11::module &mod) {
40 6 : mod.def(
41 : "fold",
42 103 : [](const T &self, const std::string &dim,
43 : const std::vector<std::string> &labels,
44 : const std::vector<scipp::index> &shape) {
45 103 : return fold(self, Dim{dim}, make_dims(labels, shape));
46 : },
47 0 : py::arg("x"), py::arg("dim"), py::arg("dims"), py::arg("shape"),
48 6 : py::call_guard<py::gil_scoped_release>());
49 6 : }
50 :
51 6 : template <class T> void bind_flatten(pybind11::module &mod) {
52 6 : mod.def(
53 : "flatten",
54 2760 : [](const T &self, const std::optional<std::vector<std::string>> &dims,
55 : const std::string &to) {
56 2760 : if (dims.has_value())
57 693 : return flatten(self, to_dim_type(*dims), Dim{to});
58 : // If no dims are given then we flatten all dims. For variables we just
59 : // provide a list of all labels. DataArrays are different, as the
60 : // behavior in the degenerate case of a 0-D 'self' must distinguish
61 : // between flattening "zero dims" and "all dims". The latter is
62 : // specified using std::nullopt.
63 : if constexpr (std::is_same_v<T, Variable>)
64 8 : return flatten(self, self.dims().labels(), Dim{to});
65 : else
66 2063 : return flatten(self, std::nullopt, Dim{to});
67 : },
68 0 : py::arg("x"), py::arg("dims"), py::arg("to"),
69 6 : py::call_guard<py::gil_scoped_release>());
70 6 : }
71 :
72 9 : template <class T> void bind_transpose(pybind11::module &mod) {
73 9 : mod.def(
74 : "transpose",
75 4423 : [](const T &self, const std::vector<std::string> &dims) {
76 4423 : return transpose(self, to_dim_type(dims));
77 : },
78 18 : py::arg("x"), py::arg("dims") = std::vector<std::string>{});
79 9 : }
80 :
81 9 : template <class T> void bind_squeeze(pybind11::module &mod) {
82 9 : mod.def(
83 : "squeeze",
84 24 : [](const T &self, const std::optional<std::vector<std::string>> &dims) {
85 24 : return squeeze(self, dims.has_value()
86 48 : ? std::optional{to_dim_type(*dims)}
87 48 : : std::optional<std::vector<Dim>>{});
88 : },
89 18 : py::arg("x"), py::arg("dims") = std::nullopt);
90 9 : }
91 : } // namespace
92 :
93 3 : void init_shape(py::module &m) {
94 3 : bind_broadcast<Variable>(m);
95 3 : bind_concat<Variable>(m);
96 3 : bind_concat<DataArray>(m);
97 3 : bind_concat<Dataset>(m);
98 3 : bind_fold<Variable>(m);
99 3 : bind_fold<DataArray>(m);
100 3 : bind_flatten<Variable>(m);
101 3 : bind_flatten<DataArray>(m);
102 3 : bind_transpose<Variable>(m);
103 3 : bind_transpose<DataArray>(m);
104 3 : bind_transpose<Dataset>(m);
105 3 : bind_squeeze<Variable>(m);
106 3 : bind_squeeze<DataArray>(m);
107 3 : bind_squeeze<Dataset>(m);
108 3 : }
|