# SPDX-License-Identifier: BSD-3-Clause# Copyright (c) 2022 Scipp contributors (https://github.com/scipp)# @author Simon Heybrockfrom__future__importannotationsimportcopyimportfunctoolsimportitertoolsimportnumbersimportoperatorfromcollections.abcimport(Callable,ItemsView,Iterable,Iterator,KeysView,Mapping,MutableMapping,Sequence,ValuesView,)fromfunctoolsimportwrapsfromtypingimport(TYPE_CHECKING,Any,Concatenate,NoReturn,ParamSpec,TypeVar,cast,overload,)importnumpyasnpfrom..import_bindingfrom.cpp_classesimport(DataArray,Dataset,DimensionError,GroupByDataArray,GroupByDataset,Unit,Variable,)ifTYPE_CHECKING:# Avoid cyclic importsfrom..coords.graphimportGraphDictfrom..typingimportScippIndexfrom.binsimportBins_T=TypeVar("_T")# Any type_V=TypeVar("_V")# Value type of self_R=TypeVar("_R")# Return type of a callable_P=ParamSpec('_P')def_item_dims(item:Any)->tuple[str,...]:returngetattr(item,'dims',())def_is_binned(item:Any)->bool:from.binsimportBinsifisinstance(item,Bins):returnTruereturngetattr(item,'bins',None)isnotNonedef_summarize(item:Any)->str:ifisinstance(item,DataGroup):returnf'{type(item).__name__}({len(item)}, {item.sizes})'ifhasattr(item,'sizes'):returnf'{type(item).__name__}({item.sizes})'returnstr(item)def_is_positional_index(key:Any)->bool:defis_int(x:object)->bool:returnisinstance(x,numbers.Integral)ifis_int(key):returnTrueifisinstance(key,slice):ifis_int(key.start)oris_int(key.stop)oris_int(key.step):returnTrueifkey.startisNoneandkey.stopisNoneandkey.stepisNone:returnTruereturnFalsedef_is_list_index(key:Any)->bool:returnisinstance(key,list|np.ndarray)classDataGroup(MutableMapping[str,_V]):""" A dict-like group of data. Additionally provides dims and shape properties. DataGroup acts like a Python dict but additionally supports Scipp functionality such as positional- and label-based indexing and Scipp operations by mapping them to the values in the dict. This may happen recursively to support tree-like data structures. .. 
versionadded:: 23.01.0 """def__init__(self,/,*args:Iterable[tuple[str,_V]]|Mapping[str,_V],**kwargs:_V)->None:self._items=dict(*args,**kwargs)ifnotall(isinstance(k,str)forkinself._items.keys()):raiseValueError("DataGroup keys must be strings.")def__copy__(self)->DataGroup[_V]:returnDataGroup(copy.copy(self._items))def__len__(self)->int:"""Return the number of items in the data group."""returnlen(self._items)def__iter__(self)->Iterator[str]:returniter(self._items)defkeys(self)->KeysView[str]:returnself._items.keys()defvalues(self)->ValuesView[_V]:returnself._items.values()defitems(self)->ItemsView[str,_V]:returnself._items.items()@overloaddef__getitem__(self,name:str)->_V:...@overloaddef__getitem__(self,name:ScippIndex)->DataGroup[_V]:...def__getitem__(self,name:Any)->Any:"""Return item of given name or index all items. When ``name`` is a string, return the item of the given name. Otherwise, this returns a new DataGroup, with items created by indexing the items in this DataGroup. This may perform, e.g., Scipp's positional indexing, label-based indexing, or advanced indexing on items that are scipp.Variable or scipp.DataArray. Label-based indexing is only possible when all items have a coordinate for the indexed dimension. Advanced indexing comprises integer-array indexing and boolean-variable indexing. Unlike positional indexing, integer-array indexing works even when the item shapes are inconsistent for the indexed dimensions, provided that all items contain the maximal index in the integer array. Boolean-variable indexing is only possible when the shape of all items is compatible with the boolean variable. 
"""from.binsimportBinsifisinstance(name,str):returnself._items[name]ifisinstance(name,tuple)andname==():returncast(DataGroup[Any],self).apply(operator.itemgetter(name))ifisinstance(name,Variable):# boolean indexingreturncast(DataGroup[Any],self).apply(operator.itemgetter(name))if_is_positional_index(name)or_is_list_index(name):ifself.ndim!=1:raiseDimensionError("Slicing with implicit dimension label is only possible "f"for 1-D objects. Got {self.sizes} with ndim={self.ndim}. Provide ""an explicit dimension label, e.g., var['x', 0] instead of var[0].")dim=self.dims[0]index=nameelse:dim,index=namereturnDataGroup({key:var[dim,index]# type: ignore[index]if(isinstance(var,Bins)ordimin_item_dims(var))elsevarforkey,varinself.items()})def__setitem__(self,name:str,value:_V)->None:"""Set self[key] to value."""ifisinstance(name,str):self._items[name]=valueelse:raiseTypeError('Keys must be strings')def__delitem__(self,name:str)->None:"""Delete self[key]."""delself._items[name]def__sizeof__(self)->int:returnself.underlying_size()defunderlying_size(self)->int:# TODO Return the underlying size of all items in DataGrouptotal_size=super.__sizeof__(self)foriteminself.values():ifisinstance(item,DataArray|Dataset|Variable|DataGroup):total_size+=item.underlying_size()elifhasattr(item,'nbytes'):total_size+=item.nbyteselse:total_size+=item.__sizeof__()returntotal_size@propertydefdims(self)->tuple[str,...]:"""Union of dims of all items. Non-Scipp items are handled as dims=()."""returntuple(self.sizes)@propertydefndim(self)->int:"""Number of dimensions, i.e., len(self.dims)."""returnlen(self.dims)@propertydefshape(self)->tuple[int|None,...]:"""Union of shape of all items. 
Non-Scipp items are handled as shape=()."""returntuple(self.sizes.values())@propertydefsizes(self)->dict[str,int|None]:"""Dict combining dims and shape, i.e., mapping dim labels to their size."""all_sizes:dict[str,set[int]]={}forxinself.values():fordim,sizeingetattr(x,'sizes',{}).items():all_sizes.setdefault(dim,set()).add(size)return{d:next(iter(s))iflen(s)==1elseNoneford,sinall_sizes.items()}def_repr_html_(self)->str:from..visualization.formatting_datagroup_htmlimportdatagroup_reprreturndatagroup_repr(self)def__repr__(self)->str:r=f'DataGroup(sizes={self.sizes}, keys=[\n'forname,varinself.items():r+=f' {name}: {_summarize(var)},\n'r+='])'returnrdef__str__(self)->str:returnf'DataGroup(sizes={self.sizes}, keys={list(self.keys())})'@propertydefbins(self)->DataGroup[DataGroup[Any]|Bins[Any]|None]:# TODO Returning a regular DataGroup here may be wrong, since the `bins`# property provides a different set of attrs and methods.returnself.apply(operator.attrgetter('bins'))defapply(self,func:Callable[Concatenate[_V,_P],_R],*args:_P.args,**kwargs:_P.kwargs,)->DataGroup[_R]:"""Call func on all values and return new DataGroup containing the results."""returnDataGroup({key:func(v,*args,**kwargs)forkey,vinself.items()})def_transform_dim(self,func:str,*,dim:None|str|Iterable[str],**kwargs:Any)->DataGroup[Any]:"""Transform items that depend on one or more dimensions given by `dim`."""dims=(dim,)ifisinstance(dim,str)elsedimdefintersects(item:_V)->bool:item_dims=_item_dims(item)ifdimsisNone:returnitem_dims!=()returnset(dims).intersection(item_dims)!=set()returnDataGroup({key:vifnotintersects(v)elseoperator.methodcaller(func,dim,**kwargs)(v)forkey,vinself.items()})def_reduce(self,method:str,dim:None|str|Sequence[str]=None,**kwargs:Any)->DataGroup[Any]:reduce_all=operator.methodcaller(method,**kwargs)def_reduce_child(v:_V)->Any:ifisinstance(v,GroupByDataArray|GroupByDataset):child_dims:tuple[None|str|Sequence[str],...]=(dim,)else:child_dims=_item_dims(v)# Reduction operations on 
binned data implicitly reduce over bin content.# Therefore, a purely dimension-based logic is not sufficient to determine# if the item has to be reduced or not.binned=_is_binned(v)ifchild_dims==()andnotbinned:returnvifdimisNone:returnreduce_all(v)ifisinstance(dim,str):dims_to_reduce:tuple[str,...]|str=dimifdiminchild_dimselse()else:dims_to_reduce=tuple(dfordindimifdinchild_dims)ifdims_to_reduce==()andbinned:returnreduce_all(v)return(vifdims_to_reduce==()elseoperator.methodcaller(method,dims_to_reduce,**kwargs)(v))returnDataGroup({key:_reduce_child(v)forkey,vinself.items()})defcopy(self,deep:bool=True)->DataGroup[_V]:returncopy.deepcopy(self)ifdeepelsecopy.copy(self)defall(self,dim:None|str|tuple[str,...]=None)->DataGroup[_V]:returnself._reduce('all',dim)defany(self,dim:None|str|tuple[str,...]=None)->DataGroup[_V]:returnself._reduce('any',dim)defastype(self,type:Any,*,copy:bool=True)->DataGroup[_V]:returnself.apply(operator.methodcaller('astype',type,copy=copy))defbin(self,arg_dict:dict[str,int|Variable]|None=None,/,**kwargs:int|Variable,)->DataGroup[_V]:returnself.apply(operator.methodcaller('bin',arg_dict,**kwargs))@overloaddefbroadcast(self,*,dims:Sequence[str],shape:Sequence[int],)->DataGroup[_V]:...@overloaddefbroadcast(self,*,sizes:dict[str,int],)->DataGroup[_V]:...defbroadcast(self,*,dims:Sequence[str]|None=None,shape:Sequence[int]|None=None,sizes:dict[str,int]|None=None,)->DataGroup[_V]:returnself.apply(operator.methodcaller('broadcast',dims=dims,shape=shape,sizes=sizes))defceil(self)->DataGroup[_V]:returnself.apply(operator.methodcaller('ceil'))defflatten(self,dims:Sequence[str]|None=None,to:str|None=None)->DataGroup[_V]:returnself._transform_dim('flatten',dim=dims,to=to)deffloor(self)->DataGroup[_V]:returnself.apply(operator.methodcaller('floor'))@overloaddeffold(self,dim:str,*,dims:Sequence[str],shape:Sequence[int],)->DataGroup[_V]:...@overloaddeffold(self,dim:str,*,sizes:dict[str,int],)->DataGroup[_V]:...deffold(self,dim:str,*,dims:Sequence[str]|None=None
,shape:Sequence[int]|None=None,sizes:dict[str,int]|None=None,)->DataGroup[_V]:returnself._transform_dim('fold',dim=dim,dims=dims,shape=shape,sizes=sizes)defgroup(self,/,*args:str|Variable)->DataGroup[_V]:returnself.apply(operator.methodcaller('group',*args))defgroupby(self,/,group:Variable|str,*,bins:Variable|None=None)->DataGroup[GroupByDataArray|GroupByDataset]:returnself.apply(operator.methodcaller('groupby',group,bins=bins))defhist(self,arg_dict:dict[str,int|Variable]|None=None,/,**kwargs:int|Variable,)->DataGroup[DataArray|Dataset]:returnself.apply(operator.methodcaller('hist',arg_dict,**kwargs))defmax(self,dim:None|str|tuple[str,...]=None)->DataGroup[_V]:returnself._reduce('max',dim)defmean(self,dim:None|str|tuple[str,...]=None)->DataGroup[_V]:returnself._reduce('mean',dim)defmedian(self,dim:None|str|tuple[str,...]=None)->DataGroup[_V]:returnself._reduce('median',dim)defmin(self,dim:None|str|tuple[str,...]=None)->DataGroup[_V]:returnself._reduce('min',dim)defnanhist(self,arg_dict:dict[str,int|Variable]|None=None,/,**kwargs:int|Variable,)->DataGroup[DataArray]:returnself.apply(operator.methodcaller('nanhist',arg_dict,**kwargs))defnanmax(self,dim:None|str|tuple[str,...]=None)->DataGroup[_V]:returnself._reduce('nanmax',dim)defnanmean(self,dim:None|str|tuple[str,...]=None)->DataGroup[_V]:returnself._reduce('nanmean',dim)defnanmedian(self,dim:None|str|tuple[str,...]=None)->DataGroup[_V]:returnself._reduce('nanmedian',dim)defnanmin(self,dim:None|str|tuple[str,...]=None)->DataGroup[_V]:returnself._reduce('nanmin',dim)defnansum(self,dim:None|str|tuple[str,...]=None)->DataGroup[_V]:returnself._reduce('nansum',dim)defnanstd(self,dim:None|str|tuple[str,...]=None,*,ddof:int)->DataGroup[_V]:returnself._reduce('nanstd',dim,ddof=ddof)defnanvar(self,dim:None|str|tuple[str,...]=None,*,ddof:int)->DataGroup[_V]:returnself._reduce('nanvar',dim,ddof=ddof)defrebin(self,arg_dict:dict[str,int|Variable]|None=None,/,**kwargs:int|Variable,)->DataGroup[_V]:returnself.apply(operator.metho
dcaller('rebin',arg_dict,**kwargs))defrename(self,dims_dict:dict[str,str]|None=None,/,**names:str)->DataGroup[_V]:returnself.apply(operator.methodcaller('rename',dims_dict,**names))defrename_dims(self,dims_dict:dict[str,str]|None=None,/,**names:str)->DataGroup[_V]:returnself.apply(operator.methodcaller('rename_dims',dims_dict,**names))defround(self)->DataGroup[_V]:returnself.apply(operator.methodcaller('round'))defsqueeze(self,dim:str|Sequence[str]|None=None)->DataGroup[_V]:returnself._reduce('squeeze',dim)defstd(self,dim:None|str|tuple[str,...]=None,*,ddof:int)->DataGroup[_V]:returnself._reduce('std',dim,ddof=ddof)defsum(self,dim:None|str|tuple[str,...]=None)->DataGroup[_V]:returnself._reduce('sum',dim)defto(self,*,unit:Unit|str|None=None,dtype:Any|None=None,copy:bool=True,)->DataGroup[_V]:returnself.apply(operator.methodcaller('to',unit=unit,dtype=dtype,copy=copy))deftransform_coords(self,targets:str|Iterable[str]|None=None,/,graph:GraphDict|None=None,*,rename_dims:bool=True,keep_aliases:bool=True,keep_intermediate:bool=True,keep_inputs:bool=True,quiet:bool=False,**kwargs:Callable[...,Variable],)->DataGroup[_V]:returnself.apply(operator.methodcaller('transform_coords',targets,graph=graph,rename_dims=rename_dims,keep_aliases=keep_aliases,keep_intermediate=keep_intermediate,keep_inputs=keep_inputs,quiet=quiet,**kwargs,))deftranspose(self,dims:None|tuple[str,...]=None)->DataGroup[_V]:returnself._transform_dim('transpose',dim=dims)defvar(self,dim:None|str|tuple[str,...]=None,*,ddof:int)->DataGroup[_V]:returnself._reduce('var',dim,ddof=ddof)defplot(self,*args:Any,**kwargs:Any)->Any:importploppreturnplopp.plot(self,*args,**kwargs)def__eq__(# type: ignore[override]self,other:DataGroup[object]|DataArray|Variable|float)->DataGroup[_V|bool]:"""Item-wise equal."""returndata_group_nary(operator.eq,self,other)def__ne__(# type: ignore[override]self,other:DataGroup[object]|DataArray|Variable|float)->DataGroup[_V|bool]:"""Item-wise 
not-equal."""returndata_group_nary(operator.ne,self,other)def__gt__(self,other:DataGroup[object]|DataArray|Variable|float)->DataGroup[_V|bool]:"""Item-wise greater-than."""returndata_group_nary(operator.gt,self,other)def__ge__(self,other:DataGroup[object]|DataArray|Variable|float)->DataGroup[_V|bool]:"""Item-wise greater-equal."""returndata_group_nary(operator.ge,self,other)def__lt__(self,other:DataGroup[object]|DataArray|Variable|float)->DataGroup[_V|bool]:"""Item-wise less-than."""returndata_group_nary(operator.lt,self,other)def__le__(self,other:DataGroup[object]|DataArray|Variable|float)->DataGroup[_V|bool]:"""Item-wise less-equal."""returndata_group_nary(operator.le,self,other)def__add__(self,other:DataGroup[Any]|DataArray|Variable|float)->DataGroup[Any]:"""Apply ``add`` item-by-item."""returndata_group_nary(operator.add,self,other)def__sub__(self,other:DataGroup[Any]|DataArray|Variable|float)->DataGroup[Any]:"""Apply ``sub`` item-by-item."""returndata_group_nary(operator.sub,self,other)def__mul__(self,other:DataGroup[Any]|DataArray|Variable|float)->DataGroup[Any]:"""Apply ``mul`` item-by-item."""returndata_group_nary(operator.mul,self,other)def__truediv__(self,other:DataGroup[Any]|DataArray|Variable|float)->DataGroup[Any]:"""Apply ``truediv`` item-by-item."""returndata_group_nary(operator.truediv,self,other)def__floordiv__(self,other:DataGroup[Any]|DataArray|Variable|float)->DataGroup[Any]:"""Apply ``floordiv`` item-by-item."""returndata_group_nary(operator.floordiv,self,other)def__mod__(self,other:DataGroup[Any]|DataArray|Variable|float)->DataGroup[Any]:"""Apply ``mod`` item-by-item."""returndata_group_nary(operator.mod,self,other)def__pow__(self,other:DataGroup[Any]|DataArray|Variable|float)->DataGroup[Any]:"""Apply ``pow`` item-by-item."""returndata_group_nary(operator.pow,self,other)def__radd__(self,other:DataArray|Variable|float)->DataGroup[Any]:"""Apply ``add`` 
item-by-item."""returndata_group_nary(operator.add,other,self)def__rsub__(self,other:DataArray|Variable|float)->DataGroup[Any]:"""Apply ``sub`` item-by-item."""returndata_group_nary(operator.sub,other,self)def__rmul__(self,other:DataArray|Variable|float)->DataGroup[Any]:"""Apply ``mul`` item-by-item."""returndata_group_nary(operator.mul,other,self)def__rtruediv__(self,other:DataArray|Variable|float)->DataGroup[Any]:"""Apply ``truediv`` item-by-item."""returndata_group_nary(operator.truediv,other,self)def__rfloordiv__(self,other:DataArray|Variable|float)->DataGroup[Any]:"""Apply ``floordiv`` item-by-item."""returndata_group_nary(operator.floordiv,other,self)def__rmod__(self,other:DataArray|Variable|float)->DataGroup[Any]:"""Apply ``mod`` item-by-item."""returndata_group_nary(operator.mod,other,self)def__rpow__(self,other:DataArray|Variable|float)->DataGroup[Any]:"""Apply ``pow`` item-by-item."""returndata_group_nary(operator.pow,other,self)def__and__(self,other:DataGroup[Any]|DataArray|Variable|float)->DataGroup[Any]:"""Return the element-wise ``and`` of items."""returndata_group_nary(operator.and_,self,other)def__or__(self,other:DataGroup[Any]|DataArray|Variable|float)->DataGroup[Any]:"""Return the element-wise ``or`` of items."""returndata_group_nary(operator.or_,self,other)def__xor__(self,other:DataGroup[Any]|DataArray|Variable|float)->DataGroup[Any]:"""Return the element-wise ``xor`` of items."""returndata_group_nary(operator.xor,self,other)def__invert__(self)->DataGroup[Any]:"""Return the element-wise ``or`` of items."""returnself.apply(operator.invert)# type: 
ignore[arg-type]defdata_group_nary(func:Callable[...,_R],*args:Any,**kwargs:Any)->DataGroup[_R]:dgs=filter(lambdax:isinstance(x,DataGroup),itertools.chain(args,kwargs.values()))keys=functools.reduce(operator.and_,[dg.keys()fordgindgs])defelem(x:Any,key:str)->Any:returnx[key]ifisinstance(x,DataGroup)elsexreturnDataGroup({key:func(*[elem(x,key)forxinargs],**{name:elem(x,key)forname,xinkwargs.items()},)forkeyinkeys})defapply_to_items(func:Callable[...,_R],dgs:Iterable[DataGroup[Any]],*args:Any,**kwargs:Any)->DataGroup[_R]:keys=functools.reduce(operator.and_,[dg.keys()fordgindgs])returnDataGroup({key:func([dg[key]fordgindgs],*args,**kwargs)forkeyinkeys})defdata_group_overload(func:Callable[Concatenate[_T,_P],_R],)->Callable[...,_R|DataGroup[_R]]:"""Add an overload for DataGroup to a function. If the first argument of the function is a data group, then the decorated function is mapped over all items. It is applied recursively for items that are themselves data groups. Otherwise, the original function is applied directly. Parameters ---------- func: Function to decorate. Returns ------- : Decorated function. """# Do not assign '__annotations__' because that causes an error in Sphinx.@wraps(func,assigned=('__module__','__name__','__qualname__','__doc__'))defimpl(data:_T|DataGroup[Any],*args:_P.args,**kwargs:_P.kwargs)->_R|DataGroup[_R]:ifisinstance(data,DataGroup):returndata.apply(impl,*args,**kwargs)# type: ignore[arg-type]returnfunc(data,*args,**kwargs)returnimpl# There are currently no in-place operations (__iadd__, etc.) because they require# a check if the operation would fail before doing it. As otherwise, a failure could# leave a partially modified data group behind. 
# Dataset implements such a check, but
# it is simpler than for DataGroup because the latter supports more data types.
# So for now, we went with the simple solution and
# not support in-place operations at all.
#
# Binding these functions dynamically has the added benefit that type checkers think
# that the operations are not implemented.
def _make_inplace_binary_op(name: str) -> Callable[..., NoReturn]:
    """Return a method that always raises ``TypeError`` for an in-place operation.

    Parameters
    ----------
    name:
        Short operation name without the ``i`` prefix or dunder underscores,
        e.g. ``'add'``; the error message reports it as ``i{name}``.
    """

    def impl(
        self: DataGroup[Any], other: DataGroup[Any] | DataArray | Variable | float
    ) -> NoReturn:
        raise TypeError(f'In-place operation i{name} is not supported by DataGroup.')

    return impl


for _name in ('add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'pow'):
    full_name = f'__i{_name}__'
    # Pass the short name: the factory prepends the 'i' itself, so passing
    # `full_name` would yield messages such as "i__iadd__".
    _binding.bind_function_as_method(
        cls=DataGroup, name=full_name, func=_make_inplace_binary_op(_name)
    )
# Do not leak the loop helpers into the module namespace.
del _name, full_name