LCOV - code coverage report
Current view: top level - variable - pow.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 49 49 100.0 %
Date: 2024-04-28 01:25:40 Functions: 13 13 100.0 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Jan-Lukas Wynen
       5             : 
       6             : #include "scipp/variable/pow.h"
       7             : 
       8             : #include "scipp/core/element/math.h"
       9             : #include "scipp/core/except.h"
      10             : #include "scipp/core/tag_util.h"
      11             : #include "scipp/variable/astype.h"
      12             : #include "scipp/variable/reduction.h"
      13             : #include "scipp/variable/transform.h"
      14             : 
      15             : using namespace scipp::core;
      16             : 
      17             : namespace scipp::variable {
      18             : 
      19             : namespace {
      20             : template <class V>
      21         269 : Variable pow_do_transform(V &&base, const Variable &exponent,
      22             :                           const bool in_place) {
      23         269 :   if (!in_place) {
      24         109 :     return variable::transform(base, exponent, element::pow, "pow");
      25             :   } else {
      26             :     if constexpr (std::is_const_v<std::remove_reference_t<V>>) {
      27             :       return variable::transform(base, exponent, element::pow, "pow");
      28             :     } else {
      29         160 :       variable::transform_in_place(base, base, exponent, element::pow_in_place,
      30             :                                    "pow");
      31         160 :       return std::forward<V>(base);
      32             :     }
      33             :   }
      34             : }
      35             : 
      36             : template <class T> struct PowUnit {
      37         112 :   static units::Unit apply(const units::Unit base_unit,
      38             :                            const Variable &exponent) {
      39         112 :     const auto exp_val = exponent.value<T>();
      40             :     if constexpr (std::is_floating_point_v<T>) {
      41          61 :       if (static_cast<T>(static_cast<int64_t>(exp_val)) != exp_val) {
      42           3 :         throw except::UnitError("Powers of dimension-full variables must be "
      43             :                                 "integers or integer valued floats. Got " +
      44             :                                 std::to_string(exp_val) + ".");
      45             :       }
      46             :     }
      47         109 :     return pow(base_unit, exp_val);
      48             :   }
      49             : };
      50             : 
      51             : template <class V>
      52         288 : Variable pow_handle_unit(V &&base, const Variable &exponent,
      53             :                          const bool in_place) {
      54         288 :   if (const auto exp_unit = variableFactory().elem_unit(exponent);
      55         288 :       exp_unit != units::one) {
      56          18 :     throw except::UnitError("Powers must be dimensionless, got exponent.unit=" +
      57             :                             to_string(exp_unit) + ".");
      58             :   }
      59             : 
      60         270 :   const auto base_unit = variableFactory().elem_unit(base);
      61         270 :   if (base_unit == units::one) {
      62         157 :     return pow_do_transform(std::forward<V>(base), exponent, in_place);
      63             :   }
      64         113 :   if (exponent.dims().ndim() != 0) {
      65           1 :     throw except::DimensionError("Exponents must be scalar if the base is not "
      66             :                                  "dimensionless. Got base.unit=" +
      67             :                                  to_string(base_unit) + " and exponent.dims=" +
      68             :                                  to_string(exponent.dims()) + ".");
      69             :   }
      70             : 
      71         112 :   Variable res = in_place ? std::forward<V>(base) : copy(std::forward<V>(base));
      72         112 :   variableFactory().set_elem_unit(res, units::one);
      73         112 :   pow_do_transform(res, exponent, true);
      74         221 :   variableFactory().set_elem_unit(
      75         112 :       res, core::CallDType<double, float, int64_t, int32_t>::apply<PowUnit>(
      76             :                exponent.dtype(), base_unit, exponent));
      77         109 :   return res;
      78         112 : }
      79             : 
      80          86 : bool has_negative_value(const Variable &var) {
      81         172 :   return astype(min(var), dtype<int64_t>, CopyPolicy::TryAvoid)
      82         172 :              .value<int64_t>() < 0l;
      83             : }
      84             : 
      85             : template <class V>
      86         293 : Variable pow_handle_dtype(V &&base, const Variable &exponent,
      87             :                           const bool in_place) {
      88         293 :   if (is_bins(exponent)) {
      89           1 :     throw std::invalid_argument("Binned exponents are not supported by pow.");
      90             :   }
      91         292 :   if (!is_int(base.dtype())) {
      92         167 :     return pow_handle_unit(std::forward<V>(base), exponent, in_place);
      93             :   }
      94         125 :   if (is_int(exponent.dtype())) {
      95          86 :     if (has_negative_value(exponent)) {
      96           4 :       throw std::invalid_argument(
      97             :           "Integers to negative powers are not allowed.");
      98             :     }
      99          82 :     return pow_handle_unit(std::forward<V>(base), exponent, in_place);
     100             :   }
     101             :   // Base has integer dtype but exponent does not.
     102          39 :   return pow_handle_unit(astype(base, exponent.dtype()), exponent, true);
     103             : }
     104             : } // namespace
     105             : 
     106         253 : Variable pow(const Variable &base, const Variable &exponent) {
     107         515 :   return pow_handle_dtype(base.broadcast(merge(base.dims(), exponent.dims())),
     108         443 :                           exponent, false);
     109             : }
     110             : 
     111          82 : Variable &pow(const Variable &base, const Variable &exponent, Variable &out) {
     112          82 :   const auto target_dims = merge(base.dims(), exponent.dims());
     113         130 :   core::expect::equals(target_dims, out.dims());
     114          58 :   copy(astype(base, out.dtype(), CopyPolicy::TryAvoid), out);
     115          58 :   pow_handle_dtype(out, exponent, true);
     116          57 :   return out;
     117          82 : }
     118             : } // namespace scipp::variable

Generated by: LCOV version 1.14