Coverage for install/scipp/testing/assertions.py: 88%
127 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-24 01:51 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-24 01:51 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3# @author Jan-Lukas Wynen
4"""Custom assertions for pytest-based tests.
6To get the best error messages, tell pytest to rewrite assertions in this module.
7Place the following code in your ``conftest.py``:
9.. code-block:: python
11 pytest.register_assert_rewrite('scipp.testing.assertions')
13"""
15from collections.abc import Callable, Iterator, Mapping
16from contextlib import contextmanager
17from typing import Any, TypeVar
19import numpy as np
21from ..core import DataArray, DataGroup, Dataset, Variable
23# Exception notes are formatted as 'PREPOSITION {loc}',
24# where 'loc' is set by the concrete assertion functions to indicate coords, attrs, etc.
25# 'PREPOSITION' is replaced at the top level to produce exception messages like:
26#
27# [...]
28# in coord 'x'
29# of data group item 'b'
30# of data group item 'a'
# Constrained type variable: both arguments of an assertion share one type,
# either Variable or DataArray.
# NOTE(review): the data comparison in this module also has branches for
# Dataset and DataGroup; confirm whether the constraint should list them.
_T = TypeVar('_T', Variable, DataArray)
def assert_allclose(
    a: _T,
    b: _T,
    rtol: Variable | None = None,
    atol: Variable | None = None,
    **kwargs: Any,
) -> None:
    """Raise an AssertionError if two objects don't have similar values
    or if their other properties are not identical.

    The structure of ``a`` and ``b`` is checked for exact equality first;
    only the data values (and variances, if present) are compared with
    tolerances.

    Parameters
    ----------
    a:
        The actual object to check.
    b:
        The desired, expected object.
    rtol:
        Relative tolerance (relative to ``b``).
        Can be a scalar or non-scalar.
        Cannot have variances.
        Is converted to unit ``dimensionless`` before use.
        Defaults to scalar 1e-7 if unset.
    atol:
        Absolute tolerance.
        Can be a scalar or non-scalar.
        Cannot have variances.
        Is converted to the unit of ``a`` before use.
        Defaults to scalar 0 if unset.
    kwargs:
        Additional arguments to pass to :func:`numpy.testing.assert_allclose`
        which is used for comparing data.

    Raises
    ------
    AssertionError
        If the structure of the objects differs or their values are not
        equal within the given tolerances.
    """
    tolerances = {'rtol': rtol, 'atol': atol}
    _assert_similar(_assert_allclose_impl, a, b, **tolerances, **kwargs)
def assert_identical(a: _T, b: _T) -> None:
    """Raise an AssertionError if two objects are not identical.

    For Scipp objects, ``assert_identical(a, b)`` is equivalent to
    ``assert sc.identical(a, b, equal_nan=True)`` but produces a more precise
    error message in pytest.
    If this function is called with arguments that are not supported by
    :func:`scipp.identical`, it calls ``assert a == b``.

    This function requires exact equality including equal types.
    For example, ``assert_identical(1, 1.0)`` will raise.

    NaN elements of Scipp variables are treated as equal.

    Parameters
    ----------
    a:
        The actual object to check.
    b:
        The desired, expected object.

    Raises
    ------
    AssertionError
        If the objects are not identical.
    """
    # All work is delegated; the wrapper only post-processes exception notes.
    _assert_similar(_assert_identical_impl, a, b)
102def _assert_similar(impl: Callable[..., None], *args: Any, **kwargs: Any) -> None:
103 try:
104 impl(*args, **kwargs)
105 except AssertionError as exc:
106 if hasattr(exc, '__notes__'):
107 # See comment above.
108 notes = []
109 rest = -1
110 for i, note in enumerate(exc.__notes__):
111 if 'PREPOSITION' in note:
112 notes.append(note.replace('PREPOSITION', 'in'))
113 rest = i
114 break
115 notes.extend(
116 note.replace('PREPOSITION', 'of') for note in exc.__notes__[rest + 1 :]
117 )
118 exc.__notes__ = notes
119 raise
def _assert_identical_impl(
    a: _T,
    b: _T,
) -> None:
    """Assert that ``a`` and ``b`` are identical.

    The structure is compared first, then the data.
    Any failure surfaces as an ``AssertionError`` from the called helpers.
    """
    for check in (_assert_identical_structure, _assert_identical_data):
        check(a, b)
def _assert_allclose_impl(
    a: _T,
    b: _T,
    **kwargs: Any,
) -> None:
    """Assert that ``a`` and ``b`` have identical structure and close data.

    ``kwargs`` (tolerances and extra numpy options) only affect the data
    comparison; the structure must always match exactly.
    """
    # Structure first: there is no tolerance for structural differences.
    _assert_identical_structure(a, b)
    # Then the values, compared within the requested tolerances.
    _assert_allclose_data(a, b, **kwargs)
