Coverage for install/scipp/scipy/interpolate/__init__.py: 100%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-17 01:51 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3# @author Simon Heybrock 

4"""Sub-package for objects used in interpolation. 

5 

6This subpackage provides wrappers for a subset of functions from 

7:py:mod:`scipy.interpolate`. 

8""" 

9 

10from __future__ import annotations 

11 

12from typing import Any, Literal, Protocol, TypeVar 

13 

14import numpy as np 

15import numpy.typing as npt 

16 

17from ...compat.wrapping import wrap1d 

18from ...core import ( 

19 DataArray, 

20 DimensionError, 

21 DType, 

22 UnitError, 

23 Variable, 

24 empty, 

25 epoch, 

26 irreducible_mask, 

27) 

28 

# Type variable constrained to either a NumPy array or a scipp Variable, so that
# helpers annotated with it (e.g. ``_as_interpolation_type``) are declared to return
# the same kind of object they receive.
_ArrayOrVar = TypeVar('_ArrayOrVar', npt.NDArray[Any], Variable)

30 

31 

32def _as_interpolation_type(x: _ArrayOrVar) -> _ArrayOrVar: 

33 if isinstance(x, np.ndarray): 

34 if x.dtype.kind == 'M': 

35 return x.astype('int64', copy=False) 

36 else: 

37 if x.dtype == DType.datetime64: 

38 return x - epoch(unit=x.unit) 

39 return x 

40 

41 

def _midpoints(var: Variable, dim: str) -> Variable:
    """Return the midpoints between neighboring elements of ``var`` along ``dim``.

    The result has one element fewer than ``var`` along ``dim``. The lower points
    are passed through ``_as_interpolation_type`` before the half-distance is added.
    For datetime input the result is therefore expressed as offsets from the epoch
    rather than as datetimes.
    """
    lower = var[dim, :-1]
    upper = var[dim, 1:]
    half_step = 0.5 * (upper - lower)
    return _as_interpolation_type(lower) + half_step

46 

47 

def _drop_masked(da: DataArray, dim: str) -> DataArray:
    """Return ``da`` without the entries masked along ``dim``.

    The mask comes from ``irreducible_mask(da.masks, dim)``. That function is
    presumably the combination of all masks depending on ``dim`` — confirm against
    its definition. If it yields ``None`` (nothing to drop), ``da`` itself is
    returned unchanged.
    """
    mask = irreducible_mask(da.masks, dim)
    if mask is None:
        return da
    return da[~mask]

52 

53 

