Line data Source code
1 : // SPDX-License-Identifier: BSD-3-Clause
2 : // Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3 : /// @file
4 : /// @author Simon Heybrock
5 :
6 : #include "../variable/operations_common.h"
7 : #include "scipp/dataset/sized_dict.h"
8 : #include "scipp/variable/arithmetic.h"
9 : #include "scipp/variable/creation.h"
10 : #include "scipp/variable/reduction.h"
11 : #include "scipp/variable/shape.h"
12 : #include "scipp/variable/special_values.h"
13 : #include "scipp/variable/util.h"
14 :
15 : #include "dataset_operations_common.h"
16 :
17 : namespace scipp::dataset {
18 :
19 : namespace {
20 : template <class Op>
21 2592 : Variable reduce_impl(const Variable &var, const Dim dim, const Masks &masks,
22 : const FillValue fill, const Op &op) {
23 4252 : if (auto mask_union = irreducible_mask(masks, dim); mask_union.is_valid()) {
24 1660 : mask_union = transpose(
25 : mask_union, intersection(var.dims(), mask_union.dims()).labels());
26 : return op(
27 3320 : where(mask_union, dense_special_like(var, Dimensions{}, fill), var),
28 1660 : dim);
29 : }
30 932 : return op(var, dim);
31 : }
32 : } // namespace
33 :
34 2383 : Variable sum(const Variable &var, const Dim dim, const Masks &masks) {
35 : return reduce_impl(var, dim, masks, FillValue::Default,
36 4766 : [](auto &&...args) { return sum(args...); });
37 : }
38 :
39 81 : Variable nansum(const Variable &var, const Dim dim, const Masks &masks) {
40 : return reduce_impl(var, dim, masks, FillValue::Default,
41 162 : [](auto &&...args) { return nansum(args...); });
42 : }
43 :
44 24 : Variable max(const Variable &var, const Dim dim, const Masks &masks) {
45 : return reduce_impl(var, dim, masks, FillValue::Lowest,
46 48 : [](auto &&...args) { return max(args...); });
47 : }
48 :
49 20 : Variable nanmax(const Variable &var, const Dim dim, const Masks &masks) {
50 : return reduce_impl(var, dim, masks, FillValue::Lowest,
51 40 : [](auto &&...args) { return nanmax(args...); });
52 : }
53 :
54 24 : Variable min(const Variable &var, const Dim dim, const Masks &masks) {
55 : return reduce_impl(var, dim, masks, FillValue::Max,
56 48 : [](auto &&...args) { return min(args...); });
57 : }
58 :
59 20 : Variable nanmin(const Variable &var, const Dim dim, const Masks &masks) {
60 : return reduce_impl(var, dim, masks, FillValue::Max,
61 40 : [](auto &&...args) { return nanmin(args...); });
62 : }
63 :
64 20 : Variable all(const Variable &var, const Dim dim, const Masks &masks) {
65 : return reduce_impl(var, dim, masks, FillValue::True,
66 40 : [](auto &&...args) { return all(args...); });
67 : }
68 :
69 20 : Variable any(const Variable &var, const Dim dim, const Masks &masks) {
70 : return reduce_impl(var, dim, masks, FillValue::False,
71 40 : [](auto &&...args) { return any(args...); });
72 : }
73 :
74 49 : Variable mean(const Variable &var, const Dim dim, const Masks &masks) {
75 98 : if (const auto mask_union = irreducible_mask(masks, dim);
76 49 : mask_union.is_valid()) {
77 23 : const auto count = sum(~mask_union, dim);
78 46 : return mean_impl(where(mask_union, zero_like(var), var), dim, count);
79 72 : }
80 26 : return mean(var, dim);
81 : }
82 :
83 37 : Variable nanmean(const Variable &var, const Dim dim, const Masks &masks) {
84 74 : if (const auto mask_union = irreducible_mask(masks, dim);
85 37 : mask_union.is_valid()) {
86 : const auto count = sum(
87 30 : where(mask_union, makeVariable<bool>(Values{false}), ~isnan(var)), dim);
88 30 : return nanmean_impl(where(mask_union, zero_like(var), var), dim, count);
89 52 : }
90 22 : return nanmean(var, dim);
91 : }
92 :
93 : } // namespace scipp::dataset
|