# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2022 Scipp contributors (https://github.com/scipp)
# @author Simon Heybrock
from __future__ import annotations
import copy
import functools
import itertools
import numbers
import operator
from collections.abc import (
Callable,
ItemsView,
Iterable,
Iterator,
KeysView,
Mapping,
MutableMapping,
Sequence,
ValuesView,
)
from functools import wraps
from typing import (
TYPE_CHECKING,
Any,
Concatenate,
NoReturn,
ParamSpec,
TypeVar,
cast,
overload,
)
import numpy as np
from .. import _binding
from .cpp_classes import (
DataArray,
Dataset,
DimensionError,
GroupByDataArray,
GroupByDataset,
Unit,
Variable,
)
if TYPE_CHECKING:
    # Avoid cyclic imports
    from ..coords.graph import GraphDict
    from ..typing import ScippIndex
    from .bins import Bins
_T = TypeVar("_T")  # Any type
_V = TypeVar("_V")  # Value type of self
_R = TypeVar("_R")  # Return type of a callable
_P = ParamSpec('_P')  # Parameters of a wrapped callable
def _item_dims(item: Any) -> tuple[str, ...]:
return getattr(item, 'dims', ())
def _is_binned(item: Any) -> bool:
    """Return True if ``item`` is a ``Bins`` wrapper or holds binned data."""
    from .bins import Bins

    # Either the item is itself the `bins` accessor, or it exposes a
    # non-None `bins` property (i.e. it is binned Variable/DataArray-like).
    return isinstance(item, Bins) or getattr(item, 'bins', None) is not None
def _summarize(item: Any) -> str:
    """Return a short single-line description of a DataGroup item."""
    name = type(item).__name__
    if isinstance(item, DataGroup):
        return f'{name}({len(item)}, {item.sizes})'
    if hasattr(item, 'sizes'):
        return f'{name}({item.sizes})'
    return str(item)
