LCOV - code coverage report
Current view: top level - dataset - sort.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 46 48 95.8 %
Date: 2024-12-01 01:56:34 Functions: 20 42 47.6 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Simon Heybrock
       5             : #include "scipp/dataset/sort.h"
       6             : #include "scipp/core/parallel.h"
       7             : #include "scipp/core/tag_util.h"
       8             : #include "scipp/dataset/extract.h"
       9             : 
      10             : namespace scipp::dataset {
      11             : 
      12             : namespace {
      13             : 
      14          62 : constexpr auto nan_sensitive_less = [](const auto &a, const auto &b) {
      15             :   if constexpr (std::is_floating_point_v<std::decay_t<decltype(a)>>)
      16          45 :     if (std::isnan(b))
      17          13 :       return !std::isnan(a);
      18          49 :   return a < b;
      19             : };
      20             : 
      21             : template <class T> struct IndicesForSorting {
      22          15 :   static Variable apply(const Variable &key, const SortOrder order) {
      23          15 :     const auto size = key.dims()[key.dim()];
      24          15 :     const auto values = key.values<T>();
      25          15 :     std::vector<std::pair<T, scipp::index>> key_index;
      26          15 :     key_index.reserve(size);
      27             : 
      28             :     {
      29          15 :       scipp::index i = 0;
      30          81 :       for (const auto &value : key.values<T>())
      31          51 :         key_index.emplace_back(value, i++);
      32             :     }
      33             : 
      34          15 :     if (order == SortOrder::Ascending)
      35           9 :       core::parallel::parallel_sort(
      36          50 :           key_index.begin(), key_index.end(), [](const auto &a, const auto &b) {
      37          32 :             return nan_sensitive_less(a.first, b.first);
      38             :           });
      39             :     else
      40           6 :       core::parallel::parallel_sort(
      41          42 :           key_index.begin(), key_index.end(), [](const auto &a, const auto &b) {
      42          30 :             return nan_sensitive_less(b.first, a.first);
      43             :           });
      44             : 
      45          15 :     auto indices =
      46          30 :         makeVariable<scipp::index_pair>(Dims{key.dim()}, Shape{size});
      47          15 :     std::transform(key_index.begin(), key_index.end(),
      48          30 :                    indices.values<scipp::index_pair>().as_span().begin(),
      49          51 :                    [](const auto &item) {
      50          51 :                      return std::pair{item.second, item.second + 1};
      51             :                    });
      52          30 :     return indices;
      53          15 :   }
      54             : };
      55             : 
      56          15 : Variable indices_for_sorting(const Variable &key, const SortOrder order) {
      57             :   return core::CallDType<
      58             :       double, float, int64_t, int32_t, bool, std::string,
      59          15 :       core::time_point>::apply<IndicesForSorting>(key.dtype(), key, order);
      60             : }
      61             : 
      62          14 : void require_same_shape(const Dimensions &var_dims, const Dimensions &key_dims,
      63             :                         const Dim dim) {
      64          14 :   if (var_dims[dim] != key_dims[dim])
      65           2 :     throw except::DimensionError(
      66           4 :         "Cannot sort: key for dimension " + to_string(dim) + " has length " +
      67           8 :         std::to_string(key_dims[dim]) + " while variable has length " +
      68           8 :         std::to_string(var_dims[dim]) + ". Lengths must agree.");
      69          12 : }
      70             : 
      71             : } // namespace
      72             : 
      73             : /// Return a Variable sorted based on key.
      74          11 : Variable sort(const Variable &var, const Variable &key, const SortOrder order) {
      75          11 :   require_same_shape(var.dims(), key.dims(), key.dim());
      76          20 :   return extract_ranges(indices_for_sorting(key, order), var, key.dim());
      77             : }
      78             : 
      79             : /// Return a DataArray sorted based on key.
      80           3 : DataArray sort(const DataArray &array, const Variable &key,
      81             :                const SortOrder order) {
      82           3 :   require_same_shape(array.dims(), key.dims(), key.dim());
      83           4 :   return extract_ranges(indices_for_sorting(key, order), array, key.dim());
      84             : }
      85             : 
      86             : /// Return a DataArray sorted based on coordinate.
      87           3 : DataArray sort(const DataArray &array, const Dim &key, const SortOrder order) {
      88           4 :   return sort(array, array.meta()[key], order);
      89             : }
      90             : 
      91             : /// Return a Dataset sorted based on key.
      92           3 : Dataset sort(const Dataset &dataset, const Variable &key,
      93             :              const SortOrder order) {
      94           6 :   return extract_ranges(indices_for_sorting(key, order), dataset, key.dim());
      95             : }
      96             : 
      97             : /// Return a Dataset sorted based on coordinate.
      98           0 : Dataset sort(const Dataset &dataset, const Dim &key, const SortOrder order) {
      99           0 :   return sort(dataset, dataset.coords()[key], order);
     100             : }
     101             : 
     102             : } // namespace scipp::dataset

Generated by: LCOV version 1.14