Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 : #pragma once
6 :
7 : #include <cmath>
8 : #include <limits>
9 : #include <numeric>
10 :
11 : #include "scipp/common/overloaded.h"
12 : #include "scipp/core/eigen.h"
13 : #include "scipp/core/element/arg_list.h"
14 : #include "scipp/core/element/util.h"
15 : #include "scipp/core/histogram.h"
16 : #include "scipp/core/subbin_sizes.h"
17 : #include "scipp/core/time_point.h"
18 : #include "scipp/core/transform_common.h"
19 :
20 : namespace scipp::core::element {
21 :
22 : template <class Index, class Coord, class Edges = Coord>
23 : using update_indices_by_binning_arg =
24 : std::tuple<Index, Coord, scipp::span<const Edges>>;
25 :
26 : static constexpr auto update_indices_by_binning = overloaded{
27 : element::arg_list<update_indices_by_binning_arg<int64_t, double>,
28 : update_indices_by_binning_arg<int32_t, double>,
29 : update_indices_by_binning_arg<int64_t, float>,
30 : update_indices_by_binning_arg<int32_t, float>,
31 : update_indices_by_binning_arg<int64_t, int64_t>,
32 : update_indices_by_binning_arg<int32_t, int64_t>,
33 : update_indices_by_binning_arg<int64_t, int32_t>,
34 : update_indices_by_binning_arg<int32_t, int32_t>,
35 : update_indices_by_binning_arg<int64_t, time_point>,
36 : update_indices_by_binning_arg<int32_t, time_point>,
37 : update_indices_by_binning_arg<int64_t, int64_t, double>,
38 : update_indices_by_binning_arg<int32_t, int64_t, double>,
39 : update_indices_by_binning_arg<int64_t, int32_t, double>,
40 : update_indices_by_binning_arg<int32_t, int32_t, double>,
41 : update_indices_by_binning_arg<int64_t, float, double>,
42 : update_indices_by_binning_arg<int32_t, float, double>,
43 : update_indices_by_binning_arg<int64_t, double, float>,
44 : update_indices_by_binning_arg<int32_t, double, float>,
45 : update_indices_by_binning_arg<int64_t, int32_t, int64_t>,
46 : update_indices_by_binning_arg<int32_t, int32_t, int64_t>>,
47 9346 : [](units::Unit &indices, const units::Unit &coord,
48 : const units::Unit &groups) {
49 9346 : expect::equals(coord, groups);
50 9346 : expect::equals(units::none, indices);
51 9346 : },
52 : transform_flags::expect_no_variance_arg<1>,
53 : transform_flags::expect_no_variance_arg<2>};
54 :
55 : // Special faster implementation for linear bins.
56 : static constexpr auto update_indices_by_binning_linspace = overloaded{
57 : update_indices_by_binning,
58 22843591 : [](auto &index, const auto &x, const auto &edges) {
59 22843591 : if (index == -1)
60 1600 : return;
61 : using Index = std::decay_t<decltype(index)>;
62 22841991 : const auto params = core::linear_edge_params(edges);
63 22841991 : if (const auto bin = get_bin<Index>(x, edges, params); bin < 0) {
64 94035 : index = -1;
65 : } else {
66 22747956 : index *= std::get<1>(params); // nbin
67 22747956 : index += bin;
68 : }
69 : }};
70 :
71 : static constexpr auto update_indices_by_binning_sorted_edges =
72 : overloaded{update_indices_by_binning,
73 10435741 : [](auto &index, const auto &x, const auto &edges) {
74 10435741 : if (index == -1)
75 639 : return;
76 10435102 : auto it = std::upper_bound(edges.begin(), edges.end(), x);
77 10435102 : index *= scipp::size(edges) - 1;
78 10435102 : if (it == edges.begin() || it == edges.end()) {
79 3496 : index = -1;
80 : } else {
81 10431606 : index += --it - edges.begin();
82 : }
83 : }};
84 :
85 : template <class Index>
86 : static constexpr auto groups_to_map = overloaded{
87 : element::arg_list<scipp::span<const double>, scipp::span<const float>,
88 : scipp::span<const int64_t>, scipp::span<const int32_t>,
89 : scipp::span<const bool>, scipp::span<const std::string>,
90 : scipp::span<const time_point>>,
91 : transform_flags::expect_no_variance_arg<0>,
92 334 : [](const units::Unit &u) { return u; },
93 334 : [](const auto &groups) {
94 : std::unordered_map<typename std::decay_t<decltype(groups)>::value_type,
95 : Index>
96 334 : index;
97 334 : scipp::index current = 0;
98 3742016 : for (const auto &item : groups)
99 3741682 : index[item] = current++;
100 334 : if (scipp::size(groups) != scipp::size(index))
101 0 : throw std::runtime_error("Duplicate group labels.");
102 334 : return index;
103 0 : }};
104 :
105 : template <class Index, class Coord, class Edges = Coord>
106 : using update_indices_by_grouping_arg =
107 : std::tuple<Index, Coord, std::unordered_map<Edges, Index>>;
108 :
109 : static constexpr auto update_indices_by_grouping = overloaded{
110 : element::arg_list<update_indices_by_grouping_arg<int64_t, double>,
111 : update_indices_by_grouping_arg<int32_t, double>,
112 : update_indices_by_grouping_arg<int64_t, float>,
113 : update_indices_by_grouping_arg<int32_t, float>,
114 : update_indices_by_grouping_arg<int64_t, int64_t>,
115 : update_indices_by_grouping_arg<int32_t, int64_t>,
116 : // Given int32 target groups, select from int64. Note that
117 : // we do not support the reverse for now, since the
118 : // `groups.find(x)` below would then have to cast to a
119 : // lower precision, i.e., we would need special handling.
120 : update_indices_by_grouping_arg<int64_t, int64_t, int32_t>,
121 : update_indices_by_grouping_arg<int32_t, int64_t, int32_t>,
122 : update_indices_by_grouping_arg<int64_t, int32_t>,
123 : update_indices_by_grouping_arg<int32_t, int32_t>,
124 : update_indices_by_grouping_arg<int64_t, bool>,
125 : update_indices_by_grouping_arg<int32_t, bool>,
126 : update_indices_by_grouping_arg<int64_t, std::string>,
127 : update_indices_by_grouping_arg<int32_t, std::string>,
128 : update_indices_by_grouping_arg<int32_t, time_point>,
129 : update_indices_by_grouping_arg<int64_t, time_point>>,
130 334 : [](units::Unit &indices, const units::Unit &coord,
131 : const units::Unit &groups) {
132 334 : expect::equals(coord, groups);
133 334 : expect::equals(units::none, indices);
134 334 : },
135 9067650 : [](auto &index, const auto &x, const auto &groups) {
136 9067650 : if (index == -1)
137 0 : return;
138 9067650 : const auto it = groups.find(x);
139 9067650 : index *= scipp::size(groups);
140 9067650 : index = (it == groups.end()) ? -1 : (index + it->second);
141 : }};
142 :
143 : template <class Index, class Coord, class Edges = Coord>
144 : using update_indices_by_grouping_contiguous_arg =
145 : std::tuple<Index, Coord, scipp::index, Edges>;
146 :
147 : static constexpr auto update_indices_by_grouping_contiguous = overloaded{
148 : element::arg_list<
149 : update_indices_by_grouping_contiguous_arg<int64_t, int64_t>,
150 : update_indices_by_grouping_contiguous_arg<int32_t, int64_t>,
151 : // Given int32 target groups, select from int64. Note that
152 : // we do not support the reverse for now, since the
153 : // `groups.find(x)` below would then have to cast to a
154 : // lower precision, i.e., we would need special handling.
155 : update_indices_by_grouping_contiguous_arg<int64_t, int64_t, int32_t>,
156 : update_indices_by_grouping_contiguous_arg<int32_t, int64_t, int32_t>,
157 : update_indices_by_grouping_contiguous_arg<int64_t, int32_t>,
158 : update_indices_by_grouping_contiguous_arg<int32_t, int32_t>>,
159 4550 : [](units::Unit &indices, const units::Unit &coord,
160 : const units::Unit &ngroup, const units::Unit &offset) {
161 4550 : expect::equals(coord, offset);
162 4550 : expect::equals(units::none, ngroup);
163 4550 : expect::equals(units::none, indices);
164 4550 : },
165 29582534 : [](auto &index, const auto &x, const auto &ngroup, const auto &offset) {
166 29582534 : if (index == -1)
167 0 : return;
168 29582534 : index *= ngroup;
169 29582534 : const auto group = x - offset;
170 29582534 : index = group < 0 || group >= ngroup ? -1 : (index + group);
171 : }};
172 :
173 : static constexpr auto update_indices_from_existing = overloaded{
174 : element::arg_list<std::tuple<int64_t, scipp::index, scipp::index>,
175 : std::tuple<int32_t, scipp::index, scipp::index>>,
176 39 : [](units::Unit &, const units::Unit &, const units::Unit &) {},
177 6647 : [](auto &index, const auto bin_index, const auto nbin) {
178 6647 : if (index == -1)
179 0 : return;
180 6647 : index *= nbin;
181 6647 : index += bin_index;
182 : }};
183 :
184 : static constexpr auto count_indices = overloaded{
185 : element::arg_list<
186 : std::tuple<scipp::span<const int64_t>, scipp::index, scipp::index>,
187 : std::tuple<scipp::span<const int32_t>, scipp::index, scipp::index>>,
188 8377 : [](const units::Unit &indices, const auto &offset, const auto &nbin) {
189 8377 : expect::equals(units::none, indices);
190 8377 : expect::equals(units::none, offset);
191 8377 : expect::equals(units::none, nbin);
192 8377 : return units::none;
193 : },
194 72000 : [](const auto &indices, const auto offset, const auto nbin) {
195 72000 : typename SubbinSizes::container_type counts(nbin);
196 63233942 : for (const auto i : indices)
197 63161942 : if (i >= 0)
198 60930769 : ++counts[i];
199 144000 : return SubbinSizes{offset, std::move(counts)};
200 72000 : }};
201 :
202 : } // namespace scipp::core::element
|