Coverage for install/scipp/coords/coord_table.py: 72%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-12-01 01:59 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3# @author Jan-Lukas Wynen 

4 

5import dataclasses 

6from collections.abc import Iterable 

7 

8from .coord import Coord 

9from .options import Options 

10from .rule import FetchRule, RenameRule, Rule, rule_output_names 

11 

12 

13class CoordTable: 

14 """ 

15 Stores a dictionary of coordinates for use in coord transforms. 

16 

17 Coords have an associated number of usages. 

18 When that number drops to 0, the coord is removed. 

19 """ 

20 

21 def __init__(self, rules: list[Rule], targets: set[str], options: Options): 

22 self._coords = {} 

23 self._total_usages = _apply_keep_options( 

24 _count_usages(rules), rules, targets, options 

25 ) 

26 # Preserve all targets regardless of keep_* options. 

27 for name in targets: 

28 self._total_usages[name] = -1 

29 

30 def add(self, name: str, coord: Coord): 

31 self._coords[name] = dataclasses.replace(coord, usages=self.total_usages(name)) 

32 

33 def consume(self, name: str) -> Coord: 

34 coord = self._coords[name] 

35 coord.aligned = False 

36 coord.use() 

37 if coord.usages == 0: 

38 # The coord's data is no longer needed in the table. 

39 # But the caller of `consume` does need it, so return `coord` as is. 

40 self._coords[name] = dataclasses.replace(coord, dense=None, event=None) 

41 return coord 

42 

43 def total_usages(self, name: str) -> int: 

44 return self._total_usages.get(name, -1) 

45 

46 def items(self) -> Iterable[tuple[str, Coord]]: 

47 yield from self._coords.items() 

48 

49 

50def _count_usages(rules: list[Rule]) -> dict[str, int]: 

51 usages = {} 

52 for rule in rules: 

53 for name in rule.dependencies: 

54 usages.setdefault(name, 0) 

55 usages[name] += 1 

56 return usages 

57 

58 

59def _apply_keep_options( 

60 usages: dict[str, int], rules: list[Rule], targets: set[str], options: Options 

61) -> dict[str, int]: 

62 def out_names(rule_type): 

63 yield from filter( 

64 lambda name: name not in targets, rule_output_names(rules, rule_type) 

65 ) 

66 

67 def handle_in(names): 

68 for name in names: 

69 usages[name] = -1 

70 

71 inputs = set(out_names(FetchRule)) 

72 aliases = set(out_names(RenameRule)) 

73 all_inputs = {dep for rule in rules for dep in rule.dependencies} 

74 if options.keep_inputs: 

75 handle_in(inputs) 

76 if options.keep_intermediate: 

77 handle_in(all_inputs - inputs - aliases) 

78 if options.keep_aliases: 

79 handle_in(aliases) 

80 return usages