Coverage for install/scipp/coords/transform_coords.py: 80%
113 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3# @author Simon Heybrock, Jan-Lukas Wynen
4from dataclasses import fields
5from fractions import Fraction
6from typing import Callable, Dict, Iterable, List, Mapping, Optional, Set, Union
8from ..core import DataArray, Dataset, DimensionError, VariableError, bins, empty
9from ..logging import get_logger
10from .coord_table import Coord, CoordTable
11from .graph import Graph, GraphDict, rule_sequence
12from .options import Options
13from .rule import ComputeRule, FetchRule, RenameRule, Rule, rule_output_names
def transform_coords(
    x: Union[DataArray, Dataset],
    targets: Optional[Union[str, Iterable[str]]] = None,
    /,
    graph: Optional[GraphDict] = None,
    *,
    rename_dims: bool = True,
    keep_aliases: bool = True,
    keep_intermediate: bool = True,
    keep_inputs: bool = True,
    quiet: bool = False,
    **kwargs: Callable,
) -> Union[DataArray, Dataset]:
    """Compute new coords based on transformations of input coords.

    See the section in the user guide on
    `Coordinate transformations <../../user-guide/coordinate-transformations.rst>`_
    for detailed explanations.

    Parameters
    ----------
    x:
        Input object with coords.
    targets:
        Name or list of names of desired output coords.
    graph:
        A graph defining how new coords can be computed from existing
        coords. This may be done in multiple steps.
        The graph is given by a :class:`dict` where:

        - Dict keys are :class:`str` or :class:`tuple` of :class:`str`,
          defining the names of outputs generated by a dict value.
        - Dict values are :class:`str` or a callable (function).
          If :class:`str`, this is a synonym for renaming a coord.
          If a callable, it must either return a single variable or a dict of
          variables. The argument names of callables must be coords in ``x``
          or be computable by other nodes in ``graph``. The coord names can be
          overridden by the callable by providing a
          ``__transform_coords_input_keys__`` property, returning a list of
          coord names in the same order as the function arguments.

        Note that ``graph`` is only used if ``targets`` is given as well.
    rename_dims:
        Rename dimensions if the corresponding dimension coords
        are used as inputs and there is a single output coord
        that can be associated with that dimension.
        See the user guide for more details and examples.
        Default is True.
    keep_aliases:
        If True, include aliases in the output.
        Default is True.
    keep_intermediate:
        If True, include intermediate results in the output.
        Default is True.
    keep_inputs:
        Include consumed input coordinates in the output.
        Default is True.
    quiet:
        If True, no log output is produced. Otherwise, ``transform_coords``
        produces a log of its actions.
    **kwargs:
        Mapping of coords to callables. This can be used as an alternate and brief
        way of specifying targets and graph. If provided, neither ``targets`` nor
        ``graph`` may be given.

    Returns
    -------
    :
        New object with desired coords. Existing data and meta-data is shallow-copied.

    Raises
    ------
    TypeError
        If one of the option keywords (``rename_dims``, ``keep_aliases``,
        ``keep_intermediate``, ``keep_inputs``, ``quiet``) is not a :class:`bool`.
    ValueError
        If keyword arguments are combined with explicit ``targets`` or ``graph``.

    Examples
    --------

    Transform input coordinates ``x`` and ``y`` to a new output coordinate ``xy``:

      >>> da = sc.data.table_xyz(nrow=10)
      >>> transformed = da.transform_coords(xy=lambda x, y: x + y)

    Equivalent full syntax based on a target name and a graph:

      >>> da = sc.data.table_xyz(nrow=10)
      >>> transformed = da.transform_coords('xy', graph={'xy': lambda x, y: x + y})

    Multiple new coordinates can be computed at once. Here ``z2`` is setup as an alias
    of ``z``:

      >>> da = sc.data.table_xyz(nrow=10)
      >>> transformed = da.transform_coords(xy=lambda x, y: x + y, z2='z')

    This is equivalent to

      >>> da = sc.data.table_xyz(nrow=10)
      >>> graph = {'xy': lambda x, y: x + y, 'z2':'z'}
      >>> transformed = da.transform_coords(['xy', 'z2'], graph=graph)

    Multi-step transformations can be performed with a graph containing nodes that
    depend on outputs of other nodes. Whether intermediate results are included in
    the output is controlled by ``keep_intermediate``:

      >>> da = sc.data.table_xyz(nrow=10)
      >>> graph = {'xy': lambda x, y: x + y, 'xyz': lambda xy, z: xy + z}
      >>> transformed = da.transform_coords('xyz', graph=graph)
    """
    options = Options(
        rename_dims=rename_dims,
        keep_aliases=keep_aliases,
        keep_intermediate=keep_intermediate,
        keep_inputs=keep_inputs,
        quiet=quiet,
    )
    # The option names share the keyword namespace with ``**kwargs``. A non-bool
    # value therefore most likely was meant to define an output coord of that name,
    # which is only possible through the explicit targets/graph syntax.
    for field in fields(options):
        if not isinstance(getattr(options, field.name), bool):
            raise TypeError(
                f"'{field.name}' is a reserved keyword argument. "
                "Use explicit targets and graph arguments to create an output "
                "coordinate of this name."
            )

    if kwargs and (targets is not None or graph is not None):
        raise ValueError(
            "Explicit targets or graph not allowed since keyword arguments "
            f"{kwargs} define targets and graph."
        )

    if targets is None:
        # Keyword syntax: the keywords are both the targets and the graph.
        targets = set(kwargs)
        graph = kwargs
    elif isinstance(targets, str):
        targets = {targets}
    else:
        targets = set(targets)

    _transform = _transform_dataset if isinstance(x, Dataset) else _transform_data_array
    return _transform(x, targets=targets, graph=Graph(graph), options=options)