def _assert_identical_structure(
    a: _T,
    b: _T,
) -> None:
    """Assert that ``a`` and ``b`` have the same type and structure.

    Structural checks exist for ``Variable`` and ``DataArray``;
    for any other type only the type equality is asserted here.
    """
    assert type(a) is type(b)
    if isinstance(a, Variable):
        _assert_identical_variable_structure(a, b)
        return
    if isinstance(a, DataArray):
        _assert_identical_data_array_structure(a, b)
def _assert_identical_variable_structure(a: Variable, b: Variable) -> None:
    """Assert that two variables agree in everything except their values.

    Compares sizes, unit, dtype and whether both are binned. For binned
    variables, the unit of the bin contents is compared; for dense
    variables, the presence of variances is compared.
    """
    assert a.sizes == b.sizes
    assert a.unit == b.unit
    assert a.dtype == b.dtype
    assert (a.bins is None) == (b.bins is None)
    if a.bins is None:
        # Dense data: both or neither must carry variances.
        if a.variances is None:
            assert b.variances is None, 'a has no variances but b does'
        else:
            assert b.variances is not None, 'a has variances but b does not'
        return
    # Binned data: already implied by the assertion above, repeated so that
    # the attribute access below is known to be valid.
    assert b.bins is not None
    assert a.bins.unit == b.bins.unit
def _assert_identical_data_array_structure(a: DataArray, b: DataArray) -> None:
    """Assert that two data arrays have identical structure.

    Checks the structure of the data variable (sizes, unit, dtype,
    variances / binned-ness) and requires coords, attrs and masks to be
    identical, including their values.

    Raises
    ------
    AssertionError
        If any of the compared properties differ.
    """
    # The data comparison performed elsewhere only looks at values and
    # variances, so the unit, dtype and sizes of the data must be checked
    # here. Without this, data arrays differing only in the unit or dtype
    # of their data would be reported as identical.
    with _add_note('data'):
        _assert_identical_variable_structure(a.data, b.data)
    _assert_mapping_eq(a.coords, b.coords, 'coord')
    _assert_mapping_eq(a.deprecated_attrs, b.deprecated_attrs, 'attr')
    _assert_mapping_eq(a.masks, b.masks, 'mask')
def _assert_identical_dataset(a: Dataset, b: Dataset) -> None:
    """Assert that two datasets have the same keys and identical items."""
    _assert_mapping_eq(a, b, map_name='dataset item')
def _assert_identical_datagroup(a: DataGroup[Any], b: DataGroup[Any]) -> None:
    """Assert that two data groups have the same keys and identical items."""
    _assert_mapping_eq(a, b, map_name='data group item')
def _assert_identical_alignment(a: Any, b: Any) -> None:
    """Assert equal alignment flags if both arguments are variables.

    Does nothing when either argument is not a ``Variable``.
    """
    both_variables = isinstance(a, Variable) and isinstance(b, Variable)
    if not both_variables:
        return
    assert a.aligned == b.aligned
def _assert_mapping_eq(
    a: Mapping[str, Any],
    b: Mapping[str, Any],
    map_name: str,
) -> None:
    """Assert that two mappings have equal keys and identical values.

    Parameters
    ----------
    a:
        The actual mapping.
    b:
        The expected mapping.
    map_name:
        Singular name of the kind of entry, e.g. ``'coord'``.
        Used to label failures: a key mismatch is noted with the plural
        (``map_name + 's'``), a value mismatch with the name and the key.
    """
    plural = map_name + 's'
    with _add_note(plural):
        assert a.keys() == b.keys()
    # Keys are known to match from here on, so every lookup in b succeeds.
    for key, left in a.items():
        with _add_note("{} '{}'", map_name, key):
            right = b[key]
            _assert_identical_impl(left, right)
            _assert_identical_alignment(left, right)
def _assert_identical_data(
    a: _T,
    b: _T,
) -> None:
    """Assert that the data of ``a`` and ``b`` are identical.

    Dispatches on the type of ``a`` only. Objects that are none of
    ``Variable``, ``DataArray``, ``Dataset`` or ``DataGroup`` are compared
    with ``==``.
    """
    if isinstance(a, Variable):
        _assert_identical_variable_data(a, b)
        return
    if isinstance(a, DataArray):
        _assert_identical_variable_data(a.data, b.data)
        return
    if isinstance(a, Dataset):
        _assert_identical_dataset(a, b)
        return
    if isinstance(a, DataGroup):
        _assert_identical_datagroup(a, b)
        return
    # Fallback for arbitrary Python objects.
    assert a == b
