Coverage for install/scipp/compat/wrapping.py: 100%

47 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-04-28 01:28 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3# @author Simon Heybrock 

4 

5from functools import wraps 

6from typing import Callable, Union 

7 

8from ..core import BinEdgeError, DataArray, DimensionError, VariancesError 

9 

10 

def _validated_masks(da, dim):
    """Return copies of the masks of ``da``, rejecting any mask that depends on ``dim``.

    Parameters
    ----------
    da:
        Data array whose ``masks`` mapping is inspected.
    dim:
        Name of the dimension along which a function is about to be applied.

    Returns
    -------
    :
        Dict mapping every mask name to a copy of the corresponding mask.

    Raises
    ------
    DimensionError
        If ``dim`` is among the dimensions of any mask.
    """
    masks = {}
    for name, mask in da.masks.items():
        # A mask along `dim` cannot be carried over to the output unchanged,
        # so refuse instead of silently dropping it.
        if dim in mask.dims:
            raise DimensionError(
                f"Cannot apply function along '{dim}' since mask '{name}' depends "
                "on this dimension."
            )
        # Copy so that the returned masks are not shared with the input.
        masks[name] = mask.copy()
    return masks

21 

22 

def wrap1d(is_partial=False, accept_masks=False, keep_coords=False):
    """Decorator factory for decorating functions that wrap non-scipp 1-D functions.

    1-D functions are typically functions from libraries such as scipy that depend
    on a single 'axis' argument.

    The decorators returned by this factory apply pre- and postprocessing as follows:

    - An 'axis' keyword argument will raise ``ValueError``, recommending use of 'dim'.
      The index of the provided dimension is added as axis to kwargs.
    - Providing data with variances will raise ``sc.VariancesError`` since third-party
      libraries typically cannot handle variances.
    - Providing data whose coordinate for 'dim' has a different length along 'dim'
      than the data itself will raise ``BinEdgeError``.
    - Coordinates, masks, and attributes that act as "observers", i.e., do not depend
      on the dimension of the function application, are added to the output data array.
      Masks are copied as per the usual requirement in Scipp.

    Parameters
    ----------
    is_partial:
        The wrapped function is partial, i.e., does not return a data
        array itself, but a callable that returns a data array. If true,
        the postprocessing step is not applied to the wrapped function.
        Instead, the callable returned by the decorated function is
        decorated with the postprocessing step.
    accept_masks:
        If false (the default), a mask that depends on 'dim' raises
        ``DimensionError``. If true, such masks are not an error; they are
        left out of the output and only the remaining masks are propagated.
    keep_coords:
        If false (the default), only coordinates and attributes that do not
        depend on 'dim' are added to the output. If true, all coordinates and
        attributes of the input are added, including those depending on 'dim'.

    Returns
    -------
    :
        A decorator. The decorated function has the signature
        ``(da, dim, **kwargs)`` and forwards to the wrapped function with
        ``axis`` added to the keyword arguments.
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def function(da: DataArray, dim: str, **kwargs) -> Union[DataArray, Callable]:
            # --- preprocessing: validate the input -------------------------
            if 'axis' in kwargs:
                raise ValueError("Use the 'dim' keyword argument instead of 'axis'.")
            if da.variances is not None:
                raise VariancesError(
                    "Cannot apply function to data with uncertainties. If uncertainties"
                    " should be ignored, use 'sc.values(da)' to extract only values."
                )
            # A coordinate whose length along `dim` differs from the data's
            # is treated as a bin-edge coordinate, which is not supported.
            if da.sizes[dim] != da.coords[dim].sizes[dim]:
                raise BinEdgeError(
                    "Cannot apply function to data array with bin edges."
                )

            # Translate the dimension label into the positional axis expected
            # by the wrapped third-party function.
            kwargs['axis'] = da.dims.index(dim)

            # --- select the metadata that is carried over to the output ----
            if accept_masks:
                # Masks along `dim` are dropped rather than rejected.
                masks = {k: v for k, v in da.masks.items() if dim not in v.dims}
            else:
                masks = _validated_masks(da, dim)
            if keep_coords:
                coords = da.coords
                attrs = da.deprecated_attrs
            else:
                coords = {k: v for k, v in da.coords.items() if dim not in v.dims}
                attrs = {
                    k: v for k, v in da.deprecated_attrs.items() if dim not in v.dims
                }

            def _add_observing_metadata(out):
                """Attach the selected coords, masks and attrs to ``out``."""
                for k, v in coords.items():
                    out.coords[k] = v
                for k, v in masks.items():
                    # Copy on every call: with `is_partial` this helper may run
                    # many times and outputs must not share mask storage.
                    out.masks[k] = v.copy()
                for k, v in attrs.items():
                    out.deprecated_attrs[k] = v
                return out

            def postprocessing(partial):
                """Wrap ``partial`` so that its result receives the metadata."""

                @wraps(partial)
                def with_metadata(*args, **kwargs):
                    return _add_observing_metadata(partial(*args, **kwargs))

                return with_metadata

            # --- call the wrapped function ---------------------------------
            if is_partial:
                # `func` returns a callable; postprocess what that callable returns.
                return postprocessing(func(da, dim, **kwargs))
            else:
                return _add_observing_metadata(func(da, dim, **kwargs))

        return function

    return decorator