def show_graph(graph: GraphDict, size: Optional[str] = None, simplified: bool = False):
    """Render a transformation graph as used by :py:func:`transform_coords`.

    Requires the `python-graphviz` package.

    Parameters
    ----------
    graph:
        Transformation graph to show.
    size:
        Size forwarded to graphviz, must be a string, "width,height"
        or "size". In the latter case, the same value is used for
        both width and height.
    simplified:
        If ``True``, do not show the conversion functions,
        only the potential input and output coordinates.

    Returns
    -------
    graph: graphviz.Digraph
        Can be displayed directly in Jupyter.
        See the
        `documentation <https://graphviz.readthedocs.io/en/stable/api.html#graphviz.Digraph>`_
        for details.

    Raises
    ------
    RuntimeError
        If graphviz is not installed.
    """
    # Wrap the plain dict so that the drawing logic of ``Graph`` can be reused.
    rule_graph = Graph(graph)
    return rule_graph.show(size=size, simplified=simplified)
def _transform_data_array(
    original: DataArray, targets: Set[str], graph: Graph, options: Options
) -> DataArray:
    """Apply the rules needed to produce ``targets`` to a single data array.

    The graph is reduced to the part relevant for ``original`` and ``targets``,
    its rules are executed in sequence, and the resulting coords are stored in a
    shallow copy of ``original``. Dimensions are renamed afterwards if
    ``options.rename_dims`` is set.
    """
    graph = graph.graph_for(original, targets)
    rules = rule_sequence(graph)
    table = CoordTable(rules, targets, options)

    dim_coords = set()
    for rule in rules:
        produced = rule(table)
        for name, coord in produced.items():
            table.add(name, coord)
            # Check if coord is a dimension-coord. Need to also check if it is in
            # the data dimensions because slicing can produce coords with dims
            # that are no longer in the data.
            is_dim_coord = name in original.dims and coord.has_dim(name)
            if is_dim_coord:
                dim_coords.add(name)

    if options.rename_dims:
        renames = _dim_name_changes(graph, dim_coords)
    else:
        renames = {}

    if not options.quiet:
        _log_transform(rules, targets, renames, table)

    stored = _store_results(original, table, targets)
    return stored.rename_dims(renames)
def _transform_dataset(
    original: Dataset, targets: Set[str], graph: Graph, *, options: Options
) -> Dataset:
    """Apply the coordinate transformation to every item of a dataset.

    A dataset without items is handled by transforming a placeholder data array
    that carries the dataset's sizes and coords, and returning only the resulting
    coords.
    """
    if len(original) == 0:
        # No items to transform: operate on the coords via a placeholder array.
        placeholder = DataArray(empty(sizes=original.sizes), coords=original.coords)
        result = _transform_data_array(
            placeholder, targets=targets, graph=graph, options=options
        )
        return Dataset(coords=result.coords)

    # Note the inefficiency here in datasets with multiple items: Coord transform is
    # repeated for every item rather than sharing what is possible. Since we may have
    # dataset items with binned data this is far from trivial. Unless we have clear
    # performance requirements, we go with the safe and simple solution.
    items = {
        name: _transform_data_array(
            original[name], targets=targets, graph=graph, options=options
        )
        for name in original
    }
    return Dataset(data=items)
