Coverage for install/scipp/curve_fit.py: 78%

125 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-04-28 01:28 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3 

4from inspect import getfullargspec 

5from numbers import Real 

6from typing import Callable, Dict, Mapping, Optional, Sequence, Tuple, Union 

7 

8import numpy as np 

9 

10from .core import BinEdgeError, DataArray, DataGroup, Variable, array, scalar, stddevs 

11from .units import default_unit, dimensionless 

12 

13 

def _as_scalar(obj, unit):
    """Wrap *obj* in a scipp scalar carrying *unit*, or return it unchanged
    when *unit* is the ``default_unit`` sentinel (i.e. no unit was requested).
    """
    return obj if unit == default_unit else scalar(value=obj, unit=unit)

18 

19 

def _wrap_scipp_func(f, p_names, p_units):
    """Adapt a scipp-aware model function *f* to the plain-number interface
    expected by :func:`scipy.optimize.curve_fit`.

    The returned callable receives a dict of coords and positional parameter
    values. Parameters that carry units are held in reusable scalar Variables
    whose values are updated in place, so *f* always sees correctly-united
    scipp objects; the result's raw ``values`` are handed back to scipy.
    """
    # One holder per parameter; Variables keep their unit across calls.
    holders = {name: _as_scalar(0.0, unit) for name, unit in zip(p_names, p_units)}

    def func(x, *args):
        for name, value in zip(p_names, args):
            holder = holders[name]
            if isinstance(holder, Variable):
                holder.value = value
            else:
                holders[name] = value
        return f(**x, **holders).values

    return func

32 

33 

34def _wrap_numpy_func(f, p_names, coord_names): 

35 def func(x, *args): 

36 # If there is only one predictor variable x might be a 1D array. 

37 # Make x 2D for consistency. 

38 if len(x.shape) == 1: 

39 x = x.reshape(1, -1) 

40 coords = dict(zip(coord_names, x)) 

41 params = dict(zip(p_names, args)) 

42 return f(**coords, **params) 

43 

44 return func 

45 

46 

def _get_sigma(da):
    """Return the standard deviations of *da* as a numpy array, suitable for
    scipy's ``sigma`` argument, or ``None`` when *da* has no variances.

    Raises
    ------
    ValueError
        If any variance is zero, which would give an infinite weight and
        break the optimizer.
    """
    if da.variances is None:
        return None
    sigma = stddevs(da).values
    if sigma.all():
        return sigma
    raise ValueError(
        'There is a 0 in the input variances. This would break the optimizer. '
        'Mask the offending elements, remove them, or assign a meaningful '
        'variance if possible before calling curve_fit.'
    )

59 

60 

def _datagroup_outputs(da, params, p_units, map_over, pdata, covdata):
    """Pack raw numpy fit results into scipp DataGroups.

    Parameters
    ----------
    da:
        Input data array; its coords and masks that depend on the
        mapped-over dims are attached to the outputs.
    params:
        Mapping whose keys are the fit-parameter names (insertion order
        matches the last axes of ``pdata`` / ``covdata``).
    p_units:
        Unit of each parameter, in the same order as ``params``.
    map_over:
        Output dims; one fit result exists per slice along these.
    pdata:
        Optimal parameter values, last axis indexing parameters.
    covdata:
        Covariance matrices, last two axes indexing parameters.

    Returns
    -------
    :
        Tuple of ``popt`` (DataGroup of DataArray) and ``pcov``
        (DataGroup of DataGroup of DataArray).
    """
    # Variances of the fitted parameters are the covariance diagonal.
    variances = np.diagonal(covdata, axis1=-2, axis2=-1)
    dg = DataGroup(
        {
            p: DataArray(
                data=array(
                    dims=map_over,
                    values=pdata[..., i] if pdata.ndim > 1 else pdata[i],
                    variances=variances[..., i] if variances.ndim > 1 else variances[i],
                    unit=u,
                ),
            )
            for i, (p, u) in enumerate(zip(params, p_units))
        }
    )
    dgcov = DataGroup(
        {
            p: DataGroup(
                {
                    q: DataArray(
                        data=array(
                            dims=map_over,
                            values=covdata[..., i, j]
                            if covdata.ndim > 2
                            else covdata[i, j],
                            # Covariance of two parameters carries the product
                            # of their units; default_unit is neutral.
                            unit=(
                                default_unit
                                if p_u == default_unit and q_u == default_unit
                                else p_u
                                if q_u == default_unit
                                else q_u
                                if p_u == default_unit
                                else p_u * q_u
                            ),
                        ),
                    )
                    for j, (q, q_u) in enumerate(zip(params, p_units))
                }
            )
            for i, (p, p_u) in enumerate(zip(params, p_units))
        }
    )
    # Attach coords that still have a dim present in the output.
    for c in da.coords:
        if set(map_over).intersection(da.coords[c].dims):
            for p in dg:
                dg[p].coords[c] = da.coords[c]
                for q in dgcov[p]:
                    dgcov[p][q].coords[c] = da.coords[c]

    # Attach masks that still have a dim present in the output.
    for m in da.masks:
        if set(map_over).intersection(da.masks[m].dims):
            for p in dg:
                # Bug fix: the original indexed with `c` (leftover loop
                # variable from the coords loop above) instead of the mask
                # name `m`, attaching the wrong entry under the wrong key.
                dg[p].masks[m] = da.masks[m]
                for q in dgcov[p]:
                    dgcov[p][q].masks[m] = da.masks[m]
    return dg, dgcov

