Coverage for install/scipp/coords/rule.py: 51%

111 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-17 01:51 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3# @author Jan-Lukas Wynen 

4""" 

5Rules encode instructions for how to compute a coordinate in ``transform_coords``. 

6They provide a common interface for renaming and computing new coordinates. 

7""" 

8 

9from __future__ import annotations 

10 

11import inspect 

12from abc import ABC, abstractmethod 

13from collections.abc import Callable, Iterable, Mapping 

14from copy import copy 

15from functools import partial 

16from typing import TYPE_CHECKING, Any 

17 

18from ..core import Variable 

19from .coord import Coord 

20 

if TYPE_CHECKING:
    from typing import Protocol as _Protocol

    # coord_table.py imports rule.py, so importing CoordTable from there
    # would create an import cycle. CoordTable is only referred to in type
    # annotations in this module, which makes a protocol sufficient.
    class _CoordProvider(_Protocol):
        """Structural type of objects that hand out coords by name."""

        def consume(self, name: str) -> Coord: ...

else:
    # Runtime stand-ins for the names that are defined for type checkers above.
    _Protocol = object
    _CoordProvider = Any

# Type of the callables that ComputeRule uses to compute coordinates.
Kernel = Callable[..., Variable]

37 

38 

class Rule(ABC):
    """
    Abstract base class of the rules used by ``transform_coords``.

    A rule stores the names of the coordinates it produces in ``out_names``,
    reports the names of the coordinates it needs via ``dependencies``,
    and produces its outputs when called with a coord provider.
    """

    def __init__(self, out_names: tuple[str, ...]) -> None:
        self.out_names = out_names

    @abstractmethod
    def __call__(self, coords: _CoordProvider) -> dict[str, Coord]:
        """Evaluate the rule."""

    @property
    @abstractmethod
    def dependencies(self) -> tuple[str, ...]:
        """Return names of coords that this rule needs as inputs."""
        # The tuple has variable length: implementations return anything
        # from an empty tuple to one name per input.

    def _format_out_names(self) -> str:
        """Return the output names as a parenthesized, comma-separated list."""
        return f'({", ".join(self.out_names)})'

54 

55 

class FetchRule(Rule):
    """
    Look up the requested coords in the given dict-like sources.

    This allows hiding the retrieval of coords from the input DataArray
    behind the common rule interface.
    """

    def __init__(
        self,
        out_names: tuple[str, ...],
        dense_sources: Mapping[str, Variable],
        event_sources: Mapping[str, Variable],
    ):
        super().__init__(out_names)
        self._dense_sources = dense_sources
        self._event_sources = event_sources

    def __call__(self, coords: _CoordProvider) -> dict[str, Coord]:
        """Build one aligned ``Coord`` per output name; ``coords`` is unused."""
        fetched = {}
        for key in self.out_names:
            # A name that is missing from a source yields None for that part.
            dense_part = self._dense_sources.get(key, None)
            event_part = self._event_sources.get(key, None)
            fetched[key] = Coord(dense=dense_part, event=event_part, aligned=True)
        return fetched

    @property
    def dependencies(self) -> tuple[str, ...]:
        """A fetch rule has no inputs."""
        return ()

    def __str__(self):
        formatted = self._format_out_names()
        return f'Input {formatted}'

89 

90 

class RenameRule(Rule):
    """
    Hand back the input coordinate under one or more new names.
    """

    def __init__(self, out_names: tuple[str, ...], in_name: str):
        super().__init__(out_names)
        self._in_name = in_name

    def __call__(self, coords: _CoordProvider) -> dict[str, Coord]:
        """Consume the input coord once per output name and return copies."""
        renamed = {}
        for alias in self.out_names:
            source = coords.consume(self._in_name)
            # A shallow copy lets the alias carry its own alignment and
            # usage count, independent of the original coord.
            renamed[alias] = copy(source)
        return renamed

    @property
    def dependencies(self) -> tuple[str, ...]:
        """The single coord that gets renamed."""
        return (self._in_name,)

    def __str__(self):
        targets = self._format_out_names()
        return f'Rename {targets} <- {self._in_name}'

113 

114 

