Coverage for install/scipp/coords/rule.py: 51%
111 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-01 01:59 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-01 01:59 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3# @author Jan-Lukas Wynen
4"""
5Rules encode instructions for how to compute a coordinate in ``transform_coords``.
6They provide a common interface for renaming and computing new coordinates.
7"""
9from __future__ import annotations
11import inspect
12from abc import ABC, abstractmethod
13from collections.abc import Callable, Iterable, Mapping
14from copy import copy
15from functools import partial
16from typing import TYPE_CHECKING, Any
18from ..core import Variable
19from .coord import Coord
# Typing scaffolding: the names below exist only so that annotations in this
# module resolve. At runtime the cheap placeholders in the ``else`` branch
# are used instead.
if TYPE_CHECKING:
    from typing import Protocol as _Protocol

    # Importing CoordTable from coord_table.py would result in an import
    # cycle because that module imports rule.py
    # CoordTable is only needed for type annotations here,
    # so a protocol is enough.
    class _CoordProvider(_Protocol):
        # The only operation the rules in this module perform on the provider.
        def consume(self, name: str) -> Coord:
            pass

else:
    # Runtime stand-ins; these carry no behavior.
    _Protocol = object
    _CoordProvider = Any

# Type of the callables wrapped by ComputeRule: any arguments, returns a Variable.
Kernel = Callable[..., Variable]
class Rule(ABC):
    """
    Abstract base class of all rules.

    A rule knows the names of the coordinates it produces (``out_names``),
    the names of the coordinates it needs (``dependencies``), and how to
    produce its outputs when called with a coord provider.
    """

    def __init__(self, out_names: tuple[str, ...]):
        self.out_names = out_names

    @abstractmethod
    def __call__(self, coords: _CoordProvider) -> dict[str, Coord]:
        """Evaluate the rule."""

    @property
    @abstractmethod
    def dependencies(self) -> tuple[str, ...]:
        """Return names of coords that this rule needs as inputs."""

    def _format_out_names(self) -> str:
        """Return the output names as a parenthesized, comma-separated list."""
        return f'({", ".join(self.out_names)})'
class FetchRule(Rule):
    """
    Look up coords in the given dict-like sources.

    Serves as the entry point through which coords of the input
    DataArray become available to other rules.
    """

    def __init__(
        self,
        out_names: tuple[str, ...],
        dense_sources: Mapping[str, Variable],
        event_sources: Mapping[str, Variable],
    ):
        super().__init__(out_names)
        self._dense_sources = dense_sources
        self._event_sources = event_sources

    def __call__(self, coords: _CoordProvider) -> dict[str, Coord]:
        # ``coords`` is not used; everything comes from the stored sources.
        # A name missing from a source yields None for that component.
        fetched: dict[str, Coord] = {}
        for out_name in self.out_names:
            dense_var = self._dense_sources.get(out_name, None)
            event_var = self._event_sources.get(out_name, None)
            fetched[out_name] = Coord(dense=dense_var, event=event_var, aligned=True)
        return fetched

    @property
    def dependencies(self) -> tuple[str, ...]:
        # Fetching requires no other coords.
        return ()

    def __str__(self):
        formatted = self._format_out_names()
        return f'Input {formatted}'
class RenameRule(Rule):
    """
    Provide an existing coordinate under one or more new names.
    """

    def __init__(self, out_names: tuple[str, ...], in_name: str):
        super().__init__(out_names)
        self._in_name = in_name

    def __call__(self, coords: _CoordProvider) -> dict[str, Coord]:
        renamed: dict[str, Coord] = {}
        for out_name in self.out_names:
            # Each alias gets its own shallow copy of the Coord so that its
            # alignment and usage count can differ from the original's.
            source = coords.consume(self._in_name)
            renamed[out_name] = copy(source)
        return renamed

    @property
    def dependencies(self) -> tuple[str, ...]:
        return (self._in_name,)

    def __str__(self):
        targets = self._format_out_names()
        return f'Rename {targets} <- {self._in_name}'