def _assert_identical_variable_data(a: Variable, b: Variable) -> None:
    """Compare the data of two variables, choosing the dense or binned path.

    The choice is based on ``a`` alone.
    """
    compare = (
        _assert_identical_dense_variable_data
        if a.bins is None
        else _assert_identical_binned_variable_data
    )
    compare(a, b)
def _assert_identical_dense_variable_data(a: Variable, b: Variable) -> None:
    """Assert exact equality of values and, if ``a`` has them, variances.

    Uses :func:`numpy.testing.assert_array_equal` for both comparisons.
    """
    with _add_note('values'):
        np.testing.assert_array_equal(
            a.values, b.values, err_msg='when comparing values'
        )
    if a.variances is None:
        # Nothing more to compare for data without variances.
        return
    with _add_note('variances'):
        np.testing.assert_array_equal(
            a.variances, b.variances, err_msg='when comparing variances'
        )
def _assert_identical_binned_variable_data(a: Variable, b: Variable) -> None:
    """Assert that the concatenated bin contents of two variables are identical.

    NOTE(review): only the result of ``bins.concat().value`` is compared;
    presumably this does not detect the same events being split differently
    across bins. Confirm whether bin boundaries need a separate check.
    """
    assert a.bins is not None
    assert b.bins is not None
    content_a = a.bins.concat().value
    content_b = b.bins.concat().value
    _assert_identical_impl(content_a, content_b)
def _assert_allclose_data(
    a: _T,
    b: _T,
    **kwargs: Any,
) -> None:
    """Compare the data of ``a`` and ``b`` within tolerances.

    Raises
    ------
    NotImplementedError
        If ``a`` is neither a ``Variable`` nor a ``DataArray``.
    """
    if isinstance(a, Variable):
        _assert_allclose_variable_data(a, b, **kwargs)
        return
    if isinstance(a, DataArray):
        _assert_allclose_variable_data(a.data, b.data, **kwargs)
        return
    raise NotImplementedError
def _assert_allclose_variable_data(a: Variable, b: Variable, **kwargs: Any) -> None:
    """Compare two variables within tolerances via the dense or binned path.

    The choice is based on ``a`` alone; ``kwargs`` are forwarded unchanged.
    """
    compare = (
        _assert_allclose_dense_variable_data
        if a.bins is None
        else _assert_allclose_binned_variable_data
    )
    compare(a, b, **kwargs)
def _assert_allclose_dense_variable_data(
    a: Variable,
    b: Variable,
    rtol: Variable | None = None,
    atol: Variable | None = None,
    **kwargs: Any,
) -> None:
    """Assert that values and, if ``a`` has them, variances are close.

    Parameters
    ----------
    a:
        The actual variable.
    b:
        The expected variable.
    rtol:
        Relative tolerance; converted to unit ``dimensionless``.
        If ``None``, numpy's own default is used.
    atol:
        Absolute tolerance; converted to the unit of ``a``.
        If ``None``, numpy's own default is used.
    kwargs:
        Forwarded to :func:`numpy.testing.assert_allclose`.
    """
    # NOTE(review): ``.value`` presumably requires a 0-D variable, which would
    # conflict with the 'non-scalar' wording of the public docstring - confirm.
    if rtol is not None:
        dimensionless_rtol = rtol.to(unit='dimensionless')
        kwargs['rtol'] = dimensionless_rtol.value
    if atol is not None:
        if hasattr(a, 'unit'):
            atol = atol.to(unit=a.unit)
        else:
            atol = atol.to(unit='dimensionless')
        kwargs['atol'] = atol.value

    with _add_note('values'):
        np.testing.assert_allclose(
            a.values, b.values, err_msg='when comparing values', **kwargs
        )
    if a.variances is None:
        return
    # The same tolerances are applied to the variances as to the values.
    with _add_note('variances'):
        np.testing.assert_allclose(
            a.variances,
            b.variances,
            err_msg='when comparing variances',
            **kwargs,
        )
def _assert_allclose_binned_variable_data(
    a: Variable, b: Variable, **kwargs: Any
) -> None:
    """Assert that the concatenated bin contents of two variables are close.

    ``kwargs`` (tolerances and numpy options) are forwarded to the
    comparison of the concatenated contents.
    """
    assert a.bins is not None
    assert b.bins is not None
    content_a = a.bins.concat().value
    content_b = b.bins.concat().value
    _assert_allclose_impl(content_a, content_b, **kwargs)
297@contextmanager
298def _add_note(loc: str, *args: str) -> Iterator[None]:
299 try:
300 yield
301 except AssertionError as exc:
302 if hasattr(exc, 'add_note'):
303 # Needs Python >= 3.11
304 exc.add_note(f'PREPOSITION {loc.format(*args)}')
305 raise
# Public API of this module; all other names are private helpers.
__all__ = ['assert_identical', 'assert_allclose']