class ComputeRule(Rule):
    """
    Compute new coordinates using the provided callable.
    """

    def __init__(self, out_names: tuple[str, ...], func: Kernel):
        super().__init__(out_names)
        self._func = func
        # Maps the name of each input coord to the name of the argument
        # of ``func`` that receives it.
        self._arg_names = _arg_names(func)

    def __call__(self, coords: _CoordProvider) -> dict[str, Coord]:
        """
        Consume the inputs from ``coords`` and compute the requested outputs.

        The kernel is called with event data if any input has events, and
        with dense data if all inputs have dense data. If both apply, the
        dense results replace the ``dense`` part of the event-based outputs.

        Raises
        ------
        TypeError
            If the kernel does not return all requested outputs.
        """
        inputs = {
            name: coords.consume(coord) for coord, name in self._arg_names.items()
        }
        outputs = None
        if any(coord.has_event for coord in inputs.values()):
            outputs = self._compute_with_events(inputs)
        if all(coord.has_dense for coord in inputs.values()):
            dense_outputs = self._compute_pure_dense(inputs)
            if outputs is None:
                outputs = dense_outputs
            else:
                for name, coord in dense_outputs.items():
                    outputs[name].dense = coord.dense
        return self._without_unrequested(outputs)

    def _compute_pure_dense(self, inputs: dict[str, Coord]) -> dict[str, Coord]:
        """Call the kernel with the dense part of every input."""
        outputs = self._func(**{name: coord.dense for name, coord in inputs.items()})
        outputs = self._to_dict(outputs)
        return {
            name: Coord(dense=var, event=None, aligned=True)
            for name, var in outputs.items()
        }

    def _compute_with_events(self, inputs: dict[str, Coord]) -> dict[str, Coord]:
        """Call the kernel with event data where available, else dense data."""
        args = {
            name: coord.event if coord.has_event else coord.dense
            for name, coord in inputs.items()
        }
        outputs = self._to_dict(self._func(**args))
        # Dense outputs may be produced as side effects of processing event
        # coords. Sort each result into the dense or event slot depending on
        # whether it is binned.
        outputs = {
            name: Coord(
                dense=var if var.bins is None else None,
                event=var if var.bins is not None else None,
                aligned=True,
            )
            for name, var in outputs.items()
        }
        return outputs

    def _without_unrequested(self, d: dict[str, Any]) -> dict[str, Any]:
        """
        Return only the entries of ``d`` that are listed in ``out_names``.

        Raises
        ------
        TypeError
            If ``d`` lacks any of the requested output names.
        """
        missing_outputs = [key for key in self.out_names if key not in d]
        if missing_outputs:
            # Use func_name instead of __name__ so that kernels without a
            # __name__ still produce this TypeError and not an AttributeError.
            raise TypeError(
                f'transform_coords was expected to compute {missing_outputs} '
                f'using `{self.func_name}` but the function returned '
                f'{list(d.keys())} instead.'
            )
        return {key: d[key] for key in self.out_names}

    def _to_dict(self, output) -> dict[str, Variable]:
        """
        Normalize the return value of the kernel to a dict.

        A dict is returned unchanged. Anything else is treated as a single
        output and stored under the only requested output name.

        Raises
        ------
        TypeError
            If a single output was returned but the number of requested
            outputs is not one.
        """
        if not isinstance(output, dict):
            if len(self.out_names) != 1:
                raise TypeError(
                    'Function returned a single output but '
                    f'{len(self.out_names)} were expected.'
                )
            return {self.out_names[0]: output}
        return output

    @property
    def dependencies(self) -> tuple[str, ...]:
        """Names of the input coords, i.e. the keys of the argument mapping."""
        return tuple(self._arg_names)

    @property
    def func_name(self) -> str:
        """Name of the kernel, or its ``repr`` if it has no ``__name__``."""
        # Class instances defining __call__ as well as objects created with
        # functools.partial may/do not define __name__.
        return getattr(self._func, '__name__', repr(self._func))

    def __str__(self):
        return (
            f'Compute {self._format_out_names()} = {self.func_name}'
            f'({", ".join(self.dependencies)})'
        )

203 

204 

def rules_of_type(rules: list[Rule], rule_type: type) -> Iterable[Rule]:
    """Lazily yield, in order, the rules that are instances of ``rule_type``."""
    for candidate in rules:
        if isinstance(candidate, rule_type):
            yield candidate

207 

208 

def rule_output_names(rules: list[Rule], rule_type: type) -> Iterable[str]:
    """Lazily yield the output names of all rules of type ``rule_type``."""
    matching = rules_of_type(rules, rule_type)
    for matched in matching:
        yield from matched.out_names

212 

213 

214def _arg_names(func) -> dict[str, str]: 

215 spec = inspect.getfullargspec(func) 

216 if spec.varargs is not None or spec.varkw is not None: 

217 raise ValueError( 

218 'Function with variable arguments not allowed in ' 

219 f'conversion graph: `{func.__name__}`.' 

220 ) 

221 if inspect.isfunction(func) or func.__class__ == partial: 

222 args = spec.args 

223 else: 

224 # Strip off the 'self'. Objects returned by functools.partial are not 

225 # functions, but nevertheless do not have 'self'. 

226 args = spec.args[1:] 

227 names = tuple(args + spec.kwonlyargs) 

228 coords = getattr(func, '__transform_coords_input_keys__', names) 

229 return dict(zip(coords, names, strict=True))