Coverage for install/scipp/coords/graph.py: 58%
93 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-17 01:51 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-17 01:51 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3# @author Simon Heybrock, Jan-Lukas Wynen
5from __future__ import annotations
7import collections
8from collections.abc import Iterable
9from graphlib import TopologicalSorter
11from ..core import DataArray
12from ..utils.graph import make_graphviz_digraph
13from .rule import ComputeRule, FetchRule, Kernel, RenameRule, Rule
# User-facing graph specification. Each key is either one output name or a
# tuple of output names; each value is either a kernel computing those
# outputs or a string, which is turned into a RenameRule (the name of the
# coordinate to rename from).
GraphDict = dict[
    str | tuple[str, ...],
    str | Kernel,
]
class Graph:
    """Dependency graph of rules for computing coordinates.

    Internally this is a mapping from output name to the :class:`Rule` that
    produces it. A rule with several outputs is stored once per output name,
    and every such entry refers to the same rule object.
    """

    def __init__(self, graph: GraphDict | dict[str, Rule]):
        """Create a graph from a rule mapping or a user-level graph dict.

        Parameters
        ----------
        graph:
            Either ``{name: Rule}``, which is shallow-copied, or a user graph
            ``{name or tuple of names: kernel or source name}``, which is
            converted to rules. Only the first value is inspected to decide
            which of the two forms was given.

        Raises
        ------
        TypeError
            If ``graph`` is not a mapping.
        ValueError
            If a user graph defines the same output name more than once.
        """
        if not isinstance(graph, collections.abc.Mapping):
            raise TypeError("'graph' must be a dict")
        self._rules: dict[str, Rule]
        if not graph:
            self._rules = {}
        elif isinstance(next(iter(graph.values())), Rule):
            # Copy so that later mutation of the caller's mapping cannot
            # silently change this graph.
            self._rules = dict(graph)
        else:
            self._rules = _convert_to_rule_graph(graph)

    def __getitem__(self, name: str) -> Rule:
        """Return the rule producing ``name``; raises KeyError if there is none."""
        return self._rules[name]

    def items(self) -> Iterable[tuple[str, Rule]]:
        """Yield ``(output name, rule)`` pairs."""
        yield from self._rules.items()

    def parents_of(self, node: str) -> Iterable[str]:
        """Yield the names that ``node`` directly depends on.

        Nodes without a rule have no parents. Input nodes are not represented
        in the graph unless the corresponding FetchRules have been added.
        """
        rule = self._rules.get(node)
        if rule is None:
            return
        yield from rule.dependencies

    def children_of(self, node: str) -> Iterable[str]:
        """Yield the output names whose rules directly depend on ``node``."""
        for candidate, rule in self.items():
            if node in rule.dependencies:
                yield candidate

    def nodes(self) -> Iterable[str]:
        """Yield every output name that has a rule in this graph."""
        yield from self._rules.keys()

    def nodes_topologically(self) -> Iterable[str]:
        """Yield node names such that each node comes after its dependencies.

        Dependencies that have no rule of their own are yielded too, so a
        yielded name is not necessarily a key of this graph. Raises
        ``graphlib.CycleError`` if the rules form a cycle.
        """
        yield from TopologicalSorter(
            {out: rule.dependencies for out, rule in self._rules.items()}
        ).static_order()

    def graph_for(self, da: DataArray, targets: set[str]) -> Graph:
        """
        Construct a graph containing only rules needed for the given DataArray
        and targets, including FetchRules for the inputs.

        Raises KeyError (via ``_rule_for``) if a required name is neither a
        coordinate of ``da`` nor produced by any rule.
        """
        subgraph = {}
        depth_first_stack = list(targets)
        while depth_first_stack:
            out_name = depth_first_stack.pop()
            if out_name in subgraph:
                continue
            rule = self._rule_for(out_name, da)
            # Register every output of a multi-output rule at once so that its
            # sibling outputs are not visited (and resolved) a second time.
            for name in rule.out_names:
                subgraph[name] = rule
            depth_first_stack.extend(rule.dependencies)
        return Graph(subgraph)

    def _rule_for(self, out_name: str, da: DataArray) -> Rule:
        """Return the rule that provides ``out_name`` for ``da``.

        Coordinates already present on ``da`` (or on its bins) are fetched
        rather than recomputed, even if the graph has a rule for them.
        """
        if _is_in_coords(out_name, da):
            bins = da.bins
            # Use the same ``is not None`` test as ``_is_in_coords`` instead of
            # relying on the truthiness of the bins object.
            bin_coords = bins.coords if bins is not None else {}
            return FetchRule((out_name,), da.coords, bin_coords)
        try:
            return self._rules[out_name]
        except KeyError:
            raise KeyError(
                f"Coordinate '{out_name}' does not exist in the input data "
                "and no rule has been provided to compute it."
            ) from None

    def show(self, size=None, simplified=False):
        """Render the graph with graphviz and return the digraph object.

        Parameters
        ----------
        size:
            Forwarded to the graph's ``size`` attribute.
        simplified:
            If False, each ComputeRule gets its own grey node labelled
            ``func_name(...)`` between its inputs and outputs. If True, edges go
            directly from inputs to outputs.

        RenameRules are drawn as dashed edges. Other rule types (e.g.
        FetchRule) add no edges here.
        """
        # strict=True merges repeated edges, which arise because a multi-output
        # rule is visited once per output.
        dot = make_graphviz_digraph(strict=True)
        dot.attr('node', shape='box', height='0.1')
        dot.attr(size=size)
        for output, rule in self._rules.items():
            if isinstance(rule, RenameRule):
                dot.edge(rule.dependencies[0], output, style='dashed')
            elif isinstance(rule, ComputeRule):
                if not simplified:
                    # Get a unique name for every node,
                    # works because str contains address of func.
                    name = str(rule)
                    label = f'{rule.func_name}(...)'
                    dot.node(
                        name,
                        label=label,
                        shape='plain',
                        style='filled',
                        color='lightgrey',
                    )
                    dot.edge(name, output)
                else:
                    name = output
                for arg in rule.dependencies:
                    dot.edge(arg, name)
        return dot
def rule_sequence(rules: Graph) -> list[Rule]:
    """Return the distinct rules of ``rules`` in topological order.

    A rule producing several outputs appears only once, at the position of
    the first of its outputs in the topological order.
    """
    # dict keeps insertion order, so ``fromkeys`` drops repeated rules while
    # keeping the first occurrence of each.
    ordered_unique = dict.fromkeys(
        rules[node] for node in rules.nodes_topologically()
    )
    return list(ordered_unique)
def _make_rule(products, producer) -> Rule:
    """Create the rule that produces ``products`` from ``producer``.

    A string producer yields a RenameRule (presumably naming the source
    coordinate). Any other producer is wrapped in a ComputeRule.
    """
    rule_type = RenameRule if isinstance(producer, str) else ComputeRule
    return rule_type(products, producer)
129def _convert_to_rule_graph(graph: GraphDict) -> dict[str, Rule]:
130 rule_graph = {}
131 for products, producer in graph.items():
132 products = (products,) if isinstance(products, str) else tuple(products)
133 rule = _make_rule(products, producer)
134 for product in products:
135 if product in rule_graph:
136 raise ValueError(
137 f'Duplicate output name defined in conversion graph: {product}'
138 )
139 rule_graph[product] = rule
140 return rule_graph
143def _is_in_coords(name: str, da: DataArray) -> bool:
144 return name in da.coords or (da.bins is not None and name in da.bins.coords)