Coverage for install/scipp/core/concepts.py: 59%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-01 01:59 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3# @author Simon Heybrock 

4from collections.abc import Callable, Iterable, Mapping 

5from functools import reduce 

6from typing import TypeVar 

7 

8from ..typing import Dims, VariableLikeType 

9from .cpp_classes import DataArray, DimensionError, Variable 

10from .logical import logical_or 

11 

# Constrained type variable: annotated functions are generic over exactly
# Variable or DataArray (no other types, no subclass-based widening).
_VarOrDa = TypeVar("_VarOrDa", Variable, DataArray)

13 

14 

def _copied(obj: Mapping[str, Variable]) -> dict[str, Variable]:
    """Return a new dict mapping every key of ``obj`` to ``value.copy()``.

    Insertion order of ``obj`` is preserved.
    """
    result = {}
    for key, value in obj.items():
        result[key] = value.copy()
    return result

17 

18 

def _reduced(obj: Mapping[str, Variable], dims: Iterable[str]) -> dict[str, Variable]:
    """Keep only the entries of ``obj`` that depend on none of ``dims``.

    An entry is dropped as soon as its ``dims`` share at least one label with
    ``dims``. Surviving values are returned as-is (they are not copied).
    """
    removed = set(dims)
    kept = {}
    for key, value in obj.items():
        if removed.isdisjoint(value.dims):
            kept[key] = value
    return kept

22 

23 

def rewrap_output_data(prototype: _VarOrDa, data: Variable) -> _VarOrDa:
    """Wrap ``data`` in the same kind of container as ``prototype``.

    If ``prototype`` is not a DataArray, ``data`` is returned unchanged.
    Otherwise a new DataArray is built around ``data`` that takes
    ``prototype``'s coords and deprecated attrs and copies of its masks.
    """
    if not isinstance(prototype, DataArray):
        return data
    # Only the masks are copied explicitly here (via _copied). Coords and attrs
    # are handed to the constructor directly.
    return DataArray(
        data=data,
        coords=prototype.coords,
        attrs=prototype.deprecated_attrs,
        masks=_copied(prototype.masks),
    )

34 

35 

def rewrap_reduced_data(prototype: DataArray, data: Variable, dim: Dims) -> DataArray:
    """Build a DataArray holding ``data``, the result of reducing ``prototype`` over ``dim``.

    Coords, masks, and attrs of ``prototype`` that depend on any of the reduced
    dimensions are dropped. The masks that remain are copied.
    """
    kept_coords = reduced_coords(prototype, dim)
    kept_masks = reduced_masks(prototype, dim)
    kept_attrs = reduced_attrs(prototype, dim)
    return DataArray(data, coords=kept_coords, masks=kept_masks, attrs=kept_attrs)

43 

44 

def transform_data(
    obj: VariableLikeType, func: Callable[[Variable], Variable]
) -> VariableLikeType:
    """Apply ``func`` to the data of ``obj``.

    A Variable is passed to ``func`` directly. For a DataArray, ``func`` is
    applied to ``obj.data`` and the result is rewrapped with
    ``rewrap_output_data``, which keeps the array's coords and attrs and copies
    its masks.

    Raises
    ------
    TypeError
        If ``obj`` is neither a Variable nor a DataArray.
    """
    if isinstance(obj, Variable):
        return func(obj)
    if not isinstance(obj, DataArray):
        raise TypeError(f"{func} only supports Variable and DataArray as inputs.")
    return rewrap_output_data(obj, func(obj.data))

54 

55 

def concrete_dims(obj: VariableLikeType, dim: Dims) -> tuple[str, ...]:
    """Turn a dimension specification into an explicit tuple of dimension labels.

    ``None`` stands for all dimensions of ``obj``. A single string becomes a
    one-element tuple, and any other iterable is converted with ``tuple``.

    This does *not* validate that the dimension labels are valid for the given object.

    Raises
    ------
    DimensionError
        If ``dim`` is None and ``obj.dims`` contains ``None``.
    """
    if isinstance(dim, str):
        return (dim,)
    if dim is not None:
        return tuple(dim)
    # dim is None: fall back to all of the object's own dimensions.
    if None in obj.dims:
        raise DimensionError(
            f'Got data group with unequal dimension lengths: dim={obj.dims}'
        )
    return obj.dims

68 

69 

def reduced_coords(da: DataArray, dim: Dims) -> dict[str, Variable]:
    """Return the coords of ``da`` that depend on none of the dimensions in ``dim``."""
    coords = da.coords
    return _reduced(coords, concrete_dims(da, dim))

72 

73 

def reduced_attrs(da: DataArray, dim: Dims) -> dict[str, Variable]:
    """Return the deprecated attrs of ``da`` that depend on none of the dimensions in ``dim``."""
    attrs = da.deprecated_attrs
    return _reduced(attrs, concrete_dims(da, dim))

76 

77 

def reduced_masks(da: DataArray, dim: Dims) -> dict[str, Variable]:
    """Return copies of the masks of ``da`` that depend on none of the dimensions in ``dim``."""
    masks = da.masks
    surviving = _reduced(masks, concrete_dims(da, dim))
    # Unlike coords/attrs, each surviving mask is copied via _copied.
    return _copied(surviving)

80 

81 

def irreducible_mask(da: DataArray, dim: Dims) -> None | Variable:
    """
    The union of masks that would need to be applied in a reduction op over dim.

    Irreducible means that a reduction operation must apply these masks since they
    depend on the reduction dimensions. Returns None if there is no irreducible mask.

    The returned mask has its dimensions ordered as they appear in ``da.dims``.
    """
    reduced_dims = set(concrete_dims(da, dim))
    masks = [m for m in da.masks.values() if not reduced_dims.isdisjoint(m.dims)]
    if not masks:
        return None

    def _ordered_like_data(var: Variable) -> Variable:
        # Keep only dims present on ``var``, in the order they have in ``da``.
        order = [d for d in da.dims if d in var.dims]
        return var.transpose(order)

    if len(masks) > 1:
        # Combine all irreducible masks with logical OR before reordering.
        return _ordered_like_data(reduce(logical_or, masks))
    # A single mask: copy after transposing so the caller does not receive the
    # object stored in ``da.masks`` (or anything derived from it by transpose).
    return _ordered_like_data(masks[0]).copy()