Coverage for install/scipp/testing/assertions.py: 88%
122 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
# @author Jan-Lukas Wynen
"""Custom assertions for pytest-based tests.

To get the best error messages, tell pytest to rewrite assertions in this module.
Place the following code in your ``conftest.py``:

.. code-block:: python

    pytest.register_assert_rewrite('scipp.testing.assertions')

"""
15from contextlib import contextmanager
16from typing import Any, Callable, Iterator, Mapping, TypeVar
18import numpy as np
20from ..core import DataArray, DataGroup, Dataset, Variable
# Exception notes are formatted as 'PREPOSITION {loc}',
# where 'loc' is set by the concrete assertion functions to indicate coords, attrs, etc.
# At the top level, 'PREPOSITION' is replaced by 'in' for the innermost location
# and by 'of' for every enclosing location, producing exception messages like:
#
# [...]
# in coord 'x'
# of data group item 'b'
# of data group item 'a'
# Generic type variable for the assertion functions: both arguments are
# expected to have the same type (enforced at runtime by the type comparison
# in ``_assert_identical_structure``).
T = TypeVar('T')
def assert_allclose(
    a: T, b: T, rtol: Variable = None, atol: Variable = None, **kwargs
) -> None:
    """Raise an AssertionError if two objects don't have similar values
    or if their other properties are not identical.

    Values (and variances, if present) are compared with
    :func:`numpy.testing.assert_allclose`; all other properties are
    compared for identity.

    Parameters
    ----------
    a:
        The actual object to check.
    b:
        The desired, expected object.
    rtol:
        Relative tolerance, converted to ``'dimensionless'`` before use.
        If ``None``, NumPy's default relative tolerance is used.
    atol:
        Absolute tolerance, converted to the unit of the compared data before
        use. If ``None``, NumPy's default absolute tolerance is used.
    **kwargs:
        Forwarded to :func:`numpy.testing.assert_allclose`.

    Raises
    ------
    AssertionError
        If the values are not close or any other property differs.
    NotImplementedError
        If the objects are neither :class:`scipp.Variable` nor
        :class:`scipp.DataArray`.
    """
    _assert_similar(_assert_allclose_impl, a, b, rtol=rtol, atol=atol, **kwargs)
def assert_identical(a: T, b: T) -> None:
    """Raise an AssertionError if two objects are not identical.

    For Scipp objects, ``assert_identical(a, b)`` is equivalent to
    ``assert sc.identical(a, b, equal_nan=True)``, but pytest reports a more
    precise error message.
    Arguments that :func:`scipp.identical` does not support are compared
    with ``assert a == b`` instead.

    Equality has to be exact and the types have to match, so
    ``assert_identical(1, 1.0)`` raises.

    NaN elements of Scipp variables compare as equal.

    Parameters
    ----------
    a:
        The actual object to check.
    b:
        The desired, expected object.

    Raises
    ------
    AssertionError
        If the objects are not identical.
    """
    _assert_similar(_assert_identical_impl, a, b)
def _assert_similar(impl: Callable, *args, **kwargs) -> None:
    """Call ``impl`` and rewrite the location notes of an AssertionError it raises.

    Location notes are attached as ``'PREPOSITION {loc}'``, innermost location
    first. The first such note gets ``'in'`` and every later one ``'of'``, e.g.
    ``in coord 'x'`` followed by ``of data group item 'a'``. Notes without the
    placeholder are kept unchanged and in their original order.

    Parameters
    ----------
    impl:
        The assertion implementation to run.
    *args, **kwargs:
        Forwarded to ``impl``.

    Raises
    ------
    AssertionError
        Re-raised from ``impl`` with rewritten notes.
    """
    try:
        impl(*args, **kwargs)
    except AssertionError as exc:
        if hasattr(exc, '__notes__'):
            preposition = 'in'
            notes = []
            for note in exc.__notes__:
                if 'PREPOSITION' in note:
                    note = note.replace('PREPOSITION', preposition)
                    # Only the innermost location reads "in ...".
                    preposition = 'of'
                # Keep every note, including unrelated ones that precede the
                # first location note (they used to be dropped).
                notes.append(note)
            exc.__notes__ = notes
        raise
def _assert_identical_impl(a: T, b: T) -> None:
    """Check that ``a`` and ``b`` have identical structure, then identical data."""
    for check in (_assert_identical_structure, _assert_identical_data):
        check(a, b)
def _assert_allclose_impl(a: T, b: T, **kwargs) -> None:
    """Check that ``a`` and ``b`` have identical structure and close data.

    Only the data comparison is tolerance-based; ``kwargs`` (e.g. ``rtol`` and
    ``atol``) are passed to it alone.
    """
    _assert_identical_structure(a, b)
    _assert_allclose_data(a, b, **kwargs)
