LCOV - code coverage report
Current view: top level - dataset - shape.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 175 186 94.1 %
Date: 2024-12-01 01:56:34 Functions: 55 58 94.8 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Simon Heybrock
       5             : #include <algorithm>
       6             : 
       7             : #include "scipp/core/except.h"
       8             : #include "scipp/variable/creation.h"
       9             : #include "scipp/variable/shape.h"
      10             : 
      11             : #include "scipp/dataset/except.h"
      12             : #include "scipp/dataset/shape.h"
      13             : 
      14             : #include "dataset_operations_common.h"
      15             : 
      16             : using namespace scipp::variable;
      17             : 
      18             : namespace scipp::dataset {
      19             : 
      20             : /// Map `op` over `items`, return vector of results
      21        6368 : template <class T, class Op> auto map(const T &items, Op op) {
      22        6368 :   std::vector<std::decay_t<decltype(op(items.front()))>> out;
      23        6368 :   out.reserve(items.size());
      24       19345 :   for (const auto &i : items)
      25       12977 :     out.emplace_back(op(i));
      26        6368 :   return out;
      27           0 : }
      28             : 
      29             : /// Concatenate `vars` along `dim`, assuming that all inputs contain bin edges.
      30             : ///
      31             : /// Checks that the last edge of each input matches the first edge of the
      32             : /// next input, then concatenates the inputs, removing duplicate bin edges.
      33          11 : Variable join_edges(const scipp::span<const Variable> vars, const Dim dim) {
      34          11 :   std::vector<Variable> tmp;
      35          11 :   tmp.reserve(vars.size());
      36          31 :   for (const auto &var : vars) {
      37          23 :     if (tmp.empty()) {
      38          11 :       tmp.emplace_back(var);
      39             :     } else {
      40          45 :       core::expect::equals(tmp.back().slice({dim, tmp.back().dims()[dim] - 1}),
      41          27 :                            var.slice({dim, 0}));
      42           9 :       tmp.emplace_back(var.slice({dim, 1, var.dims()[dim]}));
      43             :     }
      44             :   }
      45          16 :   return concat(tmp, dim);
      46          11 : }
      47             : 
      48             : namespace {
      49             : template <class T, class Key>
      50        1183 : bool equal_is_edges(const T &maps, const Key &key, const Dim dim) {
      51        1183 :   return std::adjacent_find(maps.begin(), maps.end(),
      52        2442 :                             [&key, dim](auto &a, auto &b) {
      53        1221 :                               return is_edges(a.sizes(), a[key].dims(), dim) !=
      54        1221 :                                      is_edges(b.sizes(), b[key].dims(), dim);
      55        2366 :                             }) == maps.end();
      56             : }
      57             : template <class T, class Key>
      58        1176 : bool all_is_edges(const T &maps, const Key &key, const Dim dim) {
      59        2366 :   return std::all_of(maps.begin(), maps.end(), [&key, dim](auto &var) {
      60        1190 :     return is_edges(var.sizes(), var[key].dims(), dim);
      61        1176 :   });
      62             : }
      63             : 
      64             : template <class T, class Key>
      65          58 : auto broadcast_along_dim(const T &maps, const Key &key, const Dim dim) {
      66         127 :   return map(maps, [&key, dim](const auto &map) {
      67         127 :     const auto &var = map[key];
      68         254 :     return broadcast(var, merge(Dimensions(dim, map.sizes().contains(dim)
      69         127 :                                                     ? map.sizes().at(dim)
      70             :                                                     : 1),
      71         381 :                                 var.dims()));
      72          58 :   });
      73             : }
      74             : 
      75        3487 : template <class Maps> auto concat_maps(const Maps &maps, const Dim dim) {
      76        3487 :   if (maps.empty())
      77           1 :     throw std::invalid_argument("Cannot concat empty list.");
      78             :   using T = typename Maps::value_type;
      79        3486 :   core::Dict<typename T::key_type, typename T::mapped_type> out;
      80        3486 :   const auto &a = maps.front();
      81        5089 :   for (const auto &[key, a_] : a) {
      82        4921 :     auto vars = map(maps, [&key = key](auto &&map) { return map[key]; });
      83        1603 :     if (a.dim_of(key) == dim) {
      84        1183 :       if (!equal_is_edges(maps, key, dim)) {
      85           7 :         throw except::BinEdgeError(
      86             :             "Either all or none of the inputs must have bin edge coordinates.");
      87        1176 :       } else if (!all_is_edges(maps, key, dim)) {
      88        1165 :         out.insert_or_assign(key, concat(vars, dim));
      89             :       } else {
      90          11 :         out.insert_or_assign(key, join_edges(vars, dim));
      91             :       }
      92             :     } else {
      93             :       // 1D coord is kept only if all inputs have matching 1D coords.
      94        2169 :       if (std::any_of(vars.begin(), vars.end(), [dim, &vars](auto &var) {
      95         894 :             return var.dims().contains(dim) || !equals_nan(var, vars.front());
      96             :           })) {
      97             :         // Mismatching 1D coords must be broadcast to ensure new coord shape
      98             :         // matches new data shape.
      99          58 :         out.insert_or_assign(key,
     100             :                              concat(broadcast_along_dim(maps, key, dim), dim));
     101             :       } else {
     102             :         if constexpr (std::is_same_v<T, Masks>)
     103         130 :           out.insert_or_assign(key, copy(a_));
     104             :         else
     105         232 :           out.insert_or_assign(key, a_);
     106             :       }
     107             :     }
     108             :   }
     109        3476 :   return out;
     110          10 : }
     111             : 
     112             : } // namespace
     113             : 
     114        1164 : DataArray concat(const scipp::span<const DataArray> das, const Dim dim) {
     115        2326 :   auto out = DataArray(concat(map(das, get_data), dim), {},
     116        3493 :                        concat_maps(map(das, get_masks), dim));
     117        1163 :   const auto &coords = map(das, get_coords);
     118        2392 :   for (auto &&[d, coord] : concat_maps(coords, dim)) {
     119        2377 :     coord.set_aligned(d == dim ||
     120        1148 :                       std::any_of(coords.begin(), coords.end(),
     121        1148 :                                   [&d = d](auto &_) { return _.contains(d); }));
     122        1229 :     out.coords().set(d, std::move(coord));
     123        1153 :   }
     124        1313 :   for (auto &&[d, attr] : concat_maps(map(das, get_attrs), dim)) {
     125         160 :     out.attrs().set(d, std::move(attr));
     126        1153 :   }
     127        2306 :   return out;
     128        1173 : }
     129             : 
     130          48 : Dataset concat(const scipp::span<const Dataset> dss, const Dim dim) {
     131          48 :   if (dss.empty())
     132           1 :     throw std::invalid_argument("Cannot concat empty list.");
     133          47 :   Dataset result;
     134          98 :   for (const auto &first : dss.front())
     135          53 :     if (std::all_of(dss.begin(), dss.end(),
     136         110 :                     [&first](auto &ds) { return ds.contains(first.name()); })) {
     137         154 :       auto das = map(dss, [&first](auto &&ds) { return ds[first.name()]; });
     138          50 :       result.setDataInit(first.name(), concat(das, dim));
     139         103 :     }
     140          45 :   if (result.is_valid())
     141          38 :     return result;
     142          14 :   return Dataset({}, Coords(concat(map(dss, get_sizes), dim),
     143          21 :                             concat_maps(map(dss, get_coords), dim)));
     144          47 : }
     145             : 
     146          78 : DataArray resize(const DataArray &a, const Dim dim, const scipp::index size,
     147             :                  const FillValue fill) {
     148             :   return apply_to_data_and_drop_dim(
     149         156 :       a, [](auto &&..._) { return resize(_...); }, dim, size, fill);
     150             : }
     151             : 
     152           3 : Dataset resize(const Dataset &d, const Dim dim, const scipp::index size,
     153             :                const FillValue fill) {
     154             :   return apply_to_items(
     155           6 :       d, [](auto &&..._) { return resize(_...); }, dim, size, fill);
     156             : }
     157             : 
     158           0 : DataArray resize(const DataArray &a, const Dim dim, const DataArray &shape) {
     159             :   return apply_to_data_and_drop_dim(
     160           0 :       a, [](auto &&v, const Dim, auto &&s) { return resize(v, s); }, dim,
     161           0 :       shape.data());
     162             : }
     163             : 
     164           0 : Dataset resize(const Dataset &d, const Dim dim, const Dataset &shape) {
     165           0 :   Dataset result;
     166           0 :   for (const auto &data : d)
     167           0 :     result.setData(data.name(), resize(data, dim, shape[data.name()]));
     168           0 :   return result;
     169           0 : }
     170             : 
     171             : namespace {
     172             : 
     173             : /// Either broadcast variable to from_labels before a reshape or not:
     174             : ///
     175             : /// 1. If all from_labels are contained in the variable's dims, no broadcast
     176             : /// 2. If at least one (but not all) of the from_labels is contained in the
     177             : ///    variable's dims, broadcast
     178             : /// 3. If none of the variable's dimensions are contained, no broadcast
     179        7398 : Variable maybe_broadcast(const Variable &var,
     180             :                          const scipp::span<const Dim> &from_labels,
     181             :                          const Dimensions &data_dims) {
     182        7398 :   const auto &var_dims = var.dims();
     183        7398 :   Dimensions broadcast_dims;
     184       19001 :   for (const auto &dim : var_dims.labels())
     185       23206 :     if (std::find(from_labels.begin(), from_labels.end(), dim) ==
     186       11603 :         from_labels.end())
     187          55 :       broadcast_dims.addInner(dim, var_dims[dim]);
     188             :     else
     189       43763 :       for (const auto &lab : from_labels)
     190       32215 :         if (!broadcast_dims.contains(lab)) {
     191             :           // Need to check if the variable contains that dim, and use the
     192             :           // variable shape in case we have a bin edge.
     193       19595 :           if (var_dims.contains(lab))
     194       11548 :             broadcast_dims.addInner(lab, var_dims[lab]);
     195             :           else
     196        8047 :             broadcast_dims.addInner(lab, data_dims[lab]);
     197             :         }
     198       14796 :   return broadcast(var, broadcast_dims);
     199        7398 : }
     200             : 
     201             : /// Special handling for folding coord along a dim that contains bin edges.
     202           4 : Variable fold_bin_edge(const Variable &var, const Dim from_dim,
     203             :                        const Dimensions &to_dims) {
     204           4 :   auto out = var.slice({from_dim, 0, var.dims()[from_dim] - 1})
     205           8 :                  .fold(from_dim, to_dims) // fold non-overlapping part
     206           4 :                  .as_const();             // mark readonly since we add overlap
     207             :   // Increase dims without changing strides to obtain first == last
     208           4 :   out.unchecked_dims().resize(to_dims.inner(), to_dims[to_dims.inner()] + 1);
     209           4 :   return out;
     210           0 : }
     211             : 
     212             : /// Special handling for flattening coord along a dim that contains bin edges.
     213          48 : Variable flatten_bin_edge(const Variable &var,
     214             :                           const scipp::span<const Dim> &from_labels,
     215             :                           const Dim to_dim, const Dim bin_edge_dim) {
     216          48 :   const auto data_shape = var.dims()[bin_edge_dim] - 1;
     217             : 
     218             :   // Make sure that the bin edges can be joined together
     219          48 :   const auto front = var.slice({bin_edge_dim, 0});
     220          48 :   const auto back = var.slice({bin_edge_dim, data_shape});
     221          48 :   const auto front_flat = flatten(front, front.dims().labels(), to_dim);
     222          48 :   const auto back_flat = flatten(back, back.dims().labels(), to_dim);
     223          48 :   if (front_flat.slice({to_dim, 1, front.dims().volume()}) !=
     224          96 :       back_flat.slice({to_dim, 0, back.dims().volume() - 1}))
     225          34 :     return {};
     226             : 
     227             :   // Make the bulk slice of the coord, leaving out the last bin edge
     228             :   const auto bulk =
     229          14 :       flatten(var.slice({bin_edge_dim, 0, data_shape}), from_labels, to_dim);
     230          14 :   auto out_dims = bulk.dims();
     231             :   // To make the container of the right size, we increase to_dim by 1
     232          14 :   out_dims.resize(to_dim, out_dims[to_dim] + 1);
     233          14 :   auto out = empty(out_dims, var.unit(), var.dtype(), var.has_variances());
     234          14 :   copy(bulk, out.slice({to_dim, 0, out_dims[to_dim] - 1}));
     235          14 :   copy(back_flat.slice({to_dim, back.dims().volume() - 1}),
     236          28 :        out.slice({to_dim, out_dims[to_dim] - 1}));
     237          14 :   return out;
     238          48 : }
     239             : 
     240             : /// Check if one of the from_labels is a bin edge
     241        9685 : Dim bin_edge_in_from_labels(const Variable &var, const Dimensions &array_dims,
     242             :                             const scipp::span<const Dim> &from_labels) {
     243       33974 :   for (const auto &dim : from_labels)
     244       24337 :     if (is_edges(array_dims, var.dims(), dim))
     245          48 :       return dim;
     246        9637 :   return Dim::Invalid;
     247             : }
     248             : 
     249             : } // end anonymous namespace
     250             : 
     251             : /// Fold a single dimension into multiple dimensions
     252             : /// ['x': 6] -> ['y': 2, 'z': 3]
     253          23 : DataArray fold(const DataArray &a, const Dim from_dim,
     254             :                const Dimensions &to_dims) {
     255          69 :   return dataset::transform(a, [&](const auto &var) {
     256          69 :     if (is_edges(a.dims(), var.dims(), from_dim))
     257           4 :       return fold_bin_edge(var, from_dim, to_dims);
     258          65 :     else if (var.dims().contains(from_dim))
     259          49 :       return fold(var, from_dim, to_dims);
     260             :     else
     261          16 :       return var;
     262          23 :   });
     263             : }
     264             : 
     265             : namespace {
     266        2290 : void expect_dimension_subset(const core::Dimensions &full_set,
     267             :                              const scipp::span<const Dim> &subset) {
     268        6974 :   for (const auto &dim : subset) {
     269        4687 :     if (!full_set.contains(dim)) {
     270           9 :       throw except::DimensionError{"Expected dimension " + to_string(dim) +
     271          12 :                                    "to be in " + to_string(full_set)};
     272             :     }
     273             :   }
     274        2287 : }
     275             : } // namespace
     276             : 
     277             : /// Flatten multiple dimensions into a single dimension:
     278             : /// ['y', 'z'] -> ['x']
     279        2292 : DataArray flatten(const DataArray &a,
     280             :                   const std::optional<scipp::span<const Dim>> &from_labels,
     281             :                   const Dim to_dim) {
     282        2292 :   const auto &labels = from_labels.value_or(a.dims().labels());
     283        2292 :   if (from_labels.has_value() && labels.empty())
     284           8 :     return DataArray(flatten(a.data(), labels, to_dim), a.coords(), a.masks(),
     285          10 :                      a.attrs());
     286        2290 :   expect_dimension_subset(a.dims(), labels);
     287        9685 :   return dataset::transform(a, [&](const auto &in) {
     288        9685 :     auto var = (&in == &a.data()) ? in : maybe_broadcast(in, labels, a.dims());
     289        9685 :     const auto bin_edge_dim = bin_edge_in_from_labels(in, a.dims(), labels);
     290        9685 :     if (bin_edge_dim != Dim::Invalid) {
     291          48 :       return flatten_bin_edge(var, labels, to_dim, bin_edge_dim);
     292        9637 :     } else if (a.dims().empty() || var.dims().contains(labels.front())) {
     293             :       // maybe_broadcast ensures that all variables contain
     294             :       // all dims in labels, so only need to check labels.front().
     295        9568 :       return flatten(var, labels, to_dim);
     296             :     } else {
     297             :       // This only happens for metadata.
     298          69 :       return var;
     299             :     }
     300       11972 :   });
     301             : }
     302             : 
     303        1336 : DataArray transpose(const DataArray &a, const scipp::span<const Dim> dims) {
     304        5344 :   return {transpose(a.data(), dims), a.coords(), a.masks(), a.attrs(),
     305        5344 :           a.name()};
     306             : }
     307             : 
     308           2 : Dataset transpose(const Dataset &d, const scipp::span<const Dim> dims) {
     309           6 :   return apply_to_items(d, [](auto &&..._) { return transpose(_...); }, dims);
     310             : }
     311             : 
     312             : namespace {
     313             : template <class T>
     314          37 : T squeeze_impl(const T &x, const std::optional<scipp::span<const Dim>> dims) {
     315          37 :   auto squeezed = x;
     316          77 :   for (const auto &dim : dims_for_squeezing(x.dims(), dims)) {
     317          40 :     squeezed = squeezed.slice({dim, 0});
     318             :   }
     319             :   // Copy explicitly to make sure the output does not have its read-only flag
     320             :   // set.
     321          70 :   return T(squeezed);
     322          37 : }
     323             : } // namespace
     324          32 : DataArray squeeze(const DataArray &a,
     325             :                   const std::optional<scipp::span<const Dim>> dims) {
     326          32 :   return squeeze_impl(a, dims);
     327             : }
     328             : 
     329           5 : Dataset squeeze(const Dataset &d,
     330             :                 const std::optional<scipp::span<const Dim>> dims) {
     331           5 :   return squeeze_impl(d, dims);
     332             : }
     333             : 
     334             : } // namespace scipp::dataset

Generated by: LCOV version 1.14