Coverage for install/scipp/core/data_group.py: 48%
297 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2022 Scipp contributors (https://github.com/scipp)
3# @author Simon Heybrock
4from __future__ import annotations
6import copy
7import functools
8import itertools
9import numbers
10import operator
11from collections.abc import MutableMapping
12from functools import wraps
13from typing import (
14 TYPE_CHECKING,
15 Any,
16 Callable,
17 Dict,
18 Iterable,
19 NoReturn,
20 Optional,
21 Tuple,
22 TypeVar,
23 Union,
24 cast,
25 overload,
26)
28import numpy as np
30from .. import _binding
31from .cpp_classes import (
32 DataArray,
33 Dataset,
34 DimensionError,
35 GroupByDataArray,
36 GroupByDataset,
37 Variable,
38)
40if TYPE_CHECKING:
41 # typing imports data_group.
42 # So the following import would create a cycle at runtime.
43 from ..typing import ScippIndex
def _item_dims(item):
    """Return ``item.dims``, or an empty tuple if ``item`` has no ``dims``."""
    try:
        return item.dims
    except AttributeError:
        # Objects without a ``dims`` attribute are treated as dimensionless.
        return ()
def _is_binned(item):
    """Return True if ``item`` is a ``Bins`` object or has a non-None ``bins``."""
    # Imported locally rather than at module level.
    from .bins import Bins

    return isinstance(item, Bins) or getattr(item, 'bins', None) is not None
def _summarize(item):
    """Return a one-line description of ``item`` for use in ``repr`` output.

    Data groups are shown with their length and sizes, other objects that have
    ``sizes`` with their sizes only, and everything else via ``str``.
    """
    if isinstance(item, DataGroup):
        details = f'{len(item)}, {item.sizes}'
    elif hasattr(item, 'sizes'):
        details = f'{item.sizes}'
    else:
        return str(item)
    return f'{type(item).__name__}({details})'
def _is_positional_index(key) -> bool:
    """Return True if ``key`` is an integer or a slice usable as positional index.

    A slice qualifies when at least one of start/stop/step is an integer, or
    when all three are None (i.e. ``[:]``).
    """
    if isinstance(key, numbers.Integral):
        return True
    if not isinstance(key, slice):
        return False
    bounds = (key.start, key.stop, key.step)
    if any(isinstance(bound, numbers.Integral) for bound in bounds):
        return True
    # A slice without any integer is only positional if it is fully open.
    return all(bound is None for bound in bounds)
def _is_list_index(key) -> bool:
    """Return True if ``key`` is a list or a NumPy array."""
    list_like_types = (list, np.ndarray)
    return isinstance(key, list_like_types)
class DataGroup(MutableMapping):
    """
    A dict-like group of data. Additionally provides dims and shape properties.

    DataGroup acts like a Python dict but additionally supports Scipp functionality
    such as positional- and label-based indexing and Scipp operations by mapping them
    to the values in the dict. This may happen recursively to support tree-like data
    structures.

    .. versionadded:: 23.01.0
    """

    def __init__(self, /, *args, **kwargs):
        """Create a data group from the same arguments that ``dict`` accepts.

        Raises
        ------
        ValueError
            If any of the resulting keys is not a string.
        """
        self._items = dict(*args, **kwargs)
        if not all(isinstance(k, str) for k in self._items.keys()):
            raise ValueError("DataGroup keys must be strings.")

    def __copy__(self) -> DataGroup:
        """Return a new DataGroup holding a shallow copy of the item dict."""
        return DataGroup(copy.copy(self._items))

    def __len__(self) -> int:
        """Return the number of items in the data group."""
        return len(self._items)

    def __iter__(self):
        """Iterate over the item names."""
        yield from self._items

    @overload
    def __getitem__(self, name: str) -> Any: ...

    @overload
    def __getitem__(self, name: ScippIndex) -> DataGroup: ...

    def __getitem__(self, name):
        """Return item of given name or index all items.

        When ``name`` is a string, return the item of the given name. Otherwise, this
        returns a new DataGroup, with items created by indexing the items in this
        DataGroup. This may perform, e.g., Scipp's positional indexing, label-based
        indexing, or advanced indexing on items that are scipp.Variable or
        scipp.DataArray.

        Label-based indexing is only possible when all items have a coordinate for the
        indexed dimension.

        Advanced indexing comprises integer-array indexing and boolean-variable
        indexing. Unlike positional indexing, integer-array indexing works even when
        the item shapes are inconsistent for the indexed dimensions, provided that all
        items contain the maximal index in the integer array. Boolean-variable indexing
        is only possible when the shape of all items is compatible with the boolean
        variable.

        Raises
        ------
        DimensionError
            If an index without dimension label is used on a data group
            that is not 1-D.
        """
        # Plain name lookup is the hot path (it backs items() and values()),
        # so it is handled before anything else.
        if isinstance(name, str):
            return self._items[name]
        if isinstance(name, tuple) and name == ():
            return self.apply(operator.itemgetter(name))
        if isinstance(name, Variable):  # boolean indexing
            return self.apply(operator.itemgetter(name))
        if _is_positional_index(name) or _is_list_index(name):
            if self.ndim != 1:
                raise DimensionError(
                    "Slicing with implicit dimension label is only possible "
                    f"for 1-D objects. Got {self.sizes} with ndim={self.ndim}. Provide "
                    "an explicit dimension label, e.g., var['x', 0] instead of var[0]."
                )
            dim = self.dims[0]
            index = name
        else:
            dim, index = name

        # Only the slicing branch needs Bins, so the import is deferred to here.
        from .bins import Bins

        # Items that do not depend on ``dim`` are passed through unchanged.
        return DataGroup(
            {
                key: var[dim, index]
                if (isinstance(var, Bins) or dim in _item_dims(var))
                else var
                for key, var in self.items()
            }
        )

    def __setitem__(self, name: str, value: Any) -> None:
        """Set self[key] to value.

        Raises
        ------
        TypeError
            If ``name`` is not a string.
        """
        if not isinstance(name, str):
            raise TypeError('Keys must be strings')
        self._items[name] = value

    def __delitem__(self, name: str):
        """Delete self[key]."""
        del self._items[name]

    def __sizeof__(self) -> int:
        """Return the same value as :meth:`underlying_size`."""
        return self.underlying_size()

    def underlying_size(self) -> int:
        """Return the size of this object plus the size of all of its items.

        Scipp objects and nested data groups contribute their own
        ``underlying_size()``, objects with an ``nbytes`` attribute contribute
        that, and everything else contributes ``__sizeof__()``.
        """
        total_size = super().__sizeof__()
        for item in self.values():
            if isinstance(item, (DataArray, Dataset, Variable, DataGroup)):
                total_size += item.underlying_size()
            elif hasattr(item, 'nbytes'):
                total_size += item.nbytes
            else:
                total_size += item.__sizeof__()

        return total_size

    @property
    def dims(self) -> Tuple[str, ...]:
        """Union of dims of all items. Non-Scipp items are handled as dims=()."""
        return tuple(self.sizes)

    @property
    def ndim(self) -> int:
        """Number of dimensions, i.e., len(self.dims)."""
        return len(self.dims)

    @property
    def shape(self) -> Tuple[Optional[int], ...]:
        """Union of shape of all items. Non-Scipp items are handled as shape=()."""
        return tuple(self.sizes.values())

    @property
    def sizes(self) -> Dict[str, Optional[int]]:
        """Dict combining dims and shape, i.e., mapping dim labels to their size.

        A dim whose size differs between items is mapped to None.
        """
        all_sizes = {}
        for x in self.values():
            for dim, size in getattr(x, 'sizes', {}).items():
                all_sizes.setdefault(dim, set()).add(size)
        return {d: next(iter(s)) if len(s) == 1 else None for d, s in all_sizes.items()}

    def _repr_html_(self):
        """Return the HTML representation produced by ``datagroup_repr``."""
        from ..visualization.formatting_datagroup_html import datagroup_repr

        return datagroup_repr(self)

    def __repr__(self):
        """Return a multi-line representation with one summary line per item."""
        parts = [f'DataGroup(sizes={self.sizes}, keys=[\n']
        parts.extend(f' {name}: {_summarize(var)},\n' for name, var in self.items())
        parts.append('])')
        return ''.join(parts)

    def __str__(self):
        """Return a single-line representation listing sizes and keys."""
        return f'DataGroup(sizes={self.sizes}, keys={list(self.keys())})'

    @property
    def bins(self):
        """Data group of the ``bins`` attribute of every item."""
        # TODO Returning a regular DataGroup here may be wrong, since the `bins`
        # property provides a different set of attrs and methods.
        return self.apply(operator.attrgetter('bins'))

    def apply(self, func: Callable, *args, **kwargs) -> DataGroup:
        """Call func on all values and return new DataGroup containing the results."""
        return DataGroup({key: func(v, *args, **kwargs) for key, v in self.items()})

    def _transform_dim(
        self, func: str, *, dim: Union[None, str, Iterable[str]], **kwargs
    ) -> DataGroup:
        """Transform items that depend on one or more dimensions given by `dim`.

        ``func`` is the name of the method to call on each affected item; it is
        called as ``item.<func>(dim, **kwargs)``. Items that share no dimension
        with ``dim`` are returned unchanged. With ``dim=None`` every item that
        has at least one dimension is transformed.
        """
        dims = (dim,) if isinstance(dim, str) else dim

        def intersects(item):
            item_dims = _item_dims(item)
            if dims is None:
                return item_dims != ()
            return set(dims).intersection(item_dims) != set()

        return DataGroup(
            {
                key: v
                if not intersects(v)
                else operator.methodcaller(func, dim, **kwargs)(v)
                for key, v in self.items()
            }
        )

    def _reduce(
        self, method: str, dim: Union[None, str, Tuple[str, ...]] = None, **kwargs
    ) -> DataGroup:
        """Call the reduction ``method`` on every item it applies to.

        Items without dims that are not binned are returned unchanged. With
        ``dim=None`` the method is called without a dim argument. Otherwise it is
        called only with those of the requested dims that the item has; binned
        items that have none of them are still reduced without a dim argument.
        """
        reduce_all = operator.methodcaller(method, **kwargs)

        def _reduce_child(v, dim):
            if isinstance(v, (GroupByDataArray, GroupByDataset)):
                # Group-by objects are treated as if they depend on ``dim``.
                child_dims = (dim,)
            else:
                child_dims = _item_dims(v)
            # Reduction operations on binned data implicitly reduce over bin content.
            # Therefore, a purely dimension-based logic is not sufficient to determine
            # if the item has to be reduced or not.
            binned = _is_binned(v)
            if child_dims == () and not binned:
                return v
            if dim is None:
                return reduce_all(v)
            if isinstance(dim, str):
                dims_to_reduce = dim if dim in child_dims else ()
            else:
                dims_to_reduce = tuple(d for d in dim if d in child_dims)
            if dims_to_reduce == () and binned:
                return reduce_all(v)
            return (
                v
                if dims_to_reduce == ()
                else operator.methodcaller(method, dims_to_reduce, **kwargs)(v)
            )

        return DataGroup({key: _reduce_child(v, dim) for key, v in self.items()})

    def copy(self, deep: bool = True) -> DataGroup:
        """Return a deep copy (default) or a shallow copy of the data group."""
        return copy.deepcopy(self) if deep else copy.copy(self)

    def all(self, *args, **kwargs):
        """Item-wise ``all`` reduction; see ``_reduce`` for which items are reduced."""
        return self._reduce('all', *args, **kwargs)

    def any(self, *args, **kwargs):
        """Item-wise ``any`` reduction; see ``_reduce`` for which items are reduced."""
        return self._reduce('any', *args, **kwargs)

    def astype(self, *args, **kwargs):
        """Call ``astype`` on every item."""
        return self.apply(operator.methodcaller('astype', *args, **kwargs))

    def bin(self, *args, **kwargs):
        """Call ``bin`` on every item."""
        return self.apply(operator.methodcaller('bin', *args, **kwargs))

    def broadcast(self, *args, **kwargs):
        """Call ``broadcast`` on every item."""
        return self.apply(operator.methodcaller('broadcast', *args, **kwargs))

    def ceil(self, *args, **kwargs):
        """Call ``ceil`` on every item."""
        return self.apply(operator.methodcaller('ceil', *args, **kwargs))

    def flatten(self, dims: Union[None, Iterable[str]] = None, **kwargs):
        """Call ``flatten`` on every item that depends on any of ``dims``."""
        return self._transform_dim('flatten', dim=dims, **kwargs)

    def floor(self, *args, **kwargs):
        """Call ``floor`` on every item."""
        return self.apply(operator.methodcaller('floor', *args, **kwargs))

    def fold(self, dim: str, **kwargs):
        """Call ``fold`` on every item that depends on ``dim``."""
        return self._transform_dim('fold', dim=dim, **kwargs)

    def group(self, *args, **kwargs):
        """Call ``group`` on every item."""
        return self.apply(operator.methodcaller('group', *args, **kwargs))

    def groupby(self, *args, **kwargs):
        """Call ``groupby`` on every item."""
        return self.apply(operator.methodcaller('groupby', *args, **kwargs))

    def hist(self, *args, **kwargs):
        """Call ``hist`` on every item."""
        return self.apply(operator.methodcaller('hist', *args, **kwargs))

    def max(self, *args, **kwargs):
        """Item-wise ``max`` reduction; see ``_reduce`` for which items are reduced."""
        return self._reduce('max', *args, **kwargs)

    def mean(self, *args, **kwargs):
        """Item-wise ``mean`` reduction; see ``_reduce`` for which items are reduced."""
        return self._reduce('mean', *args, **kwargs)

    def median(self, *args, **kwargs):
        """Item-wise ``median`` reduction; see ``_reduce``."""
        return self._reduce('median', *args, **kwargs)

    def min(self, *args, **kwargs):
        """Item-wise ``min`` reduction; see ``_reduce`` for which items are reduced."""
        return self._reduce('min', *args, **kwargs)

    def nanhist(self, *args, **kwargs):
        """Call ``nanhist`` on every item."""
        return self.apply(operator.methodcaller('nanhist', *args, **kwargs))

    def nanmax(self, *args, **kwargs):
        """Item-wise ``nanmax`` reduction; see ``_reduce``."""
        return self._reduce('nanmax', *args, **kwargs)

    def nanmean(self, *args, **kwargs):
        """Item-wise ``nanmean`` reduction; see ``_reduce``."""
        return self._reduce('nanmean', *args, **kwargs)

    def nanmedian(self, *args, **kwargs):
        """Item-wise ``nanmedian`` reduction; see ``_reduce``."""
        return self._reduce('nanmedian', *args, **kwargs)

    def nanmin(self, *args, **kwargs):
        """Item-wise ``nanmin`` reduction; see ``_reduce``."""
        return self._reduce('nanmin', *args, **kwargs)

    def nansum(self, *args, **kwargs):
        """Item-wise ``nansum`` reduction; see ``_reduce``."""
        return self._reduce('nansum', *args, **kwargs)

    def nanstd(self, *args, **kwargs):
        """Item-wise ``nanstd`` reduction; see ``_reduce``."""
        return self._reduce('nanstd', *args, **kwargs)

    def nanvar(self, *args, **kwargs):
        """Item-wise ``nanvar`` reduction; see ``_reduce``."""
        return self._reduce('nanvar', *args, **kwargs)

    def rebin(self, *args, **kwargs):
        """Call ``rebin`` on every item."""
        return self.apply(operator.methodcaller('rebin', *args, **kwargs))

    def rename(self, *args, **kwargs):
        """Call ``rename`` on every item."""
        return self.apply(operator.methodcaller('rename', *args, **kwargs))

    def rename_dims(self, *args, **kwargs):
        """Call ``rename_dims`` on every item."""
        return self.apply(operator.methodcaller('rename_dims', *args, **kwargs))

    def round(self, *args, **kwargs):
        """Call ``round`` on every item."""
        return self.apply(operator.methodcaller('round', *args, **kwargs))

    def squeeze(self, *args, **kwargs):
        """Item-wise ``squeeze``, dispatched through ``_reduce``."""
        return self._reduce('squeeze', *args, **kwargs)

    def std(self, *args, **kwargs):
        """Item-wise ``std`` reduction; see ``_reduce`` for which items are reduced."""
        return self._reduce('std', *args, **kwargs)

    def sum(self, *args, **kwargs):
        """Item-wise ``sum`` reduction; see ``_reduce`` for which items are reduced."""
        return self._reduce('sum', *args, **kwargs)

    def to(self, *args, **kwargs):
        """Call ``to`` on every item."""
        return self.apply(operator.methodcaller('to', *args, **kwargs))

    def transform_coords(self, *args, **kwargs):
        """Call ``transform_coords`` on every item."""
        return self.apply(operator.methodcaller('transform_coords', *args, **kwargs))

    def transpose(self, dims: Union[None, Tuple[str, ...]] = None):
        """Call ``transpose`` on every item that depends on any of ``dims``."""
        return self._transform_dim('transpose', dim=dims)

    def var(self, *args, **kwargs):
        """Item-wise ``var`` reduction; see ``_reduce`` for which items are reduced."""
        return self._reduce('var', *args, **kwargs)

    def plot(self, *args, **kwargs):
        """Plot this data group by forwarding to ``plopp.plot``."""
        import plopp

        return plopp.plot(self, *args, **kwargs)

    def __eq__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Item-wise equal."""
        return data_group_nary(operator.eq, self, other)

    def __ne__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Item-wise not-equal."""
        return data_group_nary(operator.ne, self, other)

    def __gt__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Item-wise greater-than."""
        return data_group_nary(operator.gt, self, other)

    def __ge__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Item-wise greater-equal."""
        return data_group_nary(operator.ge, self, other)

    def __lt__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Item-wise less-than."""
        return data_group_nary(operator.lt, self, other)

    def __le__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Item-wise less-equal."""
        return data_group_nary(operator.le, self, other)

    def __add__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``add`` item-by-item."""
        return data_group_nary(operator.add, self, other)

    def __sub__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``sub`` item-by-item."""
        return data_group_nary(operator.sub, self, other)

    def __mul__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``mul`` item-by-item."""
        return data_group_nary(operator.mul, self, other)

    def __truediv__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``truediv`` item-by-item."""
        return data_group_nary(operator.truediv, self, other)

    def __floordiv__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``floordiv`` item-by-item."""
        return data_group_nary(operator.floordiv, self, other)

    def __mod__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``mod`` item-by-item."""
        return data_group_nary(operator.mod, self, other)

    def __pow__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``pow`` item-by-item."""
        return data_group_nary(operator.pow, self, other)

    def __radd__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``add`` item-by-item."""
        return data_group_nary(operator.add, other, self)

    def __rsub__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``sub`` item-by-item."""
        return data_group_nary(operator.sub, other, self)

    def __rmul__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``mul`` item-by-item."""
        return data_group_nary(operator.mul, other, self)

    def __rtruediv__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``truediv`` item-by-item."""
        return data_group_nary(operator.truediv, other, self)

    def __rfloordiv__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``floordiv`` item-by-item."""
        return data_group_nary(operator.floordiv, other, self)

    def __rmod__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``mod`` item-by-item."""
        return data_group_nary(operator.mod, other, self)

    def __rpow__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Apply ``pow`` item-by-item."""
        return data_group_nary(operator.pow, other, self)

    def __and__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Return the element-wise ``and`` of items."""
        return data_group_nary(operator.and_, self, other)

    def __or__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Return the element-wise ``or`` of items."""
        return data_group_nary(operator.or_, self, other)

    def __xor__(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> DataGroup:
        """Return the element-wise ``xor`` of items."""
        return data_group_nary(operator.xor, self, other)

    def __invert__(self) -> DataGroup:
        """Return the element-wise ``not`` of items."""
        return self.apply(operator.invert)
def _data_group_binary(
    func: Callable, dg1: DataGroup, dg2: DataGroup, *args, **kwargs
) -> DataGroup:
    """Apply ``func`` to the pairs of items whose key is in both data groups.

    Extra positional and keyword arguments are forwarded to every call.
    """
    shared_keys = dg1.keys() & dg2.keys()
    results = {}
    for key in shared_keys:
        results[key] = func(dg1[key], dg2[key], *args, **kwargs)
    return DataGroup(results)
def data_group_nary(func: Callable, *args, **kwargs) -> DataGroup:
    """Apply ``func`` item-by-item over any number of arguments.

    For every key shared by all data groups among ``args`` and ``kwargs``,
    ``func`` is called with each data group replaced by its item for that key.
    Arguments that are not data groups are passed through unchanged.
    """
    key_views = [
        arg.keys()
        for arg in itertools.chain(args, kwargs.values())
        if isinstance(arg, DataGroup)
    ]
    common_keys = functools.reduce(operator.and_, key_views)

    def pick(arg, key):
        if isinstance(arg, DataGroup):
            return arg[key]
        return arg

    results = {}
    for key in common_keys:
        positional = [pick(arg, key) for arg in args]
        named = {name: pick(arg, key) for name, arg in kwargs.items()}
        results[key] = func(*positional, **named)
    return DataGroup(results)
def _apply_to_items(
    func: Callable, dgs: Iterable[DataGroup], *args, **kwargs
) -> DataGroup:
    """Call ``func`` once per shared key with the list of items for that key.

    Parameters
    ----------
    func:
        Called as ``func([dg[key] for dg in dgs], *args, **kwargs)``.
    dgs:
        Data groups to combine. Only keys present in all of them are used.

    Returns
    -------
    :
        Data group mapping each shared key to the result of ``func``.
    """
    # ``dgs`` is traversed once to collect the keys and again for every key.
    # Materialize it so that one-shot iterables (e.g. generators) are not
    # exhausted after the first pass.
    dgs = list(dgs)
    keys = functools.reduce(operator.and_, [dg.keys() for dg in dgs])
    return DataGroup(
        {key: func([dg[key] for dg in dgs], *args, **kwargs) for key in keys}
    )
_F = TypeVar('_F', bound=Callable[..., Any])


def data_group_overload(func: _F) -> _F:
    """Add an overload for DataGroup to a function.

    When the first argument is a data group, the decorated function is
    mapped over all of its items, recursing into items that are data groups
    themselves. For any other first argument the original function is
    called directly.

    Parameters
    ----------
    func:
        Function to decorate.

    Returns
    -------
    :
        Decorated function.
    """
    # Do not assign '__annotations__' because that causes an error in Sphinx.
    copied_attributes = ('__module__', '__name__', '__qualname__', '__doc__')

    @wraps(func, assigned=copied_attributes)
    def impl(data, *args, **kwargs):
        if not isinstance(data, DataGroup):
            return func(data, *args, **kwargs)
        # Passing ``impl`` (not ``func``) makes nested data groups recurse.
        return data.apply(impl, *args, **kwargs)

    return cast(_F, impl)
625# There are currently no in-place operations (__iadd__, etc.) because they require
626# a check if the operation would fail before doing it. As otherwise, a failure could
627# leave a partially modified data group behind. Dataset implements such a check, but
628# it is simpler than for DataGroup because the latter supports more data types.
629# So for now, we went with the simple solution and
630# do not support in-place operations at all.
631#
632# Binding these functions dynamically has the added benefit that type checkers think
633# that the operations are not implemented.
def _make_inplace_binary_op(name: str):
    """Return a method that rejects the in-place operation ``name``.

    Parameters
    ----------
    name:
        Name of the in-place special method, e.g. ``'__iadd__'``. It is used
        verbatim in the error message.

    Returns
    -------
    :
        Function with the signature of a binary special method that always
        raises ``TypeError``.
    """

    def impl(
        self, other: Union[DataGroup, DataArray, Variable, numbers.Real]
    ) -> NoReturn:
        # ``name`` already is the full method name, so no prefix is added.
        raise TypeError(f'In-place operation {name} is not supported by DataGroup.')

    return impl
# Attach a raising stub to DataGroup for every in-place arithmetic method.
for _name in ('add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'pow'):
    full_name = '__i' + _name + '__'
    _binding.bind_function_as_method(
        cls=DataGroup,
        name=full_name,
        func=_make_inplace_binary_op(full_name),
    )

# Drop the loop variables so they do not remain in the module namespace.
del _name, full_name