@wrap1d(is_partial=True, accept_masks=True)
def interp1d(
    da: DataArray,
    dim: str,
    *,
    kind: int
    | Literal[
        'linear',
        'nearest',
        'nearest-up',
        'zero',
        'slinear',
        'quadratic',
        'cubic',
        'previous',
        'next',
    ] = 'linear',
    fill_value: Any = np.nan,
    **kwargs: Any,
) -> _Interp1dImpl:
    """Interpolate a 1-D function.

    A data array is used to approximate some function f: y = f(x), where y is given by
    the array values and x is given by the coordinate for the given dimension. This
    function returns a callable that uses interpolation to find the value of new
    points.

    The function is a wrapper for scipy.interpolate.interp1d. The differences are:

    - Instead of x and y, a data array defining these is used as input.
    - Instead of an axis, a dimension label defines the interpolation dimension.
    - The returned function does not just return the values of f(x) but a new
      data array with values defined as f(x) and x as a coordinate for the
      interpolation dimension.
    - The returned function accepts an extra argument ``midpoints``. When setting
      ``midpoints=True`` the interpolation uses the midpoints of the new points
      instead of the points themselves. The returned data array is then a
      histogram, i.e., the new coordinate is a bin-edge coordinate.

    If the input data array contains masks that depend on the interpolation dimension
    the masked points are treated as missing, i.e., they are ignored for the definition
    of the interpolation function. If such a mask also depends on additional dimensions
    :py:class:`scipp.DimensionError` is raised since interpolation requires points to
    be 1-D.

    For structured input data dtypes such as vectors, rotations, or linear
    transformations interpolation is structure-element-wise. While this is appropriate
    for vectors, such a naive interpolation for, e.g., rotations does not typically
    yield a rotation so this should be used with care, unless the 'kind' parameter is
    set to, e.g., 'previous', 'next', or 'nearest'.

    Parameters not described above are forwarded to scipy.interpolate.interp1d. The
    most relevant ones are (see :py:class:`scipy.interpolate.interp1d` for details):

    Parameters
    ----------
    da:
        Input data. Defines both dependent and independent variables for interpolation.
    dim:
        Dimension of the interpolation.
    kind:

        - **integer**: order of the spline interpolator
        - **string**:

          - 'zero', 'slinear', 'quadratic', 'cubic':
            spline interpolation of zeroth, first, second or third order
          - 'previous' and 'next':
            simply return the previous or next value of the point
          - 'nearest-up' and 'nearest':
            differ when interpolating half-integers (e.g. 0.5, 1.5) in that
            'nearest-up' rounds up and 'nearest' rounds down
    fill_value:
        Set to 'extrapolate' to allow for extrapolation of points
        outside the range.

    Returns
    -------
    :
        A callable ``f(x)`` that returns interpolated values of ``da`` at ``x``.

    Examples
    --------

    .. plot:: :context: close-figs

        >>> x = sc.linspace(dim='x', start=0.1, stop=1.4, num=4, unit='rad')
        >>> da = sc.DataArray(sc.sin(x), coords={'x': x})

        >>> from scipp.scipy.interpolate import interp1d
        >>> f = interp1d(da, 'x')

        >>> xnew = sc.linspace(dim='x', start=0.1, stop=1.4, num=12, unit='rad')
        >>> f(xnew)  # use interpolation function returned by `interp1d`
        <scipp.DataArray>
        Dimensions: Sizes[x:12, ]
        Coordinates:
        * x                         float64            [rad]  (x)  [0.1, 0.218182, ..., 1.28182, 1.4]
        Data:
                                    float64  [dimensionless]  (x)  [0.0998334, 0.211262, ..., 0.941144, 0.98545]

        >>> f(xnew, midpoints=True)
        <scipp.DataArray>
        Dimensions: Sizes[x:11, ]
        Coordinates:
        * x                         float64            [rad]  (x [bin-edge])  [0.1, 0.218182, ..., 1.28182, 1.4]
        Data:
                                    float64  [dimensionless]  (x)  [0.155548, 0.266977, ..., 0.918992, 0.963297]

    .. plot:: :context: close-figs

        >>> sc.plot({'original':da,
        ...          'interp1d':f(xnew),
        ...          'interp1d-midpoints':f(xnew, midpoints=True)})
    """  # noqa: E501
    import scipy.interpolate as inter

    # Masked points along ``dim`` are removed once, up front, so they never take part
    # in defining the interpolation function.
    da = _drop_masked(da, dim)

    def func(xnew: Variable, *, midpoints: bool = False) -> DataArray:
        """Compute interpolation function defined by ``interp1d``
        at interpolation points.

        Parameters
        ----------
        xnew:
            Interpolation points.
        midpoints:
            Interpolate at midpoints of given points.
            The result will be a histogram.
            Default is ``False``.

        Returns
        -------
        :
            Interpolated data array with new coord given by interpolation points
            and data given by interpolation function evaluated at the
            interpolation points (or evaluated at the midpoints of the given points).

        Raises
        ------
        scipp.UnitError
            If the unit of ``xnew`` differs from the unit of the coordinate of
            ``da`` along the interpolation dimension.
        scipp.DimensionError
            If the dimension of ``xnew`` is not the interpolation dimension.
        """
        coord = da.coords[dim]
        if xnew.unit != coord.unit:
            raise UnitError(
                f"Unit of interpolation points '{xnew.unit}' does not match unit "
                f"'{coord.unit}' of points defining the interpolation "
                f"function along dimension '{dim}'."
            )
        if xnew.dim != dim:
            raise DimensionError(
                f"Dimension of interpolation points '{xnew.dim}' does not match "
                f"interpolation dimension '{dim}'"
            )
        f = inter.interp1d(
            x=_as_interpolation_type(coord.values),
            y=da.values,
            kind=kind,
            fill_value=fill_value,
            **kwargs,
        )
        x_ = _as_interpolation_type(_midpoints(xnew, dim) if midpoints else xnew)
        # Same dims (in the same order) as the input; only the length along ``dim``
        # changes. Build a fresh dict rather than mutating the one from ``da.sizes``.
        sizes = {**da.sizes, dim: x_.sizes[dim]}
        # ynew is created in this manner to allow for creation of structured dtypes,
        # which is not possible using scipp.array
        ynew = empty(sizes=sizes, unit=da.unit, dtype=da.dtype)
        ynew.values = f(x_.values)
        # The unconverted ``xnew`` becomes the coord. With ``midpoints=True`` it has
        # one more entry than the data along ``dim``, i.e., it is a bin-edge coord.
        return DataArray(data=ynew, coords={dim: xnew})

    return func

221 

222 

223class _Interp1dImpl(Protocol): 

224 def __call__(self, xnew: Variable, *, midpoints: bool = False) -> DataArray: ... 

225 

226 

# Public API of this sub-package.
__all__ = ['interp1d']