Coverage for install/scipp/compat/wrapping.py: 100%
51 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-17 01:51 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-17 01:51 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3# @author Simon Heybrock
5from collections.abc import Callable, Mapping
6from functools import wraps
7from typing import Any, TypeVar
9from ..core import (
10 BinEdgeError,
11 DataArray,
12 DimensionError,
13 Variable,
14 VariancesError,
15)
18def _validated_masks(da: DataArray, dim: str) -> dict[str, Variable]:
19 masks = {}
20 for name, mask in da.masks.items():
21 if dim in mask.dims:
22 raise DimensionError(
23 f"Cannot apply function along '{dim}' since mask '{name}' depends "
24 "on this dimension."
25 )
26 masks[name] = mask.copy()
27 return masks
# Result type of a wrapped function: either a DataArray, or (for wrappers declared
# with ``is_partial=True``) a callable that produces a DataArray.
_Out = TypeVar('_Out', bound=DataArray | Callable[..., DataArray])
def wrap1d(
    is_partial: bool = False, accept_masks: bool = False, keep_coords: bool = False
) -> Callable[[Callable[..., _Out]], Callable[..., _Out]]:
    """Decorator factory for decorating functions that wrap non-scipp 1-D functions.

    1-D functions are typically functions from libraries such as scipy that depend
    on a single 'axis' argument.

    The decorators returned by this factory apply pre- and postprocessing as follows:

    - An 'axis' keyword argument will raise ``ValueError``, recommending use of 'dim'.
      The index of the provided dimension is added as axis to kwargs.
    - Providing data with variances will raise ``sc.VariancesError`` since third-party
      libraries typically cannot handle variances.
    - If ``dim`` has a coordinate and that coordinate is a bin-edge coordinate
      (one element longer than the data along ``dim``), ``sc.BinEdgeError`` is
      raised. A missing coordinate is not an error at this stage; the wrapped
      function decides whether it needs one.
    - Coordinates, masks, and attributes that act as "observers", i.e., do not depend
      on the dimension of the function application, are added to the output data array.
      Masks are deep-copied as per the usual requirement in Scipp.

    Parameters
    ----------
    is_partial:
        The wrapped function is partial, i.e., does not return a data
        array itself, but a callable that returns a data array. If true,
        the postprocessing step is not applied to the wrapped function.
        Instead, the callable returned by the decorated function is
        decorated with the postprocessing step.
    accept_masks:
        If false, no mask may depend on the dimension that the function is
        applied to; such a mask raises ``sc.DimensionError``.
        If true, masks that depend on that dimension are dropped from the output.
    keep_coords:
        If true, preserve all input coordinates (and attributes).
        If false, drop coordinates (and attributes) that depend on the dimension
        the function is applied to.

    Returns
    -------
    :
        A decorator. The decorated function has the signature
        ``(da, dim, **kwargs)`` and forwards ``axis=da.dims.index(dim)`` to the
        wrapped function.
    """

    def decorator(func: Callable[..., _Out]) -> Callable[..., _Out]:
        @wraps(func)
        def function(da: DataArray, dim: str, **kwargs: Any) -> _Out:
            if 'axis' in kwargs:
                raise ValueError("Use the 'dim' keyword argument instead of 'axis'.")
            if da.variances is not None:
                raise VariancesError(
                    "Cannot apply function to data with uncertainties. If uncertainties"
                    " should be ignored, use 'sc.values(da)' to extract only values."
                )
            # Looked up first so that an invalid ``dim`` fails here, as before.
            size = da.sizes[dim]
            # Only an existing coordinate for ``dim`` can be bin-edges. Without the
            # membership test, data lacking such a coordinate would fail with an
            # unrelated lookup error instead of reaching the wrapped function.
            if dim in da.coords and da.coords[dim].sizes[dim] != size:
                raise BinEdgeError(
                    "Cannot apply function to data array with bin edges."
                )
            kwargs['axis'] = da.dims.index(dim)
            result = func(da, dim, **kwargs)
            return _postprocess(
                input_da=da,
                output_da=result,
                dim=dim,
                is_partial=is_partial,
                accept_masks=accept_masks,
                keep_coords=keep_coords,
            )

        return function

    return decorator
def _postprocess(
    *,
    input_da: DataArray,
    output_da: _Out,
    dim: str,
    is_partial: bool,
    accept_masks: bool,
    keep_coords: bool,
) -> _Out:
    """Attach the input's "observing" metadata to the result of a wrapped function.

    Parameters
    ----------
    input_da:
        Data array the wrapped function was applied to; metadata is read from it.
    output_da:
        Either the data array returned by the wrapped function or, when
        ``is_partial`` is true, a callable producing such data arrays.
    dim:
        Dimension the function was applied along.
    is_partial:
        If true, ``output_da`` is treated as a callable and wrapped so that every
        data array it returns receives the metadata.
    accept_masks:
        If true, masks depending on ``dim`` are dropped; if false, such masks
        raise ``DimensionError``.
    keep_coords:
        If true, all coordinates and attributes of ``input_da`` are transferred;
        otherwise only those that do not depend on ``dim``.

    Returns
    -------
    :
        ``output_da`` with the metadata added in place, or a wrapper around the
        ``output_da`` callable.
    """
    # Metadata is resolved eagerly, so mask validation errors surface immediately
    # even when the output is a callable that is evaluated later.
    masks = (
        _remove_columns_in_dim(input_da.masks, dim)
        if accept_masks
        else _validated_masks(input_da, dim)
    )
    coords: Mapping[str, Variable]
    attrs: Mapping[str, Variable]
    if not keep_coords:
        coords = _remove_columns_in_dim(input_da.coords, dim)
        attrs = _remove_columns_in_dim(input_da.deprecated_attrs, dim)
    else:
        coords = input_da.coords
        attrs = input_da.deprecated_attrs

    def attach(target: DataArray) -> DataArray:
        # Mutates ``target`` in place. Masks are copied on every call so that
        # separate outputs (e.g. repeated calls of a partial result) do not share
        # mask objects.
        target.coords.update(coords)
        target.masks.update((name, mask.copy()) for name, mask in masks.items())
        target.deprecated_attrs.update(attrs)
        return target

    if not is_partial:
        return attach(output_da)  # type: ignore[arg-type, return-value]

    # Partial case, i.e. ``output_da`` is not a DataArray but a callable.
    def decorate(inner: Callable[..., DataArray]) -> Callable[..., DataArray]:
        @wraps(inner)
        def wrapper(*args: Any, **kwargs: Any) -> DataArray:
            return attach(inner(*args, **kwargs))

        return wrapper

    return decorate(output_da)  # type: ignore[arg-type, return-value]
140def _remove_columns_in_dim(
141 mapping: Mapping[str, Variable], dim: str
142) -> dict[str, Variable]:
143 return {key: var for key, var in mapping.items() if dim not in var.dims}