Coverage for install/scipp/coords/rule.py: 51%

111 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-04-28 01:28 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3# @author Jan-Lukas Wynen 

4""" 

5Rules encode instructions for how to compute a coordinate in ``transform_coords``. 

6They provide a common interface for renaming and computing new coordinates. 

7""" 

8 

9from __future__ import annotations 

10 

11import inspect 

12from abc import ABC, abstractmethod 

13from copy import copy 

14from functools import partial 

15from typing import Any, Callable, Dict, Iterable, List, Mapping, Tuple 

16 

17from ..core import Variable 

18from .coord import Coord 

19 

try:
    from typing import Protocol as _Protocol

    # coord_table.py imports this module, so importing CoordTable from it
    # here would create an import cycle. CoordTable is only needed for
    # type annotations in this file, so a structural protocol is enough.
    class _CoordProvider(_Protocol):
        """Structural type for objects that hand out coords by name."""

        def consume(self, name: str) -> Coord:
            ...

except ImportError:
    # typing.Protocol could not be imported: fall back to placeholders that
    # keep the annotations below valid without constraining them.
    _Protocol = object
    _CoordProvider = Any

34 

35 

class Rule(ABC):
    """
    Abstract base class of all rules.

    A rule stores the names of the coordinates it produces in ``out_names``.
    Subclasses implement ``__call__`` to produce those coordinates and
    ``dependencies`` to report which input coordinates they need.
    """

    def __init__(self, out_names: Tuple[str, ...]):
        # Names of the coordinates this rule produces.
        self.out_names = out_names

    @abstractmethod
    def __call__(self, coords: _CoordProvider) -> Dict[str, Coord]:
        """Evaluate the rule."""

    @property
    @abstractmethod
    def dependencies(self) -> Tuple[str, ...]:
        """Return names of coords that this rule needs as inputs."""
        # Annotated as a variable-length tuple to agree with the subclasses,
        # which return zero, one or several names.

    def _format_out_names(self) -> str:
        """Return the output names as ``(a, b, ...)`` for use in ``__str__``."""
        return f'({", ".join(self.out_names)})'

51 

52 

class FetchRule(Rule):
    """
    Look up the output coords in the dict-like sources given at construction.

    Can be used to abstract away retrieving coords from the input DataArray.
    """

    def __init__(
        self,
        out_names: Tuple[str, ...],
        dense_sources: Mapping[str, Variable],
        event_sources: Mapping[str, Variable],
    ):
        super().__init__(out_names)
        self._dense_sources = dense_sources
        self._event_sources = event_sources

    def __call__(self, coords: _CoordProvider) -> Dict[str, Coord]:
        """Build one aligned ``Coord`` per output name from the stored sources.

        ``coords`` is not used. A name missing from a source yields ``None``
        for the corresponding part (dense or event) of the ``Coord``.
        """
        fetched = {}
        for out_name in self.out_names:
            dense_part = self._dense_sources.get(out_name, None)
            event_part = self._event_sources.get(out_name, None)
            fetched[out_name] = Coord(
                dense=dense_part, event=event_part, aligned=True
            )
        return fetched

    @property
    def dependencies(self) -> Tuple[str, ...]:
        """Return an empty tuple: fetching needs no other coords as input."""
        return tuple()

    def __str__(self):
        formatted = self._format_out_names()
        return 'Input ' + formatted

86 

87 

class RenameRule(Rule):
    """
    Hand back the input coordinate under one or more new names.
    """

    def __init__(self, out_names: Tuple[str, ...], in_name: str):
        super().__init__(out_names)
        self._in_name = in_name

    def __call__(self, coords: _CoordProvider) -> Dict[str, Coord]:
        """Consume the input coord once per output name and return copies."""
        renamed = {}
        for out_name in self.out_names:
            source = coords.consume(self._in_name)
            # A shallow copy lets the alias carry a different alignment and
            # usage count than the original coord object.
            renamed[out_name] = copy(source)
        return renamed

    @property
    def dependencies(self) -> Tuple[str, ...]:
        """Return a 1-tuple holding the name of the coord being renamed."""
        in_name = self._in_name
        return (in_name,)

    def __str__(self):
        targets = self._format_out_names()
        return f'Rename {targets} <- {self._in_name}'

110 

111 

