Coverage for install/scipp/core/data_group.py: 48%

297 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-04-28 01:28 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2022 Scipp contributors (https://github.com/scipp) 

3# @author Simon Heybrock 

4from __future__ import annotations 

5 

6import copy 

7import functools 

8import itertools 

9import numbers 

10import operator 

11from collections.abc import MutableMapping 

12from functools import wraps 

13from typing import ( 

14 TYPE_CHECKING, 

15 Any, 

16 Callable, 

17 Dict, 

18 Iterable, 

19 NoReturn, 

20 Optional, 

21 Tuple, 

22 TypeVar, 

23 Union, 

24 cast, 

25 overload, 

26) 

27 

28import numpy as np 

29 

30from .. import _binding 

31from .cpp_classes import ( 

32 DataArray, 

33 Dataset, 

34 DimensionError, 

35 GroupByDataArray, 

36 GroupByDataset, 

37 Variable, 

38) 

39 

40if TYPE_CHECKING: 

41 # typing imports data_group. 

42 # So the following import would create a cycle at runtime. 

43 from ..typing import ScippIndex 

44 

45 

46def _item_dims(item): 

47 return getattr(item, 'dims', ()) 

48 

49 

def _is_binned(item):
    """Return True if ``item`` is a ``Bins`` object or has a non-None ``bins``."""
    from .bins import Bins

    return isinstance(item, Bins) or (
        hasattr(item, 'bins') and item.bins is not None
    )

56 

57 

def _summarize(item):
    """Return a short one-line description of ``item``.

    Data groups are shown with their type name, item count and sizes, other
    objects that have ``sizes`` with their type name and sizes, and anything
    else via ``str``.
    """
    type_name = type(item).__name__
    if isinstance(item, DataGroup):
        return f'{type_name}({len(item)}, {item.sizes})'
    if not hasattr(item, 'sizes'):
        return str(item)
    return f'{type_name}({item.sizes})'

64 

65 

66def _is_positional_index(key) -> bool: 

67 def is_int(x): 

68 return isinstance(x, numbers.Integral) 

69 

70 if is_int(key): 

71 return True 

72 if isinstance(key, slice): 

73 if is_int(key.start) or is_int(key.stop) or is_int(key.step): 

74 return True 

75 if key.start is None and key.stop is None and key.step is None: 

76 return True 

77 return False 

78 

79 

80def _is_list_index(key) -> bool: 

81 return isinstance(key, (list, np.ndarray)) 

82 

83 

class DataGroup(MutableMapping):
    """
    A dict-like group of data. Additionally provides dims and shape properties.

    DataGroup acts like a Python dict but additionally supports Scipp functionality
    such as positional- and label-based indexing and Scipp operations by mapping them
    to the values in the dict. This may happen recursively to support tree-like data
    structures.

    .. versionadded:: 23.01.0
    """

    def __init__(self, /, *args, **kwargs):
        """Create a data group from the same arguments that ``dict`` accepts.

        Raises
        ------
        ValueError
            If any key is not a string.
        """
        self._items = dict(*args, **kwargs)
        if not all(isinstance(k, str) for k in self._items):
            raise ValueError("DataGroup keys must be strings.")

    def __copy__(self) -> DataGroup:
        """Return a shallow copy: a new group holding the same item objects."""
        return DataGroup(copy.copy(self._items))

    def __len__(self) -> int:
        """Return the number of items in the data group."""
        return len(self._items)

    def __iter__(self):
        """Iterate over the keys of the data group."""
        yield from self._items

    @overload
    def __getitem__(self, name: str) -> Any: ...

    @overload
    def __getitem__(self, name: ScippIndex) -> DataGroup: ...

    def __getitem__(self, name):
        """Return item of given name or index all items.

        When ``name`` is a string, return the item of the given name. Otherwise, this
        returns a new DataGroup, with items created by indexing the items in this
        DataGroup. This may perform, e.g., Scipp's positional indexing, label-based
        indexing, or advanced indexing on items that are scipp.Variable or
        scipp.DataArray.

        Label-based indexing is only possible when all items have a coordinate for the
        indexed dimension.

        Advanced indexing comprises integer-array indexing and boolean-variable
        indexing. Unlike positional indexing, integer-array indexing works even when
        the item shapes are inconsistent for the indexed dimensions, provided that all
        items contain the maximal index in the integer array. Boolean-variable indexing
        is only possible when the shape of all items is compatible with the boolean
        variable.
        """
        if isinstance(name, str):
            # Plain item lookup needs none of the indexing machinery below.
            return self._items[name]
        if isinstance(name, tuple) and name == ():
            return self.apply(operator.itemgetter(name))
        if isinstance(name, Variable):  # boolean indexing
            return self.apply(operator.itemgetter(name))
        if _is_positional_index(name) or _is_list_index(name):
            if self.ndim != 1:
                raise DimensionError(
                    "Slicing with implicit dimension label is only possible "
                    f"for 1-D objects. Got {self.sizes} with ndim={self.ndim}. Provide "
                    "an explicit dimension label, e.g., var['x', 0] instead of var[0]."
                )
            dim = self.dims[0]
            index = name
        else:
            # Anything else is expected to be a (dim, index) pair.
            dim, index = name

        # Deferred until here so that string lookup does not pay for the import.
        from .bins import Bins

        # Items that do not depend on `dim` are passed through unchanged.
        return DataGroup(
            {
                key: var[dim, index]
                if (isinstance(var, Bins) or dim in _item_dims(var))
                else var
                for key, var in self.items()
            }
        )

    @overload
    def __setitem__(self, name: str, value: Any): ...

    def __setitem__(self, name, value):
        """Set self[key] to value.

        Raises
        ------
        TypeError
            If ``name`` is not a string.
        """
        if isinstance(name, str):
            self._items[name] = value
        else:
            raise TypeError('Keys must be strings')

    def __delitem__(self, name: str):
        """Delete self[key]."""
        del self._items[name]

    def __sizeof__(self) -> int:
        """Return the same value as :meth:`underlying_size`."""
        return self.underlying_size()

    def underlying_size(self) -> int:
        """Return the size of the group object plus the size of all its items.

        Scipp objects and nested data groups contribute their own
        ``underlying_size()``, other objects their ``nbytes`` if they have that
        attribute, and everything else its ``__sizeof__()``.
        """
        # TODO Return the underlying size of all items in DataGroup
        total_size = super().__sizeof__()
        for item in self.values():
            if isinstance(item, (DataArray, Dataset, Variable, DataGroup)):
                total_size += item.underlying_size()
            elif hasattr(item, 'nbytes'):
                total_size += item.nbytes
            else:
                total_size += item.__sizeof__()

        return total_size

    @property
    def dims(self) -> Tuple[Optional[str], ...]:
        """Union of dims of all items. Non-Scipp items are handled as dims=()."""
        return tuple(self.sizes)

    @property
    def ndim(self):
        """Number of dimensions, i.e., len(self.dims)."""
        return len(self.dims)

    @property
    def shape(self) -> Tuple[Optional[int], ...]:
        """Union of shape of all items. Non-Scipp items are handled as shape=()."""
        return tuple(self.sizes.values())

    @property
    def sizes(self) -> Dict[str, Optional[int]]:
        """Dict combining dims and shape, i.e., mapping dim labels to their size.

        A dim whose size differs between items is mapped to ``None``.
        """
        all_sizes = {}
        for x in self.values():
            for dim, size in getattr(x, 'sizes', {}).items():
                all_sizes.setdefault(dim, set()).add(size)
        return {d: next(iter(s)) if len(s) == 1 else None for d, s in all_sizes.items()}

    def _repr_html_(self):
        """Return the HTML representation produced by ``datagroup_repr``."""
        from ..visualization.formatting_datagroup_html import datagroup_repr

        return datagroup_repr(self)

    def __repr__(self):
        """Return a multi-line summary with the sizes and one line per item."""
        parts = [f'DataGroup(sizes={self.sizes}, keys=[\n']
        parts.extend(f' {name}: {_summarize(var)},\n' for name, var in self.items())
        parts.append('])')
        return ''.join(parts)

    def __str__(self):
        """Return a one-line summary with the sizes and the list of keys."""
        return f'DataGroup(sizes={self.sizes}, keys={list(self.keys())})'

    @property
    def bins(self):
        """New DataGroup holding the ``bins`` attribute of every item."""
        # TODO Returning a regular DataGroup here may be wrong, since the `bins`
        # property provides a different set of attrs and methods.
        return self.apply(operator.attrgetter('bins'))

    def apply(self, func: Callable, *args, **kwargs) -> DataGroup:
        """Call func on all values and return new DataGroup containing the results."""
        return DataGroup({key: func(v, *args, **kwargs) for key, v in self.items()})

    def _transform_dim(
        self, func: str, *, dim: Union[None, str, Iterable[str]], **kwargs
    ) -> DataGroup:
        """Transform items that depend on one or more dimensions given by `dim`.

        ``func`` is the name of the method to call; affected items are replaced by
        ``item.<func>(dim, **kwargs)``. Items whose dims do not intersect ``dim``
        are passed through unchanged. With ``dim=None`` every item that has at
        least one dim is transformed.
        """
        dims = (dim,) if isinstance(dim, str) else dim

        def intersects(item):
            item_dims = _item_dims(item)
            if dims is None:
                return item_dims != ()
            return set(dims).intersection(item_dims) != set()

        return DataGroup(
            {
                key: v
                if not intersects(v)
                else operator.methodcaller(func, dim, **kwargs)(v)
                for key, v in self.items()
            }
        )

    def _reduce(
        self, method: str, dim: Union[None, str, Tuple[str, ...]] = None, **kwargs
    ) -> DataGroup:
        """Call the reduction ``method`` on every item that it affects.

        Items without dims that are not binned are returned unchanged. With
        ``dim=None`` all other items are reduced via ``item.<method>(**kwargs)``.
        Otherwise an item is reduced only over those requested dims that it has.
        GroupBy objects are treated as having exactly the requested ``dim``.
        """
        reduce_all = operator.methodcaller(method, **kwargs)

        def _reduce_child(v, dim):
            if isinstance(v, (GroupByDataArray, GroupByDataset)):
                child_dims = (dim,)
            else:
                child_dims = _item_dims(v)
            # Reduction operations on binned data implicitly reduce over bin content.
            # Therefore, a purely dimension-based logic is not sufficient to determine
            # if the item has to be reduced or not.
            binned = _is_binned(v)
            if child_dims == () and not binned:
                return v
            if dim is None:
                return reduce_all(v)
            if isinstance(dim, str):
                dims_to_reduce = dim if dim in child_dims else ()
            else:
                dims_to_reduce = tuple(d for d in dim if d in child_dims)
            if dims_to_reduce == () and binned:
                return reduce_all(v)
            return (
                v
                if dims_to_reduce == ()
                else operator.methodcaller(method, dims_to_reduce, **kwargs)(v)
            )

        return DataGroup({key: _reduce_child(v, dim) for key, v in self.items()})

    def copy(self, deep: bool = True) -> DataGroup:
        """Return a deep copy if ``deep`` is true, otherwise a shallow copy."""
        return copy.deepcopy(self) if deep else copy.copy(self)

    def all(self, *args, **kwargs):
        """Reduce items with ``all``; see :meth:`_reduce` for the rules."""
        return self._reduce('all', *args, **kwargs)

    def any(self, *args, **kwargs):
        """Reduce items with ``any``; see :meth:`_reduce` for the rules."""
        return self._reduce('any', *args, **kwargs)

    def astype(self, *args, **kwargs):
        """Call ``astype`` on every item and return the results as a DataGroup."""
        return self.apply(operator.methodcaller('astype', *args, **kwargs))

    def bin(self, *args, **kwargs):
        """Call ``bin`` on every item and return the results as a DataGroup."""
        return self.apply(operator.methodcaller('bin', *args, **kwargs))

    def broadcast(self, *args, **kwargs):
        """Call ``broadcast`` on every item and return the results as a DataGroup."""
        return self.apply(operator.methodcaller('broadcast', *args, **kwargs))

    def ceil(self, *args, **kwargs):
        """Call ``ceil`` on every item and return the results as a DataGroup."""
        return self.apply(operator.methodcaller('ceil', *args, **kwargs))

    def flatten(self, dims: Union[None, Iterable[str]] = None, **kwargs):
        """Call ``flatten`` on the items that depend on any of ``dims``."""
        return self._transform_dim('flatten', dim=dims, **kwargs)

    def floor(self, *args, **kwargs):
        """Call ``floor`` on every item and return the results as a DataGroup."""
        return self.apply(operator.methodcaller('floor', *args, **kwargs))

    def fold(self, dim: str, **kwargs):
        """Call ``fold`` on the items that depend on ``dim``."""
        return self._transform_dim('fold', dim=dim, **kwargs)

    def group(self, *args, **kwargs):
        """Call ``group`` on every item and return the results as a DataGroup."""
        return self.apply(operator.methodcaller('group', *args, **kwargs))

    def groupby(self, *args, **kwargs):
        """Call ``groupby`` on every item and return the results as a DataGroup."""
        return self.apply(operator.methodcaller('groupby', *args, **kwargs))

    def hist(self, *args, **kwargs):
        """Call ``hist`` on every item and return the results as a DataGroup."""
        return self.apply(operator.methodcaller('hist', *args, **kwargs))

    def max(self, *args, **kwargs):
        """Reduce items with ``max``; see :meth:`_reduce` for the rules."""
        return self._reduce('max', *args, **kwargs)

    def mean(self, *args, **kwargs):
        """Reduce items with ``mean``; see :meth:`_reduce` for the rules."""
        return self._reduce('mean', *args, **kwargs)

    def median(self, *args, **kwargs):
        """Reduce items with ``median``; see :meth:`_reduce` for the rules."""
        return self._reduce('median', *args, **kwargs)

    def min(self, *args, **kwargs):
        """Reduce items with ``min``; see :meth:`_reduce` for the rules."""
        return self._reduce('min', *args, **kwargs)

    def nanhist(self, *args, **kwargs):
        """Call ``nanhist`` on every item and return the results as a DataGroup."""
        return self.apply(operator.methodcaller('nanhist', *args, **kwargs))

    def nanmax(self, *args, **kwargs):
        """Reduce items with ``nanmax``; see :meth:`_reduce` for the rules."""
        return self._reduce('nanmax', *args, **kwargs)

    def nanmean(self, *args, **kwargs):
        """Reduce items with ``nanmean``; see :meth:`_reduce` for the rules."""
        return self._reduce('nanmean', *args, **kwargs)

    def nanmedian(self, *args, **kwargs):
        """Reduce items with ``nanmedian``; see :meth:`_reduce` for the rules."""
        return self._reduce('nanmedian', *args, **kwargs)

    def nanmin(self, *args, **kwargs):
        """Reduce items with ``nanmin``; see :meth:`_reduce` for the rules."""
        return self._reduce('nanmin', *args, **kwargs)

    def nansum(self, *args, **kwargs):
        """Reduce items with ``nansum``; see :meth:`_reduce` for the rules."""
        return self._reduce('nansum', *args, **kwargs)

    def nanstd(self, *args, **kwargs):
        """Reduce items with ``nanstd``; see :meth:`_reduce` for the rules."""
        return self._reduce('nanstd', *args, **kwargs)

    def nanvar(self, *args, **kwargs):
        """Reduce items with ``nanvar``; see :meth:`_reduce` for the rules."""
        return self._reduce('nanvar', *args, **kwargs)

    def rebin(self, *args, **kwargs):
        """Call ``rebin`` on every item and return the results as a DataGroup."""
        return self.apply(operator.methodcaller('rebin', *args, **kwargs))

    def rename(self, *args, **kwargs):
        """Call ``rename`` on every item and return the results as a DataGroup."""
        return self.apply(operator.methodcaller('rename', *args, **kwargs))

    def rename_dims(self, *args, **kwargs):
        """Call ``rename_dims`` on every item and return the results as a DataGroup."""
        return self.apply(operator.methodcaller('rename_dims', *args, **kwargs))

    def round(self, *args, **kwargs):
        """Call ``round`` on every item and return the results as a DataGroup."""
        return self.apply(operator.methodcaller('round', *args, **kwargs))

    def squeeze(self, *args, **kwargs):
        """Call ``squeeze`` on affected items; see :meth:`_reduce` for the rules."""
        return self._reduce('squeeze', *args, **kwargs)

    def std(self, *args, **kwargs):
        """Reduce items with ``std``; see :meth:`_reduce` for the rules."""
        return self._reduce('std', *args, **kwargs)

    def sum(self, *args, **kwargs):
        """Reduce items with ``sum``; see :meth:`_reduce` for the rules."""
        return self._reduce('sum', *args, **kwargs)

    def to(self, *args, **kwargs):
        """Call ``to`` on every item and return the results as a DataGroup."""
        return self.apply(operator.methodcaller('to', *args, **kwargs))

    def transform_coords(self, *args, **kwargs):
        """Call ``transform_coords`` on every item and return a new DataGroup."""
        return self.apply(operator.methodcaller('transform_coords', *args, **kwargs))

    def transpose(self, dims: Union[None, Tuple[str, ...]] = None):
        """Call ``transpose`` on the items that depend on any of ``dims``."""
        return self._transform_dim('transpose', dim=dims)

    def var(self, *args, **kwargs):
        """Reduce items with ``var``; see :meth:`_reduce` for the rules."""
        return self._reduce('var', *args, **kwargs)

    def plot(self, *args, **kwargs):
        """Plot the data group via ``plopp.plot``; plopp is imported on demand."""
        import plopp

        return plopp.plot(self, *args, **kwargs)

    def __eq__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Item-wise equal."""
        return data_group_nary(operator.eq, self, other)

    def __ne__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Item-wise not-equal."""
        return data_group_nary(operator.ne, self, other)

    def __gt__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Item-wise greater-than."""
        return data_group_nary(operator.gt, self, other)

    def __ge__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Item-wise greater-equal."""
        return data_group_nary(operator.ge, self, other)

    def __lt__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Item-wise less-than."""
        return data_group_nary(operator.lt, self, other)

    def __le__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Item-wise less-equal."""
        return data_group_nary(operator.le, self, other)

    def __add__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``add`` item-by-item."""
        return data_group_nary(operator.add, self, other)

    def __sub__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``sub`` item-by-item."""
        return data_group_nary(operator.sub, self, other)

    def __mul__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``mul`` item-by-item."""
        return data_group_nary(operator.mul, self, other)

    def __truediv__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``truediv`` item-by-item."""
        return data_group_nary(operator.truediv, self, other)

    def __floordiv__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``floordiv`` item-by-item."""
        return data_group_nary(operator.floordiv, self, other)

    def __mod__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``mod`` item-by-item."""
        return data_group_nary(operator.mod, self, other)

    def __pow__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``pow`` item-by-item."""
        return data_group_nary(operator.pow, self, other)

    def __radd__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``add`` item-by-item, with ``other`` as the left operand."""
        return data_group_nary(operator.add, other, self)

    def __rsub__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``sub`` item-by-item, with ``other`` as the left operand."""
        return data_group_nary(operator.sub, other, self)

    def __rmul__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``mul`` item-by-item, with ``other`` as the left operand."""
        return data_group_nary(operator.mul, other, self)

    def __rtruediv__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``truediv`` item-by-item, with ``other`` as the left operand."""
        return data_group_nary(operator.truediv, other, self)

    def __rfloordiv__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``floordiv`` item-by-item, with ``other`` as the left operand."""
        return data_group_nary(operator.floordiv, other, self)

    def __rmod__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``mod`` item-by-item, with ``other`` as the left operand."""
        return data_group_nary(operator.mod, other, self)

    def __rpow__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``pow`` item-by-item, with ``other`` as the left operand."""
        return data_group_nary(operator.pow, other, self)

    def __and__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Return the element-wise ``and`` of items."""
        return data_group_nary(operator.and_, self, other)

    def __or__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Return the element-wise ``or`` of items."""
        return data_group_nary(operator.or_, self, other)

    def __xor__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Return the element-wise ``xor`` of items."""
        return data_group_nary(operator.xor, self, other)

    def __invert__(self) -> DataGroup:
        """Return the element-wise ``not`` (``~``) of items."""
        return self.apply(operator.invert)

550 

551 

def _data_group_binary(
    func: Callable, dg1: DataGroup, dg2: DataGroup, *args, **kwargs
) -> DataGroup:
    """Apply ``func`` pairwise to the items of two data groups.

    Only keys present in both groups are processed. For each such key ``func``
    is called as ``func(dg1[key], dg2[key], *args, **kwargs)``.
    """
    shared_keys = dg1.keys() & dg2.keys()
    results = {}
    for key in shared_keys:
        results[key] = func(dg1[key], dg2[key], *args, **kwargs)
    return DataGroup(results)

561 

562 

def data_group_nary(func: Callable, *args, **kwargs) -> DataGroup:
    """Apply ``func`` key-by-key across the data groups among its arguments.

    For every key shared by all DataGroup arguments, ``func`` is called with each
    DataGroup argument replaced by its item for that key. Arguments that are not
    data groups are passed to every call unchanged.

    Parameters
    ----------
    func:
        Function to call once per shared key.
    *args, **kwargs:
        Positional and keyword arguments for ``func``. At least one of them must
        be a DataGroup.

    Returns
    -------
    :
        Data group mapping each shared key to the result of ``func``.

    Raises
    ------
    TypeError
        If none of the arguments is a DataGroup.
    """
    operands = tuple(itertools.chain(args, kwargs.values()))
    groups = [x for x in operands if isinstance(x, DataGroup)]
    if not groups:
        # Without this check, reduce() below fails with an opaque
        # "reduce() of empty iterable" TypeError.
        raise TypeError(
            'data_group_nary requires at least one DataGroup argument, '
            f'got argument types {[type(x).__name__ for x in operands]}.'
        )
    keys = functools.reduce(operator.and_, [dg.keys() for dg in groups])

    def elem(x, key):
        return x[key] if isinstance(x, DataGroup) else x

    return DataGroup(
        {
            key: func(
                *[elem(x, key) for x in args],
                **{name: elem(x, key) for name, x in kwargs.items()},
            )
            for key in keys
        }
    )

581 

582 

583def _apply_to_items( 

584 func: Callable, dgs: Iterable[DataGroup], *args, **kwargs 

585) -> DataGroup: 

586 keys = functools.reduce(operator.and_, [dg.keys() for dg in dgs]) 

587 return DataGroup( 

588 {key: func([dg[key] for dg in dgs], *args, **kwargs) for key in keys} 

589 ) 

590 

591 

_F = TypeVar('_F', bound=Callable[..., Any])


def data_group_overload(func: _F) -> _F:
    """Add an overload for DataGroup to a function.

    When the first argument of the decorated function is a data group, the
    function is mapped over all of its items, recursing into items that are
    data groups themselves. For any other first argument the original function
    is called directly.

    Parameters
    ----------
    func:
        Function to decorate.

    Returns
    -------
    :
        Decorated function.
    """
    # '__annotations__' is left out on purpose because assigning it causes an
    # error in Sphinx.
    copied_attributes = ('__module__', '__name__', '__qualname__', '__doc__')

    @wraps(func, assigned=copied_attributes)
    def impl(data, *args, **kwargs):
        if not isinstance(data, DataGroup):
            return func(data, *args, **kwargs)
        # Map the wrapper itself so that nested data groups are handled too.
        return data.apply(impl, *args, **kwargs)

    return cast(_F, impl)

623 

624 

625# There are currently no in-place operations (__iadd__, etc.) because they require 

626# a check if the operation would fail before doing it. As otherwise, a failure could 

627# leave a partially modified data group behind. Dataset implements such a check, but 

628# it is simpler than for DataGroup because the latter supports more data types. 

629# So for now, we went with the simple solution and 

630# not support in-place operations at all. 

631# 

632# Binding these functions dynamically has the added benefit that type checkers think 

633# that the operations are not implemented. 

634def _make_inplace_binary_op(name: str): 

635 def impl( 

636 self, other: Union[DataGroup, DataArray, Variable, numbers.Real] 

637 ) -> NoReturn: 

638 raise TypeError(f'In-place operation i{name} is not supported by DataGroup.') 

639 

640 return impl 

641 

642 

# Bind a raising stub for each in-place arithmetic operator. The stub factory's
# message already prepends 'i' to the name it is given, so it receives the bare
# operation name; passing the dunder name produced 'i__iadd__' in the message.
for _name in ('add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'pow'):
    full_name = f'__i{_name}__'
    _binding.bind_function_as_method(
        cls=DataGroup, name=full_name, func=_make_inplace_binary_op(_name)
    )

del _name, full_name