Coverage for install/scipp/core/comparison.py: 60%
47 statements
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-17 01:51 +0000
« prev ^ index » next coverage.py v7.6.1, created at 2024-11-17 01:51 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3# @author Simon Heybrock
4# ruff: noqa: E501
6from __future__ import annotations
8from typing import Any
10import numpy as np
12from .._scipp import core as _cpp
13from ..typing import VariableLike
14from . import data_group
15from ._cpp_wrapper_util import call_func as _call_cpp_func
16from .cpp_classes import DataArray, Dataset, Variable
17from .variable import scalar
def less(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '<' (less).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x < y`.
    """
    return _call_cpp_func(_cpp.less, x, y)
def greater(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '>' (greater).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x > y`.
    """
    return _call_cpp_func(_cpp.greater, x, y)
def less_equal(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '<=' (less_equal).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x <= y`.
    """
    return _call_cpp_func(_cpp.less_equal, x, y)
def greater_equal(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '>=' (greater_equal).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x >= y`.
    """
    return _call_cpp_func(_cpp.greater_equal, x, y)
def equal(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '==' (equal).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x == y`.
    """
    return _call_cpp_func(_cpp.equal, x, y)
def not_equal(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '!=' (not_equal).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x != y`.
    """
    return _call_cpp_func(_cpp.not_equal, x, y)
152def _identical_data_groups(
153 x: data_group.DataGroup[Any], y: data_group.DataGroup[Any], *, equal_nan: bool
154) -> bool:
155 def compare(a: Any, b: Any) -> bool:
156 if not isinstance(a, type(b)):
157 return False
158 if isinstance(a, Variable | DataArray | Dataset | data_group.DataGroup):
159 return identical(a, b, equal_nan=equal_nan)
160 if isinstance(a, np.ndarray):
161 return np.array_equal(a, b, equal_nan=equal_nan)
162 # Explicit conversion to bool in case __eq__ returns
163 # something else like an array.
164 return bool(a == b)
166 if x.keys() != y.keys():
167 return False
168 return all(compare(x[k], y[k]) for k in x.keys())
def identical(x: VariableLike, y: VariableLike, *, equal_nan: bool = False) -> bool:
    """Full comparison of x and y.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.
    equal_nan:
        If true, non-finite values at the same index in (x, y) are treated as equal.
        Signbit must match for infs.

    Returns
    -------
    :
        True if x and y have identical values, variances, dtypes, units,
        dims, shapes, coords, and masks. Else False.

    Raises
    ------
    TypeError
        If exactly one of x and y is a DataGroup.
    """
    if isinstance(x, data_group.DataGroup):
        if not isinstance(y, data_group.DataGroup):
            raise TypeError("Both or neither of the arguments must be a DataGroup")
        # DataGroups are compared item by item in Python.
        return _identical_data_groups(x, y, equal_nan=equal_nan)
    # Reject a DataGroup on the right-hand side as well, so the check is
    # symmetric instead of handing a DataGroup to the C++ implementation.
    if isinstance(y, data_group.DataGroup):
        raise TypeError("Both or neither of the arguments must be a DataGroup")

    return _call_cpp_func(_cpp.identical, x, y, equal_nan=equal_nan)  # type: ignore[return-value]
def isclose(
    x: _cpp.Variable,
    y: _cpp.Variable,
    *,
    rtol: _cpp.Variable | None = None,
    atol: _cpp.Variable | None = None,
    equal_nan: bool = False,
) -> _cpp.Variable:
    """Checks element-wise if the inputs are close to each other.

    Compares values of x and y element by element against absolute
    and relative tolerances according to (non-symmetric):

    .. code-block:: python

        abs(x - y) <= atol + rtol * abs(y)

    If both x and y have variances, the variances are also compared
    between elements. In this case, both values and variances must
    be within the computed tolerance limits. That is:

    .. code-block:: python

        abs(x.value - y.value) <= atol + rtol * abs(y.value) and
        abs(sqrt(x.variance) - sqrt(y.variance)) <= atol + rtol * abs(sqrt(y.variance))

    Attention
    ---------
    Vectors and matrices are compared element-wise.
    This is not necessarily a good measure for the similarity of `spatial`
    dtypes like ``scipp.DType.rotation3`` or ``scipp.Dtype.affine_transform3``
    (see :mod:`scipp.spatial`).

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.
    rtol:
        Tolerance value relative (to y).
        Can be a scalar or non-scalar.
        Cannot have variances.
        Defaults to scalar 1e-5 if unset.
    atol:
        Tolerance value absolute.
        Can be a scalar or non-scalar.
        Cannot have variances.
        Defaults to scalar 1e-8 if unset and takes units from y arg.
    equal_nan:
        If true, non-finite values at the same index in (x, y) are treated as equal.
        Signbit must match for infs.

    Returns
    -------
    :
        Variable same size as input.
        Element True if ``abs(x - y) <= atol + rtol * abs(y)``,
        otherwise False.

    See Also
    --------
    scipp.allclose:
        Equivalent of ``sc.all(sc.isclose(...)).value``.
    """
    if atol is None:
        # The absolute tolerance must be commensurate with y, so borrow its unit.
        atol = scalar(1e-8, unit=y.unit)
    if rtol is None:
        # rtol is a ratio: dimensionless, except that it is left unit-less
        # (None) when atol has no unit.
        rtol = scalar(1e-5, unit=None if atol.unit is None else _cpp.units.one)
    return _call_cpp_func(_cpp.isclose, x, y, rtol, atol, equal_nan)
def allclose(
    x: _cpp.Variable,
    y: _cpp.Variable,
    rtol: _cpp.Variable | None = None,
    atol: _cpp.Variable | None = None,
    equal_nan: bool = False,
) -> bool:
    """Checks if all elements in the inputs are close to each other.

    Verifies that ALL element-wise comparisons meet the condition:

    .. code-block:: python

        abs(x - y) <= atol + rtol * abs(y)

    If both x and y have variances, the variances are also compared
    between elements. In this case, both values and variances must
    be within the computed tolerance limits. That is:

    .. code-block:: python

        abs(x.value - y.value) <= atol + rtol * abs(y.value) and
        abs(sqrt(x.variance) - sqrt(y.variance)) <= atol + rtol * abs(sqrt(y.variance))

    Attention
    ---------
    Vectors and matrices are compared element-wise.
    This is not necessarily a good measure for the similarity of `spatial`
    dtypes like ``scipp.DType.rotation3`` or ``scipp.Dtype.affine_transform3``
    (see :mod:`scipp.spatial`).

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.
    rtol:
        Tolerance value relative (to y).
        Can be a scalar or non-scalar.
        Cannot have variances.
        Defaults to scalar 1e-5 if unset.
    atol:
        Tolerance value absolute.
        Can be a scalar or non-scalar.
        Cannot have variances.
        Defaults to scalar 1e-8 if unset and takes units from y arg.
    equal_nan:
        If true, non-finite values at the same index in (x, y) are treated as equal.
        Signbit must match for infs.

    Returns
    -------
    :
        True if for all elements ``abs(x - y) <= atol + rtol * abs(y)``,
        otherwise False.

    See Also
    --------
    scipp.isclose:
        Compares element-wise with specified tolerances.
    """
    # Reduce the element-wise result to a single boolean; defaults for
    # rtol/atol are resolved inside isclose.
    return _call_cpp_func(  # type:ignore[no-any-return]
        _cpp.all, isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan)
    ).value  # type: ignore[union-attr]