class ComputeRule(Rule):
    """
    Compute new coordinates using the provided callable.

    The callable's parameter names determine which input coordinates are
    consumed (see ``_arg_names``). It may return either a single value, when
    exactly one output name is requested, or a dict mapping output names to
    values.
    """

    def __init__(self, out_names: Tuple[str, ...], func: Callable):
        super().__init__(out_names)
        self._func = func
        # Maps input coord name -> name of the parameter of ``func`` that
        # receives it.
        self._arg_names = _arg_names(func)

    def __call__(self, coords: _CoordProvider) -> Dict[str, Coord]:
        """Consume the inputs, call the function and return the requested outputs.

        The function is called with event data if any input has event data,
        and with dense data if all inputs have dense data. If both calls
        happen, the dense results overwrite the ``dense`` part of the
        outputs from the event call.

        Raises
        ------
        TypeError
            If the function's result does not provide all requested outputs.
        """
        inputs = {
            name: coords.consume(coord) for coord, name in self._arg_names.items()
        }
        outputs = None
        if any(coord.has_event for coord in inputs.values()):
            outputs = self._compute_with_events(inputs)
        if all(coord.has_dense for coord in inputs.values()):
            dense_outputs = self._compute_pure_dense(inputs)
            if outputs is None:
                outputs = dense_outputs
            else:
                for name, coord in dense_outputs.items():
                    outputs[name].dense = coord.dense
        return self._without_unrequested(outputs)

    def _compute_pure_dense(self, inputs):
        """Call the function on the dense part of every input."""
        outputs = self._func(**{name: coord.dense for name, coord in inputs.items()})
        outputs = self._to_dict(outputs)
        return {
            name: Coord(dense=var, event=None, aligned=True)
            for name, var in outputs.items()
        }

    def _compute_with_events(self, inputs):
        """Call the function using event data where available, else dense data."""
        args = {
            name: coord.event if coord.has_event else coord.dense
            for name, coord in inputs.items()
        }
        outputs = self._to_dict(self._func(**args))
        # Dense outputs may be produced as side effects of processing event
        # coords. An output whose ``bins`` is None is stored as dense,
        # any other output as event.
        outputs = {
            name: Coord(
                dense=var if var.bins is None else None,
                event=var if var.bins is not None else None,
                aligned=True,
            )
            for name, var in outputs.items()
        }
        return outputs

    def _without_unrequested(self, d: Dict[str, Any]) -> Dict[str, Any]:
        """Return only the entries of ``d`` named in ``out_names``.

        Raises
        ------
        TypeError
            If any requested output name is missing from ``d``.
        """
        missing_outputs = [key for key in self.out_names if key not in d]
        if missing_outputs:
            # Use the safe name lookup so that a callable without __name__
            # cannot turn this TypeError into an AttributeError.
            raise TypeError(
                f'transform_coords was expected to compute {missing_outputs} '
                f'using `{self.func_name}` but the function returned '
                f'{list(d.keys())} instead.'
            )
        return {key: d[key] for key in self.out_names}

    def _to_dict(self, output) -> Dict[str, Variable]:
        """Normalize the function's return value to a dict keyed by output name.

        Raises
        ------
        TypeError
            If a non-dict value was returned but the number of requested
            outputs is not exactly one.
        """
        if not isinstance(output, dict):
            if len(self.out_names) != 1:
                raise TypeError(
                    'Function returned a single output but '
                    f'{len(self.out_names)} were expected.'
                )
            return {self.out_names[0]: output}
        return output

    @property
    def dependencies(self) -> Tuple[str, ...]:
        """Return the names of the input coords consumed by the function."""
        return tuple(self._arg_names)

    @property
    def func_name(self) -> str:
        """Return the function's ``__name__``, or its ``repr`` if it has none."""
        # Class instances defining __call__ as well as objects created with
        # functools.partial may/do not define __name__.
        return getattr(self._func, '__name__', repr(self._func))

    def __str__(self):
        return (
            f'Compute {self._format_out_names()} = {self.func_name}'
            f'({", ".join(self.dependencies)})'
        )

200 

201 

def rules_of_type(rules: List[Rule], rule_type: type) -> Iterable[Rule]:
    """Lazily yield, in order, the rules that are instances of ``rule_type``."""
    for candidate in rules:
        if isinstance(candidate, rule_type):
            yield candidate

204 

205 

def rule_output_names(rules: List[Rule], rule_type: type) -> Iterable[str]:
    """Lazily yield every output name of the rules that match ``rule_type``."""
    for matching_rule in rules_of_type(rules, rule_type):
        for out_name in matching_rule.out_names:
            yield out_name

209 

210 

def _arg_names(func) -> Dict[str, str]:
    """
    Map input coordinate names to the parameter names of ``func``.

    The parameter names are the positional and keyword-only parameters
    reported by ``inspect.getfullargspec``. By default each coordinate name
    equals its parameter name. If ``func`` has an attribute
    ``__transform_coords_input_keys__``, its items are used as the coordinate
    names instead and are paired with the parameter names in order.

    Raises
    ------
    ValueError
        If ``func`` accepts ``*args`` or ``**kwargs``.
    """
    spec = inspect.getfullargspec(func)
    if spec.varargs is not None or spec.varkw is not None:
        # Class instances defining __call__ and functools.partial objects
        # may not define __name__; fall back to repr so that the intended
        # ValueError is raised instead of an AttributeError.
        display_name = getattr(func, '__name__', repr(func))
        raise ValueError(
            'Function with variable arguments not allowed in '
            f'conversion graph: `{display_name}`.'
        )
    if inspect.isfunction(func) or isinstance(func, partial):
        args = spec.args
    else:
        # Strip off the 'self'. Objects returned by functools.partial are not
        # functions, but nevertheless do not have 'self'.
        args = spec.args[1:]
    names = tuple(args + spec.kwonlyargs)
    coords = getattr(func, '__transform_coords_input_keys__', names)
    return dict(zip(coords, names))