117 

118 

119def _prepare_numpy_outputs(da, params, map_over): 

120 shape = [da.sizes[d] for d in map_over] 

121 dg = np.empty([*shape, len(params)]) 

122 dgcov = np.empty(shape + 2 * [len(params)]) 

123 return dg, dgcov 

124 

125 

126def _make_defaults(f, coords, params): 

127 spec = getfullargspec(f) 

128 all_args = {*spec.args, *spec.kwonlyargs} 

129 if not set(coords).issubset(all_args): 

130 raise ValueError("Function must take the provided coords as arguments") 

131 default_arguments = dict( 

132 zip(spec.args[-len(spec.defaults) :], spec.defaults) if spec.defaults else {}, 

133 **(spec.kwonlydefaults or {}), 

134 ) 

135 return { 

136 **{a: 1.0 for a in all_args - set(coords)}, 

137 **default_arguments, 

138 **(params or {}), 

139 } 

140 

141 

142def _get_specific_bounds(bounds, name, unit) -> Tuple[float, float]: 

143 if name not in bounds: 

144 return -np.inf, np.inf 

145 b = bounds[name] 

146 if len(b) != 2: 

147 raise ValueError( 

148 "Parameter bounds must be given as a tuple of length 2. " 

149 f"Got a collection of length {len(b)} as bounds for '{name}'." 

150 ) 

151 if isinstance(b[0], Variable): 

152 return ( 

153 b[0].to(unit=unit, dtype=float).value, 

154 b[1].to(unit=unit, dtype=float).value, 

155 ) 

156 return b 

157 

158 

159def _parse_bounds( 

160 bounds, params 

161) -> Union[Tuple[float, float], Tuple[np.ndarray, np.ndarray]]: 

162 if bounds is None: 

163 return -np.inf, np.inf 

164 

165 bounds_tuples = [ 

166 _get_specific_bounds( 

167 bounds, name, param.unit if isinstance(param, Variable) else dimensionless 

168 ) 

169 for name, param in params.items() 

170 ] 

171 bounds_array = np.array(bounds_tuples).T 

172 return bounds_array[0], bounds_array[1] 

173 

174 

def _curve_fit(
    f,
    da,
    p0,
    bounds,
    map_over,
    unsafe_numpy_f,
    out,
    **kwargs,
):
    """Recursively slice *da* along *map_over* and fit each innermost slice.

    Results are written in place into ``out = (dg, dgcov)``, the numpy
    buffers pre-allocated by ``_prepare_numpy_outputs``; nothing is returned.
    """
    dg, dgcov = out

    if len(map_over) > 0:
        # Peel off one dim and recurse into each slice along it.
        dim = map_over[0]
        for i in range(da.sizes[dim]):
            _curve_fit(
                f,
                da[dim, i],
                p0,
                bounds,
                map_over[1:],
                unsafe_numpy_f,
                (dg[i], dgcov[i]),
                **kwargs,
            )

        return

    fda = da.flatten(to='row')

    # Drop all masked points before fitting.
    for m in fda.masks.values():
        fda = fda[~m]

    if not unsafe_numpy_f:
        # Making the coords into a dict improves runtime,
        # probably because of pybind overhead.
        X = dict(fda.coords)
    else:
        X = np.vstack([c.values for c in fda.coords.values()], dtype='float')

    import scipy.optimize as opt

    try:
        popt, pcov = opt.curve_fit(
            f,
            X,
            fda.data.values,
            p0,
            sigma=_get_sigma(fda),
            bounds=bounds,
            **kwargs,
        )
    except RuntimeError as err:
        # Bug fix: Python 3 exceptions have no `.message` attribute, so the
        # original `hasattr(err, 'message')` guard was always False and the
        # NaN fallback below was unreachable — failed fits always re-raised.
        # Match against str(err) instead, which contains scipy's message.
        if 'Optimal parameters not found:' in str(err):
            popt = np.full(len(p0), np.nan)
            pcov = np.full((len(p0), len(p0)), np.nan)
        else:
            raise

    dg[:] = popt
    dgcov[:] = pcov

