Coverage for install/scipp/coords/transform_coords.py: 80%

113 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-17 01:51 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3# @author Simon Heybrock, Jan-Lukas Wynen 

4from collections.abc import Iterable, Mapping 

5from dataclasses import fields 

6from fractions import Fraction 

7 

8from ..core import DataArray, Dataset, DimensionError, VariableError, bins, empty 

9from ..logging import get_logger 

10from .coord_table import Coord, CoordTable 

11from .graph import Graph, GraphDict, rule_sequence 

12from .options import Options 

13from .rule import ComputeRule, FetchRule, Kernel, RenameRule, Rule, rule_output_names 

14 

15 

def transform_coords(
    x: DataArray | Dataset,
    targets: str | Iterable[str] | None = None,
    /,
    graph: GraphDict | None = None,
    *,
    rename_dims: bool = True,
    keep_aliases: bool = True,
    keep_intermediate: bool = True,
    keep_inputs: bool = True,
    quiet: bool = False,
    **kwargs: Kernel,
) -> DataArray | Dataset:
    """Compute new coords based on transformations of input coords.

    See the section in the user guide on
    `Coordinate transformations <../../user-guide/coordinate-transformations.rst>`_
    for detailed explanations.

    Parameters
    ----------
    x:
        Input object with coords.
    targets:
        Name or list of names of desired output coords.
    graph:
        A graph defining how new coords can be computed from existing
        coords. This may be done in multiple steps.
        The graph is given by a :class:`dict` where:

        - Dict keys are :class:`str` or :class:`tuple` of :class:`str`,
          defining the names of outputs generated by a dict value.
        - Dict values are :class:`str` or a callable (function).
          If :class:`str`, this is a synonym for renaming a coord.
          If a callable, it must either return a single variable or a dict of
          variables. The argument names of callables must be coords in ``x`` or be
          computable by other nodes in ``graph``. The coord names can be overridden
          by the callable by providing a ``__transform_coords_input_keys__``
          property, returning a list of coord names in the same order as the
          function arguments.
    rename_dims:
        Rename dimensions if the corresponding dimension coords
        are used as inputs and there is a single output coord
        that can be associated with that dimension.
        See the user guide for more details and examples.
        Default is True.
    keep_aliases:
        If True, include aliases in the output.
        Default is True.
    keep_intermediate:
        If True, include intermediate results in the output.
        Default is True.
    keep_inputs:
        Include consumed input coordinates in the output.
        Default is True.
    quiet:
        If True, no log output is produced. Otherwise, ``transform_coords``
        produces a log of its actions.
    **kwargs:
        Mapping of output coord names to callables or, for aliases, to the
        names of existing coords. This can be used as an alternate and brief
        way of specifying targets and graph. If provided, neither ``targets`` nor
        ``graph`` may be given.

    Returns
    -------
    :
        New object with desired coords. Existing data and metadata are
        shallow-copied.

    Raises
    ------
    TypeError
        If one of the option arguments (``rename_dims``, ``keep_aliases``,
        ``keep_intermediate``, ``keep_inputs``, ``quiet``) is not a :class:`bool`,
        e.g. because it was intended as the name of an output coord.
    ValueError
        If keyword arguments are combined with explicit ``targets`` or ``graph``.

    Examples
    --------

    Transform input coordinates ``x`` and ``y`` to a new output coordinate ``xy``:

    >>> da = sc.data.table_xyz(nrow=10)
    >>> transformed = da.transform_coords(xy=lambda x, y: x + y)

    Equivalent full syntax based on a target name and a graph:

    >>> da = sc.data.table_xyz(nrow=10)
    >>> transformed = da.transform_coords('xy', graph={'xy': lambda x, y: x + y})

    Multiple new coordinates can be computed at once. Here ``z2`` is set up as an
    alias of ``z``:

    >>> da = sc.data.table_xyz(nrow=10)
    >>> transformed = da.transform_coords(xy=lambda x, y: x + y, z2='z')

    This is equivalent to

    >>> da = sc.data.table_xyz(nrow=10)
    >>> graph = {'xy': lambda x, y: x + y, 'z2':'z'}
    >>> transformed = da.transform_coords(['xy', 'z2'], graph=graph)

    Multi-step transformations that do not keep intermediate results as coordinates
    can be performed with a graph containing nodes that depend on outputs of other
    nodes:

    >>> da = sc.data.table_xyz(nrow=10)
    >>> graph = {'xy': lambda x, y: x + y, 'xyz': lambda xy, z: xy + z}
    >>> transformed = da.transform_coords('xyz', graph=graph)
    """
    options = Options(
        rename_dims=rename_dims,
        keep_aliases=keep_aliases,
        keep_intermediate=keep_intermediate,
        keep_inputs=keep_inputs,
        quiet=quiet,
    )
    # The option names are keyword parameters of this function, so a call like
    # ``transform_coords(da, quiet=func)`` binds ``func`` to the option instead of
    # adding it to ``kwargs``. Detect that via the non-bool value and explain.
    for field in fields(options):
        if not isinstance(getattr(options, field.name), bool):
            raise TypeError(
                f"'{field.name}' is a reserved keyword argument. "
                "Use explicit targets and graph arguments to create an output "
                "coordinate of this name."
            )

    if kwargs:
        if targets is not None or graph is not None:
            raise ValueError(
                "Explicit targets or graph not allowed since keyword arguments "
                f"{kwargs} define targets and graph."
            )

    if targets is None:
        # Short syntax: kwargs are both the targets and the graph. Note that if
        # kwargs is empty as well, this yields no targets and an empty graph,
        # discarding any explicitly passed ``graph``.
        targets = set(kwargs)
        graph = kwargs
    else:
        targets = {targets} if isinstance(targets, str) else set(targets)

    _transform = _transform_dataset if isinstance(x, Dataset) else _transform_data_array
    return _transform(x, targets=targets, graph=Graph(graph), options=options)

