Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Piotr Rozyczko
5 : #include "scipp/core/element/comparison.h"
6 : #include "scipp/core/eigen.h"
7 : #include "scipp/units/string.h"
8 : #include "scipp/variable/comparison.h"
9 : #include "scipp/variable/math.h"
10 : #include "scipp/variable/reduction.h"
11 : #include "scipp/variable/transform.h"
12 : #include "scipp/variable/util.h"
13 : #include "scipp/variable/variable.h"
14 :
15 : using namespace scipp::core;
16 :
17 : namespace scipp::variable {
18 :
19 : namespace {
20 449 : Variable _values(Variable &&in) { return in.has_variances() ? values(in) : in; }
21 :
22 : template <class T, class... Remaining>
23 : std::optional<Variable>
24 2419 : try_isclose_spatial(const Variable &a, const Variable &b, const Variable &rtol,
25 : const Variable &atol, const NanComparisons equal_nans) {
26 2419 : if (a.dtype() == dtype<T>)
27 : return std::optional(
28 47 : all(isclose(a.elements<T>(), b.elements<T>(), rtol, atol, equal_nans),
29 47 : Dim::InternalStructureComponent));
30 : if constexpr (sizeof...(Remaining) > 0)
31 1906 : return try_isclose_spatial<Remaining...>(a, b, rtol, atol, equal_nans);
32 : else
33 466 : return std::nullopt;
34 : }
35 :
36 516 : void expect_rtol_unit_dimensionless_or_none(const Variable &rtol,
37 : const Variable &ref) {
38 516 : const auto expected = ref.unit() == units::none ? scipp::units::none
39 516 : : scipp::units::dimensionless;
40 522 : core::expect::unit(rtol, expected, " For rtol arg");
41 513 : }
42 : } // namespace
43 :
44 516 : Variable isclose(const Variable &a, const Variable &b, const Variable &rtol,
45 : const Variable &atol, const NanComparisons equal_nans) {
46 516 : expect_rtol_unit_dimensionless_or_none(rtol, atol);
47 513 : if (const auto r =
48 : try_isclose_spatial<Eigen::Vector3d, Eigen::Matrix3d, Eigen::Affine3d,
49 : core::Translation, core::Quaternion>(
50 513 : a, b, rtol, atol, equal_nans);
51 513 : r.has_value()) {
52 47 : return *r;
53 513 : }
54 :
55 932 : auto tol = atol + rtol * abs(b);
56 466 : if (a.has_variances() && b.has_variances()) {
57 34 : return isclose(values(a), values(b), rtol, atol, equal_nans) &
58 51 : isclose(stddevs(a), stddevs(b), rtol, atol, equal_nans);
59 : } else {
60 449 : if (equal_nans == NanComparisons::Equal)
61 86 : return variable::transform(a, b, _values(std::move(tol)),
62 129 : element::isclose_equal_nan, "isclose");
63 : else
64 812 : return variable::transform(a, b, _values(std::move(tol)),
65 1218 : element::isclose, "isclose");
66 : }
67 466 : }
68 :
69 : } // namespace scipp::variable
|