Coverage for install/scipp/testing/assertions.py: 88%

127 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-24 01:51 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3# @author Jan-Lukas Wynen 

4"""Custom assertions for pytest-based tests. 

5 

6To get the best error messages, tell pytest to rewrite assertions in this module. 

7Place the following code in your ``conftest.py``: 

8 

9.. code-block:: python 

10 

11 pytest.register_assert_rewrite('scipp.testing.assertions') 

12 

13""" 

14 

15from collections.abc import Callable, Iterator, Mapping 

16from contextlib import contextmanager 

17from typing import Any, TypeVar 

18 

19import numpy as np 

20 

21from ..core import DataArray, DataGroup, Dataset, Variable 

22 

23# Exception notes are formatted as 'PREPOSITION {loc}', 

24# where 'loc' is set by the concrete assertion functions to indicate coords, attrs, etc. 

25# 'PREPOSITION' is replaced at the top level to produce exception messages like: 

26# 

27# [...] 

28# in coord 'x' 

29# of data group item 'b' 

30# of data group item 'a' 

31 

# Constrained type variable used to annotate the assertion helpers; statically
# restricted to Variable and DataArray.
_T = TypeVar(
    '_T',
    Variable,
    DataArray,
)

33 

34 

def assert_allclose(
    a: _T,
    b: _T,
    rtol: Variable | None = None,
    atol: Variable | None = None,
    **kwargs: Any,
) -> None:
    """Raise an AssertionError if two objects don't have similar values
    or if their other properties are not identical.

    Data values (and variances, if present) are compared with
    :func:`numpy.testing.assert_allclose`; all other properties must be
    identical. The same numeric tolerances are applied to values and
    variances.

    Parameters
    ----------
    a:
        The actual object to check.
    b:
        The desired, expected object.
    rtol:
        Relative tolerance (relative to ``b``).
        Must be a scalar (0-D) variable convertible to ``'dimensionless'``;
        only its value is used, so it should not have variances.
        Defaults to ``1e-7`` (the default of
        :func:`numpy.testing.assert_allclose`) if unset.
    atol:
        Absolute tolerance.
        Must be a scalar (0-D) variable; it is converted to the unit of ``a``
        before use. Only its value is used, so it should not have variances.
        Defaults to ``0`` (the default of
        :func:`numpy.testing.assert_allclose`) if unset.
    kwargs:
        Additional arguments to pass to :func:`numpy.testing.assert_allclose`
        which is used for comparing data.

    Raises
    ------
    AssertionError
        If the data of the objects are not close within the tolerances
        or if any of their other properties differ.
    NotImplementedError
        If ``a`` and ``b`` are of the same type, but that type is neither
        Variable nor DataArray.
    """
    _assert_similar(_assert_allclose_impl, a, b, rtol=rtol, atol=atol, **kwargs)

71 

72 

def assert_identical(a: _T, b: _T) -> None:
    """Assert that two objects are identical, raising AssertionError otherwise.

    For Scipp objects, ``assert_identical(a, b)`` has the same effect as
    ``assert sc.identical(a, b, equal_nan=True)``, but under pytest the
    failure message pinpoints what differs.
    Arguments that :func:`scipp.identical` does not support are compared
    with ``assert a == b``.

    The comparison is strict and includes the types of the arguments, so
    ``assert_identical(1, 1.0)`` fails.

    NaN elements in Scipp variables compare equal to each other.

    Parameters
    ----------
    a:
        The object under test.
    b:
        The expected object.

    Raises
    ------
    AssertionError
        If ``a`` and ``b`` are not identical.
    """
    _assert_similar(_assert_identical_impl, a, b)

100 

101 

102def _assert_similar(impl: Callable[..., None], *args: Any, **kwargs: Any) -> None: 

103 try: 

104 impl(*args, **kwargs) 

105 except AssertionError as exc: 

106 if hasattr(exc, '__notes__'): 

107 # See comment above. 

108 notes = [] 

109 rest = -1 

110 for i, note in enumerate(exc.__notes__): 

111 if 'PREPOSITION' in note: 

112 notes.append(note.replace('PREPOSITION', 'in')) 

113 rest = i 

114 break 

115 notes.extend( 

116 note.replace('PREPOSITION', 'of') for note in exc.__notes__[rest + 1 :] 

117 ) 

118 exc.__notes__ = notes 

119 raise 

120 

121 

def _assert_identical_impl(
    a: _T,
    b: _T,
) -> None:
    """Check structure/metadata first, then the stored data, for exact identity."""
    for check in (_assert_identical_structure, _assert_identical_data):
        check(a, b)

128 

129 

