LCOV - code coverage report
Current view: top level - python - shape.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 55 59 93.2 %
Date: 2024-12-01 01:56:34 Functions: 27 29 93.1 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Owen Arnold
       5             : #include "scipp/variable/shape.h"
       6             : #include "docstring.h"
       7             : #include "pybind11.h"
       8             : #include "scipp/dataset/shape.h"
       9             : #include "scipp/variable/variable.h"
      10             : 
      11             : #include "dim.h"
      12             : 
      13             : using namespace scipp;
      14             : using namespace scipp::variable;
      15             : 
      16             : namespace py = pybind11;
      17             : 
      18             : namespace {
      19             : 
      20           3 : template <class T> void bind_broadcast(py::module &m) {
      21           3 :   m.def(
      22             :       "broadcast",
      23        1312 :       [](const T &self, const std::vector<std::string> &labels,
      24             :          const std::vector<scipp::index> &shape) {
      25        1312 :         return broadcast(self, make_dims(labels, shape));
      26             :       },
      27           0 :       py::arg("x"), py::arg("dims"), py::arg("shape"));
      28           3 : }
      29             : 
      30           9 : template <class T> void bind_concat(py::module &m) {
      31           9 :   m.def(
      32             :       "concat",
      33        1150 :       [](const std::vector<T> &x, const std::string &dim) {
      34        1150 :         return concat(x, Dim{dim});
      35             :       },
      36           0 :       py::arg("x"), py::arg("dim"), py::call_guard<py::gil_scoped_release>());
      37           9 : }
      38             : 
      39           6 : template <class T> void bind_fold(pybind11::module &mod) {
      40           6 :   mod.def(
      41             :       "fold",
      42         103 :       [](const T &self, const std::string &dim,
      43             :          const std::vector<std::string> &labels,
      44             :          const std::vector<scipp::index> &shape) {
      45         103 :         return fold(self, Dim{dim}, make_dims(labels, shape));
      46             :       },
      47           0 :       py::arg("x"), py::arg("dim"), py::arg("dims"), py::arg("shape"),
      48           6 :       py::call_guard<py::gil_scoped_release>());
      49           6 : }
      50             : 
      51           6 : template <class T> void bind_flatten(pybind11::module &mod) {
      52           6 :   mod.def(
      53             :       "flatten",
      54        2855 :       [](const T &self, const std::optional<std::vector<std::string>> &dims,
      55             :          const std::string &to) {
      56        2855 :         if (dims.has_value())
      57         788 :           return flatten(self, to_dim_type(*dims), Dim{to});
      58             :         // If no dims are given then we flatten all dims. For variables we just
      59             :         // provide a list of all labels. DataArrays are different, as the
      60             :         // behavior in the degenerate case of a 0-D 'self' must distinguish
      61             :         // between flattening "zero dims" and "all dims". The latter is
      62             :         // specified using std::nullopt.
      63             :         if constexpr (std::is_same_v<T, Variable>)
      64           8 :           return flatten(self, self.dims().labels(), Dim{to});
      65             :         else
      66        2063 :           return flatten(self, std::nullopt, Dim{to});
      67             :       },
      68           0 :       py::arg("x"), py::arg("dims"), py::arg("to"),
      69           6 :       py::call_guard<py::gil_scoped_release>());
      70           6 : }
      71             : 
      72           9 : template <class T> void bind_transpose(pybind11::module &mod) {
      73           9 :   mod.def(
      74             :       "transpose",
      75        5017 :       [](const T &self, const std::vector<std::string> &dims) {
      76        5017 :         return transpose(self, to_dim_type(dims));
      77             :       },
      78          18 :       py::arg("x"), py::arg("dims") = std::vector<std::string>{});
      79           9 : }
      80             : 
      81           9 : template <class T> void bind_squeeze(pybind11::module &mod) {
      82           9 :   mod.def(
      83             :       "squeeze",
      84          24 :       [](const T &self, const std::optional<std::vector<std::string>> &dims) {
      85          24 :         return squeeze(self, dims.has_value()
      86          48 :                                  ? std::optional{to_dim_type(*dims)}
      87          48 :                                  : std::optional<std::vector<Dim>>{});
      88             :       },
      89          18 :       py::arg("x"), py::arg("dims") = std::nullopt);
      90           9 : }
      91             : } // namespace
      92             : 
      93           3 : void init_shape(py::module &m) {
      94           3 :   bind_broadcast<Variable>(m);
      95           3 :   bind_concat<Variable>(m);
      96           3 :   bind_concat<DataArray>(m);
      97           3 :   bind_concat<Dataset>(m);
      98           3 :   bind_fold<Variable>(m);
      99           3 :   bind_fold<DataArray>(m);
     100           3 :   bind_flatten<Variable>(m);
     101           3 :   bind_flatten<DataArray>(m);
     102           3 :   bind_transpose<Variable>(m);
     103           3 :   bind_transpose<DataArray>(m);
     104           3 :   bind_transpose<Dataset>(m);
     105           3 :   bind_squeeze<Variable>(m);
     106           3 :   bind_squeeze<DataArray>(m);
     107           3 :   bind_squeeze<Dataset>(m);
     108           3 : }

Generated by: LCOV version 1.14