LCOV - code coverage report
Current view: top level - dataset - extract.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 59 62 95.2 %
Date: 2024-04-28 01:25:40 Functions: 21 21 100.0 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Simon Heybrock
       5             : #include <numeric>
       6             : 
       7             : #include "scipp/variable/variable_factory.h"
       8             : 
       9             : #include "scipp/dataset/bins.h"
      10             : #include "scipp/dataset/except.h"
      11             : #include "scipp/dataset/extract.h"
      12             : #include "scipp/dataset/util.h"
      13             : 
      14             : namespace scipp {
      15             : 
      16             : namespace {
      17             : 
      18             : /// Transform data of data array or dataset, coord, masks, and and attrs are
      19             : /// shallow-copied.
      20             : ///
      21             : /// Beware of the mask-copy behavior, which is not suitable for data returned to
      22             : /// the user.
      23             : template <class T, class Func, class... Ts>
      24         964 : T transform_data(const T &obj, Func func, const Ts &...other) {
      25         964 :   T out(obj);
      26             :   if constexpr (std::is_same_v<T, Variable>) {
      27         720 :     return func(obj, other...);
      28             :   } else if constexpr (std::is_same_v<T, DataArray>) {
      29         188 :     out.setData(func(obj.data(), other.data()...));
      30             :   } else {
      31         128 :     for (const auto &item : obj)
      32          72 :       out.setData(item.name(), func(item.data(), other[item.name()].data()...),
      33             :                   dataset::AttrPolicy::Keep);
      34             :   }
      35         244 :   return out;
      36         720 : }
      37             : 
      38             : template <class Buffer>
      39         494 : Variable copy_ranges_from_buffer(const Variable &indices, const Dim dim,
      40             :                                  const Buffer &buffer) {
      41         494 :   return copy(make_bins_no_validate(indices, dim, buffer));
      42             : }
      43             : 
      44          12 : Variable copy_ranges_from_bins_buffer(const Variable &indices,
      45             :                                       const Variable &data) {
      46          12 :   if (data.dtype() == dtype<bucket<Variable>>) {
      47           6 :     const auto &[i, dim, buf] = data.constituents<Variable>();
      48           6 :     return copy_ranges_from_buffer(indices, dim, buf);
      49          12 :   } else if (data.dtype() == dtype<bucket<DataArray>>) {
      50           6 :     const auto &[i, dim, buf] = data.constituents<DataArray>();
      51           6 :     return copy_ranges_from_buffer(indices, dim, buf);
      52           6 :   } else {
      53           0 :     const auto &[i, dim, buf] = data.constituents<Dataset>();
      54           0 :     return copy_ranges_from_buffer(indices, dim, buf);
      55           0 :   }
      56             : }
      57             : 
      58         490 : Variable dense_or_bin_indices(const Variable &var) {
      59         490 :   return is_bins(var) ? var.bin_indices() : var;
      60             : }
      61             : 
      62         490 : Variable dense_or_copy_bin_elements(const Variable &dense_or_indices,
      63             :                                     const Variable &data) {
      64         490 :   return is_bins(data) ? copy_ranges_from_bins_buffer(dense_or_indices, data)
      65         490 :                        : dense_or_indices;
      66             : }
      67             : } // namespace
      68             : 
      69             : template <class T>
      70         482 : T extract_ranges(const Variable &indices, const T &data, const Dim dim) {
      71         482 :   T no_edges;
      72             :   if constexpr (std::is_same_v<T, Variable>)
      73         360 :     no_edges = data;
      74             :   else
      75         122 :     no_edges = strip_edges_along(data, dim);
      76             :   // 1. Operate on dense data, or equivalent array of indices (if binned) to
      77             :   // obtain output data of correct shape with proper meta data.
      78         482 :   auto dense = transform_data(no_edges, dense_or_bin_indices);
      79         482 :   auto out =
      80         482 :       copy_ranges_from_buffer(indices, dim, dense).template bin_buffer<T>();
      81             :   // 2. If we have binned data then the data of the DataArray or Dataset
      82             :   // obtained in step 1. give the indices into the underlying buffer to be
      83             :   // copied. This then replaces the data to obtain the final result. Does
      84             :   // nothing if dense data.
      85         964 :   return transform_data(out, dense_or_copy_bin_elements, no_edges);
      86         482 : }
      87             : 
      88             : namespace {
      89         417 : template <class T> T extract_impl(const T &obj, const Variable &condition) {
      90         417 :   if (condition.dtype() != dtype<bool>)
      91           1 :     throw except::TypeError(
      92             :         "Cannot extract elements based on condition with non-boolean dtype. If "
      93             :         "you intended to select a range based on a label you must specify the "
      94             :         "dimension.");
      95         416 :   if (condition.dims().ndim() != 1)
      96           6 :     throw except::DimensionError("Condition must by 1-D, but got " +
      97             :                                  to_string(condition.dims()) + '.');
      98         410 :   if (!obj.dims().includes(condition.dims()))
      99           7 :     throw except::DimensionError(
     100             :         "Condition dimensions " + to_string(condition.dims()) +
     101             :         " must be be included in the dimensions of the sliced object " +
     102             :         to_string(obj.dims()) + '.');
     103             : 
     104         403 :   auto values = condition.values<bool>().as_span();
     105         403 :   std::vector<scipp::index_pair> indices;
     106        2684 :   for (scipp::index i = 0; i < scipp::size(values); ++i) {
     107        2281 :     if (i > 0 && values[i - 1] == values[i])
     108        1400 :       continue;    // not an edge
     109         881 :     if (values[i]) // rising edge
     110         529 :       indices.emplace_back(i, scipp::size(values));
     111         352 :     else if (i != 0) // falling edge
     112         229 :       indices.back().second = i;
     113             :   }
     114         403 :   return extract_ranges(makeVariable<scipp::index_pair>(Dims{condition.dim()},
     115         806 :                                                         Shape{indices.size()},
     116         806 :                                                         Values(indices)),
     117        1209 :                         obj, condition.dim());
     118         403 : }
     119             : } // namespace
     120             : 
     121         337 : Variable extract(const Variable &var, const Variable &condition) {
     122         337 :   return extract_impl(var, condition);
     123             : }
     124             : 
     125          74 : DataArray extract(const DataArray &da, const Variable &condition) {
     126          74 :   return extract_impl(da, condition);
     127             : }
     128             : 
     129           6 : Dataset extract(const Dataset &ds, const Variable &condition) {
     130           6 :   return extract_impl(ds, condition);
     131             : }
     132             : 
     133             : template SCIPP_DATASET_EXPORT Variable extract_ranges(const Variable &,
     134             :                                                       const Variable &,
     135             :                                                       const Dim);
     136             : template SCIPP_DATASET_EXPORT DataArray extract_ranges(const Variable &,
     137             :                                                        const DataArray &,
     138             :                                                        const Dim);
     139             : template SCIPP_DATASET_EXPORT Dataset extract_ranges(const Variable &,
     140             :                                                      const Dataset &,
     141             :                                                      const Dim);
     142             : 
     143             : } // namespace scipp

Generated by: LCOV version 1.14