145 

146 

def show_graph(graph: GraphDict, size: str | None = None, simplified: bool = False):
    """Render a :py:func:`transform_coords` graph with graphviz.

    Requires the `python-graphviz` package.

    Parameters
    ----------
    graph:
        Transformation graph to show, in the same dict form accepted by
        :py:func:`transform_coords`.
    size:
        Size forwarded to graphviz, must be a string, "width,height"
        or "size". In the latter case, the same value is used for
        both width and height.
    simplified:
        If ``True``, do not show the conversion functions,
        only the potential input and output coordinates.

    Returns
    -------
    graph: graphviz.Digraph
        Can be displayed directly in Jupyter.
        See the
        `documentation <https://graphviz.readthedocs.io/en/stable/api.html#graphviz.Digraph>`_
        for details.

    Raises
    ------
    RuntimeError
        If graphviz is not installed.
    """
    wrapped = Graph(graph)
    return wrapped.show(size=size, simplified=simplified)

179 

180 

def _transform_data_array(
    original: DataArray, targets: set[str], graph: Graph, options: Options
) -> DataArray:
    """Run the coordinate transformation on a single data array.

    The relevant part of ``graph`` is turned into a rule sequence, the rules are
    applied to a ``CoordTable``, and the results are stored on a shallow copy of
    ``original``. Dimensions are renamed when ``options.rename_dims`` is set, and
    a log message is emitted unless ``options.quiet`` is set.
    """
    subgraph = graph.graph_for(original, targets)
    rules = rule_sequence(subgraph)
    table = CoordTable(rules, targets, options)
    produced_dim_coords: set[str] = set()
    for rule in rules:
        outputs = rule(table)
        for coord_name, new_coord in outputs.items():
            table.add(coord_name, new_coord)
            # A dimension-coord must match a dim of the *data*, not just its own
            # dims: slicing can leave coords with dims no longer in the data.
            if coord_name in original.dims and new_coord.has_dim(coord_name):
                produced_dim_coords.add(coord_name)

    if options.rename_dims:
        renames = _dim_name_changes(subgraph, produced_dim_coords)
    else:
        renames = {}

    if not options.quiet:
        _log_transform(rules, targets, renames, table)

    result = _store_results(original, table, targets)
    return result.rename_dims(renames)

204 

205 

def _transform_dataset(
    original: Dataset, targets: set[str], graph: Graph, *, options: Options
) -> Dataset:
    """Run the coordinate transformation on every item of a dataset.

    A dataset without items is transformed through a placeholder data array that
    carries the dataset's sizes and coords; only the resulting coords are kept.
    """
    if len(original) == 0:
        placeholder = DataArray(empty(sizes=original.sizes), coords=original.coords)
        result = _transform_data_array(
            placeholder, targets=targets, graph=graph, options=options
        )
        return Dataset(coords=result.coords)

    # Note the inefficiency here in datasets with multiple items: the coord
    # transform is repeated for every item rather than sharing what is possible.
    # Since items may hold binned data, sharing is far from trivial; unless there
    # are clear performance requirements, the safe and simple solution is used.
    transformed_items = {}
    for item_name in original:
        transformed_items[item_name] = _transform_data_array(
            original[item_name], targets=targets, graph=graph, options=options
        )
    return Dataset(data=transformed_items)

227 

228 

