Coverage for install/scipp/scipy/interpolate/__init__.py: 100%
42 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-17 01:51 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-17 01:51 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3# @author Simon Heybrock
4"""Sub-package for objects used in interpolation.
6This subpackage provides wrappers for a subset of functions from
7:py:mod:`scipy.interpolate`.
8"""
10from __future__ import annotations
12from typing import Any, Literal, Protocol, TypeVar
14import numpy as np
15import numpy.typing as npt
17from ...compat.wrapping import wrap1d
18from ...core import (
19 DataArray,
20 DimensionError,
21 DType,
22 UnitError,
23 Variable,
24 empty,
25 epoch,
26 irreducible_mask,
27)
# Constrained type variable: either a NumPy array or a scipp Variable. Used by
# ``_as_interpolation_type`` so its return type matches the type it was given.
_ArrayOrVar = TypeVar('_ArrayOrVar', npt.NDArray[Any], Variable)
def _as_interpolation_type(x: _ArrayOrVar) -> _ArrayOrVar:
    """Convert datetime input into a numeric form usable for interpolation.

    - NumPy arrays of datetime64 kind (``dtype.kind == 'M'``) are reinterpreted as
      ``int64`` (without copying when possible).
    - scipp Variables of dtype ``datetime64`` are shifted by the epoch in their own
      unit.

    Any other input is returned unchanged.
    """
    if not isinstance(x, np.ndarray):
        # scipp Variable: express datetimes relative to the epoch in the same unit.
        return x - epoch(unit=x.unit) if x.dtype == DType.datetime64 else x
    # NumPy array: datetime64 values become their underlying int64 counts.
    return x.astype('int64', copy=False) if x.dtype.kind == 'M' else x
def _midpoints(var: Variable, dim: str) -> Variable:
    """Return the midpoints between neighbouring elements of ``var`` along ``dim``.

    The result has one element fewer than ``var`` along ``dim``. The left points
    are passed through ``_as_interpolation_type`` before the half-step is added,
    so datetime inputs are first expressed relative to the epoch.
    """
    left = var[dim, :-1]
    right = var[dim, 1:]
    half_step = 0.5 * (right - left)
    return _as_interpolation_type(left) + half_step
def _drop_masked(da: DataArray, dim: str) -> DataArray:
    """Return ``da`` without the points that are masked along ``dim``.

    The mask is obtained from ``irreducible_mask(da.masks, dim)``. If that yields
    ``None``, ``da`` is returned unchanged. Otherwise ``da`` is indexed with the
    inverted mask, which keeps only the unmasked points.
    """
    mask = irreducible_mask(da.masks, dim)
    return da if mask is None else da[~mask]
