LCOV - code coverage report
Current view: top level - core/include/scipp/core/element - math.h (source / functions) Hit Total Coverage
Test: coverage.info Lines: 51 57 89.5 %
Date: 2024-04-28 01:25:40 Functions: 52 99 52.5 %

          Line data    Source code
       1             : // SPDX-License-Identifier: BSD-3-Clause
       2             : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
       3             : /// @file
       4             : /// @author Simon Heybrock
       5             : #pragma once
       6             : 
       7             : #include "scipp/common/numeric.h"
       8             : #include "scipp/common/overloaded.h"
       9             : #include "scipp/core/eigen.h"
      10             : #include "scipp/core/element/arg_list.h"
      11             : #include "scipp/core/transform_common.h"
      12             : #include <Eigen/Geometry>
      13             : #include <cmath>
      14             : 
      15             : #if __cplusplus > 201703L
      16             : #include <numeric>
      17             : namespace scipp::core::element::detail {
      18             : using midpoint = std::midpoint;
      19             : }
      20             : #else
      21             : namespace scipp::core::element::detail {
      22          56 : template <class T> constexpr auto midpoint(const T &a, const T &b) {
      23             :   if constexpr (std::is_integral_v<T>) {
      24             :     using U = std::make_unsigned_t<T>;
      25          14 :     int sign = 1;
      26          14 :     U m = a;
      27          14 :     U M = b;
      28          14 :     if (a > b) {
      29           4 :       sign = -1;
      30           4 :       m = b;
      31           4 :       M = a;
      32             :     }
      33          14 :     return a + sign * static_cast<T>(static_cast<U>(M - m) / 2);
      34             :   } else {
      35          42 :     constexpr auto lo = std::numeric_limits<T>::min() * 2;
      36          42 :     constexpr auto hi = std::numeric_limits<T>::max() / 2;
      37          42 :     if (std::abs(a) <= hi && std::abs(b) <= hi)
      38          42 :       return (a + b) / 2; // always correctly rounded
      39           0 :     if (std::abs(a) < lo) // not safe to halve a
      40           0 :       return a + b / 2;
      41           0 :     if (std::abs(b) < lo) // not safe to halve b
      42           0 :       return a / 2 + b;
      43           0 :     return a / 2 + b / 2; // otherwise correctly rounded
      44             :   }
      45             : }
      46             : } // namespace scipp::core::element::detail
      47             : #endif
      48             : 
      49             : namespace scipp::core::element {
      50             : 
      51             : constexpr auto abs =
      52       22640 :     overloaded{arg_list<double, float, int64_t, int32_t>, [](const auto x) {
      53             :                  using std::abs;
      54       22640 :                  return abs(x);
      55             :                }};
      56             : 
      57             : constexpr auto norm = overloaded{arg_list<Eigen::Vector3d>,
      58           4 :                                  [](const auto &x) { return x.norm(); },
      59           2 :                                  [](const units::Unit &x) { return x; }};
      60             : 
      61             : constexpr auto pow = overloaded{
      62             :     arg_list<std::tuple<double, double>, std::tuple<double, float>,
      63             :              std::tuple<double, int32_t>, std::tuple<double, int64_t>,
      64             :              std::tuple<float, double>, std::tuple<float, float>,
      65             :              std::tuple<float, int32_t>, std::tuple<float, int64_t>,
      66             :              std::tuple<int64_t, int64_t>, std::tuple<int64_t, int32_t>>,
      67             :     transform_flags::expect_no_variance_arg<1>, dimensionless_unit_check_return,
      68         669 :     [](const auto &base, const auto &exponent) {
      69             :       using numeric::pow;
      70         669 :       return pow(base, exponent);
      71             :     }};
      72             : 
      73             : constexpr auto pow_in_place = overloaded{
      74             :     arg_list<
      75             :         std::tuple<double, double, double>, std::tuple<double, double, float>,
      76             :         std::tuple<double, double, int32_t>,
      77             :         std::tuple<double, double, int64_t>, std::tuple<float, float, double>,
      78             :         std::tuple<float, float, float>, std::tuple<float, float, int32_t>,
      79             :         std::tuple<float, float, int64_t>,
      80             :         std::tuple<int64_t, int64_t, int64_t>,
      81             :         std::tuple<int64_t, int64_t, int32_t>>,
      82             :     transform_flags::expect_in_variance_if_out_variance,
      83             :     transform_flags::expect_no_variance_arg<2>,
      84         404 :     [](auto &out, const auto &base, const auto &exponent) {
      85             :       // Use element::pow instead of numeric::pow to inherit unit
      86             :       // handling.
      87         404 :       out = element::pow(base, exponent);
      88         404 :     }};
      89             : 
      90        7604 : constexpr auto sqrt = overloaded{arg_list<double, float>, [](const auto x) {
      91             :                                    using std::sqrt;
      92        7604 :                                    return sqrt(x);
      93             :                                  }};
      94             : 
      95             : constexpr auto dot = overloaded{
      96             :     arg_list<Eigen::Vector3d>,
      97           4 :     [](const auto &a, const auto &b) { return a.dot(b); },
      98           2 :     [](const units::Unit &a, const units::Unit &b) { return a * b; }};
      99             : 
     100             : constexpr auto cross = overloaded{
     101             :     arg_list<Eigen::Vector3d>,
     102           4 :     [](const auto &a, const auto &b) { return a.cross(b); },
     103           2 :     [](const units::Unit &a, const units::Unit &b) { return a * b; }};
     104             : 
     105             : constexpr auto reciprocal = overloaded{
     106             :     arg_list<double, float>,
     107         540 :     [](const auto &x) { return static_cast<std::decay_t<decltype(x)>>(1) / x; },
     108         347 :     [](const units::Unit &unit) { return units::one / unit; }};
     109             : 
     110             : constexpr auto exp =
     111             :     overloaded{arg_list<double, float>, dimensionless_unit_check_return,
     112     1159008 :                [](const auto &x) {
     113             :                  using std::exp;
     114     1159008 :                  return exp(x);
     115             :                }};
     116             : 
     117             : constexpr auto log =
     118             :     overloaded{arg_list<double, float>, dimensionless_unit_check_return,
     119          11 :                [](const auto &x) {
     120             :                  using std::log;
     121          11 :                  return log(x);
     122             :                }};
     123             : 
     124             : constexpr auto log10 =
     125             :     overloaded{arg_list<double, float>, dimensionless_unit_check_return,
     126          11 :                [](const auto &x) {
     127             :                  using std::log10;
     128          11 :                  return log10(x);
     129             :                }};
     130             : 
     131             : constexpr auto floor =
     132             :     overloaded{transform_flags::expect_no_variance_arg<0>,
     133             :                transform_flags::expect_no_variance_arg<1>,
     134          15 :                core::element::arg_list<double, float>, [](const auto &a) {
     135             :                  using std::floor;
     136          15 :                  return floor(a);
     137             :                }};
     138             : 
     139             : constexpr auto ceil =
     140             :     overloaded{transform_flags::expect_no_variance_arg<0>,
     141             :                transform_flags::expect_no_variance_arg<1>,
     142          15 :                core::element::arg_list<double, float>, [](const auto &a) {
     143             :                  using std::ceil;
     144          15 :                  return ceil(a);
     145             :                }};
     146             : 
     147             : constexpr auto rint =
     148             :     overloaded{transform_flags::expect_no_variance_arg<0>,
     149             :                transform_flags::expect_no_variance_arg<1>,
     150          21 :                core::element::arg_list<double, float>, [](const auto &a) {
     151             :                  using std::rint;
     152          21 :                  return rint(a);
     153             :                }};
     154             : 
     155             : constexpr auto special = overloaded{arg_list<double, float, int64_t, int32_t>,
     156             :                                     dimensionless_unit_check_return,
     157             :                                     transform_flags::expect_no_variance_arg<0>};
     158             : 
     159           1 : constexpr auto erf = overloaded{special, [](const auto &x) {
     160             :                                   using std::erf;
     161           1 :                                   return erf(x);
     162             :                                 }};
     163             : 
     164           1 : constexpr auto erfc = overloaded{special, [](const auto &x) {
     165             :                                    using std::erfc;
     166           1 :                                    return erfc(x);
     167             :                                  }};
     168             : 
     169             : /*
     170             :  * Variances are not allowed because the outputs would be strongly correlated.
     171             :  * Given inputs (x, y, z), the midpoints have covariance
     172             :  *     Cov(mid(x, y), mid(y, z)) = Var(y) / 4
     173             :  * In the common case that all inputs have similar variances,
     174             :  * Pearson's correlation coefficient is
     175             :  *     rho ~ 1/2
     176             :  * that is, neighboring outputs are 50% correlated.
     177             :  */
     178             : constexpr auto midpoint = overloaded{
     179             :     arg_list<double, float, int64_t, int32_t, time_point>,
     180             :     transform_flags::expect_no_variance_arg<0>,
     181             :     transform_flags::expect_no_variance_arg<1>,
     182          16 :     [](const units::Unit &a, const units::Unit &b) {
     183          16 :       expect::equals(a, b);
     184          16 :       return a;
     185             :     },
     186          56 :     [](const auto &a, const auto &b) {
     187             :       if constexpr (std::is_same_v<std::decay_t<decltype(a)>, time_point>) {
     188             :         return time_point{
     189           0 :             detail::midpoint(a.time_since_epoch(), b.time_since_epoch())};
     190             :       } else {
     191          56 :         return detail::midpoint(a, b);
     192             :       }
     193             :     }};
     194             : 
     195             : } // namespace scipp::core::element

Generated by: LCOV version 1.14