class ComputeRule(Rule):
    """
    Compute new coordinates using the provided callable.
    """

    def __init__(self, out_names: tuple[str, ...], func: Kernel):
        super().__init__(out_names)
        self._func = func
        # Maps the name of each input coord to the name of the parameter
        # of ``func`` that receives it.
        self._arg_names = _arg_names(func)

    def __call__(self, coords: _CoordProvider) -> dict[str, Coord]:
        """
        Consume the input coords and evaluate the wrapped function.

        The function is evaluated on event data if any input has events and,
        independently, on dense data if all inputs have a dense component.
        When both evaluations run, the dense results replace the ``dense``
        component of the event-based outputs.

        Raises
        ------
        TypeError
            If the function does not return all requested outputs.
        """
        inputs = {
            name: coords.consume(coord) for coord, name in self._arg_names.items()
        }
        outputs = None
        if any(coord.has_event for coord in inputs.values()):
            outputs = self._compute_with_events(inputs)
        if all(coord.has_dense for coord in inputs.values()):
            dense_outputs = self._compute_pure_dense(inputs)
            if outputs is None:
                outputs = dense_outputs
            else:
                for name, coord in dense_outputs.items():
                    outputs[name].dense = coord.dense
        return self._without_unrequested(outputs)

    def _compute_pure_dense(self, inputs):
        """Call the function with the dense component of every input."""
        outputs = self._func(**{name: coord.dense for name, coord in inputs.items()})
        outputs = self._to_dict(outputs)
        return {
            name: Coord(dense=var, event=None, aligned=True)
            for name, var in outputs.items()
        }

    def _compute_with_events(self, inputs):
        """Call the function with event data where available, dense otherwise."""
        args = {
            name: coord.event if coord.has_event else coord.dense
            for name, coord in inputs.items()
        }
        outputs = self._to_dict(self._func(**args))
        # Dense outputs may be produced as side effects of processing event
        # coords.
        outputs = {
            name: Coord(
                dense=var if var.bins is None else None,
                event=var if var.bins is not None else None,
                aligned=True,
            )
            for name, var in outputs.items()
        }
        return outputs

    def _without_unrequested(self, d: dict[str, Any]) -> dict[str, Any]:
        """
        Return only the entries of ``d`` that are listed in ``out_names``.

        Raises
        ------
        TypeError
            If a requested output is missing from ``d``.
        """
        missing_outputs = [key for key in self.out_names if key not in d]
        if missing_outputs:
            # Use the safe name lookup so that a nameless callable
            # (e.g. a functools.partial) does not turn this error into
            # an AttributeError.
            raise TypeError(
                f'transform_coords was expected to compute {missing_outputs} '
                f'using `{self._display_name()}` but the function returned '
                f'{list(d.keys())} instead.'
            )
        return {key: d[key] for key in self.out_names}

    def _to_dict(self, output) -> dict[str, Variable]:
        """
        Normalize the return value of the function to a dict.

        A non-dict return value is interpreted as the single requested output.

        Raises
        ------
        TypeError
            If a non-dict was returned but the rule has more or fewer
            than one output name.
        """
        if not isinstance(output, dict):
            if len(self.out_names) != 1:
                raise TypeError(
                    'Function returned a single output but '
                    f'{len(self.out_names)} were expected.'
                )
            return {self.out_names[0]: output}
        return output

    def _display_name(self) -> str:
        """Return ``__name__`` of the function, or its ``repr`` if it has none."""
        # Class instances defining __call__ as well as objects created with
        # functools.partial may/do not define __name__.
        return getattr(self._func, '__name__', repr(self._func))

    @property
    def dependencies(self) -> tuple[str, ...]:
        return tuple(self._arg_names)

    @property
    def func_name(self) -> str:
        """Name of the wrapped function; falls back to ``repr`` if it has no name."""
        return self._display_name()

    def __str__(self):
        return (
            f'Compute {self._format_out_names()} = {self._display_name()}'
            f'({", ".join(self.dependencies)})'
        )
def rules_of_type(rules: list[Rule], rule_type: type) -> Iterable[Rule]:
    """Yield, in their original order, the rules that are instances of ``rule_type``."""
    for candidate in rules:
        if isinstance(candidate, rule_type):
            yield candidate
def rule_output_names(rules: list[Rule], rule_type: type) -> Iterable[str]:
    """Yield the output names of every rule that is an instance of ``rule_type``."""
    matching = rules_of_type(rules, rule_type)
    for rule in matching:
        for out_name in rule.out_names:
            yield out_name
214def _arg_names(func) -> dict[str, str]:
215 spec = inspect.getfullargspec(func)
216 if spec.varargs is not None or spec.varkw is not None:
217 raise ValueError(
218 'Function with variable arguments not allowed in '
219 f'conversion graph: `{func.__name__}`.'
220 )
221 if inspect.isfunction(func) or func.__class__ == partial:
222 args = spec.args
223 else:
224 # Strip off the 'self'. Objects returned by functools.partial are not
225 # functions, but nevertheless do not have 'self'.
226 args = spec.args[1:]
227 names = tuple(args + spec.kwonlyargs)
228 coords = getattr(func, '__transform_coords_input_keys__', names)
229 return dict(zip(coords, names, strict=True))