Coverage for install/scipp/curve_fit.py: 82%
139 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-01 01:59 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-12-01 01:59 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
4from collections.abc import Callable, Mapping, Sequence
5from functools import partial
6from inspect import getfullargspec, isfunction
7from numbers import Real
9import numpy as np
11from .core import (
12 BinEdgeError,
13 DataArray,
14 DataGroup,
15 DimensionError,
16 Variable,
17 array,
18 scalar,
19 stddevs,
20 zeros,
21)
24def _wrap_scipp_func(f, p0):
25 p = {k: scalar(0.0, unit=v.unit) for k, v in p0.items()}
27 def func(x, *args):
28 for k, v in zip(p, args, strict=True):
29 p[k].value = v
30 return f(**x, **p).values
32 return func
35def _wrap_numpy_func(f, param_names, coord_names):
36 def func(x, *args):
37 # If there is only one predictor variable x might be a 1D array.
38 # Make x 2D for consistency.
39 if len(x.shape) == 1:
40 x = x.reshape(1, -1)
41 c = dict(zip(coord_names, x, strict=True))
42 p = dict(zip(param_names, args, strict=True))
43 return f(**c, **p)
45 return func
def _get_sigma(da):
    """Return the standard deviations of ``da`` for use as scipy's ``sigma``.

    Returns ``None`` when ``da`` carries no variances (an unweighted fit).

    Raises
    ------
    ValueError
        If any standard deviation is zero, since that would break the
        optimizer.
    """
    if da.variances is not None:
        sigma = stddevs(da).values
        if sigma.all():
            return sigma
        raise ValueError(
            'There is a 0 in the input variances. This would break the optimizer. '
            'Mask the offending elements, remove them, or assign a meaningful '
            'variance if possible before calling curve_fit.'
        )
    return None
def _datagroup_outputs(da, p0, map_over, pdata, covdata):
    """Convert the raw NumPy fit results into labeled scipp ``DataGroup`` objects.

    Parameters
    ----------
    da:
        The original input data array. Its coords and masks are copied onto
        the outputs where their dimensions fit the output dimensions.
    p0:
        Dict of initial guesses. Its key order defines the parameter order
        along the trailing axis (axes) of ``pdata`` (``covdata``), and its
        units define the units of the outputs.
    map_over:
        Dimensions over which separate fits were performed; these are the
        dimensions of every output array.
    pdata:
        Optimal parameters, shape ``(*sizes_of_map_over, n_params)``.
    covdata:
        Covariance matrices, shape ``(*sizes_of_map_over, n_params, n_params)``.

    Returns
    -------
    :
        ``(popt, pcov)``: ``popt[p]`` is a data array holding the values of
        parameter ``p`` with the diagonal of the covariance as variances;
        ``pcov[p][q]`` holds the covariance of parameters ``p`` and ``q``.
    """
    variances = np.diagonal(covdata, axis1=-2, axis2=-1)
    dg = DataGroup(
        {
            p: DataArray(
                data=array(
                    dims=map_over,
                    values=pdata[..., i] if pdata.ndim > 1 else pdata[i],
                    variances=variances[..., i] if variances.ndim > 1 else variances[i],
                    unit=v0.unit,
                ),
            )
            for i, (p, v0) in enumerate(p0.items())
        }
    )
    dgcov = DataGroup(
        {
            p: DataGroup(
                {
                    q: DataArray(
                        data=array(
                            dims=map_over,
                            values=covdata[..., i, j]
                            if covdata.ndim > 2
                            else covdata[i, j],
                            unit=v0.unit * u0.unit,
                        ),
                    )
                    for j, (q, u0) in enumerate(p0.items())
                }
            )
            for i, (p, v0) in enumerate(p0.items())
        }
    )
    out_dims = set(map_over)
    for c in da.coords:
        coord_dims = set(da.coords[c].dims)
        # Attach only coords that vary along a mapped dimension and depend on
        # no other dimension. A coord that additionally spans a reduced or
        # fitted dimension has dims the outputs do not have, and scipp does
        # not allow attaching such a coord (same rule as for masks below).
        if coord_dims & out_dims and coord_dims <= out_dims:
            for p in dg:
                dg[p].coords[c] = da.coords[c]
                for q in dgcov[p]:
                    dgcov[p][q].coords[c] = da.coords[c]
    for m in da.masks:
        # Drop masks that don't fit the output data
        if set(da.masks[m].dims).issubset(set(dg.dims)):
            for p in dg:
                dg[p].masks[m] = da.masks[m]
                for q in dgcov[p]:
                    dgcov[p][q].masks[m] = da.masks[m]
    return dg, dgcov