def _is_positional_index(key: Any) -> bool:
def is_int(x: object) -> bool:
return isinstance(x, numbers.Integral)
if is_int(key):
return True
if isinstance(key, slice):
if is_int(key.start) or is_int(key.stop) or is_int(key.step):
return True
if key.start is None and key.stop is None and key.step is None:
return True
return False
def _is_list_index(key: Any) -> bool:
return isinstance(key, list | np.ndarray)
class DataGroup(MutableMapping[str, _V]):
"""
A dict-like group of data. Additionally provides dims and shape properties.
DataGroup acts like a Python dict but additionally supports Scipp functionality
such as positional- and label-based indexing and Scipp operations by mapping them
to the values in the dict. This may happen recursively to support tree-like data
structures.
.. versionadded:: 23.01.0
"""
def __init__(
self, /, *args: Iterable[tuple[str, _V]] | Mapping[str, _V], **kwargs: _V
) -> None:
self._items = dict(*args, **kwargs)
if not all(isinstance(k, str) for k in self._items.keys()):
raise ValueError("DataGroup keys must be strings.")
def __copy__(self) -> DataGroup[_V]:
return DataGroup(copy.copy(self._items))
def __len__(self) -> int:
"""Return the number of items in the data group."""
return len(self._items)
def __iter__(self) -> Iterator[str]:
return iter(self._items)
def keys(self) -> KeysView[str]:
return self._items.keys()
def values(self) -> ValuesView[_V]:
return self._items.values()
def items(self) -> ItemsView[str, _V]:
return self._items.items()
@overload
def __getitem__(self, name: str) -> _V: ...
@overload
def __getitem__(self, name: ScippIndex) -> DataGroup[_V]: ...
def __getitem__(self, name: Any) -> Any:
"""Return item of given name or index all items.
When ``name`` is a string, return the item of the given name. Otherwise, this
returns a new DataGroup, with items created by indexing the items in this
DataGroup. This may perform, e.g., Scipp's positional indexing, label-based
indexing, or advanced indexing on items that are scipp.Variable or
scipp.DataArray.
Label-based indexing is only possible when all items have a coordinate for the
indexed dimension.
Advanced indexing comprises integer-array indexing and boolean-variable
indexing. Unlike positional indexing, integer-array indexing works even when
the item shapes are inconsistent for the indexed dimensions, provided that all
items contain the maximal index in the integer array. Boolean-variable indexing
is only possible when the shape of all items is compatible with the boolean
variable.
"""
from .bins import Bins
if isinstance(name, str):
return self._items[name]
if isinstance(name, tuple) and name == ():
return cast(DataGroup[Any], self).apply(operator.itemgetter(name))
if isinstance(name, Variable): # boolean indexing
return cast(DataGroup[Any], self).apply(operator.itemgetter(name))
if _is_positional_index(name) or _is_list_index(name):
if self.ndim != 1:
raise DimensionError(
"Slicing with implicit dimension label is only possible "
f"for 1-D objects. Got {self.sizes} with ndim={self.ndim}. Provide "
"an explicit dimension label, e.g., var['x', 0] instead of var[0]."
)
dim = self.dims[0]
index = name
else:
dim, index = name
return DataGroup(
{
key: var[dim, index] # type: ignore[index]
if (isinstance(var, Bins) or dim in _item_dims(var))
else var
for key, var in self.items()
}
)
def __setitem__(self, name: str, value: _V) -> None:
"""Set self[key] to value."""
if isinstance(name, str):
self._items[name] = value
else:
raise TypeError('Keys must be strings')
def __delitem__(self, name: str) -> None:
"""Delete self[key]."""
del self._items[name]
def __sizeof__(self) -> int:
return self.underlying_size()
def underlying_size(self) -> int:
# TODO Return the underlying size of all items in DataGroup
total_size = super.__sizeof__(self)
for item in self.values():
if isinstance(item, DataArray | Dataset | Variable | DataGroup):
total_size += item.underlying_size()
elif hasattr(item, 'nbytes'):
total_size += item.nbytes
else:
total_size += item.__sizeof__()
return total_size
@property
def dims(self) -> tuple[str, ...]:
"""Union of dims of all items. Non-Scipp items are handled as dims=()."""
return tuple(self.sizes)
@property
def ndim(self) -> int:
"""Number of dimensions, i.e., len(self.dims)."""
return len(self.dims)
@property
def shape(self) -> tuple[int | None, ...]:
"""Union of shape of all items. Non-Scipp items are handled as shape=()."""
return tuple(self.sizes.values())
@property
def sizes(self) -> dict[str, int | None]:
"""Dict combining dims and shape, i.e., mapping dim labels to their size."""
all_sizes: dict[str, set[int]] = {}
for x in self.values():
for dim, size in getattr(x, 'sizes', {}).items():
all_sizes.setdefault(dim, set()).add(size)
return {d: next(iter(s)) if len(s) == 1 else None for d, s in all_sizes.items()}
def _repr_html_(self) -> str:
from ..visualization.formatting_datagroup_html import datagroup_repr
return datagroup_repr(self)
def __repr__(self) -> str:
r = f'DataGroup(sizes={self.sizes}, keys=[\n'
for name, var in self.items():
r += f' {name}: {_summarize(var)},\n'
r += '])'
return r
def __str__(self) -> str:
return f'DataGroup(sizes={self.sizes}, keys={list(self.keys())})'
@property
def bins(self) -> DataGroup[Bins | None]:
# TODO Returning a regular DataGroup here may be wrong, since the `bins`
# property provides a different set of attrs and methods.
return self.apply(operator.attrgetter('bins'))
def apply(
self,
func: Callable[Concatenate[_V, _P], _R],
*args: _P.args,
**kwargs: _P.kwargs,
) -> DataGroup[_R]:
"""Call func on all values and return new DataGroup containing the results."""
return DataGroup({key: func(v, *args, **kwargs) for key, v in self.items()})
def _transform_dim(
self, func: str, *, dim: None | str | Iterable[str], **kwargs: Any
) -> DataGroup[Any]:
"""Transform items that depend on one or more dimensions given by `dim`."""
dims = (dim,) if isinstance(dim, str) else dim
def intersects(item: _V) -> bool:
item_dims = _item_dims(item)
if dims is None:
return item_dims != ()
return set(dims).intersection(item_dims) != set()
return DataGroup(
{
key: v
if not intersects(v)
else operator.methodcaller(func, dim, **kwargs)(v)
for key, v in self.items()
}
)
def _reduce(
self, method: str, dim: None | str | Sequence[str] = None, **kwargs: Any
) -> DataGroup[Any]:
reduce_all = operator.methodcaller(method, **kwargs)
def _reduce_child(v: _V) -> Any:
if isinstance(v, GroupByDataArray | GroupByDataset):
child_dims: tuple[None | str | Sequence[str], ...] = (dim,)
else:
child_dims = _item_dims(v)
# Reduction operations on binned data implicitly reduce over bin content.
# Therefore, a purely dimension-based logic is not sufficient to determine
# if the item has to be reduced or not.
binned = _is_binned(v)
if child_dims == () and not binned:
return v
if dim is None:
return reduce_all(v)
if isinstance(dim, str):
dims_to_reduce = dim if dim in child_dims else ()
else:
dims_to_reduce = tuple(d for d in dim if d in child_dims)
if dims_to_reduce == () and binned:
return reduce_all(v)
return (
v
if dims_to_reduce == ()
else operator.methodcaller(method, dims_to_reduce, **kwargs)(v)
)
return DataGroup({key: _reduce_child(v) for key, v in self.items()})
def copy(self, deep: bool = True) -> DataGroup[_V]:
return copy.deepcopy(self) if deep else copy.copy(self)
def all(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
return self._reduce('all', dim)
def any(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
return self._reduce('any', dim)
def astype(self, type: Any, *, copy: bool = True) -> DataGroup[_V]:
return self.apply(operator.methodcaller('astype', type, copy=copy))
def bin(
self,
arg_dict: dict[str, int | Variable] | None = None,
/,
**kwargs: int | Variable,
) -> DataGroup[_V]:
return self.apply(operator.methodcaller('bin', arg_dict, **kwargs))
@overload
def broadcast(
self,
*,
dims: Sequence[str],
shape: Sequence[int],
) -> DataGroup[_V]: ...
@overload
def broadcast(
self,
*,
sizes: dict[str, int],
) -> DataGroup[_V]: ...
def broadcast(
self,
*,
dims: Sequence[str] | None = None,
shape: Sequence[int] | None = None,
sizes: dict[str, int] | None = None,
) -> DataGroup[_V]:
return self.apply(
operator.methodcaller('broadcast', dims=dims, shape=shape, sizes=sizes)
)
def ceil(self) -> DataGroup[_V]:
return self.apply(operator.methodcaller('ceil'))
def flatten(
self, dims: Sequence[str] | None = None, to: str | None = None
) -> DataGroup[_V]:
return self._transform_dim('flatten', dim=dims, to=to)
def floor(self) -> DataGroup[_V]:
return self.apply(operator.methodcaller('floor'))
@overload
def fold(
self,
dim: str,
*,
dims: Sequence[str],
shape: Sequence[int],
) -> DataGroup[_V]: ...
@overload
def fold(
self,
dim: str,
*,
sizes: dict[str, int],
) -> DataGroup[_V]: ...
def fold(
self,
dim: str,
*,
dims: Sequence[str] | None = None,
shape: Sequence[int] | None = None,
sizes: dict[str, int] | None = None,
) -> DataGroup[_V]:
return self._transform_dim('fold', dim=dim, dims=dims, shape=shape, sizes=sizes)
def group(self, /, *args: str | Variable) -> DataGroup[_V]:
return self.apply(operator.methodcaller('group', *args))
def groupby(
self, /, group: Variable | str, *, bins: Variable | None = None
) -> DataGroup[GroupByDataArray | GroupByDataset]:
return self.apply(operator.methodcaller('groupby', group, bins=bins))
def hist(
self,
arg_dict: dict[str, int | Variable] | None = None,
/,
**kwargs: int | Variable,
) -> DataGroup[DataArray | Dataset]:
return self.apply(operator.methodcaller('hist', arg_dict, **kwargs))
def max(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
return self._reduce('max', dim)
def mean(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
return self._reduce('mean', dim)
def median(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
return self._reduce('median', dim)
def min(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
return self._reduce('min', dim)
def nanhist(
self,
arg_dict: dict[str, int | Variable] | None = None,
/,
**kwargs: int | Variable,
) -> DataGroup[DataArray]:
return self.apply(operator.methodcaller('nanhist', arg_dict, **kwargs))
def nanmax(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
return self._reduce('nanmax', dim)
def nanmean(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
return self._reduce('nanmean', dim)
def nanmedian(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
return self._reduce('nanmedian', dim)
def nanmin(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
return self._reduce('nanmin', dim)
def nansum(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
return self._reduce('nansum', dim)
def nanstd(
self, dim: None | str | tuple[str, ...] = None, *, ddof: int
) -> DataGroup[_V]:
return self._reduce('nanstd', dim, ddof=ddof)
def nanvar(
self, dim: None | str | tuple[str, ...] = None, *, ddof: int
) -> DataGroup[_V]:
return self._reduce('nanvar', dim, ddof=ddof)
def rebin(
self,
arg_dict: dict[str, int | Variable] | None = None,
/,
**kwargs: int | Variable,
) -> DataGroup[_V]:
return self.apply(operator.methodcaller('rebin', arg_dict, **kwargs))
def rename(
self, dims_dict: dict[str, str] | None = None, /, **names: str
) -> DataGroup[_V]:
return self.apply(operator.methodcaller('rename', dims_dict, **names))
def rename_dims(
self, dims_dict: dict[str, str] | None = None, /, **names: str
) -> DataGroup[_V]:
return self.apply(operator.methodcaller('rename_dims', dims_dict, **names))
def round(self) -> DataGroup[_V]:
return self.apply(operator.methodcaller('round'))
def squeeze(self, dim: str | Sequence[str] | None = None) -> DataGroup[_V]:
return self._reduce('squeeze', dim)
def std(
self, dim: None | str | tuple[str, ...] = None, *, ddof: int
) -> DataGroup[_V]:
return self._reduce('std', dim, ddof=ddof)
def sum(self, dim: None | str | tuple[str, ...] = None) -> DataGroup[_V]:
return self._reduce('sum', dim)
def to(
self,
*,
unit: Unit | str | None = None,
dtype: Any | None = None,
copy: bool = True,
) -> DataGroup[_V]:
return self.apply(
operator.methodcaller('to', unit=unit, dtype=dtype, copy=copy)
)
def transform_coords(
self,
targets: str | Iterable[str] | None = None,
/,
graph: GraphDict | None = None,
*,
rename_dims: bool = True,
keep_aliases: bool = True,
keep_intermediate: bool = True,
keep_inputs: bool = True,
quiet: bool = False,
**kwargs: Callable[..., Variable],
) -> DataGroup[_V]:
return self.apply(
operator.methodcaller(
'transform_coords',
targets,
graph=graph,
rename_dims=rename_dims,
keep_aliases=keep_aliases,
keep_intermediate=keep_intermediate,
keep_inputs=keep_inputs,
quiet=quiet,
**kwargs,
)
)
def transpose(self, dims: None | tuple[str, ...] = None) -> DataGroup[_V]:
return self._transform_dim('transpose', dim=dims)
def var(
self, dim: None | str | tuple[str, ...] = None, *, ddof: int
) -> DataGroup[_V]:
return self._reduce('var', dim, ddof=ddof)
def plot(self, *args: Any, **kwargs: Any) -> Any:
import plopp
return plopp.plot(self, *args, **kwargs)
def __eq__( # type: ignore[override]
self, other: DataGroup[object] | DataArray | Variable | float
) -> DataGroup[_V | bool]:
"""Item-wise equal."""
return data_group_nary(operator.eq, self, other)
def __ne__( # type: ignore[override]
self, other: DataGroup[object] | DataArray | Variable | float
) -> DataGroup[_V | bool]:
"""Item-wise not-equal."""
return data_group_nary(operator.ne, self, other)
def __gt__(
self, other: DataGroup[object] | DataArray | Variable | float
) -> DataGroup[_V | bool]:
"""Item-wise greater-than."""
return data_group_nary(operator.gt, self, other)
def __ge__(
self, other: DataGroup[object] | DataArray | Variable | float
) -> DataGroup[_V | bool]:
"""Item-wise greater-equal."""
return data_group_nary(operator.ge, self, other)
def __lt__(
self, other: DataGroup[object] | DataArray | Variable | float
) -> DataGroup[_V | bool]:
"""Item-wise less-than."""
return data_group_nary(operator.lt, self, other)
def __le__(
self, other: DataGroup[object] | DataArray | Variable | float
) -> DataGroup[_V | bool]:
"""Item-wise less-equal."""
return data_group_nary(operator.le, self, other)
def __add__(
self, other: DataGroup[Any] | DataArray | Variable | float
) -> DataGroup[Any]:
"""Apply ``add`` item-by-item."""
return data_group_nary(operator.add, self, other)
def __sub__(
self, other: DataGroup[Any] | DataArray | Variable | float
) -> DataGroup[Any]:
"""Apply ``sub`` item-by-item."""
return data_group_nary(operator.sub, self, other)
def __mul__(
self, other: DataGroup[Any] | DataArray | Variable | float
) -> DataGroup[Any]:
"""Apply ``mul`` item-by-item."""
return data_group_nary(operator.mul, self, other)
def __truediv__(
self, other: DataGroup[Any] | DataArray | Variable | float
) -> DataGroup[Any]:
"""Apply ``truediv`` item-by-item."""
return data_group_nary(operator.truediv, self, other)
def __floordiv__(
self, other: DataGroup[Any] | DataArray | Variable | float
) -> DataGroup[Any]:
"""Apply ``floordiv`` item-by-item."""
return data_group_nary(operator.floordiv, self, other)
def __mod__(
self, other: DataGroup[Any] | DataArray | Variable | float
) -> DataGroup[Any]:
"""Apply ``mod`` item-by-item."""
return data_group_nary(operator.mod, self, other)
def __pow__(
self, other: DataGroup[Any] | DataArray | Variable | float
) -> DataGroup[Any]:
"""Apply ``pow`` item-by-item."""
return data_group_nary(operator.pow, self, other)
def __radd__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
"""Apply ``add`` item-by-item."""
return data_group_nary(operator.add, other, self)
def __rsub__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
"""Apply ``sub`` item-by-item."""
return data_group_nary(operator.sub, other, self)
def __rmul__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
"""Apply ``mul`` item-by-item."""
return data_group_nary(operator.mul, other, self)
def __rtruediv__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
"""Apply ``truediv`` item-by-item."""
return data_group_nary(operator.truediv, other, self)
def __rfloordiv__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
"""Apply ``floordiv`` item-by-item."""
return data_group_nary(operator.floordiv, other, self)
def __rmod__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
"""Apply ``mod`` item-by-item."""
return data_group_nary(operator.mod, other, self)
def __rpow__(self, other: DataArray | Variable | float) -> DataGroup[Any]:
"""Apply ``pow`` item-by-item."""
return data_group_nary(operator.pow, other, self)
def __and__(
self, other: DataGroup[Any] | DataArray | Variable | float
) -> DataGroup[Any]:
"""Return the element-wise ``and`` of items."""
return data_group_nary(operator.and_, self, other)
def __or__(
self, other: DataGroup[Any] | DataArray | Variable | float
) -> DataGroup[Any]:
"""Return the element-wise ``or`` of items."""
return data_group_nary(operator.or_, self, other)
def __xor__(
self, other: DataGroup[Any] | DataArray | Variable | float
) -> DataGroup[Any]:
"""Return the element-wise ``xor`` of items."""
return data_group_nary(operator.xor, self, other)
def __invert__(self) -> DataGroup[Any]:
"""Return the element-wise ``or`` of items."""
return self.apply(operator.invert) # type: ignore[arg-type]
def data_group_nary(
    func: Callable[..., _R], *args: Any, **kwargs: Any
) -> DataGroup[_R]:
    """Apply ``func`` key-by-key across all DataGroup arguments.

    Only keys common to every DataGroup argument are kept; non-DataGroup
    arguments are forwarded unchanged to each per-key call.
    """
    groups = [
        x
        for x in itertools.chain(args, kwargs.values())
        if isinstance(x, DataGroup)
    ]
    keys = functools.reduce(operator.and_, [g.keys() for g in groups])

    def select(x: Any, key: str) -> Any:
        # Pick the per-key item from DataGroups; pass everything else through.
        if isinstance(x, DataGroup):
            return x[key]
        return x

    out: dict[str, _R] = {}
    for key in keys:
        pos = [select(x, key) for x in args]
        kw = {name: select(x, key) for name, x in kwargs.items()}
        out[key] = func(*pos, **kw)
    return DataGroup(out)
def apply_to_items(
    func: Callable[..., _R], dgs: Iterable[DataGroup[Any]], *args: Any, **kwargs: Any
) -> DataGroup[_R]:
    """Apply ``func`` to the list of values of each common key of ``dgs``.

    Parameters
    ----------
    func:
        Called as ``func([dg[key] for dg in dgs], *args, **kwargs)`` per key.
    dgs:
        Data groups to combine; only keys present in all of them are kept.
    """
    # Materialize once: `dgs` is typed as Iterable, and the original code
    # iterated it twice (once for keys, once per key), silently producing
    # empty value lists for one-shot iterators/generators.
    groups = list(dgs)
    keys = functools.reduce(operator.and_, [dg.keys() for dg in groups])
    return DataGroup(
        {key: func([dg[key] for dg in groups], *args, **kwargs) for key in keys}
    )
def data_group_overload(
    func: Callable[Concatenate[_T, _P], _R],
) -> Callable[..., _R | DataGroup[_R]]:
    """Add an overload for DataGroup to a function.

    If the first argument of the function is a data group,
    then the decorated function is mapped over all items.
    It is applied recursively for items that are themselves data groups.

    Otherwise, the original function is applied directly.

    Parameters
    ----------
    func:
        Function to decorate.

    Returns
    -------
    :
        Decorated function.
    """

    # Do not assign '__annotations__' because that causes an error in Sphinx.
    @wraps(func, assigned=('__module__', '__name__', '__qualname__', '__doc__'))
    def impl(
        data: _T | DataGroup[Any], *args: _P.args, **kwargs: _P.kwargs
    ) -> _R | DataGroup[_R]:
        # Plain values go straight to the wrapped function; data groups are
        # mapped item-by-item, recursing into nested groups via `impl` itself.
        if not isinstance(data, DataGroup):
            return func(data, *args, **kwargs)
        return data.apply(impl, *args, **kwargs)  # type: ignore[arg-type]

    return impl
# There are currently no in-place operations (__iadd__, etc.) because they require
# a check if the operation would fail before doing it. As otherwise, a failure could
# leave a partially modified data group behind. Dataset implements such a check, but
# it is simpler than for DataGroup because the latter supports more data types.
# So for now, we went with the simple solution and
# do not support in-place operations at all.
#
# Binding these functions dynamically has the added benefit that type checkers think
# that the operations are not implemented.
def _make_inplace_binary_op(name: str) -> Callable[..., NoReturn]:
def impl(
self: DataGroup[Any], other: DataGroup[Any] | DataArray | Variable | float
) -> NoReturn:
raise TypeError(f'In-place operation i{name} is not supported by DataGroup.')
return impl
# Bind stubs for all in-place binary operators onto DataGroup.
for _name in ('add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'pow'):
    full_name = f'__i{_name}__'
    # Pass the short name: the factory's message template is
    # 'In-place operation i{name} ...', so passing `full_name` would render
    # the nonsensical 'In-place operation i__iadd__ ...'.
    _binding.bind_function_as_method(
        cls=DataGroup, name=full_name, func=_make_inplace_binary_op(_name)
    )
del _name, full_name