Coverage for install/scipp/coords/transform_coords.py: 80%
113 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-01 01:59 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-01 01:59 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3# @author Simon Heybrock, Jan-Lukas Wynen
4from collections.abc import Iterable, Mapping
5from dataclasses import fields
6from fractions import Fraction
8from ..core import DataArray, Dataset, DimensionError, VariableError, bins, empty
9from ..logging import get_logger
10from .coord_table import Coord, CoordTable
11from .graph import Graph, GraphDict, rule_sequence
12from .options import Options
13from .rule import ComputeRule, FetchRule, Kernel, RenameRule, Rule, rule_output_names
def transform_coords(
    x: DataArray | Dataset,
    targets: str | Iterable[str] | None = None,
    /,
    graph: GraphDict | None = None,
    *,
    rename_dims: bool = True,
    keep_aliases: bool = True,
    keep_intermediate: bool = True,
    keep_inputs: bool = True,
    quiet: bool = False,
    **kwargs: Kernel,
) -> DataArray | Dataset:
    """Compute new coords based on transformations of input coords.

    See the section in the user guide on
    `Coordinate transformations <../../user-guide/coordinate-transformations.rst>`_
    for detailed explanations.

    Parameters
    ----------
    x:
        Input object with coords.
    targets:
        Name or list of names of desired output coords.
    graph:
        A graph defining how new coords can be computed from existing
        coords. This may be done in multiple steps.
        The graph is given by a :class:`dict` where:

        - Dict keys are :class:`str` or :class:`tuple` of :class:`str`,
          defining the names of outputs generated by a dict value.
        - Dict values are :class:`str` or a callable (function).
          If :class:`str`, this is a synonym for renaming a coord.
          If a callable, it must either return a single variable or a dict of
          variables. The argument names of callables must be coords in ``x`` or be
          computable by other nodes in ``graph``. The coord names can be overridden
          by the callable by providing a ``__transform_coords_input_keys__``
          property, returning a list of coord names in the same order as the
          function arguments.
    rename_dims:
        Rename dimensions if the corresponding dimension coords
        are used as inputs and there is a single output coord
        that can be associated with that dimension.
        See the user guide for more details and examples.
        Default is True.
    keep_aliases:
        If True, include aliases in the output.
        Default is True.
    keep_intermediate:
        If True, include intermediate results in the output.
        Default is True.
    keep_inputs:
        Include consumed input coordinates in the output.
        Default is True.
    quiet:
        If True, no log output is produced. Otherwise, ``transform_coords``
        produces a log of its actions.
    **kwargs:
        Mapping of coords to callables. This can be used as an alternate and brief
        way of specifying targets and graph. If provided, neither ``targets`` nor
        ``graph`` may be given.

    Returns
    -------
    :
        New object with desired coords. Existing data and meta-data is shallow-copied.

    Raises
    ------
    TypeError
        If one of the option arguments (``rename_dims``, ``keep_aliases``,
        ``keep_intermediate``, ``keep_inputs``, ``quiet``) is not a :class:`bool`,
        for example because it was used as the name of an output coord.
    ValueError
        If keyword arguments are combined with explicit ``targets`` or ``graph``.

    Examples
    --------

    Transform input coordinates ``x`` and ``y`` to a new output coordinate ``xy``:

    >>> da = sc.data.table_xyz(nrow=10)
    >>> transformed = da.transform_coords(xy=lambda x, y: x + y)

    Equivalent full syntax based on a target name and a graph:

    >>> da = sc.data.table_xyz(nrow=10)
    >>> transformed = da.transform_coords('xy', graph={'xy': lambda x, y: x + y})

    Multiple new coordinates can be computed at once. Here ``z2`` is set up as an
    alias of ``z``:

    >>> da = sc.data.table_xyz(nrow=10)
    >>> transformed = da.transform_coords(xy=lambda x, y: x + y, z2='z')

    This is equivalent to

    >>> da = sc.data.table_xyz(nrow=10)
    >>> graph = {'xy': lambda x, y: x + y, 'z2':'z'}
    >>> transformed = da.transform_coords(['xy', 'z2'], graph=graph)

    Multi-step transformations that do not keep intermediate results as coordinates
    can be performed with a graph containing nodes that depend on outputs of other
    nodes:

    >>> da = sc.data.table_xyz(nrow=10)
    >>> graph = {'xy': lambda x, y: x + y, 'xyz': lambda xy, z: xy + z}
    >>> transformed = da.transform_coords('xyz', graph=graph)
    """
    options = Options(
        rename_dims=rename_dims,
        keep_aliases=keep_aliases,
        keep_intermediate=keep_intermediate,
        keep_inputs=keep_inputs,
        quiet=quiet,
    )
    for field in fields(options):
        # The option flags are named keyword parameters, so a keyword argument of
        # the same name (intended as an output coord) binds to the flag instead of
        # ending up in ``kwargs``. Detect this via the non-bool value.
        if not isinstance(getattr(options, field.name), bool):
            raise TypeError(
                f"'{field.name}' is a reserved keyword argument. "
                "Use explicit targets and graph arguments to create an output "
                "coordinate of this name."
            )

    if kwargs and (targets is not None or graph is not None):
        raise ValueError(
            "Explicit targets or graph not allowed since keyword arguments "
            f"{kwargs} define targets and graph."
        )

    if targets is None:
        # No explicit targets: targets and graph are taken from the keyword
        # arguments (both are empty if no keyword arguments were given).
        targets = set(kwargs)
        graph = kwargs
    else:
        targets = {targets} if isinstance(targets, str) else set(targets)

    _transform = _transform_dataset if isinstance(x, Dataset) else _transform_data_array
    return _transform(x, targets=targets, graph=Graph(graph), options=options)