def _prepare_numpy_outputs(da, p0, map_over):
    """Allocate uninitialized float buffers for the fit results.

    Returns
    -------
    :
        ``(popt, pcov)`` with shapes ``(*sizes, n_params)`` and
        ``(*sizes, n_params, n_params)``, where ``sizes`` are the sizes of
        ``da`` along ``map_over`` and ``n_params == len(p0)``.
    """
    n_params = len(p0)
    leading = tuple(da.sizes[dim] for dim in map_over)
    popt = np.empty((*leading, n_params))
    pcov = np.empty((*leading, n_params, n_params))
    return popt, pcov
def _make_defaults(f, coords, p0):
    """Build the dict of initial guesses for every fit parameter of ``f``.

    The fit parameters are all arguments of ``f`` that have no default value
    and are not listed in ``coords``. Parameters missing from ``p0`` get a
    dimensionless ``scalar(1.0)``; values in ``p0`` that are not scipp
    Variables are wrapped with ``scalar``.

    Parameters
    ----------
    f:
        The model function, a ``functools.partial``, or a callable object.
    coords:
        Names of the arguments of ``f`` that receive the predictor variables.
    p0:
        User-provided initial guesses, or ``None``.

    Returns
    -------
    :
        Dict mapping parameter names to initial guesses. Parameters appear in
        the order in which they are declared in the signature of ``f``,
        followed by any additional keys of ``p0``.

    Raises
    ------
    ValueError
        If ``f`` does not take every name in ``coords`` as an argument.
    """
    spec = getfullargspec(f)
    non_default_args = (
        spec.args[: -len(spec.defaults)] if spec.defaults is not None else spec.args
    )
    if not isfunction(f) and not isinstance(f, partial):
        # f is a callable object (or bound method); getfullargspec reports its
        # leading 'self' argument, which must not become a fit parameter.
        non_default_args = non_default_args[1:]
    kwonly_defaults = spec.kwonlydefaults or {}
    # Use a list in signature order rather than a set: set iteration order of
    # strings varies between interpreter runs, which would make the order of
    # the fit parameters (and of the returned DataGroups) nondeterministic.
    args = [
        *non_default_args,
        *(a for a in spec.kwonlyargs if a not in kwonly_defaults),
    ]
    coord_names = set(coords)
    if not coord_names.issubset(args):
        raise ValueError("Function must take the provided coords as arguments")
    return {
        **{a: scalar(1.0) for a in args if a not in coord_names},
        **{
            k: v if isinstance(v, Variable) else scalar(v)
            for k, v in (p0 or {}).items()
        },
    }
def _get_specific_bounds(bounds, name, unit):
    """Return the ``(lower, upper)`` bounds of parameter ``name`` as scalars.

    Missing bounds (the parameter is absent from ``bounds``, or an entry is
    ``None``) become minus/plus infinity. Every bound is converted to ``unit``.

    Raises
    ------
    ValueError
        If the bounds are not a collection of length 2, or if they mix scipp
        Variables with plain numbers.
    """
    if name not in bounds:
        return -scalar(np.inf, unit=unit), scalar(np.inf, unit=unit)
    pair = bounds[name]
    if len(pair) != 2:
        raise ValueError(
            "Parameter bounds must be given as a tuple of length 2. "
            f"Got a collection of length {len(pair)} as bounds for '{name}'."
        )
    lower, upper = pair[0], pair[1]
    both_given = lower is not None and upper is not None
    if both_given and isinstance(lower, Variable) != isinstance(upper, Variable):
        raise ValueError(
            f"Bounds cannot mix Scipp variables and other number types, "
            f"got {type(lower)} and {type(upper)}"
        )
    if lower is None:
        lower = -scalar(np.inf, unit=unit)
    if upper is None:
        upper = scalar(np.inf, unit=unit)

    def _to_unit(bound):
        as_variable = bound if isinstance(bound, Variable) else scalar(bound)
        return as_variable.to(unit=unit)

    return _to_unit(lower), _to_unit(upper)
