Coverage for install/scipp/coords/rule.py: 51%
111 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3# @author Jan-Lukas Wynen
4"""
5Rules encode instructions for how to compute a coordinate in ``transform_coords``.
6They provide a common interface for renaming and computing new coordinates.
7"""
9from __future__ import annotations
11import inspect
12from abc import ABC, abstractmethod
13from copy import copy
14from functools import partial
15from typing import Any, Callable, Dict, Iterable, List, Mapping, Tuple
17from ..core import Variable
18from .coord import Coord
20try:
21 from typing import Protocol as _Protocol
23 # Importing CoordTable from coord_table.py would result in an import
24 # cycle because that module import rule.py
25 # CoordTable is only needed for type annotations here,
26 # so a protocol is enough.
27 class _CoordProvider(_Protocol):
28 def consume(self, name: str) -> Coord:
29 pass
31except ImportError:
32 _Protocol = object
33 _CoordProvider = Any
class Rule(ABC):
    """Abstract base class for rules used by ``transform_coords``.

    A rule knows how to produce one or more output coords, possibly by
    consuming other coords from a coord provider.

    Parameters
    ----------
    out_names:
        Names of the coords produced by this rule.
    """

    def __init__(self, out_names: Tuple[str, ...]):
        self.out_names = out_names

    @abstractmethod
    def __call__(self, coords: _CoordProvider) -> Dict[str, Coord]:
        """Evaluate the rule.

        Returns a dict that maps output names to the produced coords.
        """

    @property
    @abstractmethod
    def dependencies(self) -> Tuple[str, ...]:
        """Return names of coords that this rule needs as inputs."""

    def _format_out_names(self) -> str:
        """Return the output names as ``'(a, b, ...)'`` for use in ``__str__``."""
        return f'({", ".join(self.out_names)})'
class FetchRule(Rule):
    """
    Look up coords in the provided dict-like sources.

    This lets retrieving coords from the input DataArray be expressed as a
    rule like any other. Each output name is looked up in both the dense and
    the event source; a name missing from a source yields ``None`` for that
    part of the resulting coord.
    """

    def __init__(
        self,
        out_names: Tuple[str, ...],
        dense_sources: Mapping[str, Variable],
        event_sources: Mapping[str, Variable],
    ):
        super().__init__(out_names)
        self._dense_sources = dense_sources
        self._event_sources = event_sources

    def __call__(self, coords: _CoordProvider) -> Dict[str, Coord]:
        # ``coords`` is not needed: this rule has no dependencies.
        fetched = {}
        for name in self.out_names:
            dense_var = self._dense_sources.get(name, None)
            event_var = self._event_sources.get(name, None)
            fetched[name] = Coord(dense=dense_var, event=event_var, aligned=True)
        return fetched

    @property
    def dependencies(self) -> Tuple[str, ...]:
        no_inputs: Tuple[str, ...] = ()
        return no_inputs

    def __str__(self):
        formatted = self._format_out_names()
        return f'Input {formatted}'
class RenameRule(Rule):
    """
    Give an existing coordinate one or more new names.

    The input coord is consumed once per output name.
    """

    def __init__(self, out_names: Tuple[str, ...], in_name: str):
        super().__init__(out_names)
        self._in_name = in_name

    def __call__(self, coords: _CoordProvider) -> Dict[str, Coord]:
        renamed = {}
        for alias in self.out_names:
            # A shallow copy of the Coord object lets the alias carry its own
            # alignment and usage count, independent of the original.
            renamed[alias] = copy(coords.consume(self._in_name))
        return renamed

    @property
    def dependencies(self) -> Tuple[str, ...]:
        source = self._in_name
        return (source,)

    def __str__(self):
        targets = self._format_out_names()
        return f'Rename {targets} <- {self._in_name}'