def _assert_identical_structure(
    a: T,
    b: T,
) -> None:
    """Assert that ``a`` and ``b`` have exactly the same type and structure.

    Variables and data arrays are passed on to their type-specific structure
    checks. For other types only the type is compared here; their content is
    compared together with the data.
    """
    # Types are compared by identity; `==` on type objects is a lint smell (E721).
    assert type(a) is type(b)
    if isinstance(a, Variable):
        _assert_identical_variable_structure(a, b)
    elif isinstance(a, DataArray):
        _assert_identical_data_array_structure(a, b)
def _assert_identical_variable_structure(a: Variable, b: Variable) -> None:
    """Assert that two variables agree in everything except their values.

    Checks dims (including their order), sizes, unit, dtype and whether both
    are binned. For binned variables the unit of the bin content is compared;
    for dense variables both or neither must have variances.
    """
    # ``sizes`` alone may not catch transposed dims (if it is a plain dict, its
    # equality ignores key order). Values are compared positionally later, so
    # swapped dims on square data would otherwise go unnoticed.
    assert a.dims == b.dims
    assert a.sizes == b.sizes
    assert a.unit == b.unit
    assert a.dtype == b.dtype
    assert (a.bins is None) == (b.bins is None)
    if a.bins is not None:
        assert a.bins.unit == b.bins.unit
    elif a.variances is not None:
        assert b.variances is not None, 'a has variances but b does not'
    else:
        assert b.variances is None, 'a has no variances but b does'
def _assert_identical_data_array_structure(a: DataArray, b: DataArray) -> None:
    """Assert that two data arrays have identical structure and metadata.

    Compares the structure of the underlying data variables, then the coords,
    attrs and masks in full. The data values are compared separately.
    """
    # The data variable's unit, dtype, dims and variances presence are not
    # checked anywhere else: the data comparison only looks at values and
    # variances, so e.g. a unit mismatch in the data would otherwise pass.
    _assert_identical_variable_structure(a.data, b.data)
    _assert_mapping_eq(a.coords, b.coords, 'coord')
    _assert_mapping_eq(a.deprecated_attrs, b.deprecated_attrs, 'attr')
    _assert_mapping_eq(a.masks, b.masks, 'mask')
def _assert_identical_dataset(a: Dataset, b: Dataset) -> None:
    """Assert that two datasets have the same item names and identical items."""
    _assert_mapping_eq(a, b, map_name='dataset item')
def _assert_identical_datagroup(a: DataGroup, b: DataGroup) -> None:
    """Assert that two data groups have the same keys and identical items."""
    _assert_mapping_eq(a, b, map_name='data group item')
def _assert_identical_alignment(a: Any, b: Any) -> None:
    """Assert that ``a`` and ``b`` have the same ``aligned`` flag.

    The check only applies when both are variables; otherwise nothing is done.
    """
    both_variables = isinstance(a, Variable) and isinstance(b, Variable)
    if not both_variables:
        return
    assert a.aligned == b.aligned
def _assert_mapping_eq(
    a: Mapping[str, Any],
    b: Mapping[str, Any],
    map_name: str,
    **kwargs,
) -> None:
    """Assert equal keys and identical, equally aligned values of two mappings.

    ``map_name`` labels the location notes: a key mismatch is reported with
    ``map_name + 's'`` and a mismatching entry with ``"{map_name} '{key}'"``.
    ``kwargs`` is accepted but not used.
    """
    with _add_note(map_name + 's'):
        assert a.keys() == b.keys()
    for key, lhs in a.items():
        with _add_note("{} '{}'", map_name, key):
            rhs = b[key]
            _assert_identical_impl(lhs, rhs)
            _assert_identical_alignment(lhs, rhs)
def _assert_identical_data(
    a: T,
    b: T,
) -> None:
    """Assert that the data held by ``a`` and ``b`` is identical.

    Variables are compared directly and data arrays via their data variable.
    Datasets and data groups are compared item by item. Anything else is
    compared with ``==``.
    """
    if isinstance(a, Variable):
        _assert_identical_variable_data(a, b)
        return
    if isinstance(a, DataArray):
        _assert_identical_variable_data(a.data, b.data)
        return
    if isinstance(a, Dataset):
        _assert_identical_dataset(a, b)
        return
    if isinstance(a, DataGroup):
        _assert_identical_datagroup(a, b)
        return
    assert a == b
def _assert_identical_variable_data(a: Variable, b: Variable) -> None:
    """Assert identical data of two variables, choosing the check by whether ``a`` is binned."""
    compare = (
        _assert_identical_dense_variable_data
        if a.bins is None
        else _assert_identical_binned_variable_data
    )
    compare(a, b)