@wrap1d(is_partial=True, accept_masks=True)
def interp1d(
    da: DataArray,
    dim: str,
    *,
    kind: int
    | Literal[
        'linear',
        'nearest',
        'nearest-up',
        'zero',
        'slinear',
        'quadratic',
        'cubic',
        'previous',
        'next',
    ] = 'linear',
    fill_value: Any = np.nan,
    **kwargs: Any,
) -> _Interp1dImpl:
    """Interpolate a 1-D function.

    A data array is used to approximate some function f: y = f(x), where y is given by
    the array values and x is given by the coordinate for the given dimension. This
    function returns a callable that uses interpolation to find the values at new
    points.

    The function is a wrapper for scipy.interpolate.interp1d. The differences are:

    - Instead of x and y, a data array defining these is used as input.
    - Instead of an axis, a dimension label defines the interpolation dimension.
    - The returned function does not just return the values of f(x) but a new
      data array with values defined as f(x) and x as a coordinate for the
      interpolation dimension.
    - The returned function accepts an extra argument ``midpoints``. When setting
      ``midpoints=True`` the interpolation uses the midpoints of the new points
      instead of the points themselves. The returned data array is then a histogram,
      i.e., the new coordinate is a bin-edge coordinate.

    If the input data array contains masks that depend on the interpolation dimension
    the masked points are treated as missing, i.e., they are ignored for the definition
    of the interpolation function. If such a mask also depends on additional dimensions
    :py:class:`scipp.DimensionError` is raised since interpolation requires points to
    be 1-D.

    For structured input data dtypes such as vectors, rotations, or linear
    transformations interpolation is structure-element-wise. While this is appropriate
    for vectors, such a naive interpolation for, e.g., rotations does not typically
    yield a rotation, so this should be used with care, unless the 'kind' parameter is
    set to, e.g., 'previous', 'next', or 'nearest'.

    Parameters not described above are forwarded to scipy.interpolate.interp1d. The
    most relevant ones are (see :py:class:`scipy.interpolate.interp1d` for details):

    Parameters
    ----------
    da:
        Input data. Defines both dependent and independent variables for interpolation.
    dim:
        Dimension of the interpolation.
    kind:

        - **integer**: order of the spline interpolator
        - **string**:

          - 'zero', 'slinear', 'quadratic', 'cubic':
            spline interpolation of zeroth, first, second or third order
          - 'previous' and 'next':
            simply return the previous or next value of the point
          - 'nearest-up' and 'nearest'
            differ when interpolating half-integers (e.g. 0.5, 1.5) in that
            'nearest-up' rounds up and 'nearest' rounds down
    fill_value:
        Set to 'extrapolate' to allow for extrapolation of points
        outside the range.

    Returns
    -------
    :
        A callable ``f(x)`` that returns interpolated values of ``da`` at ``x``.

    Examples
    --------

    .. plot:: :context: close-figs

        >>> x = sc.linspace(dim='x', start=0.1, stop=1.4, num=4, unit='rad')
        >>> da = sc.DataArray(sc.sin(x), coords={'x': x})

        >>> from scipp.scipy.interpolate import interp1d
        >>> f = interp1d(da, 'x')

        >>> xnew = sc.linspace(dim='x', start=0.1, stop=1.4, num=12, unit='rad')
        >>> f(xnew)  # use interpolation function returned by `interp1d`
        <scipp.DataArray>
        Dimensions: Sizes[x:12, ]
        Coordinates:
        * x                         float64            [rad]  (x)  [0.1, 0.218182, ..., 1.28182, 1.4]
        Data:
                                    float64  [dimensionless]  (x)  [0.0998334, 0.211262, ..., 0.941144, 0.98545]

        >>> f(xnew, midpoints=True)
        <scipp.DataArray>
        Dimensions: Sizes[x:11, ]
        Coordinates:
        * x                         float64            [rad]  (x [bin-edge])  [0.1, 0.218182, ..., 1.28182, 1.4]
        Data:
                                    float64  [dimensionless]  (x)  [0.155548, 0.266977, ..., 0.918992, 0.963297]

    .. plot:: :context: close-figs

        >>> sc.plot({'original':da,
        ...          'interp1d':f(xnew),
        ...          'interp1d-midpoints':f(xnew, midpoints=True)})
    """  # noqa: E501
    import scipy.interpolate as inter

    # Masked points along ``dim`` are removed so they do not define the function.
    da = _drop_masked(da, dim)

    def func(xnew: Variable, *, midpoints: bool = False) -> DataArray:
        """Compute interpolation function defined by ``interp1d``
        at interpolation points.

        Parameters
        ----------
        xnew:
            Interpolation points.
        midpoints:
            Interpolate at midpoints of given points.
            The result will be a histogram.
            Default is ``False``.

        Returns
        -------
        :
            Interpolated data array with new coord given by interpolation points
            and data given by interpolation function evaluated at the
            interpolation points (or evaluated at the midpoints of the given points).

        Raises
        ------
        scipp.UnitError
            If the unit of ``xnew`` differs from the unit of the coordinate of the
            input data array along the interpolation dimension.
        scipp.DimensionError
            If the dimension of ``xnew`` differs from the interpolation dimension.
        """
        if xnew.unit != da.coords[dim].unit:
            raise UnitError(
                f"Unit of interpolation points '{xnew.unit}' does not match unit "
                f"'{da.coords[dim].unit}' of points defining the interpolation "
                f"function along dimension '{dim}'."
            )
        if xnew.dim != dim:
            raise DimensionError(
                f"Dimension of interpolation points '{xnew.dim}' does not match "
                f"interpolation dimension '{dim}'"
            )
        f = inter.interp1d(
            x=_as_interpolation_type(da.coords[dim].values),
            y=da.values,
            kind=kind,
            fill_value=fill_value,
            **kwargs,
        )
        x_ = _as_interpolation_type(_midpoints(xnew, dim) if midpoints else xnew)
        # Work on an explicit copy so adjusting the output length along ``dim``
        # can never affect ``da``.
        sizes = dict(da.sizes)
        sizes[dim] = x_.sizes[dim]
        # ynew is created in this manner to allow for creation of structured dtypes,
        # which is not possible using scipp.array
        ynew = empty(sizes=sizes, unit=da.unit, dtype=da.dtype)
        ynew.values = f(x_.values)
        return DataArray(data=ynew, coords={dim: xnew})

    return func
class _Interp1dImpl(Protocol):
    """Call signature of the interpolation function returned by ``interp1d``."""

    def __call__(self, xnew: Variable, *, midpoints: bool = False) -> DataArray:
        """Evaluate the interpolation at ``xnew`` (or at its midpoints)."""
# Public API of this module: only ``interp1d`` is exported.
__all__ = ['interp1d']