Coverage for install/scipp/curve_fit.py: 78%
125 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
4from inspect import getfullargspec
5from numbers import Real
6from typing import Callable, Dict, Mapping, Optional, Sequence, Tuple, Union
8import numpy as np
10from .core import BinEdgeError, DataArray, DataGroup, Variable, array, scalar, stddevs
11from .units import default_unit, dimensionless
def _as_scalar(obj, unit):
    """Wrap ``obj`` in a scipp scalar of ``unit``; pass through if the unit is the default."""
    return obj if unit == default_unit else scalar(value=obj, unit=unit)
def _wrap_scipp_func(f, p_names, p_units):
    """Adapt a scipp-aware model ``f`` to the plain-value signature scipy expects.

    Pre-allocates one scalar per parameter (carrying its unit, if any); on each
    call the raw values from scipy are written into those scalars before
    evaluating ``f``, and the result's plain values are returned to scipy.
    """
    _params = {name: _as_scalar(0.0, unit) for name, unit in zip(p_names, p_units)}

    def func(x, *args):
        for name, value in zip(p_names, args):
            param = _params[name]
            if isinstance(param, Variable):
                # Unit-carrying parameter: update in place to keep the unit.
                param.value = value
            else:
                _params[name] = value
        return f(**x, **_params).values

    return func
34def _wrap_numpy_func(f, p_names, coord_names):
35 def func(x, *args):
36 # If there is only one predictor variable x might be a 1D array.
37 # Make x 2D for consistency.
38 if len(x.shape) == 1:
39 x = x.reshape(1, -1)
40 coords = dict(zip(coord_names, x))
41 params = dict(zip(p_names, args))
42 return f(**coords, **params)
44 return func
47def _get_sigma(da):
48 if da.variances is None:
49 return None
51 sigma = stddevs(da).values
52 if not sigma.all():
53 raise ValueError(
54 'There is a 0 in the input variances. This would break the optimizer. '
55 'Mask the offending elements, remove them, or assign a meaningful '
56 'variance if possible before calling curve_fit.'
57 )
58 return sigma
def _datagroup_outputs(da, params, p_units, map_over, pdata, covdata):
    """Package raw fit results into scipp DataGroups.

    Parameters
    ----------
    da:
        Input data array; source of the coords and masks to propagate.
    params:
        Fit-parameter names; iteration order defines the index order used in
        ``pdata`` and ``covdata``.
    p_units:
        Unit (or ``default_unit``) of each parameter, in the same order.
    map_over:
        Dims the fit was mapped over; these become the dims of the outputs.
    pdata:
        Optimal parameter values, last axis indexed by parameter.
    covdata:
        Covariance matrices, last two axes indexed by parameter pairs.

    Returns
    -------
    :
        Tuple of (parameter DataGroup, covariance DataGroup of DataGroups),
        both keyed by parameter name. Parameter variances are taken from the
        covariance diagonal.
    """
    variances = np.diagonal(covdata, axis1=-2, axis2=-1)
    dg = DataGroup(
        {
            p: DataArray(
                data=array(
                    dims=map_over,
                    values=pdata[..., i] if pdata.ndim > 1 else pdata[i],
                    variances=variances[..., i] if variances.ndim > 1 else variances[i],
                    unit=u,
                ),
            )
            for i, (p, u) in enumerate(zip(params, p_units))
        }
    )
    dgcov = DataGroup(
        {
            p: DataGroup(
                {
                    q: DataArray(
                        data=array(
                            dims=map_over,
                            values=covdata[..., i, j]
                            if covdata.ndim > 2
                            else covdata[i, j],
                            # A covariance entry has the product of the two
                            # parameter units; default_unit acts as neutral.
                            unit=(
                                default_unit
                                if p_u == default_unit and q_u == default_unit
                                else p_u
                                if q_u == default_unit
                                else q_u
                                if p_u == default_unit
                                else p_u * q_u
                            ),
                        ),
                    )
                    for j, (q, q_u) in enumerate(zip(params, p_units))
                }
            )
            for i, (p, p_u) in enumerate(zip(params, p_units))
        }
    )
    # Attach coords/masks that vary along the mapped-over dims to all outputs.
    for c in da.coords:
        if set(map_over).intersection(da.coords[c].dims):
            for p in dg:
                dg[p].coords[c] = da.coords[c]
                for q in dgcov[p]:
                    dgcov[p][q].coords[c] = da.coords[c]
    for m in da.masks:
        if set(map_over).intersection(da.masks[m].dims):
            for p in dg:
                # BUG FIX: the original indexed with ``c`` — the stale loop
                # variable left over from the coords loop above — instead of
                # the current mask name ``m``.
                dg[p].masks[m] = da.masks[m]
                for q in dgcov[p]:
                    dgcov[p][q].masks[m] = da.masks[m]
    return dg, dgcov
