Coverage for install/scipp/compat/wrapping.py: 100%
47 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3# @author Simon Heybrock
5from functools import wraps
6from typing import Callable, Union
8from ..core import BinEdgeError, DataArray, DimensionError, VariancesError
def _validated_masks(da, dim):
    """Return copies of all masks of ``da``, rejecting masks that depend on ``dim``.

    Parameters
    ----------
    da:
        Object whose ``masks`` mapping is inspected.
    dim:
        Dimension along which a function is about to be applied.

    Returns
    -------
    :
        Dict mapping every mask name to ``mask.copy()``.

    Raises
    ------
    DimensionError
        If any mask lists ``dim`` among its ``dims``.
    """
    masks = {}
    for name, mask in da.masks.items():
        # A mask along ``dim`` cannot simply be carried over to the output,
        # so refuse instead of silently dropping it.
        if dim in mask.dims:
            raise DimensionError(
                f"Cannot apply function along '{dim}' since mask '{name}' depends "
                "on this dimension."
            )
        masks[name] = mask.copy()
    return masks
def wrap1d(is_partial=False, accept_masks=False, keep_coords=False):
    """Decorator factory for decorating functions that wrap non-scipp 1-D functions.

    1-D functions are typically functions from libraries such as scipy that depend
    on a single 'axis' argument.

    The decorators returned by this factory apply pre- and postprocessing as follows:

    - An 'axis' keyword argument will raise ``ValueError``, recommending use of 'dim'.
      The index of the provided dimension is added as axis to kwargs.
    - Providing data with variances will raise ``sc.VariancesError`` since third-party
      libraries typically cannot handle variances.
    - If the size of the data along 'dim' differs from the size of the coordinate
      for 'dim', ``BinEdgeError`` is raised.
    - Coordinates, masks, and attributes that act as "observers", i.e., do not depend
      on the dimension of the function application, are added to the output data array.
      Masks are deep-copied as per the usual requirement in Scipp.

    Parameters
    ----------
    is_partial:
        The wrapped function is partial, i.e., does not return a data
        array itself, but a callable that returns a data array. If true,
        the postprocessing step is not applied to the wrapped function.
        Instead, the callable returned by the decorated function is
        decorated with the postprocessing step.
    accept_masks:
        If true, masks that depend on 'dim' are tolerated on the input and are
        left out of the metadata added to the output. If false, any mask
        that depends on 'dim' raises ``DimensionError``.
    keep_coords:
        If true, all coordinates and attributes of the input are added to the
        output, including those that depend on 'dim'. If false, only
        those that do not depend on 'dim' are added.

    Returns
    -------
    :
        A decorator. The decorated function has the signature
        ``function(da, dim, **kwargs)``.
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def function(da: DataArray, dim: str, **kwargs) -> Union[DataArray, Callable]:
            if 'axis' in kwargs:
                raise ValueError("Use the 'dim' keyword argument instead of 'axis'.")
            if da.variances is not None:
                raise VariancesError(
                    "Cannot apply function to data with uncertainties. If uncertainties"
                    " should be ignored, use 'sc.values(da)' to extract only values."
                )
            # NOTE(review): this indexes da.coords[dim] directly, so a coordinate
            # for ``dim`` must exist -- confirm that a failed lookup is the
            # intended error for inputs without such a coordinate.
            if da.sizes[dim] != da.coords[dim].sizes[dim]:
                raise BinEdgeError(
                    "Cannot apply function to data array with bin edges."
                )

            # Translate the named dimension into the positional axis expected by
            # the wrapped third-party function.
            kwargs['axis'] = da.dims.index(dim)

            if accept_masks:
                masks = {k: v for k, v in da.masks.items() if dim not in v.dims}
            else:
                masks = _validated_masks(da, dim)
            if keep_coords:
                coords = da.coords
                attrs = da.deprecated_attrs
            else:
                coords = {k: v for k, v in da.coords.items() if dim not in v.dims}
                attrs = {
                    k: v for k, v in da.deprecated_attrs.items() if dim not in v.dims
                }

            def _add_observing_metadata(da):
                for k, v in coords.items():
                    da.coords[k] = v
                # Copy per call: with ``is_partial`` this may run many times and
                # each output gets its own mask copies.
                for k, v in masks.items():
                    da.masks[k] = v.copy()
                for k, v in attrs.items():
                    da.deprecated_attrs[k] = v
                return da

            def postprocessing(func):
                @wraps(func)
                def function(*args, **kwargs):
                    return _add_observing_metadata(func(*args, **kwargs))

                return function

            if is_partial:
                # ``func`` returns a callable; decorate that callable so the
                # metadata is attached to whatever it eventually returns.
                return postprocessing(func(da, dim, **kwargs))
            else:
                return _add_observing_metadata(func(da, dim, **kwargs))

        return function

    return decorator