Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 : #pragma once
6 :
7 : #include "scipp/common/numeric.h"
8 : #include "scipp/common/overloaded.h"
9 : #include "scipp/core/eigen.h"
10 : #include "scipp/core/element/arg_list.h"
11 : #include "scipp/core/transform_common.h"
12 : #include <Eigen/Geometry>
13 : #include <cmath>
14 :
15 : #if __cplusplus > 201703L
16 : #include <numeric>
17 : namespace scipp::core::element::detail {
18 : using midpoint = std::midpoint;
19 : }
20 : #else
21 : namespace scipp::core::element::detail {
22 56 : template <class T> constexpr auto midpoint(const T &a, const T &b) {
23 : if constexpr (std::is_integral_v<T>) {
24 : using U = std::make_unsigned_t<T>;
25 14 : int sign = 1;
26 14 : U m = a;
27 14 : U M = b;
28 14 : if (a > b) {
29 4 : sign = -1;
30 4 : m = b;
31 4 : M = a;
32 : }
33 14 : return a + sign * static_cast<T>(static_cast<U>(M - m) / 2);
34 : } else {
35 42 : constexpr auto lo = std::numeric_limits<T>::min() * 2;
36 42 : constexpr auto hi = std::numeric_limits<T>::max() / 2;
37 42 : if (std::abs(a) <= hi && std::abs(b) <= hi)
38 42 : return (a + b) / 2; // always correctly rounded
39 0 : if (std::abs(a) < lo) // not safe to halve a
40 0 : return a + b / 2;
41 0 : if (std::abs(b) < lo) // not safe to halve b
42 0 : return a / 2 + b;
43 0 : return a / 2 + b / 2; // otherwise correctly rounded
44 : }
45 : }
46 : } // namespace scipp::core::element::detail
47 : #endif
48 :
49 : namespace scipp::core::element {
50 :
51 : constexpr auto abs =
52 22640 : overloaded{arg_list<double, float, int64_t, int32_t>, [](const auto x) {
53 : using std::abs;
54 22640 : return abs(x);
55 : }};
56 :
57 : constexpr auto norm = overloaded{arg_list<Eigen::Vector3d>,
58 4 : [](const auto &x) { return x.norm(); },
59 2 : [](const units::Unit &x) { return x; }};
60 :
61 : constexpr auto pow = overloaded{
62 : arg_list<std::tuple<double, double>, std::tuple<double, float>,
63 : std::tuple<double, int32_t>, std::tuple<double, int64_t>,
64 : std::tuple<float, double>, std::tuple<float, float>,
65 : std::tuple<float, int32_t>, std::tuple<float, int64_t>,
66 : std::tuple<int64_t, int64_t>, std::tuple<int64_t, int32_t>>,
67 : transform_flags::expect_no_variance_arg<1>, dimensionless_unit_check_return,
68 669 : [](const auto &base, const auto &exponent) {
69 : using numeric::pow;
70 669 : return pow(base, exponent);
71 : }};
72 :
73 : constexpr auto pow_in_place = overloaded{
74 : arg_list<
75 : std::tuple<double, double, double>, std::tuple<double, double, float>,
76 : std::tuple<double, double, int32_t>,
77 : std::tuple<double, double, int64_t>, std::tuple<float, float, double>,
78 : std::tuple<float, float, float>, std::tuple<float, float, int32_t>,
79 : std::tuple<float, float, int64_t>,
80 : std::tuple<int64_t, int64_t, int64_t>,
81 : std::tuple<int64_t, int64_t, int32_t>>,
82 : transform_flags::expect_in_variance_if_out_variance,
83 : transform_flags::expect_no_variance_arg<2>,
84 404 : [](auto &out, const auto &base, const auto &exponent) {
85 : // Use element::pow instead of numeric::pow to inherit unit
86 : // handling.
87 404 : out = element::pow(base, exponent);
88 404 : }};
89 :
90 7604 : constexpr auto sqrt = overloaded{arg_list<double, float>, [](const auto x) {
91 : using std::sqrt;
92 7604 : return sqrt(x);
93 : }};
94 :
95 : constexpr auto dot = overloaded{
96 : arg_list<Eigen::Vector3d>,
97 4 : [](const auto &a, const auto &b) { return a.dot(b); },
98 2 : [](const units::Unit &a, const units::Unit &b) { return a * b; }};
99 :
100 : constexpr auto cross = overloaded{
101 : arg_list<Eigen::Vector3d>,
102 4 : [](const auto &a, const auto &b) { return a.cross(b); },
103 2 : [](const units::Unit &a, const units::Unit &b) { return a * b; }};
104 :
105 : constexpr auto reciprocal = overloaded{
106 : arg_list<double, float>,
107 540 : [](const auto &x) { return static_cast<std::decay_t<decltype(x)>>(1) / x; },
108 347 : [](const units::Unit &unit) { return units::one / unit; }};
109 :
110 : constexpr auto exp =
111 : overloaded{arg_list<double, float>, dimensionless_unit_check_return,
112 1159008 : [](const auto &x) {
113 : using std::exp;
114 1159008 : return exp(x);
115 : }};
116 :
117 : constexpr auto log =
118 : overloaded{arg_list<double, float>, dimensionless_unit_check_return,
119 11 : [](const auto &x) {
120 : using std::log;
121 11 : return log(x);
122 : }};
123 :
124 : constexpr auto log10 =
125 : overloaded{arg_list<double, float>, dimensionless_unit_check_return,
126 11 : [](const auto &x) {
127 : using std::log10;
128 11 : return log10(x);
129 : }};
130 :
131 : constexpr auto floor =
132 : overloaded{transform_flags::expect_no_variance_arg<0>,
133 : transform_flags::expect_no_variance_arg<1>,
134 15 : core::element::arg_list<double, float>, [](const auto &a) {
135 : using std::floor;
136 15 : return floor(a);
137 : }};
138 :
139 : constexpr auto ceil =
140 : overloaded{transform_flags::expect_no_variance_arg<0>,
141 : transform_flags::expect_no_variance_arg<1>,
142 15 : core::element::arg_list<double, float>, [](const auto &a) {
143 : using std::ceil;
144 15 : return ceil(a);
145 : }};
146 :
147 : constexpr auto rint =
148 : overloaded{transform_flags::expect_no_variance_arg<0>,
149 : transform_flags::expect_no_variance_arg<1>,
150 21 : core::element::arg_list<double, float>, [](const auto &a) {
151 : using std::rint;
152 21 : return rint(a);
153 : }};
154 :
155 : constexpr auto special = overloaded{arg_list<double, float, int64_t, int32_t>,
156 : dimensionless_unit_check_return,
157 : transform_flags::expect_no_variance_arg<0>};
158 :
159 1 : constexpr auto erf = overloaded{special, [](const auto &x) {
160 : using std::erf;
161 1 : return erf(x);
162 : }};
163 :
164 1 : constexpr auto erfc = overloaded{special, [](const auto &x) {
165 : using std::erfc;
166 1 : return erfc(x);
167 : }};
168 :
169 : /*
170 : * Variances are not allowed because the outputs would be strongly correlated.
171 : * Given inputs (x, y, z), the midpoints have covariance
172 : * Cov(mid(x, y), mid(y, z)) = Var(y) / 4
173 : * In the common case that all inputs have similar variances,
174 : * Pearson's correlation coefficient is
175 : * rho ~ 1/2
176 : * that is, neighboring outputs are 50% correlated.
177 : */
178 : constexpr auto midpoint = overloaded{
179 : arg_list<double, float, int64_t, int32_t, time_point>,
180 : transform_flags::expect_no_variance_arg<0>,
181 : transform_flags::expect_no_variance_arg<1>,
182 16 : [](const units::Unit &a, const units::Unit &b) {
183 16 : expect::equals(a, b);
184 16 : return a;
185 : },
186 56 : [](const auto &a, const auto &b) {
187 : if constexpr (std::is_same_v<std::decay_t<decltype(a)>, time_point>) {
188 : return time_point{
189 0 : detail::midpoint(a.time_since_epoch(), b.time_since_epoch())};
190 : } else {
191 56 : return detail::midpoint(a, b);
192 : }
193 : }};
194 :
195 : } // namespace scipp::core::element
|