Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Jan-Lukas Wynen
5 :
6 : #include "scipp/variable/pow.h"
7 :
8 : #include "scipp/core/element/math.h"
9 : #include "scipp/core/except.h"
10 : #include "scipp/core/tag_util.h"
11 : #include "scipp/variable/astype.h"
12 : #include "scipp/variable/reduction.h"
13 : #include "scipp/variable/transform.h"
14 :
15 : using namespace scipp::core;
16 :
17 : namespace scipp::variable {
18 :
19 : namespace {
20 : template <class V>
21 269 : Variable pow_do_transform(V &&base, const Variable &exponent,
22 : const bool in_place) {
23 269 : if (!in_place) {
24 109 : return variable::transform(base, exponent, element::pow, "pow");
25 : } else {
26 : if constexpr (std::is_const_v<std::remove_reference_t<V>>) {
27 : return variable::transform(base, exponent, element::pow, "pow");
28 : } else {
29 160 : variable::transform_in_place(base, base, exponent, element::pow_in_place,
30 : "pow");
31 160 : return std::forward<V>(base);
32 : }
33 : }
34 : }
35 :
36 : template <class T> struct PowUnit {
37 112 : static units::Unit apply(const units::Unit base_unit,
38 : const Variable &exponent) {
39 112 : const auto exp_val = exponent.value<T>();
40 : if constexpr (std::is_floating_point_v<T>) {
41 61 : if (static_cast<T>(static_cast<int64_t>(exp_val)) != exp_val) {
42 3 : throw except::UnitError("Powers of dimension-full variables must be "
43 : "integers or integer valued floats. Got " +
44 : std::to_string(exp_val) + ".");
45 : }
46 : }
47 109 : return pow(base_unit, exp_val);
48 : }
49 : };
50 :
51 : template <class V>
52 288 : Variable pow_handle_unit(V &&base, const Variable &exponent,
53 : const bool in_place) {
54 288 : if (const auto exp_unit = variableFactory().elem_unit(exponent);
55 288 : exp_unit != units::one) {
56 18 : throw except::UnitError("Powers must be dimensionless, got exponent.unit=" +
57 : to_string(exp_unit) + ".");
58 : }
59 :
60 270 : const auto base_unit = variableFactory().elem_unit(base);
61 270 : if (base_unit == units::one) {
62 157 : return pow_do_transform(std::forward<V>(base), exponent, in_place);
63 : }
64 113 : if (exponent.dims().ndim() != 0) {
65 1 : throw except::DimensionError("Exponents must be scalar if the base is not "
66 : "dimensionless. Got base.unit=" +
67 : to_string(base_unit) + " and exponent.dims=" +
68 : to_string(exponent.dims()) + ".");
69 : }
70 :
71 112 : Variable res = in_place ? std::forward<V>(base) : copy(std::forward<V>(base));
72 112 : variableFactory().set_elem_unit(res, units::one);
73 112 : pow_do_transform(res, exponent, true);
74 221 : variableFactory().set_elem_unit(
75 112 : res, core::CallDType<double, float, int64_t, int32_t>::apply<PowUnit>(
76 : exponent.dtype(), base_unit, exponent));
77 109 : return res;
78 112 : }
79 :
80 86 : bool has_negative_value(const Variable &var) {
81 172 : return astype(min(var), dtype<int64_t>, CopyPolicy::TryAvoid)
82 172 : .value<int64_t>() < 0l;
83 : }
84 :
85 : template <class V>
86 293 : Variable pow_handle_dtype(V &&base, const Variable &exponent,
87 : const bool in_place) {
88 293 : if (is_bins(exponent)) {
89 1 : throw std::invalid_argument("Binned exponents are not supported by pow.");
90 : }
91 292 : if (!is_int(base.dtype())) {
92 167 : return pow_handle_unit(std::forward<V>(base), exponent, in_place);
93 : }
94 125 : if (is_int(exponent.dtype())) {
95 86 : if (has_negative_value(exponent)) {
96 4 : throw std::invalid_argument(
97 : "Integers to negative powers are not allowed.");
98 : }
99 82 : return pow_handle_unit(std::forward<V>(base), exponent, in_place);
100 : }
101 : // Base has integer dtype but exponent does not.
102 39 : return pow_handle_unit(astype(base, exponent.dtype()), exponent, true);
103 : }
104 : } // namespace
105 :
106 253 : Variable pow(const Variable &base, const Variable &exponent) {
107 515 : return pow_handle_dtype(base.broadcast(merge(base.dims(), exponent.dims())),
108 443 : exponent, false);
109 : }
110 :
111 82 : Variable &pow(const Variable &base, const Variable &exponent, Variable &out) {
112 82 : const auto target_dims = merge(base.dims(), exponent.dims());
113 130 : core::expect::equals(target_dims, out.dims());
114 58 : copy(astype(base, out.dtype(), CopyPolicy::TryAvoid), out);
115 58 : pow_handle_dtype(out, exponent, true);
116 57 : return out;
117 82 : }
118 : } // namespace scipp::variable
|