def show_graph(graph: GraphDict, size: str | None = None, simplified: bool = False):
    """Show graphical representation of a graph as required by
    :py:func:`transform_coords`

    Requires the `python-graphviz` package.

    Parameters
    ----------
    graph:
        Transformation graph to show.
    size:
        Size forwarded to graphviz, must be a string, "width,height"
        or "size". In the latter case, the same value is used for
        both width and height.
    simplified:
        If ``True``, do not show the conversion functions,
        only the potential input and output coordinates.

    Returns
    -------
    graph: graphviz.Digraph
        Can be displayed directly in Jupyter.
        See the
        `documentation <https://graphviz.readthedocs.io/en/stable/api.html#graphviz.Digraph>`_
        for details.

    Raises
    ------
    RuntimeError
        If graphviz is not installed.
    """
    # Normalize the user-facing dict into a rule graph, which knows how to render itself.
    rule_graph = Graph(graph)
    return rule_graph.show(size=size, simplified=simplified)
def _transform_data_array(
    original: DataArray, targets: set[str], graph: Graph, options: Options
) -> DataArray:
    """Run a coordinate transformation on a single data array.

    Restricts ``graph`` to what is needed for ``targets``, evaluates the resulting
    rules in sequence, optionally logs and renames dimensions, and returns a
    new data array holding the computed coords.
    """
    sub_graph = graph.graph_for(original, targets)
    rules = rule_sequence(sub_graph)
    table = CoordTable(rules, targets, options)
    produced_dim_coords = set()
    for rule in rules:
        outputs = rule(table)
        for coord_name, value in outputs.items():
            table.add(coord_name, value)
            # A coord only counts as a dimension-coord if its name is one of the
            # data dims *and* the coord itself has that dim: slicing can produce
            # coords with dims that are no longer in the data.
            if coord_name in original.dims and value.has_dim(coord_name):
                produced_dim_coords.add(coord_name)

    if options.rename_dims:
        dim_name_changes = _dim_name_changes(sub_graph, produced_dim_coords)
    else:
        dim_name_changes = {}
    if not options.quiet:
        _log_transform(rules, targets, dim_name_changes, table)
    stored = _store_results(original, table, targets)
    return stored.rename_dims(dim_name_changes)
def _transform_dataset(
    original: Dataset, targets: set[str], graph: Graph, *, options: Options
) -> Dataset:
    """Run a coordinate transformation on every item of a dataset.

    A dataset without items is handled by transforming its coords through a
    placeholder data array; the result then carries only the transformed coords.
    """
    # Note the inefficiency here in datasets with multiple items: Coord transform is
    # repeated for every item rather than sharing what is possible. Since we may have
    # dataset items with binned data this is far from trivial. Unless we have clear
    # performance requirements, we go with the safe and simple solution.
    if len(original) == 0:
        placeholder = DataArray(empty(sizes=original.sizes), coords=original.coords)
        result = _transform_data_array(
            placeholder, targets=targets, graph=graph, options=options
        )
        return Dataset(coords=result.coords)

    transformed_items = {}
    for item_name in original:
        transformed_items[item_name] = _transform_data_array(
            original[item_name], targets=targets, graph=graph, options=options
        )
    return Dataset(data=transformed_items)
def _log_transform(
    rules: list[Rule],
    targets: set[str],
    dim_name_changes: Mapping[str, str],
    coords: CoordTable,
) -> None:
    """Log a summary of a finished coordinate transformation at info level.

    The message lists the fetched inputs and the targets. It also lists kept
    byproducts (non-target outputs of rename/compute rules whose total usage
    count is negative), renamed dimensions, targets that were already inputs,
    and every non-fetch rule that was applied.
    """

    def joined(names):
        return ", ".join(sorted(names))

    inputs = set(rule_output_names(rules, FetchRule))
    produced = set(rule_output_names(rules, RenameRule)) | set(
        rule_output_names(rules, ComputeRule)
    )
    byproducts = {name for name in produced - targets if coords.total_usages(name) < 0}
    preexisting = inputs.intersection(targets)
    steps = [rule for rule in rules if not isinstance(rule, FetchRule)]

    parts = [f'Transformed coords ({joined(inputs)}) -> ({joined(targets)})']
    if byproducts:
        parts.append(f'\n Byproducts:\n {joined(byproducts)}')
    if dim_name_changes:
        renames = [f' {new} <- {old}' for old, new in dim_name_changes.items()]
        parts.append('\n Renamed dimensions:\n' + '\n'.join(renames))
    if preexisting:
        parts.append('\n Outputs already present in input:')
        parts.append(f'\n {joined(preexisting)}')
    parts.append('\n Steps:\n')
    if steps:
        parts.append('\n'.join(f' {rule}' for rule in steps))
    else:
        parts.append(' None')

    get_logger().info(''.join(parts))