def _parse_bounds(bounds, p0):
    """Resolve the bounds of every parameter in ``p0``.

    Returns a dict mapping each parameter name to its ``(lower, upper)``
    scalars, expressed in the unit of the parameter's initial guess.
    """
    specified = bounds or {}
    parsed = {}
    for name, guess in p0.items():
        parsed[name] = _get_specific_bounds(specified, name, guess.unit)
    return parsed
171def _reshape_bounds(bounds):
172 left, right = zip(*bounds.values(), strict=True)
173 left, right = [le.value for le in left], [ri.value for ri in right]
174 if all(le == -np.inf and ri == np.inf for le, ri in zip(left, right, strict=True)):
175 return -np.inf, np.inf
176 return left, right
def _curve_fit(
    f,
    da,
    p0,
    bounds,
    map_over,
    unsafe_numpy_f,
    out,
    **kwargs,
):
    """Recursively fit ``f`` to every slice of ``da`` along ``map_over``.

    Results are written in place into the NumPy buffers in ``out``.

    Parameters
    ----------
    f:
        Wrapped fit function with the ``f(x, *params)`` signature passed to
        :py:func:`scipy.optimize.curve_fit`.
    da:
        Data array holding the data and the coords that participate in the fit.
    p0:
        Dict of initial guesses (scipp Variables). Entries that depend on the
        current mapped dimension are sliced together with ``da``.
    bounds:
        Dict mapping parameter names to ``(lower, upper)`` Variables, sliced
        like ``p0``.
    map_over:
        Dimensions still to be iterated over; one fit is done per slice.
    unsafe_numpy_f:
        If true, the coords are passed to ``f`` as one stacked NumPy array
        instead of a dict of Variables.
    out:
        ``(popt, pcov)`` NumPy buffers whose leading axes correspond to
        ``map_over``; filled in place. Slices with fewer unmasked data points
        than parameters, or for which scipy reports that optimal parameters
        were not found, are filled with NaN.
    **kwargs:
        Forwarded to :py:func:`scipy.optimize.curve_fit`.

    Raises
    ------
    DimensionError
        If, after slicing, an initial guess or bound is not a scalar.
    """
    dg, dgcov = out

    if len(map_over) > 0:
        dim = map_over[0]
        for i in range(da.sizes[dim]):
            _curve_fit(
                f,
                da[dim, i],
                {k: v[dim, i] if dim in v.dims else v for k, v in p0.items()},
                {
                    k: (
                        le[dim, i] if dim in le.dims else le,
                        ri[dim, i] if dim in ri.dims else ri,
                    )
                    for k, (le, ri) in bounds.items()
                },
                map_over[1:],
                unsafe_numpy_f,
                (dg[i], dgcov[i]),
                **kwargs,
            )
        return

    for k, v in p0.items():
        if v.shape != ():
            raise DimensionError(f'Parameter {k} has unexpected dimensions {v.dims}')

    for k, (le, ri) in bounds.items():
        if le.shape != ():
            raise DimensionError(
                f'Left bound of parameter {k} has unexpected dimensions {le.dims}'
            )
        if ri.shape != ():
            raise DimensionError(
                f'Right bound of parameter {k} has unexpected dimensions {ri.dims}'
            )

    fda = da.flatten(to='row')
    if len(fda.masks) > 0:
        # Drop every data point that is masked by any of the masks.
        _mask = zeros(dims=fda.dims, shape=fda.shape, dtype='bool')
        for mask in fda.masks.values():
            _mask |= mask
        fda = fda[~_mask]

    if not unsafe_numpy_f:
        # Making the coords into a dict improves runtime,
        # probably because of pybind overhead.
        X = dict(fda.coords)
    else:
        X = np.vstack([c.values for c in fda.coords.values()], dtype='float')

    import scipy.optimize as opt

    if len(fda) < len(dg):
        # More parameters than data points, unable to fit, abort.
        dg[:] = np.nan
        dgcov[:] = np.nan
        return

    try:
        popt, pcov = opt.curve_fit(
            f,
            X,
            fda.data.values,
            [v.value for v in p0.values()],
            sigma=_get_sigma(fda),
            bounds=_reshape_bounds(bounds),
            **kwargs,
        )
    except RuntimeError as err:
        # scipy reports a failed optimization as a RuntimeError whose text
        # contains 'Optimal parameters not found:'. Python 3 exceptions have
        # no ``message`` attribute, so the text must be read via ``str(err)``.
        if 'Optimal parameters not found:' not in str(err):
            raise
        n_params = len(p0)
        popt = np.full(n_params, np.nan)
        pcov = np.full((n_params, n_params), np.nan)

    dg[:] = popt
    dgcov[:] = pcov
