Coverage for install/scipp/core/concepts.py: 62%
42 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3# @author Simon Heybrock
4from functools import reduce
5from typing import Callable, Dict, List, Mapping, Tuple, Union
7from ..typing import Dims, VariableLikeType
8from .cpp_classes import DataArray, Variable
9from .logical import logical_or
def _copied(obj: Mapping[str, Variable]) -> Dict[str, Variable]:
    """Return a new dict mapping each key of ``obj`` to a copy of its value.

    Every value is duplicated through its own ``copy`` method, so the result
    does not share value objects with ``obj``. Key order is preserved.
    """
    copies = {}
    for key, value in obj.items():
        copies[key] = value.copy()
    return copies
def _reduced(
    obj: Mapping[str, Variable], dims: Union[List[str], Tuple[str, ...]]
) -> Dict[str, Variable]:
    """Keep only the entries of ``obj`` that depend on none of ``dims``.

    Parameters
    ----------
    obj:
        Mapping from names to objects exposing a ``dims`` attribute.
    dims:
        Dimension labels being removed. Callers in this module pass the tuple
        returned by ``concrete_dims``, so both lists and tuples are accepted.

    Returns
    -------
    :
        New dict containing the entries whose ``dims`` share no label with
        ``dims``. The values themselves are not copied.
    """
    # Build the set once so each per-entry disjointness test is cheap, and bind
    # it to a new name rather than rebinding the parameter to a different type.
    removed = set(dims)
    return {name: var for name, var in obj.items() if removed.isdisjoint(var.dims)}
def rewrap_output_data(prototype: VariableLikeType, data) -> VariableLikeType:
    """Wrap ``data`` in the same kind of container as ``prototype``.

    When ``prototype`` is a :class:`DataArray`, build a new data array holding
    ``data`` together with ``prototype``'s coords and deprecated attrs (passed
    through as-is) and a copy of its masks. For any other prototype, ``data``
    is returned unchanged.
    """
    if not isinstance(prototype, DataArray):
        return data
    # Only the masks are duplicated here; coords and attrs are handed over directly.
    return DataArray(
        data=data,
        coords=prototype.coords,
        attrs=prototype.deprecated_attrs,
        masks=_copied(prototype.masks),
    )
def rewrap_reduced_data(
    prototype: VariableLikeType, data, dim: Dims
) -> VariableLikeType:
    """Wrap ``data``, the result of reducing ``prototype`` over ``dim``.

    Coords, masks, and deprecated attrs of ``prototype`` that depend on any of
    the reduced dimensions are dropped (``dim=None`` means all dimensions of
    ``prototype``). The surviving masks are copied. ``prototype`` must provide
    ``coords``, ``masks``, and ``deprecated_attrs``, i.e. behave like a
    :class:`DataArray`. The result is always a :class:`DataArray`.
    """
    kept_coords = reduced_coords(prototype, dim)
    kept_masks = reduced_masks(prototype, dim)
    kept_attrs = reduced_attrs(prototype, dim)
    return DataArray(data, coords=kept_coords, masks=kept_masks, attrs=kept_attrs)
def transform_data(obj: VariableLikeType, func: Callable) -> VariableLikeType:
    """Apply ``func`` to the data of ``obj`` and return the result.

    A :class:`Variable` is passed to ``func`` directly. For a
    :class:`DataArray`, ``func`` is applied to ``obj.data`` and the result is
    rewrapped with ``obj``'s coords, deprecated attrs, and a copy of its masks.

    Raises
    ------
    TypeError
        If ``obj`` is neither a Variable nor a DataArray.
    """
    if isinstance(obj, Variable):
        return func(obj)
    if not isinstance(obj, DataArray):
        raise TypeError(f"{func} only supports Variable and DataArray as inputs.")
    transformed = func(obj.data)
    return rewrap_output_data(obj, transformed)
def concrete_dims(obj: VariableLikeType, dim: Dims) -> Tuple[str, ...]:
    """Convert a dimension specification into a concrete tuple of dimension labels.

    This does *not* validate that the dimension labels are valid for the given object.

    Parameters
    ----------
    obj:
        Object whose ``dims`` are returned when ``dim`` is ``None``.
    dim:
        ``None`` (meaning all dimensions of ``obj``), a single dimension label,
        or an iterable of labels.

    Returns
    -------
    :
        ``obj.dims`` if ``dim`` is ``None``, ``(dim,)`` for a single string,
        otherwise ``tuple(dim)``. The length is arbitrary, hence
        ``Tuple[str, ...]`` rather than the 1-tuple type ``Tuple[str]``.
    """
    if dim is None:
        return obj.dims
    if isinstance(dim, str):
        # A str is itself iterable; wrap it so 'xy' is not split into 'x', 'y'.
        return (dim,)
    return tuple(dim)
def reduced_coords(da: DataArray, dim: Dims) -> Dict[str, Variable]:
    """Return the coords of ``da`` that depend on none of the dimensions in ``dim``.

    ``dim=None`` selects every dimension of ``da``. The coords are not copied.
    """
    coords = da.coords
    labels = concrete_dims(da, dim)
    return _reduced(coords, labels)
def reduced_attrs(da: DataArray, dim: Dims) -> Dict[str, Variable]:
    """Return the deprecated attrs of ``da`` independent of every dimension in ``dim``.

    ``dim=None`` selects every dimension of ``da``. The attrs are not copied.
    """
    attrs = da.deprecated_attrs
    labels = concrete_dims(da, dim)
    return _reduced(attrs, labels)
def reduced_masks(da: DataArray, dim: Dims) -> Dict[str, Variable]:
    """Return copies of the masks of ``da`` that depend on none of the dims in ``dim``.

    ``dim=None`` selects every dimension of ``da``. Unlike coords and attrs,
    the surviving masks are copied, so the result shares no mask objects with ``da``.
    """
    masks = da.masks
    labels = concrete_dims(da, dim)
    surviving = _reduced(masks, labels)
    return _copied(surviving)
def irreducible_mask(da: DataArray, dim: Dims) -> Union[None, Variable]:
    """
    The union of masks that would need to be applied in a reduction op over dim.

    Irreducible means that a reduction operation must apply these masks since they
    depend on the reduction dimensions. Returns None if there is no irreducible mask.

    ``dim=None`` means a reduction over all dimensions of ``da``. The returned
    mask is transposed so that its dimensions follow their order in ``da.dims``.
    When exactly one mask qualifies, it is copied, so the result does not alias
    the mask stored in ``da``. Several qualifying masks are combined with
    ``logical_or``.
    """
    reduction_dims = set(concrete_dims(da, dim))
    affected = [
        mask for mask in da.masks.values() if not reduction_dims.isdisjoint(mask.dims)
    ]
    if not affected:
        return None

    def _in_data_order(var):
        # Reorder var's dimensions to match the order they have in da.dims.
        # The label is deliberately not named `dim`, to avoid shadowing the parameter.
        return var.transpose([label for label in da.dims if label in var.dims])

    if len(affected) == 1:
        return _in_data_order(affected[0]).copy()
    return _in_data_order(reduce(logical_or, affected))