Coverage for install/scipp/testing/assertions.py: 88%

122 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-04-28 01:28 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3# @author Jan-Lukas Wynen 

4"""Custom assertions for pytest-based tests. 

5 

6To get the best error messages, tell pytest to rewrite assertions in this module. 

7Place the following code in your ``conftest.py``: 

8 

9.. code-block:: python 

10 

11 pytest.register_assert_rewrite('scipp.testing.assertions') 

12 

13""" 

14 

15from contextlib import contextmanager 

16from typing import Any, Callable, Iterator, Mapping, TypeVar 

17 

18import numpy as np 

19 

20from ..core import DataArray, DataGroup, Dataset, Variable 

21 

# Exception notes are formatted as 'PREPOSITION {loc}',
# where 'loc' is set by the concrete assertion functions to indicate coords, attrs, etc.
# 'PREPOSITION' is replaced at the top level: the first (innermost) location
# becomes 'in' and every enclosing location becomes 'of', producing exception
# messages like:
#
# [...]
# in coord 'x'
# of data group item 'b'
# of data group item 'a'

30 

# Type variable tying both compared objects to the same (generic) type.
T = TypeVar('T')

32 

33 

def assert_allclose(
    a: T, b: T, rtol: Variable = None, atol: Variable = None, **kwargs
) -> None:
    """Raise an AssertionError if two objects don't have similar values
    or if their other properties are not identical.

    Data values (and variances, if present) are compared with
    :func:`numpy.testing.assert_allclose`. Everything else must be identical:
    type, sizes, unit, dtype, and, for data arrays, coords, attrs and masks.

    Parameters
    ----------
    a:
        The actual object to check.
    b:
        The desired, expected object.
    rtol:
        Relative tolerance. Must be convertible to ``dimensionless``.
        If ``None``, NumPy's default is used.
    atol:
        Absolute tolerance. It is converted to the unit of the compared data.
        If ``None``, NumPy's default is used.
    **kwargs:
        Forwarded to :func:`numpy.testing.assert_allclose`.

    Raises
    ------
    AssertionError
        If the values are not close or if any other property differs.
    NotImplementedError
        If the objects are neither Variables nor DataArrays.
    """
    return _assert_similar(_assert_allclose_impl, a, b, rtol=rtol, atol=atol, **kwargs)

53 

54 

def assert_identical(a: T, b: T) -> None:
    """Raise an AssertionError if two objects are not identical.

    For Scipp objects, this behaves like
    ``assert sc.identical(a, b, equal_nan=True)``, but pytest can report
    precisely where the objects differ.
    For arguments that :func:`scipp.identical` does not support, the check
    falls back to ``assert a == b``.

    Equality must be exact, and the types must match as well, so
    ``assert_identical(1, 1.0)`` raises.

    NaN elements of Scipp variables compare equal to each other.

    Parameters
    ----------
    a:
        The actual object to check.
    b:
        The desired, expected object.

    Raises
    ------
    AssertionError
        If the objects are not identical.
    """
    _assert_similar(_assert_identical_impl, a, b)

82 

83 

84def _assert_similar(impl: Callable, *args, **kwargs) -> None: 

85 try: 

86 impl(*args, **kwargs) 

87 except AssertionError as exc: 

88 if hasattr(exc, '__notes__'): 

89 # See comment above. 

90 notes = [] 

91 rest = -1 

92 for i, note in enumerate(exc.__notes__): 

93 if 'PREPOSITION' in note: 

94 notes.append(note.replace('PREPOSITION', 'in')) 

95 rest = i 

96 break 

97 notes.extend( 

98 note.replace('PREPOSITION', 'of') for note in exc.__notes__[rest + 1 :] 

99 ) 

100 exc.__notes__ = notes 

101 raise 

102 

103 

def _assert_identical_impl(a: T, b: T) -> None:
    """Check that ``a`` and ``b`` are identical, structure first, then data."""
    for check in (_assert_identical_structure, _assert_identical_data):
        check(a, b)

110 

111 

def _assert_allclose_impl(a: T, b: T, **kwargs) -> None:
    """Check that ``a`` and ``b`` share an identical structure and have close data.

    ``kwargs`` (e.g. tolerances) affect only the data comparison; the
    structure must match exactly.
    """
    _assert_identical_structure(a, b)
    data_options = kwargs
    _assert_allclose_data(a, b, **data_options)

119 

120 

