Coverage for install/scipp/core/data_group.py: 48%

312 statements  

coverage.py v7.6.1, created at 2024-11-17 01:51 +0000

# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2022 Scipp contributors (https://github.com/scipp)
# @author Simon Heybrock
from __future__ import annotations

import copy
import functools
import itertools
import numbers
import operator
from collections.abc import (
    Callable,
    ItemsView,
    Iterable,
    Iterator,
    KeysView,
    Mapping,
    MutableMapping,
    Sequence,
    ValuesView,
)
from functools import wraps
from typing import (
    TYPE_CHECKING,
    Any,
    Concatenate,
    NoReturn,
    ParamSpec,
    TypeVar,
    cast,
    overload,
)

import numpy as np

from .. import _binding
from .cpp_classes import (
    DataArray,
    Dataset,
    DimensionError,
    GroupByDataArray,
    GroupByDataset,
    Unit,
    Variable,
)

if TYPE_CHECKING:
    # Avoid cyclic imports
    from ..coords.graph import GraphDict
    from ..typing import ScippIndex
    from .bins import Bins


_T = TypeVar("_T")  # Any type
_V = TypeVar("_V")  # Value type of self
_R = TypeVar("_R")  # Return type of a callable
_P = ParamSpec('_P')


def _item_dims(item: Any) -> tuple[str, ...]:
    return getattr(item, 'dims', ())


def _is_binned(item: Any) -> bool:
    from .bins import Bins

    if isinstance(item, Bins):
        return True
    return getattr(item, 'bins', None) is not None


def _summarize(item: Any) -> str:
    if isinstance(item, DataGroup):
        return f'{type(item).__name__}({len(item)}, {item.sizes})'
    if hasattr(item, 'sizes'):
        return f'{type(item).__name__}({item.sizes})'
    return str(item)


def _is_positional_index(key: Any) -> bool:
    def is_int(x: object) -> bool:
        return isinstance(x, numbers.Integral)

    if is_int(key):
        return True
    if isinstance(key, slice):
        if is_int(key.start) or is_int(key.stop) or is_int(key.step):
            return True
        if key.start is None and key.stop is None and key.step is None:
            return True
    return False


def _is_list_index(key: Any) -> bool:
    return isinstance(key, list | np.ndarray)


class DataGroup(MutableMapping[str, _V]):
    """
    A dict-like group of data. Additionally provides dims and shape properties.

    DataGroup acts like a Python dict but additionally supports Scipp functionality
    such as positional- and label-based indexing and Scipp operations by mapping them
    to the values in the dict. This may happen recursively to support tree-like data
    structures.

    .. versionadded:: 23.01.0
    """

    def __init__(
        self, /, *args: Iterable[tuple[str, _V]] | Mapping[str, _V], **kwargs: _V
    ) -> None:
        self._items = dict(*args, **kwargs)
        if not all(isinstance(k, str) for k in self._items.keys()):
            raise ValueError("DataGroup keys must be strings.")

    def __copy__(self) -> DataGroup[_V]:
        return DataGroup(copy.copy(self._items))

    def __len__(self) -> int:
        """Return the number of items in the data group."""
        return len(self._items)

    def __iter__(self) -> Iterator[str]:
        return iter(self._items)

    def keys(self) -> KeysView[str]:
        return self._items.keys()

    def values(self) -> ValuesView[_V]:
        return self._items.values()

    def items(self) -> ItemsView[str, _V]:
        return self._items.items()

    @overload
    def __getitem__(self, name: str) -> _V: ...

    @overload
    def __getitem__(self, name: ScippIndex) -> DataGroup[_V]: ...

    def __getitem__(self, name: Any) -> Any:
        """Return item of given name or index all items.

        When ``name`` is a string, return the item of the given name. Otherwise, this
        returns a new DataGroup, with items created by indexing the items in this
        DataGroup. This may perform, e.g., Scipp's positional indexing, label-based
        indexing, or advanced indexing on items that are scipp.Variable or
        scipp.DataArray.

        Label-based indexing is only possible when all items have a coordinate for the
        indexed dimension.

        Advanced indexing comprises integer-array indexing and boolean-variable
        indexing. Unlike positional indexing, integer-array indexing works even when
        the item shapes are inconsistent for the indexed dimensions, provided that all
        items contain the maximal index in the integer array. Boolean-variable indexing
        is only possible when the shape of all items is compatible with the boolean
        variable.
        """

        from .bins import Bins

        if isinstance(name, str):
            return self._items[name]
        if isinstance(name, tuple) and name == ():
            return cast(DataGroup[Any], self).apply(operator.itemgetter(name))
        if isinstance(name, Variable):  # boolean indexing
            return cast(DataGroup[Any], self).apply(operator.itemgetter(name))
        if _is_positional_index(name) or _is_list_index(name):
            if self.ndim != 1:
                raise DimensionError(
                    "Slicing with implicit dimension label is only possible "
                    f"for 1-D objects. Got {self.sizes} with ndim={self.ndim}. Provide "
                    "an explicit dimension label, e.g., var['x', 0] instead of var[0]."
                )
            dim = self.dims[0]
            index = name
        else:
            dim, index = name
        return DataGroup(
            {
                key: var[dim, index]  # type: ignore[index]
                if (isinstance(var, Bins) or dim in _item_dims(var))
                else var
                for key, var in self.items()
            }
        )

    def __setitem__(self, name: str, value: _V) -> None:
        """Set self[key] to value."""
        if isinstance(name, str):
            self._items[name] = value
        else:
            raise TypeError('Keys must be strings')

    def __delitem__(self, name: str) -> None:
        """Delete self[key]."""
        del self._items[name]

    def __sizeof__(self) -> int:
        return self.underlying_size()

    def underlying_size(self) -> int:
        # TODO Return the underlying size of all items in DataGroup
        total_size = super.__sizeof__(self)
        for item in self.values():
            if isinstance(item, DataArray | Dataset | Variable | DataGroup):
                total_size += item.underlying_size()
            elif hasattr(item, 'nbytes'):
                total_size += item.nbytes
            else:
                total_size += item.__sizeof__()

        return total_size

    @property
    def dims(self) -> tuple[str, ...]:
        """Union of dims of all items. Non-Scipp items are handled as dims=()."""
        return tuple(self.sizes)

    @property
    def ndim(self) -> int:
        """Number of dimensions, i.e., len(self.dims)."""
        return len(self.dims)

    @property
    def shape(self) -> tuple[int | None, ...]:
        """Union of shape of all items. Non-Scipp items are handled as shape=()."""
        return tuple(self.sizes.values())

    @property
    def sizes(self) -> dict[str, int | None]:
        """Dict combining dims and shape, i.e., mapping dim labels to their size."""
        all_sizes: dict[str, set[int]] = {}
        for x in self.values():
            for dim, size in getattr(x, 'sizes', {}).items():
                all_sizes.setdefault(dim, set()).add(size)
        return {d: next(iter(s)) if len(s) == 1 else None for d, s in all_sizes.items()}
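
    # Worked example (illustrative): for items with sizes {'x': 3} and
    # {'x': 3, 'y': 2}, ``sizes`` is {'x': 3, 'y': 2}; if the items disagreed
    # on the length of 'x', that entry would be None instead.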

    def _repr_html_(self) -> str:
        from ..visualization.formatting_datagroup_html import datagroup_repr

        return datagroup_repr(self)

    def __repr__(self) -> str:
        r = f'DataGroup(sizes={self.sizes}, keys=[\n'
        for name, var in self.items():
            r += f' {name}: {_summarize(var)},\n'
        r += '])'
        return r

    def __str__(self) -> str:
        return f'DataGroup(sizes={self.sizes}, keys={list(self.keys())})'

    @property
    def bins(self) -> DataGroup[DataGroup[Any] | Bins[Any] | None]:
        # TODO Returning a regular DataGroup here may be wrong, since the `bins`
        # property provides a different set of attrs and methods.
        return self.apply(operator.attrgetter('bins'))

    def apply(
        self,
        func: Callable[Concatenate[_V, _P], _R],
        *args: _P.args,
        **kwargs: _P.kwargs,
    ) -> DataGroup[_R]:
        """Call func on all values and return new DataGroup containing the results."""
        return DataGroup({key: func(v, *args, **kwargs) for key, v in self.items()})
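
    # Sketch (illustrative): ``dg.apply(lambda v: v * 2)`` returns a new
    # DataGroup with every value doubled; extra ``*args``/``**kwargs`` are
    # forwarded to ``func`` for each item.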

    def _transform_dim(
        self, func: str, *, dim: None | str | Iterable[str], **kwargs: Any
    ) -> DataGroup[Any]:
        """Transform items that depend on one or more dimensions given by `dim`."""
        dims = (dim,) if isinstance(dim, str) else dim

        def intersects(item: _V) -> bool:
            item_dims = _item_dims(item)
            if dims is None:
                return item_dims != ()
            return set(dims).intersection(item_dims) != set()

        return DataGroup(
            {
                key: v
                if not intersects(v)
                else operator.methodcaller(func, dim, **kwargs)(v)
                for key, v in self.items()
            }
        )

    def _reduce(
        self, method: str, dim: None | str | Sequence[str] = None, **kwargs: Any
    ) -> DataGroup[Any]:
        reduce_all = operator.methodcaller(method, **kwargs)

        def _reduce_child(v: _V) -> Any:
            if isinstance(v, GroupByDataArray | GroupByDataset):
                child_dims: tuple[None | str | Sequence[str], ...] = (dim,)
            else:
                child_dims = _item_dims(v)
            # Reduction operations on binned data implicitly reduce over bin content.
            # Therefore, a purely dimension-based logic is not sufficient to determine
            # if the item has to be reduced or not.
            binned = _is_binned(v)
            if child_dims == () and not binned:
                return v
            if dim is None:
                return reduce_all(v)
            if isinstance(dim, str):
                dims_to_reduce: tuple[str, ...] | str = dim if dim in child_dims else ()
            else:
                dims_to_reduce = tuple(d for d in dim if d in child_dims)
            if dims_to_reduce == () and binned:
                return reduce_all(v)
            return (
                v
                if dims_to_reduce == ()
                else operator.methodcaller(method, dims_to_reduce, **kwargs)(v)
            )

        return DataGroup({key: _reduce_child(v) for key, v in self.items()})

    def copy(self, deep: bool = True) -> DataGroup[_V]:
        return copy.deepcopy(self) if deep else copy.copy(self)

    def all(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        return self._reduce('all', dim)

    def any(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        return self._reduce('any', dim)

    def astype(self, type: Any, *, copy: bool = True) -> DataGroup[_V]:
        return self.apply(operator.methodcaller('astype', type, copy=copy))

    def bin(
        self,
        arg_dict: dict[str, int | Variable] | None = None,
        /,
        **kwargs: int | Variable,
    ) -> DataGroup[_V]:
        return self.apply(operator.methodcaller('bin', arg_dict, **kwargs))

    @overload
    def broadcast(
        self,
        *,
        dims: Sequence[str],
        shape: Sequence[int],
    ) -> DataGroup[_V]: ...

    @overload
    def broadcast(
        self,
        *,
        sizes: dict[str, int],
    ) -> DataGroup[_V]: ...

    def broadcast(
        self,
        *,
        dims: Sequence[str] | None = None,
        shape: Sequence[int] | None = None,
        sizes: dict[str, int] | None = None,
    ) -> DataGroup[_V]:
        return self.apply(
            operator.methodcaller('broadcast', dims=dims, shape=shape, sizes=sizes)
        )

    def ceil(self) -> DataGroup[_V]:
        return self.apply(operator.methodcaller('ceil'))

    def flatten(
        self, dims: Sequence[str] | None = None, to: str | None = None
    ) -> DataGroup[_V]:
        return self._transform_dim('flatten', dim=dims, to=to)

    def floor(self) -> DataGroup[_V]:
        return self.apply(operator.methodcaller('floor'))

    @overload
    def fold(
        self,
        dim: str,
        *,
        dims: Sequence[str],
        shape: Sequence[int],
    ) -> DataGroup[_V]: ...

    @overload
    def fold(
        self,
        dim: str,
        *,
        sizes: dict[str, int],
    ) -> DataGroup[_V]: ...

    def fold(
        self,
        dim: str,
        *,
        dims: Sequence[str] | None = None,
        shape: Sequence[int] | None = None,
        sizes: dict[str, int] | None = None,
    ) -> DataGroup[_V]:
        return self._transform_dim('fold', dim=dim, dims=dims, shape=shape, sizes=sizes)

    def group(self, /, *args: str | Variable) -> DataGroup[_V]:
        return self.apply(operator.methodcaller('group', *args))

    def groupby(
        self, /, group: Variable | str, *, bins: Variable | None = None
    ) -> DataGroup[GroupByDataArray | GroupByDataset]:
        return self.apply(operator.methodcaller('groupby', group, bins=bins))

    def hist(
        self,
        arg_dict: dict[str, int | Variable] | None = None,
        /,
        **kwargs: int | Variable,
    ) -> DataGroup[DataArray | Dataset]:
        return self.apply(operator.methodcaller('hist', arg_dict, **kwargs))

    def max(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        return self._reduce('max', dim)

    def mean(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        return self._reduce('mean', dim)

    def median(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        return self._reduce('median', dim)

    def min(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        return self._reduce('min', dim)

    def nanhist(
        self,
        arg_dict: dict[str, int | Variable] | None = None,
        /,
        **kwargs: int | Variable,
    ) -> DataGroup[DataArray]:
        return self.apply(operator.methodcaller('nanhist', arg_dict, **kwargs))

    def nanmax(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        return self._reduce('nanmax', dim)

    def nanmean(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        return self._reduce('nanmean', dim)

    def nanmedian(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        return self._reduce('nanmedian', dim)

    def nanmin(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        return self._reduce('nanmin', dim)

    def nansum(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        return self._reduce('nansum', dim)

    def nanstd(
        self, dim: None | str | tuple[str, ...] = None, *, ddof: int
    ) -> DataGroup[_V]:
        return self._reduce('nanstd', dim, ddof=ddof)

    def nanvar(
        self, dim: None | str | tuple[str, ...] = None, *, ddof: int
    ) -> DataGroup[_V]:
        return self._reduce('nanvar', dim, ddof=ddof)

    def rebin(
        self,
        arg_dict: dict[str, int | Variable] | None = None,
        /,
        **kwargs: int | Variable,
    ) -> DataGroup[_V]:
        return self.apply(operator.methodcaller('rebin', arg_dict, **kwargs))

    def rename(
        self, dims_dict: dict[str, str] | None = None, /, **names: str
    ) -> DataGroup[_V]:
        return self.apply(operator.methodcaller('rename', dims_dict, **names))

    def rename_dims(
        self, dims_dict: dict[str, str] | None = None, /, **names: str
    ) -> DataGroup[_V]:
        return self.apply(operator.methodcaller('rename_dims', dims_dict, **names))

    def round(self) -> DataGroup[_V]:
        return self.apply(operator.methodcaller('round'))

    def squeeze(self, dim: str | Sequence[str] | None = None) -> DataGroup[_V]:
        return self._reduce('squeeze', dim)

    def std(
        self, dim: None | str | tuple[str, ...] = None, *, ddof: int
    ) -> DataGroup[_V]:
        return self._reduce('std', dim, ddof=ddof)

    def sum(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        return self._reduce('sum', dim)

    def to(
        self,
        *,
        unit: Unit | str | None = None,
        dtype: Any | None = None,
        copy: bool = True,
    ) -> DataGroup[_V]:
        return self.apply(
            operator.methodcaller('to', unit=unit, dtype=dtype, copy=copy)
        )

    def transform_coords(
        self,
        targets: str | Iterable[str] | None = None,
        /,
        graph: GraphDict | None = None,
        *,
        rename_dims: bool = True,
        keep_aliases: bool = True,
        keep_intermediate: bool = True,
        keep_inputs: bool = True,
        quiet: bool = False,
        **kwargs: Callable[..., Variable],
    ) -> DataGroup[_V]:
        return self.apply(
            operator.methodcaller(
                'transform_coords',
                targets,
                graph=graph,
                rename_dims=rename_dims,
                keep_aliases=keep_aliases,
                keep_intermediate=keep_intermediate,
                keep_inputs=keep_inputs,
                quiet=quiet,
                **kwargs,
            )
        )

    def transpose(self, dims: None | tuple[str, ...] = None) -> DataGroup[_V]:
        return self._transform_dim('transpose', dim=dims)

    def var(
        self, dim: None | str | tuple[str, ...] = None, *, ddof: int
    ) -> DataGroup[_V]:
        return self._reduce('var', dim, ddof=ddof)

    def plot(self, *args: Any, **kwargs: Any) -> Any:
        import plopp

        return plopp.plot(self, *args, **kwargs)

    def __eq__(  # type: ignore[override]
        self, other: DataGroup[object] | DataArray | Variable | float
    ) -> DataGroup[_V | bool]:
        """Item-wise equal."""
        return data_group_nary(operator.eq, self, other)

    def __ne__(  # type: ignore[override]
        self, other: DataGroup[object] | DataArray | Variable | float
    ) -> DataGroup[_V | bool]:
        """Item-wise not-equal."""
        return data_group_nary(operator.ne, self, other)

    def __gt__(
        self, other: DataGroup[object] | DataArray | Variable | float
    ) -> DataGroup[_V | bool]:
        """Item-wise greater-than."""
        return data_group_nary(operator.gt, self, other)

    def __ge__(
        self, other: DataGroup[object] | DataArray | Variable | float
    ) -> DataGroup[_V | bool]:
        """Item-wise greater-equal."""
        return data_group_nary(operator.ge, self, other)

    def __lt__(
        self, other: DataGroup[object] | DataArray | Variable | float
    ) -> DataGroup[_V | bool]:
        """Item-wise less-than."""
        return data_group_nary(operator.lt, self, other)

    def __le__(
        self, other: DataGroup[object] | DataArray | Variable | float
    ) -> DataGroup[_V | bool]:
        """Item-wise less-equal."""
        return data_group_nary(operator.le, self, other)

    def __add__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Apply ``add`` item-by-item."""
        return data_group_nary(operator.add, self, other)

    def __sub__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Apply ``sub`` item-by-item."""
        return data_group_nary(operator.sub, self, other)

    def __mul__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Apply ``mul`` item-by-item."""
        return data_group_nary(operator.mul, self, other)

    def __truediv__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Apply ``truediv`` item-by-item."""
        return data_group_nary(operator.truediv, self, other)

    def __floordiv__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Apply ``floordiv`` item-by-item."""
        return data_group_nary(operator.floordiv, self, other)

    def __mod__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Apply ``mod`` item-by-item."""
        return data_group_nary(operator.mod, self, other)

    def __pow__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Apply ``pow`` item-by-item."""
        return data_group_nary(operator.pow, self, other)

    def __radd__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
        """Apply ``add`` item-by-item."""
        return data_group_nary(operator.add, other, self)

    def __rsub__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
        """Apply ``sub`` item-by-item."""
        return data_group_nary(operator.sub, other, self)

    def __rmul__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
        """Apply ``mul`` item-by-item."""
        return data_group_nary(operator.mul, other, self)

    def __rtruediv__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
        """Apply ``truediv`` item-by-item."""
        return data_group_nary(operator.truediv, other, self)

    def __rfloordiv__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
        """Apply ``floordiv`` item-by-item."""
        return data_group_nary(operator.floordiv, other, self)

    def __rmod__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
        """Apply ``mod`` item-by-item."""
        return data_group_nary(operator.mod, other, self)

    def __rpow__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
        """Apply ``pow`` item-by-item."""
        return data_group_nary(operator.pow, other, self)

    def __and__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Return the element-wise ``and`` of items."""
        return data_group_nary(operator.and_, self, other)

    def __or__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Return the element-wise ``or`` of items."""
        return data_group_nary(operator.or_, self, other)

    def __xor__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Return the element-wise ``xor`` of items."""
        return data_group_nary(operator.xor, self, other)

    def __invert__(self) -> DataGroup[Any]:
        """Return the element-wise ``not`` of items."""
        return self.apply(operator.invert)  # type: ignore[arg-type]


def data_group_nary(
    func: Callable[..., _R], *args: Any, **kwargs: Any
) -> DataGroup[_R]:
    dgs = filter(
        lambda x: isinstance(x, DataGroup), itertools.chain(args, kwargs.values())
    )
    keys = functools.reduce(operator.and_, [dg.keys() for dg in dgs])

    def elem(x: Any, key: str) -> Any:
        return x[key] if isinstance(x, DataGroup) else x

    return DataGroup(
        {
            key: func(
                *[elem(x, key) for x in args],
                **{name: elem(x, key) for name, x in kwargs.items()},
            )
            for key in keys
        }
    )


def apply_to_items(
    func: Callable[..., _R], dgs: Iterable[DataGroup[Any]], *args: Any, **kwargs: Any
) -> DataGroup[_R]:
    keys = functools.reduce(operator.and_, [dg.keys() for dg in dgs])
    return DataGroup(
        {key: func([dg[key] for dg in dgs], *args, **kwargs) for key in keys}
    )


def data_group_overload(
    func: Callable[Concatenate[_T, _P], _R],
) -> Callable[..., _R | DataGroup[_R]]:
    """Add an overload for DataGroup to a function.

    If the first argument of the function is a data group,
    then the decorated function is mapped over all items.
    It is applied recursively for items that are themselves data groups.

    Otherwise, the original function is applied directly.

    Parameters
    ----------
    func:
        Function to decorate.

    Returns
    -------
    :
        Decorated function.
    """

    # Do not assign '__annotations__' because that causes an error in Sphinx.
    @wraps(func, assigned=('__module__', '__name__', '__qualname__', '__doc__'))
    def impl(
        data: _T | DataGroup[Any], *args: _P.args, **kwargs: _P.kwargs
    ) -> _R | DataGroup[_R]:
        if isinstance(data, DataGroup):
            return data.apply(impl, *args, **kwargs)  # type: ignore[arg-type]
        return func(data, *args, **kwargs)

    return impl
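
# Illustrative sketch of ``data_group_overload`` (the function ``double`` and
# the group contents are placeholders, assuming ``import scipp as sc``):
#
#     @data_group_overload
#     def double(x):
#         return x * 2
#
#     double(sc.DataGroup(a=sc.arange('x', 3)))  # maps over items (recursively)
#     double(5)                                  # plain call, returns 10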


# There are currently no in-place operations (__iadd__, etc.) because they require
# a check whether the operation would fail before doing it. Otherwise, a failure
# could leave a partially modified data group behind. Dataset implements such a
# check, but it is simpler than for DataGroup because the latter supports more data
# types. So, for now, we went with the simple solution and do not support in-place
# operations at all.
#
# Binding these functions dynamically has the added benefit that type checkers think
# that the operations are not implemented.
def _make_inplace_binary_op(name: str) -> Callable[..., NoReturn]:
    def impl(
        self: DataGroup[Any], other: DataGroup[Any] | DataArray | Variable | float
    ) -> NoReturn:
        raise TypeError(f'In-place operation i{name} is not supported by DataGroup.')

    return impl


for _name in ('add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'pow'):
    full_name = f'__i{_name}__'
    _binding.bind_function_as_method(
        cls=DataGroup, name=full_name, func=_make_inplace_binary_op(full_name)
    )

del _name, full_name
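
# Resulting behavior (illustrative sketch, assuming ``import scipp as sc``):
# in-place operators raise instead of partially modifying the group, e.g.
#
#     dg = sc.DataGroup(a=sc.arange('x', 3))
#     dg += 1  # raises TypeError; use ``dg = dg + 1`` instead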