Coverage for install/scipp/coords/coord_table.py: 72%
47 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-17 01:51 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3# @author Jan-Lukas Wynen
5import dataclasses
6from collections.abc import Iterable
8from .coord import Coord
9from .options import Options
10from .rule import FetchRule, RenameRule, Rule, rule_output_names
class CoordTable:
    """
    Name -> Coord store used while executing coordinate transforms.

    Each stored coord is given a usage budget taken from ``total_usages``.
    Every call to ``consume`` spends one use. Once a coord's ``usages``
    reaches 0, its table entry is replaced by a copy without data
    (``dense`` and ``event`` set to None). The entry itself stays in the table.

    A budget of -1 marks coords that must be preserved: transform targets and
    anything selected by the ``keep_*`` options. (Assumes ``Coord.use`` never
    brings -1 down to 0 — ``Coord`` is defined elsewhere; confirm there.)
    """

    def __init__(self, rules: list[Rule], targets: set[str], options: Options):
        self._coords = {}
        usage_counts = _count_usages(rules)
        self._total_usages = _apply_keep_options(
            usage_counts, rules, targets, options
        )
        # Targets are always preserved, whatever the keep_* options say.
        self._total_usages.update(dict.fromkeys(targets, -1))

    def add(self, name: str, coord: Coord):
        """Store a copy of ``coord`` under ``name`` with its usage budget."""
        budget = self.total_usages(name)
        self._coords[name] = dataclasses.replace(coord, usages=budget)

    def consume(self, name: str) -> Coord:
        """
        Spend one use of the coord called ``name`` and return it.

        The coord is marked as not aligned before use. If this was its last
        use, the table keeps only a data-less copy. The caller still receives
        the object with its data intact.
        """
        entry = self._coords[name]
        entry.aligned = False
        entry.use()
        if entry.usages != 0:
            return entry
        # The table no longer needs the data, but the caller does,
        # so only the stored entry is stripped.
        self._coords[name] = dataclasses.replace(entry, dense=None, event=None)
        return entry

    def total_usages(self, name: str) -> int:
        """Return the usage budget for ``name``; -1 if it has none recorded."""
        return self._total_usages.get(name, -1)

    def items(self) -> Iterable[tuple[str, Coord]]:
        """Lazily iterate over ``(name, coord)`` pairs in the table."""
        for pair in self._coords.items():
            yield pair
def _count_usages(rules: list[Rule]) -> dict[str, int]:
    """
    Count how many times each name appears in the rules' dependencies.

    Returns a dict mapping each dependency name to the number of rules
    (counting repeats) that list it, in first-seen order.
    """
    counts: dict[str, int] = {}
    for dependency in (dep for rule in rules for dep in rule.dependencies):
        counts[dependency] = counts.get(dependency, 0) + 1
    return counts
def _apply_keep_options(
    usages: dict[str, int], rules: list[Rule], targets: set[str], options: Options
) -> dict[str, int]:
    """
    Set the usage count to -1 for every name the ``keep_*`` options preserve.

    Names are grouped as follows (targets are excluded from the first two):

    - inputs: outputs of ``FetchRule`` rules;
    - aliases: outputs of ``RenameRule`` rules;
    - intermediates: dependencies of any rule that are neither of the above.

    ``usages`` is modified in place and also returned.
    """
    fetched = {
        name for name in rule_output_names(rules, FetchRule) if name not in targets
    }
    renamed = {
        name for name in rule_output_names(rules, RenameRule) if name not in targets
    }
    dependencies = set()
    for rule in rules:
        dependencies.update(rule.dependencies)
    intermediate = dependencies - fetched - renamed

    selections = (
        (options.keep_inputs, fetched),
        (options.keep_intermediate, intermediate),
        (options.keep_aliases, renamed),
    )
    for keep, names in selections:
        if keep:
            usages.update(dict.fromkeys(names, -1))
    return usages