def _assert_identical_dense_variable_data(a: Variable, b: Variable) -> None:
    """Assert that values, and variances if ``a`` has any, are exactly equal.

    Comparison uses :func:`numpy.testing.assert_array_equal`, which does not
    raise for NaNs in matching positions.
    """

    def compare(field: str) -> None:
        # Attach the field name as a location note to any mismatch.
        with _add_note(field):
            np.testing.assert_array_equal(
                getattr(a, field), getattr(b, field), err_msg=f'when comparing {field}'
            )

    compare('values')
    if a.variances is not None:
        compare('variances')
def _assert_identical_binned_variable_data(a: Variable, b: Variable) -> None:
    """Assert identical bin contents by comparing ``bins.concat().value`` of both."""
    content_a = a.bins.concat().value
    content_b = b.bins.concat().value
    _assert_identical_impl(content_a, content_b)
def _assert_allclose_data(
    a: T,
    b: T,
    **kwargs,
) -> None:
    """Assert that the data of ``a`` and ``b`` is close.

    Only variables and data arrays are supported; for a data array its data
    variable is compared. ``kwargs`` (tolerances) are forwarded.

    Raises
    ------
    NotImplementedError
        For any other type, naming the unsupported type.
    """
    if isinstance(a, Variable):
        _assert_allclose_variable_data(a, b, **kwargs)
    elif isinstance(a, DataArray):
        _assert_allclose_variable_data(a.data, b.data, **kwargs)
    else:
        # A bare NotImplementedError gave no hint about what was unsupported.
        raise NotImplementedError(
            'assert_allclose is not implemented for objects of type '
            f'{type(a).__name__}'
        )
def _assert_allclose_variable_data(a: Variable, b: Variable, **kwargs) -> None:
    """Assert that the data of two variables is close, dense or binned.

    The check is chosen by whether ``a`` is binned; ``kwargs`` (tolerances)
    are forwarded unchanged.
    """
    if a.bins is not None:
        _assert_allclose_binned_variable_data(a, b, **kwargs)
        return
    _assert_allclose_dense_variable_data(a, b, **kwargs)
def _assert_allclose_dense_variable_data(
    a: Variable, b: Variable, rtol: Variable = None, atol: Variable = None, **kwargs
) -> None:
    """Assert that values, and variances if ``a`` has any, are close.

    Parameters
    ----------
    a, b:
        Dense variables to compare.
    rtol:
        Relative tolerance, converted to ``'dimensionless'``. NumPy's default
        applies if ``None``.
    atol:
        Absolute tolerance, converted to the unit of ``a``. NumPy's default
        applies if ``None``.
    **kwargs:
        Forwarded to :func:`numpy.testing.assert_allclose`.
    """
    if rtol is not None:
        kwargs['rtol'] = rtol.to(unit='dimensionless').value
    if atol is not None:
        atol_unit = a.unit if hasattr(a, 'unit') else 'dimensionless'
        kwargs['atol'] = atol.to(unit=atol_unit).value

    def compare(field: str) -> None:
        # Attach the field name as a location note to any mismatch.
        with _add_note(field):
            np.testing.assert_allclose(
                getattr(a, field),
                getattr(b, field),
                err_msg=f'when comparing {field}',
                **kwargs,
            )

    compare('values')
    # NOTE(review): the same atol (in the unit of the values) is also applied
    # to the variances, which presumably carry the squared unit — confirm.
    if a.variances is not None:
        compare('variances')
def _assert_allclose_binned_variable_data(
    a: Variable,
    b: Variable,
    rtol: Variable = None,
    atol: Variable = None,
    **kwargs,
) -> None:
    """Assert that the bin contents of two binned variables are close.

    The contents (``bins.concat().value`` of each variable) are compared with
    the same tolerances as dense data.
    """
    # rtol and atol are named parameters and therefore not part of ``kwargs``;
    # they must be forwarded explicitly, otherwise binned data is silently
    # compared with NumPy's default tolerances.
    _assert_allclose_impl(
        a.bins.concat().value,
        b.bins.concat().value,
        rtol=rtol,
        atol=atol,
        **kwargs,
    )
@contextmanager
def _add_note(loc: str, *args: str) -> Iterator[None]:
    """Attach a location note to any AssertionError raised in the ``with`` body.

    The note is ``'PREPOSITION '`` followed by ``loc.format(*args)``. If the
    exception has no ``add_note`` method (it needs Python >= 3.11), the
    exception propagates without a note. Other exception types pass through
    untouched.
    """
    try:
        yield
    except AssertionError as err:
        add_note = getattr(err, 'add_note', None)
        if add_note is not None:
            add_note('PREPOSITION ' + loc.format(*args))
        raise
# Public API of this module; all other names are private helpers.
__all__ = ['assert_identical', 'assert_allclose']