Coverage for install/scipp/core/concepts.py: 62%

42 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-04-28 01:28 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3# @author Simon Heybrock 

4from functools import reduce 

5from typing import Callable, Dict, List, Mapping, Tuple, Union 

6 

7from ..typing import Dims, VariableLikeType 

8from .cpp_classes import DataArray, Variable 

9from .logical import logical_or 

10 

11 

12def _copied(obj: Mapping[str, Variable]) -> Dict[str, Variable]: 

13 return {name: var.copy() for name, var in obj.items()} 

14 

15 

16def _reduced(obj: Mapping[str, Variable], dims: List[str]) -> Dict[str, Variable]: 

17 dims = set(dims) 

18 return {name: var for name, var in obj.items() if dims.isdisjoint(var.dims)} 

19 

20 

def rewrap_output_data(prototype: VariableLikeType, data) -> VariableLikeType:
    """Wrap ``data`` in the same kind of container as ``prototype``.

    If ``prototype`` is a DataArray, return a new DataArray around ``data`` that
    reuses the prototype's coords and deprecated attrs and holds copies of its
    masks. For any other prototype, ``data`` is returned unchanged.
    """
    if not isinstance(prototype, DataArray):
        return data
    coords = prototype.coords
    attrs = prototype.deprecated_attrs
    # Masks are copied (coords/attrs are passed through as-is) so the new
    # array does not share mask objects with the prototype.
    masks = _copied(prototype.masks)
    return DataArray(data=data, coords=coords, attrs=attrs, masks=masks)

31 

32 

def rewrap_reduced_data(
    prototype: VariableLikeType, data, dim: Dims
) -> VariableLikeType:
    """Build a DataArray around ``data`` for the result of a reduction over ``dim``.

    Only the coords, masks and attrs of ``prototype`` that do not depend on any
    of the reduced dims are kept; the kept masks are copies.
    NOTE(review): ``prototype`` must provide ``coords``/``masks``/``deprecated_attrs``,
    i.e. presumably a DataArray despite the broader annotation.
    """
    coords = reduced_coords(prototype, dim)
    masks = reduced_masks(prototype, dim)
    attrs = reduced_attrs(prototype, dim)
    return DataArray(data, coords=coords, masks=masks, attrs=attrs)

42 

43 

def transform_data(obj: VariableLikeType, func: Callable) -> VariableLikeType:
    """Apply ``func`` to the data of ``obj``.

    A Variable is passed to ``func`` directly. For a DataArray, ``func`` is
    applied to ``obj.data`` and the result is rewrapped with the array's coords,
    attrs and (copied) masks.

    Raises
    ------
    TypeError
        If ``obj`` is neither a Variable nor a DataArray.
    """
    if isinstance(obj, Variable):
        result = func(obj)
    elif isinstance(obj, DataArray):
        result = rewrap_output_data(obj, func(obj.data))
    else:
        raise TypeError(f"{func} only supports Variable and DataArray as inputs.")
    return result

51 

52 

def concrete_dims(obj: VariableLikeType, dim: Dims) -> Tuple[str, ...]:
    """Convert a dimension specification into a concrete tuple of dimension labels.

    Parameters
    ----------
    obj:
        Object whose ``dims`` are returned when ``dim`` is None.
    dim:
        None (meaning all dims of ``obj``), a single label, or an iterable
        of labels.

    Returns
    -------
    :
        The dimension labels as a tuple of arbitrary length.

    This does *not* validate that the dimension labels are valid for the given object.
    """
    if dim is None:
        return obj.dims
    # A str is itself iterable, so it must be special-cased before tuple().
    return (dim,) if isinstance(dim, str) else tuple(dim)

61 

62 

def reduced_coords(da: DataArray, dim: Dims) -> Dict[str, Variable]:
    """Return the coords of ``da`` that do not depend on any dim in ``dim``."""
    coords = da.coords
    return _reduced(coords, concrete_dims(da, dim))

65 

66 

def reduced_attrs(da: DataArray, dim: Dims) -> Dict[str, Variable]:
    """Return the deprecated attrs of ``da`` that do not depend on any dim in ``dim``."""
    attrs = da.deprecated_attrs
    return _reduced(attrs, concrete_dims(da, dim))

69 

70 

def reduced_masks(da: DataArray, dim: Dims) -> Dict[str, Variable]:
    """Return copies of the masks of ``da`` that do not depend on any dim in ``dim``."""
    kept = _reduced(da.masks, concrete_dims(da, dim))
    return _copied(kept)

73 

74 

def irreducible_mask(da: DataArray, dim: Dims) -> Union[None, Variable]:
    """
    The union of masks that would need to be applied in a reduction op over dim.

    Irreducible means that a reduction operation must apply these masks since they
    depend on the reduction dimensions. Returns None if there is no irreducible mask.

    The returned mask has its dims ordered as in ``da.dims``. It is combined with
    ``logical_or`` when more than one mask is involved.
    """
    reduction_dims = set(concrete_dims(da, dim))
    irreducible = [
        mask
        for mask in da.masks.values()
        if not reduction_dims.isdisjoint(mask.dims)
    ]
    if not irreducible:
        return None

    def _transposed_like_data(x: Variable) -> Variable:
        # Reorder x's dims to follow the order in which they appear in da.dims.
        # The loop variable is deliberately not called `dim`, to avoid shadowing
        # the function parameter.
        return x.transpose([label for label in da.dims if label in x.dims])

    if len(irreducible) == 1:
        # NOTE(review): the explicit copy suggests transpose returns a view of
        # the stored mask; copying keeps callers from mutating da.masks.
        return _transposed_like_data(irreducible[0]).copy()
    # NOTE(review): presumably logical_or already yields a fresh variable,
    # which is why no copy is made on this path.
    return _transposed_like_data(reduce(logical_or, irreducible))