def _log_transform(
    rules: List[Rule],
    targets: Set[str],
    dim_name_changes: Mapping[str, str],
    coords: CoordTable,
) -> None:
    """Log a summary of a coordinate transformation at INFO level.

    The message lists the fetched inputs and the targets, followed by optional
    sections for byproducts, renamed dimensions and targets that were already
    present in the input, and finally all rules that are not fetch rules.
    """
    fetched = set(rule_output_names(rules, FetchRule))
    produced = set(rule_output_names(rules, RenameRule)) | set(
        rule_output_names(rules, ComputeRule)
    )
    # Non-target outputs with a negative usage count are reported as byproducts.
    byproducts = {
        name for name in produced - targets if coords.total_usages(name) < 0
    }
    preexisting = fetched.intersection(targets)
    steps = [rule for rule in rules if not isinstance(rule, FetchRule)]

    parts = [
        f'Transformed coords ({", ".join(sorted(fetched))}) '
        f'-> ({", ".join(sorted(targets))})'
    ]
    if byproducts:
        parts.append(f'\n Byproducts:\n {", ".join(sorted(byproducts))}')
    if dim_name_changes:
        renamed = '\n'.join(
            f' {new} <- {old}' for old, new in dim_name_changes.items()
        )
        parts.append('\n Renamed dimensions:\n' + renamed)
    if preexisting:
        parts.append(
            '\n Outputs already present in input:'
            f'\n {", ".join(sorted(preexisting))}'
        )
    if steps:
        step_lines = '\n'.join(f' {rule}' for rule in steps)
    else:
        step_lines = ' None'
    parts.append('\n Steps:\n' + step_lines)

    get_logger().info(''.join(parts))
def _store_coord(da: DataArray, name: str, coord: Coord) -> None:
    """Write ``coord`` into ``da`` under ``name``, or remove it if it is used up.

    A coord with a usage count of zero is removed from the dense coords and, for
    binned data, from the bin coords. Otherwise its dense and/or event parts are
    stored and the alignment flag is set from ``coord.aligned``.
    """
    if coord.usages == 0:
        da.coords.pop(name, None)
        if da.bins is not None:
            da.bins.coords.pop(name, None)
        return

    def assign(target, value):
        target.coords[name] = value
        target.coords.set_aligned(name, coord.aligned)

    if coord.has_dense:
        assign(da, coord.dense)
    if not coord.has_event:
        return
    try:
        assign(da.bins, coord.event)
    except (DimensionError, VariableError):
        # Thrown on mismatching bin indices, e.g. slice.
        # Retry after replacing the data with a copy.
        da.data = da.data.copy()
        assign(da.bins, coord.event)
def _store_results(da: DataArray, coords: CoordTable, targets: Set[str]) -> DataArray:
    """Return a shallow copy of ``da`` with all coords of ``coords`` applied.

    Coords whose name is in ``targets`` are marked as aligned before being stored.
    """
    result = da.copy(deep=False)
    # See #2773 for why this is necessary.
    if result.bins is not None:
        result.data = bins(**result.bins.constituents)

    for name, coord in coords.items():
        is_target = name in targets
        if is_target:
            coord.aligned = True
        _store_coord(result, name, coord)
    return result
def _color_dims(graph: Graph, dim_coords: Set[str]) -> Dict[str, Dict[str, Fraction]]:
    """Distribute the 'color' of each dimension-coord over the graph.

    Every dim-coord starts with the full amount (1) of its own color. Walking
    from the dim-coord towards its children, the amount held by a node is split
    evenly among that node's children. Children that are dim-coords themselves
    do not accumulate color but are still traversed.

    Returns
    -------
    :
        For every node of the graph, a dict mapping each dim-coord to the
        fraction of its color that reached the node.
    """
    colors = {}
    for node in graph.nodes():
        colors[node] = {dim: Fraction(0, 1) for dim in dim_coords}

    for dim in dim_coords:
        colors[dim][dim] = Fraction(1, 1)
        pending = [dim]
        while pending:
            parent = pending.pop()
            children = tuple(graph.children_of(parent))
            if not children:
                continue
            share = Fraction(1, len(children))
            for child in children:
                # test for produced dim coords
                if child in dim_coords:
                    continue
                colors[child][dim] += colors[parent][dim] * share
            pending.extend(children)

    return colors
def _has_full_color_of_dim(colors: Dict[str, Fraction], dim: str) -> bool:
    """Return True if ``colors`` is fully colored by ``dim`` and by no other dim.

    The entry for ``dim`` (if present) must equal 1 and every other entry must
    differ from 1.
    """
    for name, fraction in colors.items():
        is_full = fraction == 1
        # Exactly the entry of ``dim`` may be full.
        if is_full != (name == dim):
            return False
    return True
def _dim_name_changes(rule_graph: Graph, dim_coords: Set[str]) -> Dict[str, str]:
    """Determine which dimensions get renamed and to what.

    For every dim-coord, the nodes of the graph are searched in reverse
    topological order for the first node that carries the full color of that
    dim and of no other dim. Dims without such a node are not included.

    Returns
    -------
    :
        Mapping from old dimension name to new dimension name.
    """
    colors = _color_dims(rule_graph, dim_coords)
    candidates = list(rule_graph.nodes_topologically())
    candidates.reverse()

    renames = {}
    for dim in dim_coords:
        # Node names are strings, so None can serve as the 'not found' marker.
        found = next(
            (
                node
                for node in candidates
                if _has_full_color_of_dim(colors[node], dim)
            ),
            None,
        )
        if found is not None:
            renames[dim] = found
    return renames