119def _prepare_numpy_outputs(da, params, map_over):
120 shape = [da.sizes[d] for d in map_over]
121 dg = np.empty([*shape, len(params)])
122 dgcov = np.empty(shape + 2 * [len(params)])
123 return dg, dgcov
126def _make_defaults(f, coords, params):
127 spec = getfullargspec(f)
128 all_args = {*spec.args, *spec.kwonlyargs}
129 if not set(coords).issubset(all_args):
130 raise ValueError("Function must take the provided coords as arguments")
131 default_arguments = dict(
132 zip(spec.args[-len(spec.defaults) :], spec.defaults) if spec.defaults else {},
133 **(spec.kwonlydefaults or {}),
134 )
135 return {
136 **{a: 1.0 for a in all_args - set(coords)},
137 **default_arguments,
138 **(params or {}),
139 }
142def _get_specific_bounds(bounds, name, unit) -> Tuple[float, float]:
143 if name not in bounds:
144 return -np.inf, np.inf
145 b = bounds[name]
146 if len(b) != 2:
147 raise ValueError(
148 "Parameter bounds must be given as a tuple of length 2. "
149 f"Got a collection of length {len(b)} as bounds for '{name}'."
150 )
151 if isinstance(b[0], Variable):
152 return (
153 b[0].to(unit=unit, dtype=float).value,
154 b[1].to(unit=unit, dtype=float).value,
155 )
156 return b
def _parse_bounds(
    bounds, params
) -> Union[Tuple[float, float], Tuple[np.ndarray, np.ndarray]]:
    """Convert the user-facing ``bounds`` dict into scipy's (lower, upper) form."""
    if bounds is None:
        return -np.inf, np.inf
    pairs = []
    for name, param in params.items():
        # Unit-carrying initial guesses determine the unit of their bounds.
        unit = param.unit if isinstance(param, Variable) else dimensionless
        pairs.append(_get_specific_bounds(bounds, name, unit))
    pairs_by_side = np.array(pairs).T
    return pairs_by_side[0], pairs_by_side[1]
def _curve_fit(
    f,
    da,
    p0,
    bounds,
    map_over,
    unsafe_numpy_f,
    out,
    **kwargs,
):
    """Run the (possibly mapped-over) fit, writing results into ``out``.

    Recurses over the dims in ``map_over``, slicing ``da`` and ``out`` until a
    single fit remains, then delegates to ``scipy.optimize.curve_fit``.

    Parameters
    ----------
    f:
        Wrapped model function taking ``(x, *param_values)``.
    da:
        Data array (slice) to fit; masked points are dropped.
    p0:
        Initial parameter values as plain floats.
    bounds:
        (lower, upper) bounds in scipy form.
    map_over:
        Remaining dims to iterate over.
    unsafe_numpy_f:
        If True, coords are stacked into a single numpy array for ``f``.
    out:
        Tuple of (parameter, covariance) numpy arrays, filled in place.
    """
    dg, dgcov = out

    if len(map_over) > 0:
        # Peel off one mapped dim and fit each slice independently.
        dim = map_over[0]
        for i in range(da.sizes[dim]):
            _curve_fit(
                f,
                da[dim, i],
                p0,
                bounds,
                map_over[1:],
                unsafe_numpy_f,
                (dg[i], dgcov[i]),
                **kwargs,
            )
        return

    fda = da.flatten(to='row')

    # Remove masked points before fitting.
    for m in fda.masks.values():
        fda = fda[~m]

    if not unsafe_numpy_f:
        # Making the coords into a dict improves runtime,
        # probably because of pybind overhead.
        X = dict(fda.coords)
    else:
        X = np.vstack([c.values for c in fda.coords.values()], dtype='float')

    import scipy.optimize as opt

    try:
        popt, pcov = opt.curve_fit(
            f,
            X,
            fda.data.values,
            p0,
            sigma=_get_sigma(fda),
            bounds=bounds,
            **kwargs,
        )
    except RuntimeError as err:
        # BUG FIX: Python 3 exceptions have no ``.message`` attribute, so the
        # original ``hasattr(err, 'message')`` guard was always False and a
        # failed fit was re-raised instead of producing the intended NaN
        # placeholders. Inspect str(err) instead.
        if 'Optimal parameters not found:' in str(err):
            popt = np.full(len(p0), np.nan)
            pcov = np.full((len(p0), len(p0)), np.nan)
        else:
            raise

    dg[:] = popt
    dgcov[:] = pcov
