LCOV - code coverage report
Current view: top level - variable - shape.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 95 98 96.9 %
Date: 2024-11-17 01:47:58 Functions: 10 11 90.9 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Simon Heybrock
       5             : #include <algorithm>
       6             : 
       7             : #include "scipp/core/dimensions.h"
       8             : 
       9             : #include "scipp/variable/arithmetic.h"
      10             : #include "scipp/variable/bins.h"
      11             : #include "scipp/variable/creation.h"
      12             : #include "scipp/variable/except.h"
      13             : #include "scipp/variable/shape.h"
      14             : #include "scipp/variable/util.h"
      15             : #include "scipp/variable/variable_concept.h"
      16             : #include "scipp/variable/variable_factory.h"
      17             : 
      18             : using namespace scipp::core;
      19             : 
      20             : namespace scipp::variable {
      21             : 
      22       38370 : Variable broadcast(const Variable &var, const Dimensions &dims) {
      23       38370 :   return var.broadcast(dims);
      24             : }
      25             : 
      26             : namespace {
      27          34 : auto get_bin_sizes(const scipp::span<const Variable> vars) {
      28          34 :   std::vector<Variable> sizes;
      29          34 :   sizes.reserve(vars.size());
      30         110 :   for (const auto &var : vars)
      31          76 :     sizes.emplace_back(bin_sizes(var));
      32          34 :   return sizes;
      33           0 : }
      34             : } // namespace
      35             : 
      36        2605 : Variable concat(const scipp::span<const Variable> vars, const Dim dim) {
      37        2605 :   if (vars.empty())
      38           1 :     throw std::invalid_argument("Cannot concat empty list.");
      39             :   const auto it =
      40        2604 :       std::find_if(vars.begin(), vars.end(),
      41        2828 :                    [dim](const auto &var) { return var.dims().contains(dim); });
      42        2604 :   Dimensions dims;
      43             :   // Expand dims for inputs that do not contain dim already. Favor order given
      44             :   // by first input, if not found add as outer dim.
      45        2604 :   if (it == vars.end()) {
      46         129 :     dims = vars.front().dims();
      47         129 :     dims.add(dim, 1);
      48             :   } else {
      49        2475 :     dims = it->dims();
      50        2475 :     dims.resize(dim, 1);
      51             :   }
      52        2604 :   std::vector<Variable> tmp;
      53        2604 :   scipp::index size = 0;
      54        7925 :   for (const auto &var : vars) {
      55        5322 :     if (var.dims().contains(dim))
      56        4949 :       tmp.emplace_back(var);
      57             :     else
      58         373 :       tmp.emplace_back(broadcast(var, dims));
      59        5321 :     size += tmp.back().dims()[dim];
      60             :   }
      61        2603 :   dims.resize(dim, size);
      62        2603 :   Variable out;
      63        2603 :   if (is_bins(vars.front())) {
      64          34 :     out = empty_like(vars.front(), {}, concat(get_bin_sizes(vars), dim));
      65             :   } else {
      66        2569 :     out = empty_like(vars.front(), dims);
      67             :   }
      68        2603 :   scipp::index offset = 0;
      69        7899 :   for (const auto &var : tmp) {
      70        5320 :     const auto extent = var.dims()[dim];
      71        5344 :     out.data().copy(var, out.slice({dim, offset, offset + extent}));
      72        5296 :     offset += extent;
      73             :   }
      74        5158 :   return out;
      75        2653 : }
      76             : 
      77          78 : Variable resize(const Variable &var, const Dim dim, const scipp::index size,
      78             :                 const FillValue fill) {
      79          78 :   auto dims = var.dims();
      80          78 :   dims.resize(dim, size);
      81         234 :   return special_like(broadcast(Variable(var, Dimensions{}), dims), fill);
      82          78 : }
      83             : 
      84             : /// Return new variable resized to given shape.
      85             : ///
      86             : /// For bucket variables the values of `shape` are interpreted as bucket sizes
      87             : /// to RESERVE and the buffer is also resized accordingly. The emphasis is on
      88             : /// "reserve", i.e., buffer size and begin indices are set up accordingly, but
      89             : /// end=begin is set, i.e., the buckets are empty, but may be grown up to the
      90             : /// requested size. For normal (non-bucket) variable the values of `shape` are
      91             : /// ignored, i.e., only `shape.dims()` is used to determine the shape of the
      92             : /// output.
      93           0 : Variable resize(const Variable &var, const Variable &shape) {
      94           0 :   return {shape.dims(), var.data().makeDefaultFromParent(shape)};
      95             : }
      96             : 
      97         255 : Variable fold(const Variable &view, const Dim from_dim,
      98             :               const Dimensions &to_dims) {
      99         255 :   return view.fold(from_dim, to_dims);
     100             : }
     101             : 
     102       14411 : Variable flatten(const Variable &view,
     103             :                  const scipp::span<const Dim> &from_labels, const Dim to_dim) {
     104       14411 :   if (from_labels.empty()) {
     105          31 :     auto out(view);
     106          31 :     out.unchecked_dims().addInner(to_dim, 1);
     107          31 :     out.unchecked_strides().push_back(1);
     108          31 :     return out;
     109          31 :   }
     110       14380 :   const auto &labels = view.dims().labels();
     111       14380 :   auto it = std::search(labels.begin(), labels.end(), from_labels.begin(),
     112             :                         from_labels.end());
     113       14380 :   if (it == labels.end())
     114           3 :     throw except::DimensionError("Can only flatten a contiguous set of "
     115           6 :                                  "dimensions in the correct order");
     116       14377 :   scipp::index size = 1;
     117       14377 :   auto to = std::distance(labels.begin(), it);
     118       14377 :   auto out(view);
     119       41230 :   for (const auto &from : from_labels) {
     120       31066 :     size *= out.dims().size(to);
     121       31066 :     if (from == from_labels.back()) {
     122       10164 :       out.unchecked_dims().resize(from, size);
     123       10164 :       out.unchecked_dims().replace_key(from, to_dim);
     124             :     } else {
     125       20902 :       if (out.strides()[to] != out.dims().size(to + 1) * out.strides()[to + 1])
     126        8426 :         return flatten(copy(view), from_labels, to_dim);
     127       16689 :       out.unchecked_dims().erase(from);
     128       16689 :       out.unchecked_strides().erase(to);
     129             :     }
     130             :   }
     131       10164 :   return out;
     132       14377 : }
     133             : 
     134        6550 : Variable transpose(const Variable &var, const scipp::span<const Dim> dims) {
     135        6550 :   return var.transpose(dims);
     136             : }
     137             : 
     138             : std::vector<scipp::Dim>
     139        9979 : dims_for_squeezing(const core::Sizes &data_dims,
     140             :                    const std::optional<scipp::span<const Dim>> selected_dims) {
     141        9979 :   if (selected_dims.has_value()) {
     142       11258 :     for (const auto &dim : *selected_dims) {
     143        2756 :       if (const auto size = data_dims[dim]; size != 1)
     144          15 :         throw except::DimensionError("Cannot squeeze '" + to_string(dim) +
     145          20 :                                      "' of length " + std::to_string(size) +
     146          10 :                                      ", must be of length 1.");
     147             :     }
     148        8502 :     return std::vector<Dim>{selected_dims->begin(), selected_dims->end()};
     149             :   } else {
     150        1472 :     std::vector<Dim> length_1_dims;
     151        1472 :     length_1_dims.reserve(data_dims.size());
     152        2982 :     for (const auto &dim : data_dims) {
     153        1510 :       if (data_dims[dim] == 1) {
     154         170 :         length_1_dims.push_back(dim);
     155             :       }
     156             :     }
     157        1472 :     return length_1_dims;
     158        1472 :   }
     159             : }
     160             : 
     161        9942 : Variable squeeze(const Variable &var,
     162             :                  const std::optional<scipp::span<const Dim>> dims) {
     163        9942 :   auto squeezed = var;
     164       12821 :   for (const auto &dim : dims_for_squeezing(var.dims(), dims)) {
     165        2879 :     squeezed = squeezed.slice({dim, 0});
     166        9939 :   }
     167        9939 :   return squeezed;
     168           3 : }
     169             : 
     170             : } // namespace scipp::variable

Generated by: LCOV version 1.14