Coverage for install/scipp/core/data_group.py: 48%
312 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-01 01:59 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-01 01:59 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2022 Scipp contributors (https://github.com/scipp)
3# @author Simon Heybrock
4from __future__ import annotations
6import copy
7import functools
8import itertools
9import numbers
10import operator
11from collections.abc import (
12 Callable,
13 ItemsView,
14 Iterable,
15 Iterator,
16 KeysView,
17 Mapping,
18 MutableMapping,
19 Sequence,
20 ValuesView,
21)
22from functools import wraps
23from typing import (
24 TYPE_CHECKING,
25 Any,
26 Concatenate,
27 NoReturn,
28 ParamSpec,
29 TypeVar,
30 cast,
31 overload,
32)
34import numpy as np
36from .. import _binding
37from .cpp_classes import (
38 DataArray,
39 Dataset,
40 DimensionError,
41 GroupByDataArray,
42 GroupByDataset,
43 Unit,
44 Variable,
45)
47if TYPE_CHECKING:
48 # Avoid cyclic imports
49 from ..coords.graph import GraphDict
50 from ..typing import ScippIndex
51 from .bins import Bins
# Type variables shared by the generic signatures in this module.
_T = TypeVar("_T")  # Any type
_V = TypeVar("_V")  # Value type of self
_R = TypeVar("_R")  # Return type of a callable
_P = ParamSpec('_P')  # Parameters of a callable that follow its first argument
60def _item_dims(item: Any) -> tuple[str, ...]:
61 return getattr(item, 'dims', ())
def _is_binned(item: Any) -> bool:
    """Return True if ``item`` is a ``Bins`` object or has a non-None ``bins``."""
    # Imported locally, as in the rest of this module's uses of ``Bins``.
    from .bins import Bins

    return isinstance(item, Bins) or getattr(item, 'bins', None) is not None
def _summarize(item: Any) -> str:
    """Return a short one-line description of ``item``.

    Data groups are rendered with their length and sizes, other objects that have a
    ``sizes`` attribute with their sizes, and anything else via ``str``.
    """
    type_name = type(item).__name__
    if isinstance(item, DataGroup):
        return f'{type_name}({len(item)}, {item.sizes})'
    if hasattr(item, 'sizes'):
        return f'{type_name}({item.sizes})'
    return str(item)
80def _is_positional_index(key: Any) -> bool:
81 def is_int(x: object) -> bool:
82 return isinstance(x, numbers.Integral)
84 if is_int(key):
85 return True
86 if isinstance(key, slice):
87 if is_int(key.start) or is_int(key.stop) or is_int(key.step):
88 return True
89 if key.start is None and key.stop is None and key.step is None:
90 return True
91 return False
94def _is_list_index(key: Any) -> bool:
95 return isinstance(key, list | np.ndarray)
class DataGroup(MutableMapping[str, _V]):
    """
    A dict-like group of data. Additionally provides dims and shape properties.

    DataGroup acts like a Python dict but additionally supports Scipp functionality
    such as positional- and label-based indexing and Scipp operations by mapping them
    to the values in the dict. This may happen recursively to support tree-like data
    structures.

    .. versionadded:: 23.01.0
    """

    def __init__(
        self, /, *args: Iterable[tuple[str, _V]] | Mapping[str, _V], **kwargs: _V
    ) -> None:
        """Create a data group from the same arguments that ``dict`` accepts.

        Raises
        ------
        ValueError
            If any key is not a string.
        """
        self._items = dict(*args, **kwargs)
        if not all(isinstance(k, str) for k in self._items.keys()):
            raise ValueError("DataGroup keys must be strings.")

    def __copy__(self) -> DataGroup[_V]:
        """Return a new DataGroup holding a shallow copy of the item dict."""
        return DataGroup(copy.copy(self._items))

    def __len__(self) -> int:
        """Return the number of items in the data group."""
        return len(self._items)

    def __iter__(self) -> Iterator[str]:
        """Iterate over the keys of the data group."""
        return iter(self._items)

    def keys(self) -> KeysView[str]:
        """Return a view of the keys."""
        return self._items.keys()

    def values(self) -> ValuesView[_V]:
        """Return a view of the values."""
        return self._items.values()

    def items(self) -> ItemsView[str, _V]:
        """Return a view of the (key, value) pairs."""
        return self._items.items()

    @overload
    def __getitem__(self, name: str) -> _V: ...

    @overload
    def __getitem__(self, name: ScippIndex) -> DataGroup[_V]: ...

    def __getitem__(self, name: Any) -> Any:
        """Return item of given name or index all items.

        When ``name`` is a string, return the item of the given name. Otherwise, this
        returns a new DataGroup, with items created by indexing the items in this
        DataGroup. This may perform, e.g., Scipp's positional indexing, label-based
        indexing, or advanced indexing on items that are scipp.Variable or
        scipp.DataArray.

        Label-based indexing is only possible when all items have a coordinate for the
        indexed dimension.

        Advanced indexing comprises integer-array indexing and boolean-variable
        indexing. Unlike positional indexing, integer-array indexing works even when
        the item shapes are inconsistent for the indexed dimensions, provided that all
        items contain the maximal index in the integer array. Boolean-variable indexing
        is only possible when the shape of all items is compatible with the boolean
        variable.
        """
        from .bins import Bins

        if isinstance(name, str):
            return self._items[name]
        if isinstance(name, tuple) and name == ():
            return cast(DataGroup[Any], self).apply(operator.itemgetter(name))
        if isinstance(name, Variable):  # boolean indexing
            return cast(DataGroup[Any], self).apply(operator.itemgetter(name))
        if _is_positional_index(name) or _is_list_index(name):
            # No dimension label given: only unambiguous for a 1-D group.
            if self.ndim != 1:
                raise DimensionError(
                    "Slicing with implicit dimension label is only possible "
                    f"for 1-D objects. Got {self.sizes} with ndim={self.ndim}. Provide "
                    "an explicit dimension label, e.g., var['x', 0] instead of var[0]."
                )
            dim = self.dims[0]
            index = name
        else:
            dim, index = name
        # Items that do not depend on ``dim`` (and are not Bins) are passed through.
        return DataGroup(
            {
                key: var[dim, index]  # type: ignore[index]
                if (isinstance(var, Bins) or dim in _item_dims(var))
                else var
                for key, var in self.items()
            }
        )

    def __setitem__(self, name: str, value: _V) -> None:
        """Set self[key] to value.

        Raises
        ------
        TypeError
            If ``name`` is not a string.
        """
        if isinstance(name, str):
            self._items[name] = value
        else:
            raise TypeError('Keys must be strings')

    def __delitem__(self, name: str) -> None:
        """Delete self[key]."""
        del self._items[name]

    def __sizeof__(self) -> int:
        """Return :meth:`underlying_size`."""
        return self.underlying_size()

    def underlying_size(self) -> int:
        """Return the size of this object plus the size of all its items.

        Scipp objects and nested data groups contribute their ``underlying_size()``,
        other objects with an ``nbytes`` attribute contribute that, and everything
        else contributes its ``__sizeof__()``.
        """
        # The original spelled this ``super.__sizeof__(self)``, which only worked
        # because attribute lookup on the ``super`` type falls back to
        # ``object.__sizeof__``. The zero-argument form yields the same value.
        total_size = super().__sizeof__()
        for item in self.values():
            if isinstance(item, DataArray | Dataset | Variable | DataGroup):
                total_size += item.underlying_size()
            elif hasattr(item, 'nbytes'):
                total_size += item.nbytes
            else:
                total_size += item.__sizeof__()

        return total_size

    @property
    def dims(self) -> tuple[str, ...]:
        """Union of dims of all items. Non-Scipp items are handled as dims=()."""
        return tuple(self.sizes)

    @property
    def ndim(self) -> int:
        """Number of dimensions, i.e., len(self.dims)."""
        return len(self.dims)

    @property
    def shape(self) -> tuple[int | None, ...]:
        """Union of shape of all items. Non-Scipp items are handled as shape=()."""
        return tuple(self.sizes.values())

    @property
    def sizes(self) -> dict[str, int | None]:
        """Dict combining dims and shape, i.e., mapping dim labels to their size.

        A dim maps to None if the items disagree on its size.
        """
        all_sizes: dict[str, set[int]] = {}
        for x in self.values():
            for dim, size in getattr(x, 'sizes', {}).items():
                all_sizes.setdefault(dim, set()).add(size)
        return {d: next(iter(s)) if len(s) == 1 else None for d, s in all_sizes.items()}

    def _repr_html_(self) -> str:
        """Return the HTML representation produced by ``datagroup_repr``."""
        from ..visualization.formatting_datagroup_html import datagroup_repr

        return datagroup_repr(self)

    def __repr__(self) -> str:
        """Return a multi-line summary with one line per item."""
        # NOTE(review): the reviewed text had runs of whitespace collapsed; confirm
        # the indent width at the start of the per-item literal below.
        parts = [f'DataGroup(sizes={self.sizes}, keys=[\n']
        parts.extend(f' {name}: {_summarize(var)},\n' for name, var in self.items())
        parts.append('])')
        return ''.join(parts)

    def __str__(self) -> str:
        """Return a one-line summary listing sizes and keys."""
        return f'DataGroup(sizes={self.sizes}, keys={list(self.keys())})'

    @property
    def bins(self) -> DataGroup[DataGroup[Any] | Bins[Any] | None]:
        """DataGroup of the ``bins`` attribute of every item."""
        # TODO Returning a regular DataGroup here may be wrong, since the `bins`
        # property provides a different set of attrs and methods.
        return self.apply(operator.attrgetter('bins'))

    def apply(
        self,
        func: Callable[Concatenate[_V, _P], _R],
        *args: _P.args,
        **kwargs: _P.kwargs,
    ) -> DataGroup[_R]:
        """Call func on all values and return new DataGroup containing the results."""
        return DataGroup({key: func(v, *args, **kwargs) for key, v in self.items()})

    def _transform_dim(
        self, func: str, *, dim: None | str | Iterable[str], **kwargs: Any
    ) -> DataGroup[Any]:
        """Transform items that depend on one or more dimensions given by `dim`.

        Items that share no dimension with ``dim`` are passed through unchanged. If
        ``dim`` is None, every item with at least one dimension is transformed.
        """
        dims = (dim,) if isinstance(dim, str) else dim
        # Build the lookup set once instead of once per item.
        wanted = None if dims is None else set(dims)

        def intersects(item: _V) -> bool:
            item_dims = _item_dims(item)
            if wanted is None:
                return item_dims != ()
            return not wanted.isdisjoint(item_dims)

        return DataGroup(
            {
                key: v
                if not intersects(v)
                else operator.methodcaller(func, dim, **kwargs)(v)
                for key, v in self.items()
            }
        )

    def _reduce(
        self, method: str, dim: None | str | Sequence[str] = None, **kwargs: Any
    ) -> DataGroup[Any]:
        """Call the reduction ``method`` on every item that it applies to.

        Items without dimensions that are not binned are passed through. Otherwise
        an item is reduced over the subset of ``dim`` it depends on; binned items are
        reduced even if that subset is empty.
        """
        reduce_all = operator.methodcaller(method, **kwargs)

        def _reduce_child(v: _V) -> Any:
            if isinstance(v, GroupByDataArray | GroupByDataset):
                child_dims: tuple[None | str | Sequence[str], ...] = (dim,)
            else:
                child_dims = _item_dims(v)
            # Reduction operations on binned data implicitly reduce over bin content.
            # Therefore, a purely dimension-based logic is not sufficient to determine
            # if the item has to be reduced or not.
            binned = _is_binned(v)
            if child_dims == () and not binned:
                return v
            if dim is None:
                return reduce_all(v)
            if isinstance(dim, str):
                dims_to_reduce: tuple[str, ...] | str = dim if dim in child_dims else ()
            else:
                dims_to_reduce = tuple(d for d in dim if d in child_dims)
            if dims_to_reduce == () and binned:
                return reduce_all(v)
            return (
                v
                if dims_to_reduce == ()
                else operator.methodcaller(method, dims_to_reduce, **kwargs)(v)
            )

        return DataGroup({key: _reduce_child(v) for key, v in self.items()})

    def copy(self, deep: bool = True) -> DataGroup[_V]:
        """Return a deep copy (default) or, with ``deep=False``, a shallow copy."""
        return copy.deepcopy(self) if deep else copy.copy(self)

    def all(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        """Reduce items with ``all``; see :meth:`_reduce` for the item selection."""
        return self._reduce('all', dim)

    def any(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        """Reduce items with ``any``; see :meth:`_reduce` for the item selection."""
        return self._reduce('any', dim)

    def astype(self, type: Any, *, copy: bool = True) -> DataGroup[_V]:
        """Call ``astype(type, copy=copy)`` on every item."""
        return self.apply(operator.methodcaller('astype', type, copy=copy))

    def bin(
        self,
        arg_dict: dict[str, int | Variable] | None = None,
        /,
        **kwargs: int | Variable,
    ) -> DataGroup[_V]:
        """Call ``bin(arg_dict, **kwargs)`` on every item."""
        return self.apply(operator.methodcaller('bin', arg_dict, **kwargs))

    @overload
    def broadcast(
        self,
        *,
        dims: Sequence[str],
        shape: Sequence[int],
    ) -> DataGroup[_V]: ...

    @overload
    def broadcast(
        self,
        *,
        sizes: dict[str, int],
    ) -> DataGroup[_V]: ...

    def broadcast(
        self,
        *,
        dims: Sequence[str] | None = None,
        shape: Sequence[int] | None = None,
        sizes: dict[str, int] | None = None,
    ) -> DataGroup[_V]:
        """Call ``broadcast(dims=dims, shape=shape, sizes=sizes)`` on every item."""
        return self.apply(
            operator.methodcaller('broadcast', dims=dims, shape=shape, sizes=sizes)
        )

    def ceil(self) -> DataGroup[_V]:
        """Call ``ceil()`` on every item."""
        return self.apply(operator.methodcaller('ceil'))

    def flatten(
        self, dims: Sequence[str] | None = None, to: str | None = None
    ) -> DataGroup[_V]:
        """Call ``flatten(dims, to=to)`` on items selected by :meth:`_transform_dim`."""
        return self._transform_dim('flatten', dim=dims, to=to)

    def floor(self) -> DataGroup[_V]:
        """Call ``floor()`` on every item."""
        return self.apply(operator.methodcaller('floor'))

    @overload
    def fold(
        self,
        dim: str,
        *,
        dims: Sequence[str],
        shape: Sequence[int],
    ) -> DataGroup[_V]: ...

    @overload
    def fold(
        self,
        dim: str,
        *,
        sizes: dict[str, int],
    ) -> DataGroup[_V]: ...

    def fold(
        self,
        dim: str,
        *,
        dims: Sequence[str] | None = None,
        shape: Sequence[int] | None = None,
        sizes: dict[str, int] | None = None,
    ) -> DataGroup[_V]:
        """Call ``fold`` on the items that depend on ``dim``."""
        return self._transform_dim('fold', dim=dim, dims=dims, shape=shape, sizes=sizes)

    def group(self, /, *args: str | Variable) -> DataGroup[_V]:
        """Call ``group(*args)`` on every item."""
        return self.apply(operator.methodcaller('group', *args))

    def groupby(
        self, /, group: Variable | str, *, bins: Variable | None = None
    ) -> DataGroup[GroupByDataArray | GroupByDataset]:
        """Call ``groupby(group, bins=bins)`` on every item."""
        return self.apply(operator.methodcaller('groupby', group, bins=bins))

    def hist(
        self,
        arg_dict: dict[str, int | Variable] | None = None,
        /,
        **kwargs: int | Variable,
    ) -> DataGroup[DataArray | Dataset]:
        """Call ``hist(arg_dict, **kwargs)`` on every item."""
        return self.apply(operator.methodcaller('hist', arg_dict, **kwargs))

    def max(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        """Reduce items with ``max``; see :meth:`_reduce` for the item selection."""
        return self._reduce('max', dim)

    def mean(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        """Reduce items with ``mean``; see :meth:`_reduce` for the item selection."""
        return self._reduce('mean', dim)

    def median(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        """Reduce items with ``median``; see :meth:`_reduce` for the item selection."""
        return self._reduce('median', dim)

    def min(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        """Reduce items with ``min``; see :meth:`_reduce` for the item selection."""
        return self._reduce('min', dim)

    def nanhist(
        self,
        arg_dict: dict[str, int | Variable] | None = None,
        /,
        **kwargs: int | Variable,
    ) -> DataGroup[DataArray]:
        """Call ``nanhist(arg_dict, **kwargs)`` on every item."""
        return self.apply(operator.methodcaller('nanhist', arg_dict, **kwargs))

    def nanmax(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        """Reduce items with ``nanmax``; see :meth:`_reduce` for the item selection."""
        return self._reduce('nanmax', dim)

    def nanmean(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        """Reduce items with ``nanmean``; see :meth:`_reduce` for the item selection."""
        return self._reduce('nanmean', dim)

    def nanmedian(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        """Reduce items with ``nanmedian``; see :meth:`_reduce` for details."""
        return self._reduce('nanmedian', dim)

    def nanmin(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        """Reduce items with ``nanmin``; see :meth:`_reduce` for the item selection."""
        return self._reduce('nanmin', dim)

    def nansum(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        """Reduce items with ``nansum``; see :meth:`_reduce` for the item selection."""
        return self._reduce('nansum', dim)

    def nanstd(
        self, dim: None | str | tuple[str, ...] = None, *, ddof: int
    ) -> DataGroup[_V]:
        """Reduce items with ``nanstd``, forwarding ``ddof``; see :meth:`_reduce`."""
        return self._reduce('nanstd', dim, ddof=ddof)

    def nanvar(
        self, dim: None | str | tuple[str, ...] = None, *, ddof: int
    ) -> DataGroup[_V]:
        """Reduce items with ``nanvar``, forwarding ``ddof``; see :meth:`_reduce`."""
        return self._reduce('nanvar', dim, ddof=ddof)

    def rebin(
        self,
        arg_dict: dict[str, int | Variable] | None = None,
        /,
        **kwargs: int | Variable,
    ) -> DataGroup[_V]:
        """Call ``rebin(arg_dict, **kwargs)`` on every item."""
        return self.apply(operator.methodcaller('rebin', arg_dict, **kwargs))

    def rename(
        self, dims_dict: dict[str, str] | None = None, /, **names: str
    ) -> DataGroup[_V]:
        """Call ``rename(dims_dict, **names)`` on every item."""
        return self.apply(operator.methodcaller('rename', dims_dict, **names))

    def rename_dims(
        self, dims_dict: dict[str, str] | None = None, /, **names: str
    ) -> DataGroup[_V]:
        """Call ``rename_dims(dims_dict, **names)`` on every item."""
        return self.apply(operator.methodcaller('rename_dims', dims_dict, **names))

    def round(self) -> DataGroup[_V]:
        """Call ``round()`` on every item."""
        return self.apply(operator.methodcaller('round'))

    def squeeze(self, dim: str | Sequence[str] | None = None) -> DataGroup[_V]:
        """Call ``squeeze`` on items; see :meth:`_reduce` for the item selection."""
        return self._reduce('squeeze', dim)

    def std(
        self, dim: None | str | tuple[str, ...] = None, *, ddof: int
    ) -> DataGroup[_V]:
        """Reduce items with ``std``, forwarding ``ddof``; see :meth:`_reduce`."""
        return self._reduce('std', dim, ddof=ddof)

    def sum(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
        """Reduce items with ``sum``; see :meth:`_reduce` for the item selection."""
        return self._reduce('sum', dim)

    def to(
        self,
        *,
        unit: Unit | str | None = None,
        dtype: Any | None = None,
        copy: bool = True,
    ) -> DataGroup[_V]:
        """Call ``to(unit=unit, dtype=dtype, copy=copy)`` on every item."""
        return self.apply(
            operator.methodcaller('to', unit=unit, dtype=dtype, copy=copy)
        )

    def transform_coords(
        self,
        targets: str | Iterable[str] | None = None,
        /,
        graph: GraphDict | None = None,
        *,
        rename_dims: bool = True,
        keep_aliases: bool = True,
        keep_intermediate: bool = True,
        keep_inputs: bool = True,
        quiet: bool = False,
        **kwargs: Callable[..., Variable],
    ) -> DataGroup[_V]:
        """Call ``transform_coords`` with the given arguments on every item."""
        return self.apply(
            operator.methodcaller(
                'transform_coords',
                targets,
                graph=graph,
                rename_dims=rename_dims,
                keep_aliases=keep_aliases,
                keep_intermediate=keep_intermediate,
                keep_inputs=keep_inputs,
                quiet=quiet,
                **kwargs,
            )
        )

    def transpose(self, dims: None | tuple[str, ...] = None) -> DataGroup[_V]:
        """Call ``transpose(dims)`` on items selected by :meth:`_transform_dim`."""
        return self._transform_dim('transpose', dim=dims)

    def var(
        self, dim: None | str | tuple[str, ...] = None, *, ddof: int
    ) -> DataGroup[_V]:
        """Reduce items with ``var``, forwarding ``ddof``; see :meth:`_reduce`."""
        return self._reduce('var', dim, ddof=ddof)

    def plot(self, *args: Any, **kwargs: Any) -> Any:
        """Forward this data group and all arguments to ``plopp.plot``."""
        import plopp

        return plopp.plot(self, *args, **kwargs)

    def __eq__(  # type: ignore[override]
        self, other: DataGroup[object] | DataArray | Variable | float
    ) -> DataGroup[_V | bool]:
        """Item-wise equal."""
        return data_group_nary(operator.eq, self, other)

    def __ne__(  # type: ignore[override]
        self, other: DataGroup[object] | DataArray | Variable | float
    ) -> DataGroup[_V | bool]:
        """Item-wise not-equal."""
        return data_group_nary(operator.ne, self, other)

    def __gt__(
        self, other: DataGroup[object] | DataArray | Variable | float
    ) -> DataGroup[_V | bool]:
        """Item-wise greater-than."""
        return data_group_nary(operator.gt, self, other)

    def __ge__(
        self, other: DataGroup[object] | DataArray | Variable | float
    ) -> DataGroup[_V | bool]:
        """Item-wise greater-equal."""
        return data_group_nary(operator.ge, self, other)

    def __lt__(
        self, other: DataGroup[object] | DataArray | Variable | float
    ) -> DataGroup[_V | bool]:
        """Item-wise less-than."""
        return data_group_nary(operator.lt, self, other)

    def __le__(
        self, other: DataGroup[object] | DataArray | Variable | float
    ) -> DataGroup[_V | bool]:
        """Item-wise less-equal."""
        return data_group_nary(operator.le, self, other)

    def __add__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Apply ``add`` item-by-item."""
        return data_group_nary(operator.add, self, other)

    def __sub__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Apply ``sub`` item-by-item."""
        return data_group_nary(operator.sub, self, other)

    def __mul__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Apply ``mul`` item-by-item."""
        return data_group_nary(operator.mul, self, other)

    def __truediv__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Apply ``truediv`` item-by-item."""
        return data_group_nary(operator.truediv, self, other)

    def __floordiv__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Apply ``floordiv`` item-by-item."""
        return data_group_nary(operator.floordiv, self, other)

    def __mod__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Apply ``mod`` item-by-item."""
        return data_group_nary(operator.mod, self, other)

    def __pow__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Apply ``pow`` item-by-item."""
        return data_group_nary(operator.pow, self, other)

    def __radd__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
        """Apply ``add`` item-by-item."""
        return data_group_nary(operator.add, other, self)

    def __rsub__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
        """Apply ``sub`` item-by-item."""
        return data_group_nary(operator.sub, other, self)

    def __rmul__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
        """Apply ``mul`` item-by-item."""
        return data_group_nary(operator.mul, other, self)

    def __rtruediv__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
        """Apply ``truediv`` item-by-item."""
        return data_group_nary(operator.truediv, other, self)

    def __rfloordiv__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
        """Apply ``floordiv`` item-by-item."""
        return data_group_nary(operator.floordiv, other, self)

    def __rmod__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
        """Apply ``mod`` item-by-item."""
        return data_group_nary(operator.mod, other, self)

    def __rpow__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
        """Apply ``pow`` item-by-item."""
        return data_group_nary(operator.pow, other, self)

    def __and__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Return the element-wise ``and`` of items."""
        return data_group_nary(operator.and_, self, other)

    def __or__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Return the element-wise ``or`` of items."""
        return data_group_nary(operator.or_, self, other)

    def __xor__(
        self, other: DataGroup[Any] | DataArray | Variable | float
    ) -> DataGroup[Any]:
        """Return the element-wise ``xor`` of items."""
        return data_group_nary(operator.xor, self, other)

    def __invert__(self) -> DataGroup[Any]:
        """Return the element-wise ``invert`` (``~``) of items."""
        return self.apply(operator.invert)  # type: ignore[arg-type]
def data_group_nary(
    func: Callable[..., _R], *args: Any, **kwargs: Any
) -> DataGroup[_R]:
    """Apply ``func`` item-by-item across all data groups among the arguments.

    For every key that is present in all data groups found in ``args`` and
    ``kwargs``, ``func`` is called with each data group replaced by its item for
    that key. Arguments that are not data groups are passed through unchanged.

    The keys of the result follow the order of the first data group argument.

    Raises
    ------
    TypeError
        If none of the arguments is a DataGroup.
    """
    groups = [
        x for x in itertools.chain(args, kwargs.values()) if isinstance(x, DataGroup)
    ]
    if not groups:
        raise TypeError(
            'data_group_nary requires at least one DataGroup among its arguments.'
        )
    first, *rest = groups
    # Intersect while keeping the first group's order. A plain set intersection
    # of key views would make the result order depend on string hashing.
    keys = [key for key in first.keys() if all(key in dg.keys() for dg in rest)]

    def elem(x: Any, key: str) -> Any:
        return x[key] if isinstance(x, DataGroup) else x

    return DataGroup(
        {
            key: func(
                *[elem(x, key) for x in args],
                **{name: elem(x, key) for name, x in kwargs.items()},
            )
            for key in keys
        }
    )
def apply_to_items(
    func: Callable[..., _R], dgs: Iterable[DataGroup[Any]], *args: Any, **kwargs: Any
) -> DataGroup[_R]:
    """Call ``func`` on the list of same-named items of several data groups.

    For every key present in all of ``dgs``, ``func`` receives the list
    ``[dg[key] for dg in dgs]`` followed by ``*args`` and ``**kwargs``.

    The keys of the result follow the order of the first data group.

    Raises
    ------
    TypeError
        If ``dgs`` is empty.
    """
    # ``dgs`` is walked once for the keys and again for every key, so a one-shot
    # iterable (e.g. a generator) must be materialized first.
    groups = tuple(dgs)
    if not groups:
        raise TypeError('apply_to_items requires at least one DataGroup.')
    first, *rest = groups
    # Intersect while keeping the first group's order. A plain set intersection
    # of key views would make the result order depend on string hashing.
    keys = [key for key in first.keys() if all(key in dg.keys() for dg in rest)]
    return DataGroup(
        {key: func([dg[key] for dg in groups], *args, **kwargs) for key in keys}
    )
def data_group_overload(
    func: Callable[Concatenate[_T, _P], _R],
) -> Callable[..., _R | DataGroup[_R]]:
    """Add an overload for DataGroup to a function.

    When the decorated function is called with a data group as its first
    argument, it is mapped over all items of that group. Items that are data
    groups themselves are handled recursively.

    For any other first argument, the original function is called directly.

    Parameters
    ----------
    func:
        Function to decorate.

    Returns
    -------
    :
        Decorated function.
    """

    # Do not assign '__annotations__' because that causes an error in Sphinx.
    @wraps(func, assigned=('__module__', '__name__', '__qualname__', '__doc__'))
    def dispatch(
        data: _T | DataGroup[Any], *args: _P.args, **kwargs: _P.kwargs
    ) -> _R | DataGroup[_R]:
        if not isinstance(data, DataGroup):
            return func(data, *args, **kwargs)
        # Map the wrapper itself so that nested data groups recurse.
        return data.apply(dispatch, *args, **kwargs)  # type: ignore[arg-type]

    return dispatch
745# There are currently no in-place operations (__iadd__, etc.) because they require
746# a check if the operation would fail before doing it. As otherwise, a failure could
747# leave a partially modified data group behind. Dataset implements such a check, but
748# it is simpler than for DataGroup because the latter supports more data types.
# So for now, we went with the simple solution and
# do not support in-place operations at all.
751#
752# Binding these functions dynamically has the added benefit that type checkers think
753# that the operations are not implemented.
754def _make_inplace_binary_op(name: str) -> Callable[..., NoReturn]:
755 def impl(
756 self: DataGroup[Any], other: DataGroup[Any] | DataArray | Variable | float
757 ) -> NoReturn:
758 raise TypeError(f'In-place operation i{name} is not supported by DataGroup.')
760 return impl
# Bind a raising stub for every in-place arithmetic operator of DataGroup.
for _name in ('add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'pow'):
    full_name = f'__i{_name}__'
    # The factory formats its message as ``i{name}``, so it must receive the bare
    # operator name. Passing ``full_name`` produced messages like "i__iadd__".
    _binding.bind_function_as_method(
        cls=DataGroup, name=full_name, func=_make_inplace_binary_op(_name)
    )

# Do not leave the loop variables behind in the module namespace.
del _name, full_name