236 

237 

def curve_fit(
    coords: Union[Sequence[str], Mapping[str, Union[str, Variable]]],
    f: Callable,
    da: DataArray,
    *,
    p0: Optional[Dict[str, Union[Variable, Real]]] = None,
    bounds: Optional[
        Dict[str, Union[Tuple[Variable, Variable], Tuple[Real, Real]]]
    ] = None,
    reduce_dims: Sequence[str] = (),
    unsafe_numpy_f: bool = False,
    **kwargs,
) -> Tuple[DataGroup, DataGroup]:
    """Use non-linear least squares to fit a function, f, to data.
    The function interface is similar to that of :py:func:`xarray.DataArray.curvefit`.

    .. versionadded:: 23.12.0

    This is a wrapper around :py:func:`scipy.optimize.curve_fit`. See there for
    indepth documentation and keyword arguments. The differences are:

    - Instead of separate ``xdata``, ``ydata``, and ``sigma`` arguments,
      the input data array defines these, ``xdata`` by the coords on the data array,
      ``ydata`` by ``da.data``, and ``sigma`` is defined as the square root of
      the variances, if present, i.e., the standard deviations.
    - The fit function ``f`` must work with scipp objects. This provides additional
      safety over the underlying scipy function by ensuring units are consistent.
    - The initial guess in ``p0`` must be provided as a dict, mapping from fit-function
      parameter names to initial guesses.
    - The parameter bounds must also be provided as a dict, like ``p0``.
    - The fit parameters may be scalar scipp variables. In that case an initial guess
      ``p0`` with the correct units must be provided.
    - The returned optimal parameter values ``popt`` and the covariance matrix ``pcov``
      will have units provided that the initial parameters have units. ``popt`` and
      ``pcov`` are DataGroup and a DataGroup of DataGroup respectively. They are indexed
      by the fit parameter names. The variances of the parameter values in ``popt``
      are set to the corresponding diagonal value in the covariance matrix.

    Parameters
    ----------
    coords:
        The coords that act as predictor variables in the fit.
        If a mapping, the keys signify names of arguments to ``f`` and the values
        signify coordinate names in ``da.coords``. If a sequence, the names of the
        arguments to ``f`` and the names of the coords are taken to be the same.
        To use a fit coordinate not present in ``da.coords``, pass it as a Variable.
    f:
        The model function, ``f(x, y..., a, b...)``. It must take all coordinates
        listed in ``coords`` as arguments, otherwise a ``ValueError`` will be raised,
        all *other* arguments will be treated as parameters of the fit.
    da:
        The values of the data array provide the dependent data. If the data array
        stores variances then the standard deviations (square root of the variances)
        are taken into account when fitting.
    p0:
        An optional dict of initial guesses for the parameters.
        If None, then the initial values will all be dimensionless 1.
        If the fit function cannot handle initial values of 1, in particular for
        parameters that are not dimensionless, then typically a
        :py:class:``scipp.UnitError`` is raised,
        but details will depend on the function.
    bounds:
        Lower and upper bounds on parameters.
        Defaults to no bounds.
        Bounds are given as a dict of 2-tuples of (lower, upper) for each parameter
        where lower and upper are either both Variables or plain numbers.
        Parameters omitted from the ``bounds`` dict are unbounded.
    reduce_dims:
        Additional dimensions to aggregate while fitting.
        If a dimension is not in ``reduce_dims``, or in the dimensions
        of the coords used in the fit, then the values of the optimal parameters
        will depend on that dimension. One fit will be performed for every slice,
        and the data arrays in the output will have the dimension in their ``dims``.
        If a dimension is passed to ``reduce_dims`` all data in that dimension
        is instead aggregated in a single fit and the dimension will *not*
        be present in the output.
    unsafe_numpy_f:
        By default the provided fit function ``f`` is assumed to take scipp Variables
        as input and use scipp operations to produce a scipp Variable as output.
        This has the safety advantage of unit checking.
        However, in some cases it might be advantageous to implement ``f`` using Numpy
        operations for performance reasons. This is particularly the case if the
        curve fit will make many small curve fits involving relatively few data points.
        In this case the pybind overhead on scipp operations might be considerable.
        If ``unsafe_numpy_f`` is set to ``True`` then the arguments passed to ``f``
        will be Numpy arrays instead of scipp Variables and the output of ``f`` is
        expected to be a Numpy array.

    Returns
    -------
    popt:
        Optimal values for the parameters.
    pcov:
        The estimated covariance of popt.

    See Also
    --------
    scipp.scipy.optimize.curve_fit:
        Similar functionality for 1D fits.

    Examples
    --------

    A 1D example

    >>> def round(a, d):
    ...     'Helper for the doctests'
    ...     return sc.round(10**d * a) / 10**d

    >>> def func(x, a, b):
    ...     return a * sc.exp(-b * x)

    >>> rng = np.random.default_rng(1234)
    >>> x = sc.linspace(dim='x', start=0.0, stop=0.4, num=50, unit='m')
    >>> y = func(x, a=5, b=17/sc.Unit('m'))
    >>> y.values += 0.01 * rng.normal(size=50)
    >>> da = sc.DataArray(y, coords={'x': x})

    >>> from scipp import curve_fit
    >>> popt, _ = curve_fit(['x'], func, da, p0 = {'b': 1.0 / sc.Unit('m')})
    >>> round(sc.values(popt['a']), 3), round(sc.stddevs(popt['a']), 4)
    (<scipp.DataArray>
    Dimensions: Sizes[]
    Data:
                        float64  [dimensionless]  ()  4.999
    ,
     <scipp.DataArray>
    Dimensions: Sizes[]
    Data:
                        float64  [dimensionless]  ()  0.0077
    )

    A 2D example where two coordinates participate in the fit

    >>> def func(x, z, a, b):
    ...     return a * z * sc.exp(-b * x)

    >>> x = sc.linspace(dim='x', start=0.0, stop=0.4, num=50, unit='m')
    >>> z = sc.linspace(dim='z', start=0.0, stop=1, num=10)
    >>> y = func(x, z, a=5, b=17/sc.Unit('m'))
    >>> y.values += 0.01 * rng.normal(size=500).reshape(10, 50)
    >>> da = sc.DataArray(y, coords={'x': x, 'z': z})

    >>> popt, _ = curve_fit(['x', 'z'], func, da, p0 = {'b': 1.0 / sc.Unit('m')})
    >>> round(sc.values(popt['a']), 3), round(sc.stddevs(popt['a']), 3)
    (<scipp.DataArray>
    Dimensions: Sizes[]
    Data:
                        float64  [dimensionless]  ()  5.004
    ,
     <scipp.DataArray>
    Dimensions: Sizes[]
    Data:
                        float64  [dimensionless]  ()  0.004
    )

    A 2D example where only one coordinate participates in the fit and we
    map over the dimension of the other coordinate.
    Note that the value of one of the parameters is z-dependent
    and that the output has a z-dimension

    >>> def func(x, a, b):
    ...     return a * sc.exp(-b * x)

    >>> x = sc.linspace(dim='xx', start=0.0, stop=0.4, num=50, unit='m')
    >>> z = sc.linspace(dim='zz', start=0.0, stop=1, num=10)
    >>> # Note that parameter a is z-dependent.
    >>> y = func(x, a=z, b=17/sc.Unit('m'))
    >>> y.values += 0.01 * rng.normal(size=500).reshape(10, 50)
    >>> da = sc.DataArray(y, coords={'x': x, 'z': z})

    >>> popt, _ = curve_fit(
    ...     ['x'], func, da,
    ...     p0 = {'b': 1.0 / sc.Unit('m')})
    >>> # Note that approximately a = z
    >>> round(sc.values(popt['a']), 2),
    (<scipp.DataArray>
    Dimensions: Sizes[zz:10, ]
    Coordinates:
    * z                         float64  [dimensionless]  (zz)  [0, 0.111111, ..., 0.888889, 1]
    Data:
                        float64  [dimensionless]  (zz)  [-0.01, 0.11, ..., 0.89, 1.01]
    ,)

    Lastly, a 2D example where only one coordinate participates in the fit and
    the other coordinate is reduced.

    >>> def func(x, a, b):
    ...     return a * sc.exp(-b * x)

    >>> x = sc.linspace(dim='xx', start=0.0, stop=0.4, num=50, unit='m')
    >>> z = sc.linspace(dim='zz', start=0.0, stop=1, num=10)
    >>> y = z * 0 + func(x, a=5, b=17/sc.Unit('m'))
    >>> y.values += 0.01 * rng.normal(size=500).reshape(10, 50)
    >>> da = sc.DataArray(y, coords={'x': x, 'z': z})

    >>> popt, _ = curve_fit(
    ...     ['x'], func, da,
    ...     p0 = {'b': 1.0 / sc.Unit('m')}, reduce_dims=['zz'])
    >>> round(sc.values(popt['a']), 3), round(sc.stddevs(popt['a']), 4)
    (<scipp.DataArray>
    Dimensions: Sizes[]
    Data:
                        float64  [dimensionless]  ()  5
    ,
     <scipp.DataArray>
    Dimensions: Sizes[]
    Data:
                        float64  [dimensionless]  ()  0.0021
    )

    Note that the variance is about 10x lower in this example than in the
    first 1D example. That is because in this example 50x10 points are used
    in the fit while in the first example only 50 points were used in the fit.

    """

    # 'jac' would need to be wrapped like `f` itself; not supported yet.
    if 'jac' in kwargs:
        raise NotImplementedError(
            "The 'jac' argument is not yet supported. "
            "See https://github.com/scipp/scipp/issues/2544"
        )

    # These scipy arguments are derived from `da` and may not be overridden.
    for arg in ['xdata', 'ydata', 'sigma']:
        if arg in kwargs:
            raise TypeError(
                f"Invalid argument '{arg}', already defined by the input data array."
            )

    for c in coords:
        if c in da.coords and da.coords.is_edges(c):
            raise BinEdgeError("Cannot fit data array with bin-edge coordinate.")

    # Normalize a sequence of names into the mapping form.
    if not isinstance(coords, dict):
        if not all(isinstance(c, str) for c in coords):
            raise TypeError(
                'Expected sequence of coords to only contain values of type `str`.'
            )
        coords = {c: c for c in coords}

    # Mapping from function argument names to fit variables
    coords = {
        arg: da.coords[coord] if isinstance(coord, str) else coord
        for arg, coord in coords.items()
    }

    # Fill in the initial guesses and record each parameter's unit.
    p0 = _make_defaults(f, coords.keys(), p0)
    p_units = [p.unit if isinstance(p, Variable) else default_unit for p in p0.values()]

    # Wrap `f` so scipy can call it with plain numbers.
    f = (
        _wrap_scipp_func(f, p0, p_units)
        if not unsafe_numpy_f
        else _wrap_numpy_func(f, p0, coords.keys())
    )

    # Dims to map over: not reduced and not participating in the fit itself.
    map_over = tuple(
        d
        for d in da.dims
        if d not in reduce_dims and not any(d in c.dims for c in coords.values())
    )

    dims_participating_in_fit = set(da.dims) - set(map_over)

    # Create a dataarray with only the participating coords and masks
    # and coordinate names matching the argument names of f.
    _da = DataArray(
        da.data,
        coords=coords,
        masks={
            m: da.masks[m]
            for m in da.masks
            if dims_participating_in_fit.intersection(da.masks[m].dims)
        },
    )

    out = _prepare_numpy_outputs(da, p0, map_over)

    # Run the fits; results are written into `out` in place.
    _curve_fit(
        f,
        _da,
        [p.value if isinstance(p, Variable) else p for p in p0.values()],
        _parse_bounds(bounds, p0),
        map_over,
        unsafe_numpy_f,
        out,
        **kwargs,
    )

    return _datagroup_outputs(da, p0, p_units, map_over, *out)

527 

528 

# Public API of this module.
__all__ = ['curve_fit']