def _assert_identical_structure(a: T, b: T) -> None:
    """Assert that ``a`` and ``b`` have the same type and metadata layout.

    Data values are not examined here. Variables and data arrays get extra
    structural checks; other types only need matching types.
    """
    assert type(a) == type(b)
    if isinstance(a, Variable):
        return _assert_identical_variable_structure(a, b)
    if isinstance(a, DataArray):
        return _assert_identical_data_array_structure(a, b)

130 

131 

def _assert_identical_variable_structure(a: Variable, b: Variable) -> None:
    """Assert that two variables agree in sizes, unit, dtype and binning.

    For binned variables the unit of the bin content must match as well.
    For dense variables, either both or neither must have variances.
    """
    assert a.sizes == b.sizes
    assert a.unit == b.unit
    assert a.dtype == b.dtype
    assert (a.bins is not None) == (b.bins is not None)
    if a.bins is not None:
        assert a.bins.unit == b.bins.unit
    elif a.variances is None:
        assert b.variances is None, 'a has no variances but b does'
    else:
        assert b.variances is not None, 'a has variances but b does not'

144 

145 

def _assert_identical_data_array_structure(a: DataArray, b: DataArray) -> None:
    """Assert that coords, (deprecated) attrs and masks of two data arrays match."""
    # (attribute on the data array, label used in failure notes)
    metadata = (
        ('coords', 'coord'),
        ('deprecated_attrs', 'attr'),
        ('masks', 'mask'),
    )
    for attribute, label in metadata:
        _assert_mapping_eq(getattr(a, attribute), getattr(b, attribute), label)

150 

151 

def _assert_identical_dataset(a: Dataset, b: Dataset) -> None:
    """Assert that two datasets have the same item names and identical items."""
    _assert_mapping_eq(a, b, map_name='dataset item')

154 

155 

def _assert_identical_datagroup(a: DataGroup, b: DataGroup) -> None:
    """Assert that two data groups have the same keys and identical items."""
    _assert_mapping_eq(a, b, map_name='data group item')

158 

159 

def _assert_identical_alignment(a: Any, b: Any) -> None:
    """Assert equal ``aligned`` flags when both arguments are Variables.

    Any other combination of types is accepted without a check.
    """
    both_variables = isinstance(a, Variable) and isinstance(b, Variable)
    if not both_variables:
        return
    assert a.aligned == b.aligned

163 

164 

def _assert_mapping_eq(
    a: Mapping[str, Any],
    b: Mapping[str, Any],
    map_name: str,
    **kwargs,
) -> None:
    """Assert that two mappings have equal keys and pairwise identical values.

    ``map_name`` labels the failure notes: its plural is used for a key
    mismatch, and ``"<map_name> '<key>'"`` for a value mismatch.
    Values must also agree in alignment (see ``_assert_identical_alignment``).
    ``kwargs`` is accepted but not used.
    """
    with _add_note(f'{map_name}s'):
        assert a.keys() == b.keys()
    for key, value_a in a.items():
        with _add_note("{} '{}'", map_name, key):
            value_b = b[key]
            _assert_identical_impl(value_a, value_b)
            _assert_identical_alignment(value_a, value_b)

178 

179 

def _assert_identical_data(a: T, b: T) -> None:
    """Assert that the data held by ``a`` and ``b`` is identical.

    Variables and data arrays are compared by their (dense or binned) data.
    Datasets and data groups are compared item by item. Any other object
    falls back to ``==``.
    """
    if isinstance(a, Variable):
        return _assert_identical_variable_data(a, b)
    if isinstance(a, DataArray):
        return _assert_identical_variable_data(a.data, b.data)
    if isinstance(a, Dataset):
        return _assert_identical_dataset(a, b)
    if isinstance(a, DataGroup):
        return _assert_identical_datagroup(a, b)
    assert a == b

194 

195 

def _assert_identical_variable_data(a: Variable, b: Variable) -> None:
    """Send the data comparison to the dense or the binned implementation."""
    compare = (
        _assert_identical_dense_variable_data
        if a.bins is None
        else _assert_identical_binned_variable_data
    )
    compare(a, b)

201 

202 

def _assert_identical_dense_variable_data(a: Variable, b: Variable) -> None:
    """Assert exact equality of values and, if present, variances.

    This relies on ``numpy.testing.assert_array_equal``, so NaNs in the
    same positions compare equal.
    """

    def compare(field: str) -> None:
        with _add_note(field):
            np.testing.assert_array_equal(
                getattr(a, field),
                getattr(b, field),
                err_msg=f'when comparing {field}',
            )

    compare('values')
    if a.variances is not None:
        compare('variances')