def curve_fit(
    coords: Sequence[str] | Mapping[str, str | Variable],
    f: Callable,
    da: DataArray,
    *,
    p0: dict[str, Variable | Real] | None = None,
    bounds: dict[str, tuple[Variable, Variable] | tuple[Real, Real]] | None = None,
    reduce_dims: Sequence[str] = (),
    unsafe_numpy_f: bool = False,
    **kwargs,
) -> tuple[DataGroup, DataGroup]:
    """Use non-linear least squares to fit a function, f, to data.
    The function interface is similar to that of :py:func:`xarray.DataArray.curvefit`.

    .. versionadded:: 23.12.0

    This is a wrapper around :py:func:`scipy.optimize.curve_fit`. See there for
    in depth documentation and keyword arguments. The differences are:

    - Instead of separate ``xdata``, ``ydata``, and ``sigma`` arguments,
      the input data array defines these, ``xdata`` by the coords on the data array,
      ``ydata`` by ``da.data``, and ``sigma`` is defined as the square root of
      the variances, if present, i.e., the standard deviations.
    - The fit function ``f`` must work with scipp objects. This provides additional
      safety over the underlying scipy function by ensuring units are consistent.
    - The initial guess in ``p0`` must be provided as a dict, mapping from fit-function
      parameter names to initial guesses.
    - The parameter bounds must also be provided as a dict, like ``p0``.
    - If the fit parameters are not dimensionless the initial guess must be
      a scipp ``Variable`` with the correct unit.
    - If the fit parameters are not dimensionless the bounds must be
      variables with the correct unit.
    - The bounds and initial guesses may be scalars or arrays to allow the
      initial guesses or bounds to vary in different regions.
      If they are arrays they will be broadcast to the shape of the output.
    - The returned optimal parameter values ``popt`` and the covariance matrix ``pcov``
      will have units if the initial guess has units. ``popt`` and
      ``pcov`` are ``DataGroup`` and a ``DataGroup`` of ``DataGroup`` respectively.
      They are indexed by the fit parameter names. The variances of the parameter values
      in ``popt`` are set to the corresponding diagonal value in the covariance matrix.
    - Slices that have fewer (unmasked) data points than fit parameters are not
      fitted; their results are set to NaN.

    Parameters
    ----------
    coords:
        The coords that act as predictor variables in the fit.
        If a mapping, the keys signify names of arguments to ``f`` and the values
        signify coordinate names in ``da.coords``. If a sequence, the names of the
        arguments to ``f`` and the names of the coords are taken to be the same.
        To use a fit coordinate not present in ``da.coords``, pass it as a Variable.
    f:
        The model function, ``f(x, y..., a, b...)``. It must take all coordinates
        listed in ``coords`` as arguments, otherwise a ``ValueError`` will be raised,
        all *other* arguments will be treated as parameters of the fit.
    da:
        The values of the data array provide the dependent data. If the data array
        stores variances then the standard deviations (square root of the variances)
        are taken into account when fitting.
    p0:
        An optional dict of initial guesses for the parameters.
        If None, then the initial values will all be dimensionless 1.
        If the fit function cannot handle initial values of 1, in particular for
        parameters that are not dimensionless, then typically a
        :py:class:`scipp.UnitError` is raised,
        but details will depend on the function.
    bounds:
        Lower and upper bounds on parameters.
        Defaults to no bounds.
        Bounds are given as a dict of 2-tuples of (lower, upper) for each parameter
        where lower and upper are either both Variables or plain numbers.
        Parameters omitted from the ``bounds`` dict are unbounded.
    reduce_dims:
        Additional dimensions to aggregate while fitting.
        If a dimension is not in ``reduce_dims``, or in the dimensions
        of the coords used in the fit, then the values of the optimal parameters
        will depend on that dimension. One fit will be performed for every slice,
        and the data arrays in the output will have the dimension in their ``dims``.
        If a dimension is passed to ``reduce_dims`` all data in that dimension
        is instead aggregated in a single fit and the dimension will *not*
        be present in the output.
    unsafe_numpy_f:
        By default the provided fit function ``f`` is assumed to take scipp Variables
        as input and use scipp operations to produce a scipp Variable as output.
        This has the safety advantage of unit checking.
        However, in some cases it might be advantageous to implement ``f`` using Numpy
        operations for performance reasons. This is particularly the case if the
        curve fit will make many small curve fits involving relatively few data points.
        In this case the pybind overhead on scipp operations might be considerable.
        If ``unsafe_numpy_f`` is set to ``True`` then the arguments passed to ``f``
        will be Numpy arrays instead of scipp Variables and the output of ``f`` is
        expected to be a Numpy array.

    Returns
    -------
    popt:
        Optimal values for the parameters.
    pcov:
        The estimated covariance of popt.

    Raises
    ------
    NotImplementedError
        If the ``jac`` keyword argument is given.
    TypeError
        If ``xdata``, ``ydata`` or ``sigma`` is passed as a keyword argument, or
        if ``coords`` is a sequence containing entries that are not strings.
    BinEdgeError
        If a coordinate of ``da`` used in the fit is a bin-edge coordinate.
    ValueError
        If ``f`` does not take all names in ``coords`` as arguments, if the
        bounds are malformed, or if the variances of ``da`` contain a zero.

    See Also
    --------
    scipp.scipy.optimize.curve_fit:
        Similar functionality for 1D fits.

    Examples
    --------

    A 1D example

    >>> def round(a, d):
    ...     'Helper for the doctests'
    ...     return sc.round(10**d * a) / 10**d

    >>> def func(x, a, b):
    ...     return a * sc.exp(-b * x)

    >>> rng = np.random.default_rng(1234)
    >>> x = sc.linspace(dim='x', start=0.0, stop=0.4, num=50, unit='m')
    >>> y = func(x, a=5, b=17/sc.Unit('m'))
    >>> y.values += 0.01 * rng.normal(size=50)
    >>> da = sc.DataArray(y, coords={'x': x})

    >>> from scipp import curve_fit
    >>> popt, _ = curve_fit(['x'], func, da, p0 = {'b': 1.0 / sc.Unit('m')})
    >>> round(sc.values(popt['a']), 3), round(sc.stddevs(popt['a']), 4)
    (<scipp.DataArray>
    Dimensions: Sizes[]
    Data:
    float64 [dimensionless] () 4.999
    ,
    <scipp.DataArray>
    Dimensions: Sizes[]
    Data:
    float64 [dimensionless] () 0.0077
    )

    A 2D example where two coordinates participate in the fit

    >>> def func(x, z, a, b):
    ...     return a * z * sc.exp(-b * x)

    >>> x = sc.linspace(dim='x', start=0.0, stop=0.4, num=50, unit='m')
    >>> z = sc.linspace(dim='z', start=0.0, stop=1, num=10)
    >>> y = func(x, z, a=5, b=17/sc.Unit('m'))
    >>> y.values += 0.01 * rng.normal(size=500).reshape(10, 50)
    >>> da = sc.DataArray(y, coords={'x': x, 'z': z})

    >>> popt, _ = curve_fit(['x', 'z'], func, da, p0 = {'b': 1.0 / sc.Unit('m')})
    >>> round(sc.values(popt['a']), 3), round(sc.stddevs(popt['a']), 3)
    (<scipp.DataArray>
    Dimensions: Sizes[]
    Data:
    float64 [dimensionless] () 5.004
    ,
    <scipp.DataArray>
    Dimensions: Sizes[]
    Data:
    float64 [dimensionless] () 0.004
    )

    A 2D example where only one coordinate participates in the fit and we
    map over the dimension of the other coordinate.
    Note that the value of one of the parameters is z-dependent
    and that the output has a z-dimension

    >>> def func(x, a, b):
    ...     return a * sc.exp(-b * x)

    >>> x = sc.linspace(dim='xx', start=0.0, stop=0.4, num=50, unit='m')
    >>> z = sc.linspace(dim='zz', start=0.0, stop=1, num=10)
    >>> # Note that parameter a is z-dependent.
    >>> y = func(x, a=z, b=17/sc.Unit('m'))
    >>> y.values += 0.01 * rng.normal(size=500).reshape(10, 50)
    >>> da = sc.DataArray(y, coords={'x': x, 'z': z})

    >>> popt, _ = curve_fit(
    ...     ['x'], func, da,
    ...     p0 = {'b': 1.0 / sc.Unit('m')})
    >>> # Note that approximately a = z
    >>> round(sc.values(popt['a']), 2),
    (<scipp.DataArray>
    Dimensions: Sizes[zz:10, ]
    Coordinates:
    * z float64 [dimensionless] (zz) [0, 0.111111, ..., 0.888889, 1]
    Data:
    float64 [dimensionless] (zz) [-0.01, 0.11, ..., 0.89, 1.01]
    ,)

    Lastly, a 2D example where only one coordinate participates in the fit and
    the other coordinate is reduced.

    >>> def func(x, a, b):
    ...     return a * sc.exp(-b * x)

    >>> x = sc.linspace(dim='xx', start=0.0, stop=0.4, num=50, unit='m')
    >>> z = sc.linspace(dim='zz', start=0.0, stop=1, num=10)
    >>> y = z * 0 + func(x, a=5, b=17/sc.Unit('m'))
    >>> y.values += 0.01 * rng.normal(size=500).reshape(10, 50)
    >>> da = sc.DataArray(y, coords={'x': x, 'z': z})

    >>> popt, _ = curve_fit(
    ...     ['x'], func, da,
    ...     p0 = {'b': 1.0 / sc.Unit('m')}, reduce_dims=['zz'])
    >>> round(sc.values(popt['a']), 3), round(sc.stddevs(popt['a']), 4)
    (<scipp.DataArray>
    Dimensions: Sizes[]
    Data:
    float64 [dimensionless] () 5
    ,
    <scipp.DataArray>
    Dimensions: Sizes[]
    Data:
    float64 [dimensionless] () 0.0021
    )

    Note that the variance is about 10x lower in this example than in the
    first 1D example. That is because in this example 50x10 points are used
    in the fit while in the first example only 50 points were used in the fit.

    """

    if 'jac' in kwargs:
        raise NotImplementedError(
            "The 'jac' argument is not yet supported. "
            "See https://github.com/scipp/scipp/issues/2544"
        )

    for arg in ['xdata', 'ydata', 'sigma']:
        if arg in kwargs:
            raise TypeError(
                f"Invalid argument '{arg}', already defined by the input data array."
            )

    # Accept any Mapping (as the annotation promises), not only dict;
    # otherwise a non-dict mapping would be treated as a sequence of its keys.
    if not isinstance(coords, Mapping):
        if not all(isinstance(c, str) for c in coords):
            raise TypeError(
                'Expected sequence of coords to only contain values of type `str`.'
            )
        coords = {c: c for c in coords}

    # Check the coordinate *names* (the mapping values), not the function
    # argument names (the keys); the two differ when a mapping is passed.
    for coord in coords.values():
        if isinstance(coord, str) and coord in da.coords and da.coords.is_edges(coord):
            raise BinEdgeError("Cannot fit data array with bin-edge coordinate.")

    # Mapping from function argument names to fit variables
    coords = {
        arg: da.coords[coord] if isinstance(coord, str) else coord
        for arg, coord in coords.items()
    }

    p0 = _make_defaults(f, coords.keys(), p0)

    f = (
        _wrap_scipp_func(f, p0)
        if not unsafe_numpy_f
        else _wrap_numpy_func(f, p0, coords.keys())
    )

    # One independent fit per slice along every dimension that is neither
    # reduced nor spanned by a fit coordinate.
    map_over = tuple(
        d
        for d in da.dims
        if d not in reduce_dims and not any(d in c.dims for c in coords.values())
    )

    # Create a dataarray with only the participating coords
    _da = DataArray(da.data, coords=coords, masks=da.masks)

    out = _prepare_numpy_outputs(da, p0, map_over)

    _curve_fit(
        f,
        _da,
        p0,
        _parse_bounds(bounds, p0),
        map_over,
        unsafe_numpy_f,
        out,
        **kwargs,
    )

    return _datagroup_outputs(da, p0, map_over, *out)
# Public API of this module; all other names are private helpers.
__all__ = ["curve_fit"]