class ComputeRule(Rule):
    """
    Compute new coordinates using the provided callable.

    The coords passed to the callable are derived from its signature via
    ``_arg_names``. Each argument receives the coord consumed under the
    corresponding coord name. The callable may return a single variable
    (only allowed when exactly one output is requested) or a dict of
    variables keyed by output name.
    """

    def __init__(self, out_names: Tuple[str, ...], func: Callable):
        super().__init__(out_names)
        self._func = func
        # Maps coord name -> argument name of ``func``.
        self._arg_names = _arg_names(func)

    def __call__(self, coords: _CoordProvider) -> Dict[str, Coord]:
        """Consume the input coords and evaluate ``func`` on them.

        If any input has event data, ``func`` is called with the event data of
        those inputs and the dense data of the others. If every input has
        dense data, ``func`` is additionally called on the dense data alone.
        The dense results then replace the ``dense`` part of the outputs, or
        form the outputs if there was no event computation.

        Raises
        ------
        TypeError
            If ``func`` does not produce every name in ``out_names``.
        """
        inputs = {
            name: coords.consume(coord) for coord, name in self._arg_names.items()
        }
        outputs = None
        if any(coord.has_event for coord in inputs.values()):
            outputs = self._compute_with_events(inputs)
        if all(coord.has_dense for coord in inputs.values()):
            dense_outputs = self._compute_pure_dense(inputs)
            if outputs is None:
                outputs = dense_outputs
            else:
                for name, coord in dense_outputs.items():
                    outputs[name].dense = coord.dense
        # NOTE(review): if no input has event data and not every input has
        # dense data, ``outputs`` is still None here; presumably callers
        # guarantee this cannot happen -- TODO confirm.
        return self._without_unrequested(outputs)

    def _compute_pure_dense(self, inputs):
        """Call ``func`` on the dense data of all inputs; wrap results as dense coords."""
        outputs = self._func(**{name: coord.dense for name, coord in inputs.items()})
        outputs = self._to_dict(outputs)
        return {
            name: Coord(dense=var, event=None, aligned=True)
            for name, var in outputs.items()
        }

    def _compute_with_events(self, inputs):
        """Call ``func`` with event data where available, dense data otherwise."""
        args = {
            name: coord.event if coord.has_event else coord.dense
            for name, coord in inputs.items()
        }
        outputs = self._to_dict(self._func(**args))
        # Dense outputs may be produced as side effects of processing event
        # coords, so each result is classified by whether it is binned.
        outputs = {
            name: Coord(
                dense=var if var.bins is None else None,
                event=var if var.bins is not None else None,
                aligned=True,
            )
            for name, var in outputs.items()
        }
        return outputs

    def _without_unrequested(self, d: Dict[str, Any]) -> Dict[str, Any]:
        """Return only the entries of ``d`` named in ``out_names``.

        Raises
        ------
        TypeError
            If any name in ``out_names`` is missing from ``d``.
        """
        missing_outputs = [key for key in self.out_names if key not in d]
        if missing_outputs:
            # Use ``func_name`` rather than ``self._func.__name__``:
            # functools.partial objects and callable instances have no
            # ``__name__`` and would otherwise mask this error with an
            # AttributeError.
            raise TypeError(
                f'transform_coords was expected to compute {missing_outputs} '
                f'using `{self.func_name}` but the function returned '
                f'{list(d.keys())} instead.'
            )
        return {key: d[key] for key in self.out_names}

    def _to_dict(self, output) -> Dict[str, Variable]:
        """Normalize the return value of ``func`` to a dict keyed by output name.

        Raises
        ------
        TypeError
            If ``func`` returned a single (non-dict) value but more than one
            output is expected.
        """
        if not isinstance(output, dict):
            if len(self.out_names) != 1:
                raise TypeError(
                    'Function returned a single output but '
                    f'{len(self.out_names)} were expected.'
                )
            return {self.out_names[0]: output}
        return output

    @property
    def dependencies(self) -> Tuple[str, ...]:
        """Names of the coords consumed as inputs to ``func``."""
        return tuple(self._arg_names)

    @property
    def func_name(self) -> str:
        """Name of ``func``, or its ``repr`` if it has no ``__name__``."""
        # Class instances defining __call__ as well as objects created with
        # functools.partial may/do not define __name__.
        return getattr(self._func, '__name__', repr(self._func))

    def __str__(self):
        return (
            f'Compute {self._format_out_names()} = {self.func_name}'
            f'({", ".join(self.dependencies)})'
        )
def rules_of_type(rules: List[Rule], rule_type: type) -> Iterable[Rule]:
    """Lazily yield, in order, the rules that are instances of ``rule_type``."""
    for candidate in rules:
        if isinstance(candidate, rule_type):
            yield candidate
def rule_output_names(rules: List[Rule], rule_type: type) -> Iterable[str]:
    """Lazily yield the output names of all rules of type ``rule_type``."""
    for matching in rules_of_type(rules, rule_type):
        for out_name in matching.out_names:
            yield out_name
211def _arg_names(func) -> Dict[str, str]:
212 spec = inspect.getfullargspec(func)
213 if spec.varargs is not None or spec.varkw is not None:
214 raise ValueError(
215 'Function with variable arguments not allowed in '
216 f'conversion graph: `{func.__name__}`.'
217 )
218 if inspect.isfunction(func) or func.__class__ == partial:
219 args = spec.args
220 else:
221 # Strip off the 'self'. Objects returned by functools.partial are not
222 # functions, but nevertheless do not have 'self'.
223 args = spec.args[1:]
224 names = tuple(args + spec.kwonlyargs)
225 coords = getattr(func, '__transform_coords_input_keys__', names)
226 return dict(zip(coords, names))