213 

214 

def _assert_identical_binned_variable_data(a: Variable, b: Variable) -> None:
    """Assert that two binned variables have identical bins.

    First the number of elements in each bin is compared. Then the contents
    of all bins, concatenated into a single data array, are compared.
    The first check is needed because concatenation throws away the bin
    boundaries: without it, ``[[1, 2], [3]]`` and ``[[1], [2, 3]]`` would
    be reported as identical.
    """
    with _add_note('bin sizes'):
        np.testing.assert_array_equal(
            a.bins.size().values,
            b.bins.size().values,
            err_msg='when comparing bin sizes',
        )
    _assert_identical_impl(a.bins.concat().value, b.bins.concat().value)

217 

218 

def _assert_allclose_data(a: T, b: T, **kwargs) -> None:
    """Assert that the data of ``a`` and ``b`` is close.

    Only Variables and DataArrays are supported. For a data array, only its
    ``data`` is compared here; ``kwargs`` (tolerances, NumPy options) are
    forwarded to the variable comparison.

    Raises
    ------
    NotImplementedError
        For any other object type.
    """
    if isinstance(a, Variable):
        _assert_allclose_variable_data(a, b, **kwargs)
    elif isinstance(a, DataArray):
        _assert_allclose_variable_data(a.data, b.data, **kwargs)
    else:
        # Name the unsupported type so the failure is actionable.
        raise NotImplementedError(
            f'assert_allclose does not support objects of type {type(a).__name__}'
        )

230 

231 

def _assert_allclose_variable_data(a: Variable, b: Variable, **kwargs) -> None:
    """Send the closeness check to the dense or the binned implementation."""
    check = (
        _assert_allclose_dense_variable_data
        if a.bins is None
        else _assert_allclose_binned_variable_data
    )
    check(a, b, **kwargs)

237 

238 

def _assert_allclose_dense_variable_data(
    a: Variable, b: Variable, rtol: Variable = None, atol: Variable = None, **kwargs
) -> None:
    """Assert that values and (if present) variances of two dense variables are close.

    ``rtol`` is converted to ``dimensionless``. ``atol`` is converted to the
    unit of ``a``, or to ``dimensionless`` if ``a`` has no ``unit`` attribute.
    The plain numbers, together with ``kwargs``, are passed to
    ``numpy.testing.assert_allclose``.
    """
    options = dict(kwargs)
    if rtol is not None:
        options['rtol'] = rtol.to(unit='dimensionless').value
    if atol is not None:
        atol_unit = a.unit if hasattr(a, 'unit') else 'dimensionless'
        options['atol'] = atol.to(unit=atol_unit).value

    def compare(field: str) -> None:
        with _add_note(field):
            np.testing.assert_allclose(
                getattr(a, field),
                getattr(b, field),
                err_msg=f'when comparing {field}',
                **options,
            )

    compare('values')
    if a.variances is not None:
        compare('variances')

263 

264 

def _assert_allclose_binned_variable_data(
    a: Variable, b: Variable, rtol: Variable = None, atol: Variable = None, **kwargs
) -> None:
    """Assert that two binned variables have equal bin sizes and close bin contents.

    Bin sizes must match exactly, because concatenating the bins throws away
    the bin boundaries. The concatenated contents are then compared with the
    same tolerances as dense data. ``rtol`` and ``atol`` used to be dropped
    here, which silently fell back to NumPy's default tolerances.
    """
    with _add_note('bin sizes'):
        np.testing.assert_array_equal(
            a.bins.size().values,
            b.bins.size().values,
            err_msg='when comparing bin sizes',
        )
    _assert_allclose_impl(
        a.bins.concat().value,
        b.bins.concat().value,
        rtol=rtol,
        atol=atol,
        **kwargs,
    )

269 

270 

271@contextmanager 

272def _add_note(loc: str, *args: str) -> Iterator[None]: 

273 try: 

274 yield 

275 except AssertionError as exc: 

276 if hasattr(exc, 'add_note'): 

277 # Needs Python >= 3.11 

278 exc.add_note(f'PREPOSITION {loc.format(*args)}') 

279 raise 

280 

281 

# Public API of this module; all other names are implementation details.
__all__ = [
    'assert_identical',
    'assert_allclose',
]