Coverage for install/scipp/coords/graph.py: 58%

93 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-01 01:59 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3# @author Simon Heybrock, Jan-Lukas Wynen 

4 

5from __future__ import annotations 

6 

7import collections 

8from collections.abc import Iterable 

9from graphlib import TopologicalSorter 

10 

11from ..core import DataArray 

12from ..utils.graph import make_graphviz_digraph 

13from .rule import ComputeRule, FetchRule, Kernel, RenameRule, Rule 

14 

# User-facing graph specification. Each key is either one output name or a
# tuple of output names. Each value is either a str naming the coordinate to
# rename (turned into a RenameRule by _make_rule) or a Kernel (wrapped in a
# ComputeRule by _make_rule).
GraphDict = dict[str | tuple[str, ...], str | Kernel]

16 

17 

class Graph:
    """Coordinate-transformation graph, stored as a mapping from output name to rule.

    A rule that produces several outputs is stored once under each of its
    output names.
    """

    def __init__(self, graph: GraphDict | dict[str, Rule]):
        """Build the graph from either a rule mapping or a ``GraphDict``.

        Parameters
        ----------
        graph:
            A mapping whose values are already ``Rule`` instances, or a
            ``GraphDict`` (output name(s) -> input name or kernel) which is
            converted via ``_convert_to_rule_graph``.

        Raises
        ------
        TypeError
            If ``graph`` is not a mapping.
        """
        if not isinstance(graph, collections.abc.Mapping):
            raise TypeError("'graph' must be a dict")
        if not graph:
            rules: dict[str, Rule] = {}
        elif isinstance(next(iter(graph.values())), Rule):
            # Only the first value is inspected; the mapping is assumed to be
            # homogeneous. Copy it so that `_rules` really is a dict and so that
            # later mutation of the caller's mapping cannot alter this graph.
            rules = dict(graph)
        else:
            rules = _convert_to_rule_graph(graph)
        self._rules = rules

    def __getitem__(self, name: str) -> Rule:
        """Return the rule producing ``name``; raises KeyError if there is none."""
        return self._rules[name]

    def items(self) -> Iterable[tuple[str, Rule]]:
        """Yield ``(output name, rule)`` pairs."""
        yield from self._rules.items()

    def parents_of(self, node: str) -> Iterable[str]:
        """Yield the names of the direct dependencies of ``node``.

        Input nodes have no parents but are not represented in the graph unless
        the corresponding FetchRules have been added, so an unknown ``node``
        yields nothing instead of raising.
        """
        rule = self._rules.get(node)
        if rule is None:
            return
        # The lookup is deliberately kept outside any `try`: a KeyError raised
        # while iterating the dependencies must not be mistaken for a missing node.
        yield from rule.dependencies

    def children_of(self, node: str) -> Iterable[str]:
        """Yield the output names whose rules list ``node`` as a dependency."""
        for candidate, rule in self.items():
            if node in rule.dependencies:
                yield candidate

    def nodes(self) -> Iterable[str]:
        """Yield every output name that has a rule in this graph."""
        yield from self._rules.keys()

    def nodes_topologically(self) -> Iterable[str]:
        """Yield node names so that every node comes after its dependencies.

        Dependency names without a rule of their own (inputs) are included too.
        """
        yield from TopologicalSorter(
            {out: rule.dependencies for out, rule in self._rules.items()}
        ).static_order()

    def graph_for(self, da: DataArray, targets: set[str]) -> Graph:
        """
        Construct a graph containing only rules needed for the given DataArray
        and targets, including FetchRules for the inputs.

        Names that are already coordinates of ``da`` (or of its bins) get a
        FetchRule and are not expanded further. Any other name must have a
        rule in this graph; otherwise a KeyError is raised.
        """
        subgraph = {}
        depth_first_stack = list(targets)
        while depth_first_stack:
            out_name = depth_first_stack.pop()
            if out_name in subgraph:
                continue
            rule = self._rule_for(out_name, da)
            for name in rule.out_names:
                subgraph[name] = rule
            depth_first_stack.extend(rule.dependencies)
        return Graph(subgraph)

    def _rule_for(self, out_name: str, da: DataArray) -> Rule:
        """Return a FetchRule if ``out_name`` is present in ``da``, else the graph rule."""
        if _is_in_coords(out_name, da):
            # Compare with None (as `_is_in_coords` does) instead of relying on
            # the truthiness of the bins object.
            bins = da.bins
            bin_coords = bins.coords if bins is not None else {}
            return FetchRule((out_name,), da.coords, bin_coords)
        try:
            return self._rules[out_name]
        except KeyError:
            raise KeyError(
                f"Coordinate '{out_name}' does not exist in the input data "
                "and no rule has been provided to compute it."
            ) from None

    def show(self, size=None, simplified=False):
        """Return a graphviz digraph visualizing the rename and compute rules.

        Parameters
        ----------
        size:
            Forwarded to the digraph's ``size`` attribute.
        simplified:
            If True, compute rules are drawn as direct edges from their inputs
            to their output instead of going through a separate function node.

        FetchRules are not drawn.
        """
        dot = make_graphviz_digraph(strict=True)
        dot.attr('node', shape='box', height='0.1')
        dot.attr(size=size)
        for output, rule in self._rules.items():
            if isinstance(rule, RenameRule):
                dot.edge(rule.dependencies[0], output, style='dashed')
            elif isinstance(rule, ComputeRule):
                if not simplified:
                    # Get a unique name for every node,
                    # works because str contains address of func.
                    name = str(rule)
                    label = f'{rule.func_name}(...)'
                    dot.node(
                        name,
                        label=label,
                        shape='plain',
                        style='filled',
                        color='lightgrey',
                    )
                    dot.edge(name, output)
                else:
                    name = output
                for arg in rule.dependencies:
                    dot.edge(arg, name)
        return dot

110 

111 

def rule_sequence(rules: Graph) -> list[Rule]:
    """Return the rules of ``rules`` in topological order, each exactly once.

    A rule stored under several output names is listed only at the position
    of its first output in the topological order.
    """
    in_order = (rules[node] for node in rules.nodes_topologically())
    # dict keeps insertion order and ignores later duplicate keys, giving an
    # order-preserving de-duplication.
    return list(dict.fromkeys(in_order))

121 

122 

def _make_rule(products, producer) -> Rule:
    """Create the rule producing ``products`` from ``producer``.

    A str producer names the coordinate to rename and yields a RenameRule;
    any other producer is handed to a ComputeRule.
    """
    rule_type = RenameRule if isinstance(producer, str) else ComputeRule
    return rule_type(products, producer)

127 

128 

129def _convert_to_rule_graph(graph: GraphDict) -> dict[str, Rule]: 

130 rule_graph = {} 

131 for products, producer in graph.items(): 

132 products = (products,) if isinstance(products, str) else tuple(products) 

133 rule = _make_rule(products, producer) 

134 for product in products: 

135 if product in rule_graph: 

136 raise ValueError( 

137 f'Duplicate output name defined in conversion graph: {product}' 

138 ) 

139 rule_graph[product] = rule 

140 return rule_graph 

141 

142 

143def _is_in_coords(name: str, da: DataArray) -> bool: 

144 return name in da.coords or (da.bins is not None and name in da.bins.coords)