def _assert_allclose_impl(
    a: _T,
    b: _T,
    **kwargs: Any,
) -> None:
    """Require identical structure, then closeness of the data.

    ``kwargs`` (including ``rtol``/``atol``) are only used for the data
    comparison; the structure check is always exact.
    """
    # Metadata must match exactly before tolerances are applied to the data.
    _assert_identical_structure(a, b)
    _assert_allclose_data(
        a,
        b,
        **kwargs,
    )

137 

138 

def _assert_identical_structure(
    a: _T,
    b: _T,
) -> None:
    """Assert that ``a`` and ``b`` have the same type and matching structure.

    Variables and DataArrays get a detailed structural comparison; for any
    other type only the type itself is compared here.
    """
    assert type(a) is type(b)
    if isinstance(a, Variable):
        _assert_identical_variable_structure(a, b)
        return
    if isinstance(a, DataArray):
        _assert_identical_data_array_structure(a, b)

148 

149 

def _assert_identical_variable_structure(a: Variable, b: Variable) -> None:
    """Assert that two variables agree in everything except their element data.

    Compares sizes, unit, dtype and whether they are binned. For binned
    variables the unit of the bin contents is compared; for dense variables,
    both or neither must have variances.
    """
    assert a.sizes == b.sizes
    assert a.unit == b.unit
    assert a.dtype == b.dtype
    bins_a = a.bins
    bins_b = b.bins
    assert (bins_a is None) == (bins_b is None)
    if bins_a is None:
        if a.variances is None:
            assert b.variances is None, 'a has no variances but b does'
        else:
            assert b.variances is not None, 'a has variances but b does not'
        return
    assert bins_b is not None
    assert bins_a.unit == bins_b.unit

163 

164 

def _assert_identical_data_array_structure(a: DataArray, b: DataArray) -> None:
    """Assert that two data arrays agree in everything but their data values.

    Compares the structure of the underlying data variable (sizes, unit,
    dtype, binned-ness, presence of variances). Coords, attrs and masks are
    compared item by item, including their values.
    """
    # The data values are compared separately (exactly or with tolerances),
    # but the data's structure has to be checked here. Without this, data
    # arrays whose data differ only in unit, dtype, or variances would pass.
    with _add_note('data'):
        _assert_identical_variable_structure(a.data, b.data)
    _assert_mapping_eq(a.coords, b.coords, 'coord')
    _assert_mapping_eq(a.deprecated_attrs, b.deprecated_attrs, 'attr')
    _assert_mapping_eq(a.masks, b.masks, 'mask')

169 

170 

def _assert_identical_dataset(a: Dataset, b: Dataset) -> None:
    """Compare two datasets item by item (item names and full item contents)."""
    _assert_mapping_eq(a, b, map_name='dataset item')

173 

174 

def _assert_identical_datagroup(a: DataGroup[Any], b: DataGroup[Any]) -> None:
    """Compare two data groups item by item (item names and full item contents)."""
    _assert_mapping_eq(a, b, map_name='data group item')

177 

178 

def _assert_identical_alignment(a: Any, b: Any) -> None:
    """Assert matching ``aligned`` flags when both arguments are Variables.

    Any other combination of argument types is accepted without checks.
    """
    both_variables = isinstance(a, Variable) and isinstance(b, Variable)
    if not both_variables:
        return
    assert a.aligned == b.aligned

182 

183 

def _assert_mapping_eq(
    a: Mapping[str, Any],
    b: Mapping[str, Any],
    map_name: str,
) -> None:
    """Assert that two mappings have equal keys and identical, equally aligned values.

    ``map_name`` is the singular name of the mapping's items (e.g. ``'coord'``).
    It is used to annotate failures with the location where they occurred.
    """
    with _add_note(f'{map_name}s'):
        assert a.keys() == b.keys()
    for key, left in a.items():
        # The key is passed as a format argument instead of being interpolated
        # into the template, so braces in key names are not re-interpreted.
        with _add_note("{} '{}'", map_name, key):
            right = b[key]
            _assert_identical_impl(left, right)
            _assert_identical_alignment(left, right)

196 

197 

def _assert_identical_data(
    a: _T,
    b: _T,
) -> None:
    """Exactly compare the stored data, dispatching on the type of ``a``.

    Datasets and data groups are compared item by item; objects that are
    not Scipp types fall back to ``==``.
    """
    if isinstance(a, Variable):
        _assert_identical_variable_data(a, b)
        return
    if isinstance(a, DataArray):
        _assert_identical_variable_data(a.data, b.data)
        return
    if isinstance(a, Dataset):
        _assert_identical_dataset(a, b)
        return
    if isinstance(a, DataGroup):
        _assert_identical_datagroup(a, b)
        return
    assert a == b