def _store_coord(da: DataArray, name: str, coord: Coord) -> None:
    """Write ``coord`` into ``da`` (and its bins), or remove it if unused.

    A coord with ``usages == 0`` is popped from ``da.coords`` and, for binned data,
    from ``da.bins.coords``. Otherwise, its dense part is stored on ``da`` and its
    event part on ``da.bins``, each with the coord's ``aligned`` flag.
    """

    def assign(container, value):
        container.coords[name] = value
        container.coords.set_aligned(name, coord.aligned)

    if coord.usages == 0:
        da.coords.pop(name, None)
        if da.bins is not None:
            da.bins.coords.pop(name, None)
        return

    if coord.has_dense:
        assign(da, coord.dense)
    if not coord.has_event:
        return
    try:
        assign(da.bins, coord.event)
    except (DimensionError, VariableError):
        # Thrown on mismatching bin indices, e.g. slice
        da.data = da.data.copy()
        assign(da.bins, coord.event)
def _store_results(da: DataArray, coords: CoordTable, targets: set[str]) -> DataArray:
    """Return a shallow copy of ``da`` updated with all coords from ``coords``.

    Target coords are marked as aligned before being stored.
    """
    result = da.copy(deep=False)
    if result.bins is not None:
        # See #2773 for why this is necessary.
        result.data = bins(**result.bins.constituents)
    for coord_name, coord in coords.items():
        if coord_name in targets:
            coord.aligned = True
        _store_coord(result, coord_name, coord)
    return result
def _color_dims(graph: Graph, dim_coords: set[str]) -> dict[str, dict[str, Fraction]]:
    """Compute, for every node, the fraction of each dimension-coord flowing into it.

    Every dimension-coord starts with a "color" of 1 for its own dim. A node passes
    its color on to its children, splitting it equally (``1 / len(children)``
    each). Children that are themselves dimension-coords receive no color but still
    count towards the split.

    Nodes are processed in topological order (all parents before a child), so each
    node forwards its *final* color exactly once. Color that reaches a node via
    several paths (``x -> a -> c`` and ``x -> b -> c``) is therefore not
    double-counted in that node's descendants.

    Parameters
    ----------
    graph:
        Rule graph providing ``nodes()`` and ``children_of(node)``.
    dim_coords:
        Names of coords that are dimension-coords of the input.

    Returns
    -------
    :
        Mapping ``node -> {dim: fraction}`` for every node of ``graph``.
    """
    colors = {
        coord: {dim: Fraction(0, 1) for dim in dim_coords} for coord in graph.nodes()
    }
    for dim in dim_coords:
        colors[dim][dim] = Fraction(1, 1)

    # Iterate over the de-duplicated node set in ``colors``.
    children = {node: tuple(graph.children_of(node)) for node in colors}

    # Kahn's algorithm: a node becomes ready once all of its parents have been
    # processed, i.e., once its color can no longer change.
    # The previous depth-first traversal re-visited a node once per incoming path.
    # Each visit forwarded the node's accumulated color again, which over-colored
    # descendants after a reconvergence (e.g. 3/2 instead of 1).
    pending_parents = dict.fromkeys(colors, 0)
    for node_children in children.values():
        for child in node_children:
            pending_parents[child] += 1
    ready = [node for node, count in pending_parents.items() if count == 0]

    while ready:
        coord = ready.pop()
        coord_children = children[coord]
        for child in coord_children:
            # test for produced dim coords
            if child not in dim_coords:
                share = Fraction(1, len(coord_children))
                for dim in dim_coords:
                    colors[child][dim] += colors[coord][dim] * share
            pending_parents[child] -= 1
            if pending_parents[child] == 0:
                ready.append(child)

    return colors
def _has_full_color_of_dim(colors: dict[str, Fraction], dim: str) -> bool:
    """Return True if ``colors`` is exactly 1 for ``dim`` and not 1 for any other dim."""
    for other_dim, fraction in colors.items():
        # The fraction must be exactly 1 precisely when this entry is ``dim``.
        if (fraction == 1) != (other_dim == dim):
            return False
    return True
def _dim_name_changes(rule_graph: Graph, dim_coords: set[str]) -> dict[str, str]:
    """Map each dimension-coord to the node its dimension should be renamed to.

    For each dim, the chosen node is the first one, scanning the graph's nodes in
    reverse topological order, that carries the full color of that dim and of no
    other dim. Dims without such a node are omitted from the result.
    """
    colors = _color_dims(rule_graph, dim_coords)
    candidates = list(reversed(list(rule_graph.nodes_topologically())))
    changes = {}
    for dim in dim_coords:
        match = next(
            (node for node in candidates if _has_full_color_of_dim(colors[node], dim)),
            None,
        )
        if match is not None:
            changes[dim] = match
    return changes