def _log_transform(
    rules: list[Rule],
    targets: set[str],
    dim_name_changes: Mapping[str, str],
    coords: CoordTable,
) -> None:
    """Log a summary of a coordinate transformation via ``get_logger().info``.

    The message lists the fetched inputs and the targets, followed by optional
    sections for byproducts, renamed dimensions, targets that were already inputs,
    and the non-fetch rules that were applied.
    """
    inputs = set(rule_output_names(rules, FetchRule))
    produced = set(rule_output_names(rules, RenameRule)) | set(
        rule_output_names(rules, ComputeRule)
    )
    # Non-target outputs with negative total usage; presumably these are the
    # intermediates/aliases kept in the output -- confirm against CoordTable.
    byproducts = {
        name for name in produced - targets if coords.total_usages(name) < 0
    }
    preexisting = inputs.intersection(targets)
    steps = [rule for rule in rules if not isinstance(rule, FetchRule)]

    sections = [
        f'Transformed coords ({", ".join(sorted(inputs))}) '
        f'-> ({", ".join(sorted(targets))})'
    ]
    if byproducts:
        sections.append(f'\n Byproducts:\n {", ".join(sorted(byproducts))}')
    if dim_name_changes:
        rename_lines = [f' {new} <- {old}' for old, new in dim_name_changes.items()]
        sections.append('\n Renamed dimensions:\n' + '\n'.join(rename_lines))
    if preexisting:
        sections.append(
            '\n Outputs already present in input:'
            f'\n {", ".join(sorted(preexisting))}'
        )
    if steps:
        step_text = '\n'.join(f' {rule}' for rule in steps)
    else:
        step_text = ' None'
    sections.append('\n Steps:\n' + step_text)

    get_logger().info(''.join(sections))

269 

270 

def _store_coord(da: DataArray, name: str, coord: Coord) -> None:
    """Write ``coord`` onto ``da`` in place under ``name``, or remove it.

    If ``coord.usages`` is 0, any coord called ``name`` is dropped from ``da``
    and, for binned data, from its bins. Otherwise the dense and/or event parts
    of ``coord`` are assigned, with alignment taken from ``coord.aligned``.
    """

    def assign(target, value) -> None:
        target.coords[name] = value
        target.coords.set_aligned(name, coord.aligned)

    if coord.usages == 0:
        da.coords.pop(name, None)
        if da.bins is not None:
            da.bins.coords.pop(name, None)
        return

    if coord.has_dense:
        assign(da, coord.dense)
    if not coord.has_event:
        return
    try:
        assign(da.bins, coord.event)
    except (DimensionError, VariableError):
        # Thrown on mismatching bin indices, e.g. slice. Replace the data with a
        # copy and retry the assignment.
        da.data = da.data.copy()
        assign(da.bins, coord.event)

293 

294 

def _store_results(da: DataArray, coords: CoordTable, targets: set[str]) -> DataArray:
    """Return a shallow copy of ``da`` with every coord of ``coords`` stored on it.

    Coords whose names are in ``targets`` are marked as aligned before storing.
    """
    out = da.copy(deep=False)
    # See #2773 for why this is necessary.
    if out.bins is not None:
        out.data = bins(**out.bins.constituents)
    for coord_name, entry in coords.items():
        if coord_name in targets:
            entry.aligned = True
        _store_coord(out, coord_name, entry)
    return out

305 

306 

307def _color_dims(graph: Graph, dim_coords: set[str]) -> dict[str, dict[str, Fraction]]: 

308 colors = { 

309 coord: {dim: Fraction(0, 1) for dim in dim_coords} for coord in graph.nodes() 

310 } 

311 for dim in dim_coords: 

312 colors[dim][dim] = Fraction(1, 1) 

313 depth_first_stack = [dim] 

314 while depth_first_stack: 

315 coord = depth_first_stack.pop() 

316 children = tuple(graph.children_of(coord)) 

317 for child in children: 

318 # test for produced dim coords 

319 if child not in dim_coords: 

320 colors[child][dim] += colors[coord][dim] * Fraction( 

321 1, len(children) 

322 ) 

323 depth_first_stack.extend(children) 

324 

325 return colors 

326 

327 

328def _has_full_color_of_dim(colors: dict[str, Fraction], dim: str) -> bool: 

329 return all( 

330 fraction == 1 if d == dim else fraction != 1 for d, fraction in colors.items() 

331 ) 

332 

333 

def _dim_name_changes(rule_graph: Graph, dim_coords: set[str]) -> dict[str, str]:
    """Map each dimension coord to the node its dimension should be renamed to.

    For each dim, the candidate nodes are scanned in reverse
    ``rule_graph.nodes_topologically()`` order; the first node that carries the
    full colour of that dim (and of no other dim) is chosen. Dims without such a
    node are absent from the result.
    """
    colors = _color_dims(rule_graph, dim_coords)
    newest_first = list(reversed(list(rule_graph.nodes_topologically())))
    changes = {}
    for dim in dim_coords:
        chosen = next(
            (
                node
                for node in newest_first
                if _has_full_color_of_dim(colors[node], dim)
            ),
            None,
        )
        if chosen is not None:
            changes[dim] = chosen
    return changes