Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Piotr Rozyczko
5 : #include "scipp/core/element/comparison.h"
6 : #include "scipp/core/eigen.h"
7 : #include "scipp/units/string.h"
8 : #include "scipp/variable/comparison.h"
9 : #include "scipp/variable/math.h"
10 : #include "scipp/variable/reduction.h"
11 : #include "scipp/variable/transform.h"
12 : #include "scipp/variable/util.h"
13 : #include "scipp/variable/variable.h"
14 :
15 : using namespace scipp::core;
16 :
17 : namespace scipp::variable {
18 :
19 : namespace {
20 453 : Variable _values(Variable &&in) { return in.has_variances() ? values(in) : in; }
21 :
22 : template <class T, class... Remaining>
23 : std::optional<Variable>
24 2439 : try_isclose_spatial(const Variable &a, const Variable &b, const Variable &rtol,
25 : const Variable &atol, const NanComparisons equal_nans) {
26 2439 : if (a.dtype() == dtype<T>)
27 : return std::optional(
28 47 : all(isclose(a.elements<T>(), b.elements<T>(), rtol, atol, equal_nans),
29 47 : Dim::InternalStructureComponent));
30 : if constexpr (sizeof...(Remaining) > 0)
31 1922 : return try_isclose_spatial<Remaining...>(a, b, rtol, atol, equal_nans);
32 : else
33 470 : return std::nullopt;
34 : }
35 :
36 520 : void expect_rtol_unit_dimensionless_or_none(const Variable &rtol,
37 : const Variable &ref) {
38 520 : const auto expected = ref.unit() == units::none ? scipp::units::none
39 520 : : scipp::units::dimensionless;
40 526 : core::expect::equals(expected, rtol.unit(), " For rtol arg");
41 517 : }
42 : } // namespace
43 :
44 520 : Variable isclose(const Variable &a, const Variable &b, const Variable &rtol,
45 : const Variable &atol, const NanComparisons equal_nans) {
46 520 : expect_rtol_unit_dimensionless_or_none(rtol, atol);
47 517 : if (const auto r =
48 : try_isclose_spatial<Eigen::Vector3d, Eigen::Matrix3d, Eigen::Affine3d,
49 : core::Translation, core::Quaternion>(
50 517 : a, b, rtol, atol, equal_nans);
51 517 : r.has_value()) {
52 47 : return *r;
53 517 : }
54 :
55 940 : auto tol = atol + rtol * abs(b);
56 470 : if (a.has_variances() && b.has_variances()) {
57 34 : return isclose(values(a), values(b), rtol, atol, equal_nans) &
58 51 : isclose(stddevs(a), stddevs(b), rtol, atol, equal_nans);
59 : } else {
60 453 : if (equal_nans == NanComparisons::Equal)
61 86 : return variable::transform(a, b, _values(std::move(tol)),
62 129 : element::isclose_equal_nan, "isclose");
63 : else
64 820 : return variable::transform(a, b, _values(std::move(tol)),
65 1230 : element::isclose, "isclose");
66 : }
67 470 : }
68 :
69 : } // namespace scipp::variable
|