Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : #include "scipp/variable/arithmetic.h"
4 : #include "scipp/core/dtype.h"
5 : #include "scipp/core/eigen.h"
6 : #include "scipp/core/element/arithmetic.h"
7 : #include "scipp/core/spatial_transforms.h"
8 : #include "scipp/variable/astype.h"
9 : #include "scipp/variable/pow.h"
10 : #include "scipp/variable/transform.h"
11 : #include "scipp/variable/variable_factory.h"
12 :
13 : namespace scipp::variable {
14 :
15 : namespace {
16 :
17 144037 : bool is_transform_with_translation(const Variable &var) {
18 288060 : return var.dtype() == dtype<Eigen::Affine3d> ||
19 288060 : var.dtype() == dtype<scipp::core::Translation>;
20 : }
21 :
22 26 : auto make_factor(const Variable &prototype, const double value) {
23 26 : const auto unit = variableFactory().elem_unit(prototype) == units::none
24 26 : ? units::none
25 26 : : units::one;
26 52 : return astype(makeVariable<double>(Values{value}, unit),
27 78 : variableFactory().elem_dtype(prototype));
28 : }
29 :
30 : /// True if a and b are correlated, currently only if referencing same.
31 493232 : bool correlated(const Variable &a, const Variable &b) {
32 493232 : return variableFactory().has_variances(a) &&
33 493232 : variableFactory().has_variances(b) && a.is_same(b);
34 : }
35 :
36 : } // namespace
37 :
38 305841 : Variable operator+(const Variable &a, const Variable &b) {
39 305841 : if (correlated(a, b))
40 36 : return a * make_factor(a, 2.0);
41 305823 : return transform(a, b, core::element::add, "add");
42 : }
43 :
44 33805 : Variable operator-(const Variable &a, const Variable &b) {
45 33805 : if (correlated(a, b))
46 4 : return a * make_factor(a, 0.0);
47 33803 : return transform(a, b, core::element::subtract, "subtract");
48 : }
49 :
50 144016 : Variable operator*(const Variable &a, const Variable &b) {
51 144037 : if (is_transform_with_translation(a) &&
52 35 : (is_transform_with_translation(b) ||
53 144030 : b.dtype() == dtype<Eigen::Vector3d>)) {
54 : return transform(a, b, core::element::apply_spatial_transformation,
55 38 : std::string_view("apply_spatial_transformation"));
56 : } else {
57 143996 : if (correlated(a, b))
58 2 : return pow(a, make_factor(a, 2.0));
59 : return transform(a, b, core::element::multiply,
60 287986 : std::string_view("multiply"));
61 : }
62 : }
63 :
64 8957 : Variable operator/(const Variable &a, const Variable &b) {
65 8957 : if (correlated(a, b))
66 2 : return pow(a, make_factor(a, 0.0));
67 8956 : return transform(a, b, core::element::divide, "divide");
68 : }
69 :
70 155 : Variable &operator+=(Variable &a, const Variable &b) {
71 155 : operator+=(Variable(a), b);
72 145 : return a;
73 : }
74 :
75 31 : Variable &operator-=(Variable &a, const Variable &b) {
76 31 : operator-=(Variable(a), b);
77 29 : return a;
78 : }
79 :
80 63 : Variable &operator*=(Variable &a, const Variable &b) {
81 63 : operator*=(Variable(a), b);
82 61 : return a;
83 : }
84 :
85 19 : Variable &operator/=(Variable &a, const Variable &b) {
86 19 : operator/=(Variable(a), b);
87 17 : return a;
88 : }
89 :
90 2 : Variable &floor_divide_equals(Variable &a, const Variable &b) {
91 2 : floor_divide_equals(Variable(a), b);
92 2 : return a;
93 : }
94 :
95 320 : Variable operator+=(Variable &&a, const Variable &b) {
96 320 : if (correlated(a, b))
97 2 : return a *= make_factor(a, 2.0);
98 319 : transform_in_place(a, b, core::element::add_equals,
99 319 : std::string_view("add_equals"));
100 308 : return std::move(a);
101 : }
102 :
103 92 : Variable operator-=(Variable &&a, const Variable &b) {
104 92 : if (correlated(a, b))
105 2 : return a *= make_factor(a, 0.0);
106 91 : transform_in_place(a, b, core::element::subtract_equals,
107 91 : std::string_view("subtract_equals"));
108 89 : return std::move(a);
109 : }
110 :
111 151 : Variable operator*=(Variable &&a, const Variable &b) {
112 151 : if (correlated(a, b))
113 2 : return pow(a, make_factor(a, 2.0), a);
114 150 : transform_in_place(a, b, core::element::multiply_equals,
115 150 : std::string_view("multiply_equals"));
116 145 : return std::move(a);
117 : }
118 :
119 70 : Variable operator/=(Variable &&a, const Variable &b) {
120 70 : if (correlated(a, b))
121 2 : return pow(a, make_factor(a, 0.0), a);
122 69 : transform_in_place(a, b, core::element::divide_equals,
123 69 : std::string_view("divide_equals"));
124 67 : return std::move(a);
125 : }
126 :
127 2 : Variable floor_divide_equals(Variable &&a, const Variable &b) {
128 2 : transform_in_place(a, b, core::element::floor_divide_equals,
129 2 : std::string_view("divide_equals"));
130 2 : return std::move(a);
131 : }
132 :
133 : } // namespace scipp::variable
|