Coverage for install/scipp/compat/wrapping.py: 100%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-01 01:59 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3# @author Simon Heybrock 

4 

5from collections.abc import Callable, Mapping 

6from functools import wraps 

7from typing import Any, TypeVar 

8 

9from ..core import ( 

10 BinEdgeError, 

11 DataArray, 

12 DimensionError, 

13 Variable, 

14 VariancesError, 

15) 

16 

17 

def _validated_masks(da: DataArray, dim: str) -> dict[str, Variable]:
    """Return copies of the masks of ``da``, rejecting any that depend on ``dim``.

    Parameters
    ----------
    da:
        Data array whose masks are inspected.
    dim:
        Dimension along which a function is about to be applied.

    Returns
    -------
    :
        A new dict mapping mask names to copies of the masks.

    Raises
    ------
    DimensionError
        If any mask of ``da`` has ``dim`` among its dimensions.
    """
    copied: dict[str, Variable] = {}
    for mask_name, mask_var in da.masks.items():
        if dim not in mask_var.dims:
            copied[mask_name] = mask_var.copy()
            continue
        # A mask along ``dim`` cannot be carried over to an output produced by
        # applying a function along that very dimension.
        raise DimensionError(
            f"Cannot apply function along '{dim}' since mask '{mask_name}' depends "
            "on this dimension."
        )
    return copied

28 

29 

# Result type of a function decorated by ``wrap1d``: either a data array, or
# (for partial functions) a callable that produces a data array.
_Out = TypeVar(
    '_Out',
    bound=DataArray | Callable[..., DataArray],
)

31 

32 

def wrap1d(
    is_partial: bool = False, accept_masks: bool = False, keep_coords: bool = False
) -> Callable[[Callable[..., _Out]], Callable[..., _Out]]:
    """Decorator factory for decorating functions that wrap non-scipp 1-D functions.

    1-D functions are typically functions from libraries such as scipy that depend
    on a single 'axis' argument.

    The decorators returned by this factory apply pre- and postprocessing as follows:

    - An 'axis' keyword argument will raise ``ValueError``, recommending use of 'dim'.
      The index of the provided dimension is added as axis to kwargs.
    - Providing data with variances will raise ``sc.VariancesError`` since third-party
      libraries typically cannot handle variances.
    - The input data array must have a coordinate for ``dim``. If that coordinate
      holds bin edges (its length along ``dim`` differs from the data's length),
      ``sc.BinEdgeError`` is raised.
    - Coordinates, masks, and attributes that act as "observers", i.e., do not depend
      on the dimension of the function application, are added to the output data array.
      Masks are deep-copied as per the usual requirement in Scipp.

    Parameters
    ----------
    is_partial:
        The wrapped function is partial, i.e., does not return a data
        array itself, but a callable that returns a data array. If true,
        the postprocessing step is not applied to the wrapped function.
        Instead, the callable returned by the decorated function is
        decorated with the postprocessing step.
    accept_masks:
        If false, no mask may depend on the dimension that the function is
        applied to; ``sc.DimensionError`` is raised otherwise.
        If true, masks that depend on this dimension are dropped from the
        output, and all other masks are copied to it.
    keep_coords:
        If true, preserve all input coordinates and attributes, including
        those that depend on the dimension the function is applied to.
        If false, drop coordinates and attributes that depend on the
        dimension the function is applied to.

    Returns
    -------
    :
        A decorator. The decorated function takes ``(da, dim, **kwargs)``
        and forwards ``axis=da.dims.index(dim)`` to the wrapped function.
    """

    def decorator(func: Callable[..., _Out]) -> Callable[..., _Out]:
        @wraps(func)
        def function(da: DataArray, dim: str, **kwargs: Any) -> _Out:
            # 'axis' is derived from 'dim' below; accepting both would be ambiguous.
            if 'axis' in kwargs:
                raise ValueError("Use the 'dim' keyword argument instead of 'axis'.")
            if da.variances is not None:
                raise VariancesError(
                    "Cannot apply function to data with uncertainties. If uncertainties"
                    " should be ignored, use 'sc.values(da)' to extract only values."
                )
            # A coordinate that is longer than the data along ``dim`` holds bin
            # edges, which point-wise 1-D functions cannot interpret.
            if da.sizes[dim] != da.coords[dim].sizes[dim]:
                raise BinEdgeError(
                    "Cannot apply function to data array with bin edges."
                )

            # Translate the scipp dimension label into the positional axis
            # expected by the wrapped third-party function.
            kwargs['axis'] = da.dims.index(dim)
            result = func(da, dim, **kwargs)
            return _postprocess(
                input_da=da,
                output_da=result,
                dim=dim,
                is_partial=is_partial,
                accept_masks=accept_masks,
                keep_coords=keep_coords,
            )

        return function

    return decorator

