Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 : #include "scipp/dataset/sort.h"
6 : #include "scipp/core/parallel.h"
7 : #include "scipp/core/tag_util.h"
8 : #include "scipp/dataset/extract.h"
9 :
10 : namespace scipp::dataset {
11 :
12 : namespace {
13 :
14 62 : constexpr auto nan_sensitive_less = [](const auto &a, const auto &b) {
15 : if constexpr (std::is_floating_point_v<std::decay_t<decltype(a)>>)
16 45 : if (std::isnan(b))
17 13 : return !std::isnan(a);
18 49 : return a < b;
19 : };
20 :
21 : template <class T> struct IndicesForSorting {
22 15 : static Variable apply(const Variable &key, const SortOrder order) {
23 15 : const auto size = key.dims()[key.dim()];
24 15 : const auto values = key.values<T>();
25 15 : std::vector<std::pair<T, scipp::index>> key_index;
26 15 : key_index.reserve(size);
27 :
28 : {
29 15 : scipp::index i = 0;
30 81 : for (const auto &value : key.values<T>())
31 51 : key_index.emplace_back(value, i++);
32 : }
33 :
34 15 : if (order == SortOrder::Ascending)
35 9 : core::parallel::parallel_sort(
36 50 : key_index.begin(), key_index.end(), [](const auto &a, const auto &b) {
37 32 : return nan_sensitive_less(a.first, b.first);
38 : });
39 : else
40 6 : core::parallel::parallel_sort(
41 42 : key_index.begin(), key_index.end(), [](const auto &a, const auto &b) {
42 30 : return nan_sensitive_less(b.first, a.first);
43 : });
44 :
45 15 : auto indices =
46 30 : makeVariable<scipp::index_pair>(Dims{key.dim()}, Shape{size});
47 15 : std::transform(key_index.begin(), key_index.end(),
48 30 : indices.values<scipp::index_pair>().as_span().begin(),
49 51 : [](const auto &item) {
50 51 : return std::pair{item.second, item.second + 1};
51 : });
52 30 : return indices;
53 15 : }
54 : };
55 :
56 15 : Variable indices_for_sorting(const Variable &key, const SortOrder order) {
57 : return core::CallDType<
58 : double, float, int64_t, int32_t, bool, std::string,
59 15 : core::time_point>::apply<IndicesForSorting>(key.dtype(), key, order);
60 : }
61 :
62 14 : void require_same_shape(const Dimensions &var_dims, const Dimensions &key_dims,
63 : const Dim dim) {
64 14 : if (var_dims[dim] != key_dims[dim])
65 2 : throw except::DimensionError(
66 4 : "Cannot sort: key for dimension " + to_string(dim) + " has length " +
67 8 : std::to_string(key_dims[dim]) + " while variable has length " +
68 8 : std::to_string(var_dims[dim]) + ". Lengths must agree.");
69 12 : }
70 :
71 : } // namespace
72 :
73 : /// Return a Variable sorted based on key.
74 11 : Variable sort(const Variable &var, const Variable &key, const SortOrder order) {
75 11 : require_same_shape(var.dims(), key.dims(), key.dim());
76 20 : return extract_ranges(indices_for_sorting(key, order), var, key.dim());
77 : }
78 :
79 : /// Return a DataArray sorted based on key.
80 3 : DataArray sort(const DataArray &array, const Variable &key,
81 : const SortOrder order) {
82 3 : require_same_shape(array.dims(), key.dims(), key.dim());
83 4 : return extract_ranges(indices_for_sorting(key, order), array, key.dim());
84 : }
85 :
86 : /// Return a DataArray sorted based on coordinate.
87 3 : DataArray sort(const DataArray &array, const Dim &key, const SortOrder order) {
88 4 : return sort(array, array.meta()[key], order);
89 : }
90 :
91 : /// Return a Dataset sorted based on key.
92 3 : Dataset sort(const Dataset &dataset, const Variable &key,
93 : const SortOrder order) {
94 6 : return extract_ranges(indices_for_sorting(key, order), dataset, key.dim());
95 : }
96 :
97 : /// Return a Dataset sorted based on coordinate.
98 0 : Dataset sort(const Dataset &dataset, const Dim &key, const SortOrder order) {
99 0 : return sort(dataset, dataset.coords()[key], order);
100 : }
101 :
102 : } // namespace scipp::dataset
|