212 

213 

def _assert_identical_variable_data(a: Variable, b: Variable) -> None:
    """Exactly compare the data of two variables, choosing dense or binned
    comparison based on whether ``a`` is binned."""
    compare = (
        _assert_identical_dense_variable_data
        if a.bins is None
        else _assert_identical_binned_variable_data
    )
    compare(a, b)

219 

220 

def _assert_identical_dense_variable_data(a: Variable, b: Variable) -> None:
    """Exactly compare the values and, if ``a`` has them, the variances.

    Uses :func:`numpy.testing.assert_array_equal`, which does not raise for
    NaNs located at the same positions in both arrays.
    """
    with _add_note('values'):
        np.testing.assert_array_equal(
            a.values,
            b.values,
            err_msg='when comparing values',
        )
    if a.variances is None:
        return
    with _add_note('variances'):
        np.testing.assert_array_equal(
            a.variances,
            b.variances,
            err_msg='when comparing variances',
        )

231 

232 

def _assert_identical_binned_variable_data(a: Variable, b: Variable) -> None:
    """Compare binned variables via the contents of their concatenated bins.

    Applies ``bins.concat()`` to both variables and compares the resulting
    bin contents (obtained with ``.value``) for full identity.
    """
    assert a.bins is not None
    assert b.bins is not None
    content_a = a.bins.concat().value
    content_b = b.bins.concat().value
    _assert_identical_impl(content_a, content_b)

237 

238 

def _assert_allclose_data(
    a: _T,
    b: _T,
    **kwargs: Any,
) -> None:
    """Compare the data of Variables or DataArrays within tolerances.

    Raises NotImplementedError for any other type.
    """
    if isinstance(a, Variable):
        var_a, var_b = a, b
    elif isinstance(a, DataArray):
        var_a, var_b = a.data, b.data
    else:
        raise NotImplementedError
    _assert_allclose_variable_data(var_a, var_b, **kwargs)

250 

251 

def _assert_allclose_variable_data(a: Variable, b: Variable, **kwargs: Any) -> None:
    """Compare variable data within tolerances, choosing the dense or binned
    comparison based on whether ``a`` is binned."""
    if a.bins is not None:
        _assert_allclose_binned_variable_data(a, b, **kwargs)
    else:
        _assert_allclose_dense_variable_data(a, b, **kwargs)

257 

258 

def _assert_allclose_dense_variable_data(
    a: Variable,
    b: Variable,
    rtol: Variable | None = None,
    atol: Variable | None = None,
    **kwargs: Any,
) -> None:
    """Compare dense values (and variances, if ``a`` has them) within tolerances.

    If given, ``rtol`` is converted to ``'dimensionless'`` and ``atol`` to the
    unit of ``a``. Their ``.value`` is then passed to
    :func:`numpy.testing.assert_allclose` together with any extra ``kwargs``.
    Because ``.value`` is used, the tolerances must be scalar (0-D) variables.
    """
    np_kwargs = dict(kwargs)
    if rtol is not None:
        np_kwargs['rtol'] = rtol.to(unit='dimensionless').value
    if atol is not None:
        atol_unit = a.unit if hasattr(a, 'unit') else 'dimensionless'
        np_kwargs['atol'] = atol.to(unit=atol_unit).value

    with _add_note('values'):
        np.testing.assert_allclose(
            a.values, b.values, err_msg='when comparing values', **np_kwargs
        )
    if a.variances is None:
        return
    # NOTE(review): variances carry the square of the data unit, yet they are
    # checked with the same numeric atol as the values. Confirm this is intended.
    with _add_note('variances'):
        np.testing.assert_allclose(
            a.variances,
            b.variances,
            err_msg='when comparing variances',
            **np_kwargs,
        )

287 

288 

def _assert_allclose_binned_variable_data(
    a: Variable, b: Variable, **kwargs: Any
) -> None:
    """Compare binned variables via the contents of their concatenated bins.

    The contents must have identical structure; their data are compared
    with the given tolerances.
    """
    assert a.bins is not None
    assert b.bins is not None
    content_a = a.bins.concat().value
    content_b = b.bins.concat().value
    _assert_allclose_impl(content_a, content_b, **kwargs)

295 

296 

297@contextmanager 

298def _add_note(loc: str, *args: str) -> Iterator[None]: 

299 try: 

300 yield 

301 except AssertionError as exc: 

302 if hasattr(exc, 'add_note'): 

303 # Needs Python >= 3.11 

304 exc.add_note(f'PREPOSITION {loc.format(*args)}') 

305 raise 

306 

307 

# Public API of this module.
__all__ = [
    'assert_identical',
    'assert_allclose',
]