97 

98 

def _postprocess(
    *,
    input_da: DataArray,
    output_da: _Out,
    dim: str,
    is_partial: bool,
    accept_masks: bool,
    keep_coords: bool,
) -> _Out:
    """Attach the observing metadata of ``input_da`` to a wrapped function's result.

    Parameters
    ----------
    input_da:
        Data array that was passed to the wrapped function; source of the
        coordinates, masks and (deprecated) attributes.
    output_da:
        Result of the wrapped function: a data array, or, if ``is_partial``,
        a callable returning a data array.
    dim:
        Dimension along which the function was applied.
    is_partial:
        If true, ``output_da`` is a callable and a wrapper around it is
        returned that attaches the metadata to every data array it produces.
    accept_masks:
        If true, masks depending on ``dim`` are dropped. If false, such masks
        make ``_validated_masks`` raise ``DimensionError``.
    keep_coords:
        If true, all coordinates and attributes of ``input_da`` are attached.
        If false, only those that do not depend on ``dim``.

    Returns
    -------
    :
        ``output_da`` with metadata added in place, or the wrapping callable.
    """
    # Masks are resolved first so that a disallowed mask fails before anything else.
    masks = (
        _remove_columns_in_dim(input_da.masks, dim)
        if accept_masks
        else _validated_masks(input_da, dim)
    )

    coords: Mapping[str, Variable]
    attrs: Mapping[str, Variable]
    if not keep_coords:
        coords = _remove_columns_in_dim(input_da.coords, dim)
        attrs = _remove_columns_in_dim(input_da.deprecated_attrs, dim)
    else:
        # The input's mapping objects themselves are captured; no snapshot is taken.
        coords = input_da.coords
        attrs = input_da.deprecated_attrs

    def attach_metadata(target: DataArray) -> DataArray:
        # Mutates ``target`` in place and returns it for chaining. Masks are
        # copied on every call so each produced data array owns its masks.
        target.coords.update(coords)
        target.masks.update((mask_name, m.copy()) for mask_name, m in masks.items())
        target.deprecated_attrs.update(attrs)
        return target

    if not is_partial:
        return attach_metadata(output_da)  # type: ignore[arg-type, return-value]

    # Partial case: ``output_da`` is a callable; wrap it so that each data array
    # it returns receives the metadata.
    producer: Callable[..., DataArray] = output_da  # type: ignore[assignment]

    @wraps(producer)
    def with_metadata(*args: Any, **kwargs: Any) -> DataArray:
        return attach_metadata(producer(*args, **kwargs))

    return with_metadata  # type: ignore[return-value]

138 

139 

def _remove_columns_in_dim(
    mapping: Mapping[str, Variable], dim: str
) -> dict[str, Variable]:
    """Return a new dict holding the entries of ``mapping`` that do not depend on ``dim``.

    Entries are not copied; the returned dict references the same variables.
    """
    kept: dict[str, Variable] = {}
    for key, var in mapping.items():
        if dim in var.dims:
            # Depends on the dimension of the function application: skip it.
            continue
        kept[key] = var
    return kept