Coverage for install/scipp/core/concepts.py: 59%
46 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-01 01:59 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-01 01:59 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3# @author Simon Heybrock
4from collections.abc import Callable, Iterable, Mapping
5from functools import reduce
6from typing import TypeVar
8from ..typing import Dims, VariableLikeType
9from .cpp_classes import DataArray, DimensionError, Variable
10from .logical import logical_or
12_VarOrDa = TypeVar('_VarOrDa', Variable, DataArray)
15def _copied(obj: Mapping[str, Variable]) -> dict[str, Variable]:
16 return {name: var.copy() for name, var in obj.items()}
19def _reduced(obj: Mapping[str, Variable], dims: Iterable[str]) -> dict[str, Variable]:
20 ref_dims = set(dims)
21 return {name: var for name, var in obj.items() if ref_dims.isdisjoint(var.dims)}
def rewrap_output_data(prototype: _VarOrDa, data: Variable) -> _VarOrDa:
    """Wrap ``data`` in the same kind of container as ``prototype``.

    When ``prototype`` is a DataArray, a new DataArray is built from ``data``
    together with the prototype's coords, its deprecated attrs, and copies of
    its masks. For any other prototype, ``data`` is returned unchanged.
    """
    if not isinstance(prototype, DataArray):
        return data
    # Built in the same order as the keyword arguments are evaluated:
    # coords, then attrs, then the copied masks.
    metadata = {
        'coords': prototype.coords,
        'attrs': prototype.deprecated_attrs,
        'masks': _copied(prototype.masks),
    }
    return DataArray(data=data, **metadata)
def rewrap_reduced_data(prototype: DataArray, data: Variable, dim: Dims) -> DataArray:
    """Build a DataArray from ``data`` and the reduced metadata of ``prototype``.

    Coords, masks and attrs are obtained from ``reduced_coords``,
    ``reduced_masks`` and ``reduced_attrs`` respectively, each called with
    ``prototype`` and ``dim``.
    """
    kept_coords = reduced_coords(prototype, dim)
    kept_masks = reduced_masks(prototype, dim)
    kept_attrs = reduced_attrs(prototype, dim)
    return DataArray(data, coords=kept_coords, masks=kept_masks, attrs=kept_attrs)
def transform_data(
    obj: VariableLikeType, func: Callable[[Variable], Variable]
) -> VariableLikeType:
    """Apply ``func`` to the data of ``obj`` and return a result of matching kind.

    A Variable is passed to ``func`` directly. For a DataArray, ``func`` is
    applied to ``obj.data`` and the result is rewrapped with
    ``rewrap_output_data``.

    Raises
    ------
    TypeError
        If ``obj`` is neither a Variable nor a DataArray.
    """
    if not isinstance(obj, Variable | DataArray):
        raise TypeError(f"{func} only supports Variable and DataArray as inputs.")
    if isinstance(obj, Variable):
        return func(obj)
    return rewrap_output_data(obj, func(obj.data))
def concrete_dims(obj: VariableLikeType, dim: Dims) -> tuple[str, ...]:
    """Turn a dimension specification into a concrete tuple of labels.

    A single string becomes a one-element tuple, any other non-None value is
    converted with ``tuple``, and ``None`` selects ``obj.dims``.

    The labels are *not* checked for validity against ``obj``.

    Raises
    ------
    DimensionError
        If ``dim`` is None and ``obj.dims`` contains None.
    """
    if isinstance(dim, str):
        return (dim,)
    if dim is not None:
        return tuple(dim)
    # No explicit dims requested: fall back to the dims of the object itself.
    if None in obj.dims:
        raise DimensionError(
            f'Got data group with unequal dimension lengths: dim={obj.dims}'
        )
    return obj.dims
def reduced_coords(da: DataArray, dim: Dims) -> dict[str, Variable]:
    """Return the coords of ``da`` that share no dimension with ``dim``.

    ``dim`` is resolved with ``concrete_dims`` before filtering.
    """
    coords = da.coords
    return _reduced(coords, concrete_dims(da, dim))
def reduced_attrs(da: DataArray, dim: Dims) -> dict[str, Variable]:
    """Return the deprecated attrs of ``da`` that share no dimension with ``dim``.

    ``dim`` is resolved with ``concrete_dims`` before filtering.
    """
    attrs = da.deprecated_attrs
    return _reduced(attrs, concrete_dims(da, dim))
def reduced_masks(da: DataArray, dim: Dims) -> dict[str, Variable]:
    """Return copies of the masks of ``da`` that share no dimension with ``dim``.

    ``dim`` is resolved with ``concrete_dims``; the surviving masks are then
    copied, so the returned dict does not hold the original mask objects.
    """
    masks = da.masks
    independent = _reduced(masks, concrete_dims(da, dim))
    return _copied(independent)
def irreducible_mask(da: DataArray, dim: Dims) -> None | Variable:
    """
    Combine the masks of ``da`` that depend on any of the reduction dims.

    A mask is irreducible when its dims overlap with the dims selected by
    ``dim``, so a reduction over those dims has to apply it. The matching
    masks are merged with ``logical_or`` and transposed to follow the dim
    order of ``da``. With exactly one matching mask, a transposed copy of it
    is returned. Returns None if no mask matches.
    """
    reduction_dims = set(concrete_dims(da, dim))
    dependent = []
    for mask in da.masks.values():
        if not reduction_dims.isdisjoint(mask.dims):
            dependent.append(mask)
    if not dependent:
        return None

    def _in_data_order(x: Variable) -> Variable:
        # Order the dims of x the way they appear in da.dims.
        return x.transpose([d for d in da.dims if d in x.dims])

    if len(dependent) == 1:
        # Copy so that the caller does not receive the mask stored in da.
        return _in_data_order(dependent[0]).copy()
    return _in_data_order(reduce(logical_or, dependent))