def curve_fit(
    coords: Union[Sequence[str], Mapping[str, Union[str, Variable]]],
    f: Callable,
    da: DataArray,
    *,
    p0: Optional[Dict[str, Union[Variable, Real]]] = None,
    bounds: Optional[
        Dict[str, Union[Tuple[Variable, Variable], Tuple[Real, Real]]]
    ] = None,
    reduce_dims: Sequence[str] = (),
    unsafe_numpy_f: bool = False,
    **kwargs,
) -> Tuple[DataGroup, DataGroup]:
    """Use non-linear least squares to fit a function, f, to data.

    The function interface is similar to that of :py:func:`xarray.DataArray.curvefit`.

    .. versionadded:: 23.12.0

    This is a wrapper around :py:func:`scipy.optimize.curve_fit`. See there for
    indepth documentation and keyword arguments. The differences are:

    - Instead of separate ``xdata``, ``ydata``, and ``sigma`` arguments,
      the input data array defines these, ``xdata`` by the coords on the data array,
      ``ydata`` by ``da.data``, and ``sigma`` is defined as the square root of
      the variances, if present, i.e., the standard deviations.
    - The fit function ``f`` must work with scipp objects. This provides additional
      safety over the underlying scipy function by ensuring units are consistent.
    - The initial guess in ``p0`` must be provided as a dict, mapping from fit-function
      parameter names to initial guesses.
    - The parameter bounds must also be provided as a dict, like ``p0``.
    - The fit parameters may be scalar scipp variables. In that case an initial guess
      ``p0`` with the correct units must be provided.
    - The returned optimal parameter values ``popt`` and the covariance matrix ``pcov``
      will have units provided that the initial parameters have units. ``popt`` and
      ``pcov`` are DataGroup and a DataGroup of DataGroup respectively. They are indexed
      by the fit parameter names. The variances of the parameter values in ``popt``
      are set to the corresponding diagonal value in the covariance matrix.

    Parameters
    ----------
    coords:
        The coords that act as predictor variables in the fit.
        If a mapping, the keys signify names of arguments to ``f`` and the values
        signify coordinate names in ``da.coords``. If a sequence, the names of the
        arguments to ``f`` and the names of the coords are taken to be the same.
        To use a fit coordinate not present in ``da.coords``, pass it as a Variable.
    f:
        The model function, ``f(x, y..., a, b...)``. It must take all coordinates
        listed in ``coords`` as arguments, otherwise a ``ValueError`` will be raised,
        all *other* arguments will be treated as parameters of the fit.
    da:
        The values of the data array provide the dependent data. If the data array
        stores variances then the standard deviations (square root of the variances)
        are taken into account when fitting.
    p0:
        An optional dict of initial guesses for the parameters.
        If None, then the initial values will all be dimensionless 1.
        If the fit function cannot handle initial values of 1, in particular for
        parameters that are not dimensionless, then typically a
        :py:class:``scipp.UnitError`` is raised,
        but details will depend on the function.
    bounds:
        Lower and upper bounds on parameters.
        Defaults to no bounds.
        Bounds are given as a dict of 2-tuples of (lower, upper) for each parameter
        where lower and upper are either both Variables or plain numbers.
        Parameters omitted from the ``bounds`` dict are unbounded.
    reduce_dims:
        Additional dimensions to aggregate while fitting.
        If a dimension is not in ``reduce_dims``, or in the dimensions
        of the coords used in the fit, then the values of the optimal parameters
        will depend on that dimension. One fit will be performed for every slice,
        and the data arrays in the output will have the dimension in their ``dims``.
        If a dimension is passed to ``reduce_dims`` all data in that dimension
        is instead aggregated in a single fit and the dimension will *not*
        be present in the output.
    unsafe_numpy_f:
        By default the provided fit function ``f`` is assumed to take scipp Variables
        as input and use scipp operations to produce a scipp Variable as output.
        This has the safety advantage of unit checking.
        However, in some cases it might be advantageous to implement ``f`` using Numpy
        operations for performance reasons. This is particularly the case if the
        curve fit will make many small curve fits involving relatively few data points.
        In this case the pybind overhead on scipp operations might be considerable.
        If ``unsafe_numpy_f`` is set to ``True`` then the arguments passed to ``f``
        will be Numpy arrays instead of scipp Variables and the output of ``f`` is
        expected to be a Numpy array.

    Returns
    -------
    popt:
        Optimal values for the parameters.
    pcov:
        The estimated covariance of popt.

    See Also
    --------
    scipp.scipy.optimize.curve_fit:
        Similar functionality for 1D fits.

    Examples
    --------

    A 1D example

    >>> def round(a, d):
    ...     'Helper for the doctests'
    ...     return sc.round(10**d * a) / 10**d

    >>> def func(x, a, b):
    ...     return a * sc.exp(-b * x)

    >>> rng = np.random.default_rng(1234)
    >>> x = sc.linspace(dim='x', start=0.0, stop=0.4, num=50, unit='m')
    >>> y = func(x, a=5, b=17/sc.Unit('m'))
    >>> y.values += 0.01 * rng.normal(size=50)
    >>> da = sc.DataArray(y, coords={'x': x})

    >>> from scipp import curve_fit
    >>> popt, _ = curve_fit(['x'], func, da, p0 = {'b': 1.0 / sc.Unit('m')})
    >>> round(sc.values(popt['a']), 3), round(sc.stddevs(popt['a']), 4)
    (<scipp.DataArray>
    Dimensions: Sizes[]
    Data:
      float64 [dimensionless] () 4.999
    ,
    <scipp.DataArray>
    Dimensions: Sizes[]
    Data:
      float64 [dimensionless] () 0.0077
    )

    A 2D example where two coordinates participate in the fit

    >>> def func(x, z, a, b):
    ...     return a * z * sc.exp(-b * x)

    >>> x = sc.linspace(dim='x', start=0.0, stop=0.4, num=50, unit='m')
    >>> z = sc.linspace(dim='z', start=0.0, stop=1, num=10)
    >>> y = func(x, z, a=5, b=17/sc.Unit('m'))
    >>> y.values += 0.01 * rng.normal(size=500).reshape(10, 50)
    >>> da = sc.DataArray(y, coords={'x': x, 'z': z})

    >>> popt, _ = curve_fit(['x', 'z'], func, da, p0 = {'b': 1.0 / sc.Unit('m')})
    >>> round(sc.values(popt['a']), 3), round(sc.stddevs(popt['a']), 3)
    (<scipp.DataArray>
    Dimensions: Sizes[]
    Data:
      float64 [dimensionless] () 5.004
    ,
    <scipp.DataArray>
    Dimensions: Sizes[]
    Data:
      float64 [dimensionless] () 0.004
    )

    A 2D example where only one coordinate participates in the fit and we
    map over the dimension of the other coordinate.
    Note that the value of one of the parameters is z-dependent
    and that the output has a z-dimension

    >>> def func(x, a, b):
    ...     return a * sc.exp(-b * x)

    >>> x = sc.linspace(dim='xx', start=0.0, stop=0.4, num=50, unit='m')
    >>> z = sc.linspace(dim='zz', start=0.0, stop=1, num=10)
    >>> # Note that parameter a is z-dependent.
    >>> y = func(x, a=z, b=17/sc.Unit('m'))
    >>> y.values += 0.01 * rng.normal(size=500).reshape(10, 50)
    >>> da = sc.DataArray(y, coords={'x': x, 'z': z})

    >>> popt, _ = curve_fit(
    ...     ['x'], func, da,
    ...     p0 = {'b': 1.0 / sc.Unit('m')})
    >>> # Note that approximately a = z
    >>> round(sc.values(popt['a']), 2),
    (<scipp.DataArray>
    Dimensions: Sizes[zz:10, ]
    Coordinates:
    * z float64 [dimensionless] (zz) [0, 0.111111, ..., 0.888889, 1]
    Data:
      float64 [dimensionless] (zz) [-0.01, 0.11, ..., 0.89, 1.01]
    ,)

    Lastly, a 2D example where only one coordinate participates in the fit and
    the other coordinate is reduced.

    >>> def func(x, a, b):
    ...     return a * sc.exp(-b * x)

    >>> x = sc.linspace(dim='xx', start=0.0, stop=0.4, num=50, unit='m')
    >>> z = sc.linspace(dim='zz', start=0.0, stop=1, num=10)
    >>> y = z * 0 + func(x, a=5, b=17/sc.Unit('m'))
    >>> y.values += 0.01 * rng.normal(size=500).reshape(10, 50)
    >>> da = sc.DataArray(y, coords={'x': x, 'z': z})

    >>> popt, _ = curve_fit(
    ...     ['x'], func, da,
    ...     p0 = {'b': 1.0 / sc.Unit('m')}, reduce_dims=['zz'])
    >>> round(sc.values(popt['a']), 3), round(sc.stddevs(popt['a']), 4)
    (<scipp.DataArray>
    Dimensions: Sizes[]
    Data:
      float64 [dimensionless] () 5
    ,
    <scipp.DataArray>
    Dimensions: Sizes[]
    Data:
      float64 [dimensionless] () 0.0021
    )

    Note that the variance is about 10x lower in this example than in the
    first 1D example. That is because in this example 50x10 points are used
    in the fit while in the first example only 50 points were used in the fit.

    """

    if 'jac' in kwargs:
        raise NotImplementedError(
            "The 'jac' argument is not yet supported. "
            "See https://github.com/scipp/scipp/issues/2544"
        )

    # Reject scipy arguments that are derived from the data array here.
    for arg in ['xdata', 'ydata', 'sigma']:
        if arg in kwargs:
            raise TypeError(
                f"Invalid argument '{arg}', already defined by the input data array."
            )

    for c in coords:
        if c in da.coords and da.coords.is_edges(c):
            raise BinEdgeError("Cannot fit data array with bin-edge coordinate.")

    # A plain sequence of names is shorthand for an identity mapping.
    if not isinstance(coords, dict):
        if not all(isinstance(c, str) for c in coords):
            raise TypeError(
                'Expected sequence of coords to only contain values of type `str`.'
            )
        coords = {c: c for c in coords}

    # Mapping from function argument names to fit variables
    coords = {
        arg: da.coords[coord] if isinstance(coord, str) else coord
        for arg, coord in coords.items()
    }

    # Complete the initial guesses (default 1.0) for all fit parameters of f.
    p0 = _make_defaults(f, coords.keys(), p0)
    p_units = [p.unit if isinstance(p, Variable) else default_unit for p in p0.values()]

    # Wrap f so scipy can call it with plain values / numpy arrays.
    f = (
        _wrap_scipp_func(f, p0, p_units)
        if not unsafe_numpy_f
        else _wrap_numpy_func(f, p0, coords.keys())
    )

    # Dims not reduced and not used by any fit coord: one fit per slice.
    map_over = tuple(
        d
        for d in da.dims
        if d not in reduce_dims and not any(d in c.dims for c in coords.values())
    )

    dims_participating_in_fit = set(da.dims) - set(map_over)

    # Create a dataarray with only the participating coords and masks
    # and coordinate names matching the argument names of f.
    _da = DataArray(
        da.data,
        coords=coords,
        masks={
            m: da.masks[m]
            for m in da.masks
            if dims_participating_in_fit.intersection(da.masks[m].dims)
        },
    )

    out = _prepare_numpy_outputs(da, p0, map_over)

    _curve_fit(
        f,
        _da,
        [p.value if isinstance(p, Variable) else p for p in p0.values()],
        _parse_bounds(bounds, p0),
        map_over,
        unsafe_numpy_f,
        out,
        **kwargs,
    )

    return _datagroup_outputs(da, p0, p_units, map_over, *out)
# Public API of this module.
__all__ = ['curve_fit']