Coverage for install/scipp/core/comparison.py: 61%
46 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3# @author Simon Heybrock
4# ruff: noqa: E501
6from __future__ import annotations
8import numpy as np
10from .._scipp import core as _cpp
11from ..typing import VariableLike
12from . import data_group
13from ._cpp_wrapper_util import call_func as _call_cpp_func
14from .cpp_classes import DataArray, Dataset, Variable
15from .variable import scalar
def less(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '<' (less).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x < y`.
    """
    # Forward to ``_cpp.less`` through the shared ``_call_cpp_func`` helper.
    return _call_cpp_func(_cpp.less, x, y)
def greater(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '>' (greater).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x > y`.
    """
    # Forward to ``_cpp.greater`` through the shared ``_call_cpp_func`` helper.
    return _call_cpp_func(_cpp.greater, x, y)
def less_equal(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '<=' (less_equal).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x <= y`.
    """
    # Forward to ``_cpp.less_equal`` through the shared ``_call_cpp_func`` helper.
    return _call_cpp_func(_cpp.less_equal, x, y)
def greater_equal(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '>=' (greater_equal).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x >= y`.
    """
    # Forward to ``_cpp.greater_equal`` through the shared ``_call_cpp_func`` helper.
    return _call_cpp_func(_cpp.greater_equal, x, y)
def equal(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '==' (equal).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x == y`.
    """
    # Forward to ``_cpp.equal`` through the shared ``_call_cpp_func`` helper.
    return _call_cpp_func(_cpp.equal, x, y)
def not_equal(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '!=' (not_equal).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x != y`.
    """
    # Forward to ``_cpp.not_equal`` through the shared ``_call_cpp_func`` helper.
    return _call_cpp_func(_cpp.not_equal, x, y)
150def _identical_data_groups(
151 x: data_group.DataGroup, y: data_group.DataGroup, *, equal_nan: bool
152) -> bool:
153 def compare(a, b):
154 if not isinstance(a, type(b)):
155 return False
156 if isinstance(a, (Variable, DataArray, Dataset, data_group.DataGroup)):
157 return identical(a, b, equal_nan=equal_nan)
158 if isinstance(a, np.ndarray):
159 return np.array_equal(a, b, equal_nan=equal_nan)
160 # Explicit conversion to bool in case __eq__ returns
161 # something else like an array.
162 return bool(a == b)
164 if x.keys() != y.keys():
165 return False
166 return all(compare(x[k], y[k]) for k in x.keys())
def identical(x: VariableLike, y: VariableLike, *, equal_nan: bool = False) -> bool:
    """Full comparison of x and y.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.
    equal_nan:
        If true, non-finite values at the same index in (x, y) are treated as equal.
        Signbit must match for infs.

    Returns
    -------
    :
        True if x and y have identical values, variances, dtypes, units,
        dims, shapes, coords, and masks. Else False.

    Raises
    ------
    TypeError
        If exactly one of ``x`` and ``y`` is a DataGroup.
    """
    x_is_group = isinstance(x, data_group.DataGroup)
    # Enforce the "both or neither" rule symmetrically: a DataGroup passed
    # only as ``y`` is rejected here with the same message instead of being
    # forwarded to ``_cpp.identical``.
    if x_is_group != isinstance(y, data_group.DataGroup):
        raise TypeError("Both or neither of the arguments must be a DataGroup")
    if x_is_group:
        return _identical_data_groups(x, y, equal_nan=equal_nan)

    return _call_cpp_func(_cpp.identical, x, y, equal_nan=equal_nan)
def isclose(
    x: _cpp.Variable,
    y: _cpp.Variable,
    *,
    rtol: _cpp.Variable = None,
    atol: _cpp.Variable = None,
    equal_nan: bool = False,
) -> _cpp.Variable:
    """Checks element-wise if the inputs are close to each other.

    Each element of x is compared with the matching element of y against an
    absolute and a relative tolerance, using the (non-symmetric) condition

    .. code-block:: python

        abs(x - y) <= atol + rtol * abs(y)

    If both x and y have variances, the variances are compared as well and
    both values and variances must lie within the tolerance limits:

    .. code-block:: python

        abs(x.value - y.value) <= atol + rtol * abs(y.value) and
        abs(sqrt(x.variance) - sqrt(y.variance)) <= atol + rtol * abs(sqrt(y.variance))

    Attention
    ---------
        Vectors and matrices are compared element-wise.
        This is not necessarily a good measure for the similarity of `spatial`
        dtypes like ``scipp.DType.rotation3`` or ``scipp.DType.affine_transform3``
        (see :mod:`scipp.spatial`).

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.
    rtol:
        Tolerance value relative (to y).
        Can be a scalar or non-scalar.
        Defaults to scalar 1e-5 if unset.
    atol:
        Tolerance value absolute.
        Can be a scalar or non-scalar.
        Defaults to scalar 1e-8 if unset and takes units from y arg.
    equal_nan:
        If true, non-finite values at the same index in (x, y) are treated as equal.
        Signbit must match for infs.

    Returns
    -------
    :
        Variable same size as input.
        Element True if absolute diff of value <= atol + rtol * abs(y),
        otherwise False.

    See Also
    --------
    scipp.allclose:
        Equivalent of ``sc.all(sc.isclose(...)).value``.
    """
    # The absolute tolerance must be settled first because the unit of the
    # default relative tolerance is derived from it.
    abs_tol = scalar(1e-8, unit=y.unit) if atol is None else atol
    if rtol is None:
        # The default rtol has no unit when atol has none, otherwise ``units.one``.
        rel_unit = _cpp.units.one if abs_tol.unit is not None else None
        rel_tol = scalar(1e-5, unit=rel_unit)
    else:
        rel_tol = rtol
    return _call_cpp_func(_cpp.isclose, x, y, rel_tol, abs_tol, equal_nan)
def allclose(
    x: _cpp.Variable,
    y: _cpp.Variable,
    rtol: _cpp.Variable | None = None,
    atol: _cpp.Variable | None = None,
    equal_nan: bool = False,
) -> bool:
    """Checks if all elements in the inputs are close to each other.

    Verifies that ALL element-wise comparisons meet the condition:

    .. code-block:: python

        abs(x - y) <= atol + rtol * abs(y)

    If both x and y have variances, the variances are also compared
    between elements. In this case, both values and variances must
    be within the computed tolerance limits. That is:

    .. code-block:: python

        abs(x.value - y.value) <= atol + rtol * abs(y.value) and
        abs(sqrt(x.variance) - sqrt(y.variance)) <= atol + rtol * abs(sqrt(y.variance))

    Attention
    ---------
        Vectors and matrices are compared element-wise.
        This is not necessarily a good measure for the similarity of `spatial`
        dtypes like ``scipp.DType.rotation3`` or ``scipp.DType.affine_transform3``
        (see :mod:`scipp.spatial`).

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.
    rtol:
        Tolerance value relative (to y).
        Can be a scalar or non-scalar.
        Cannot have variances.
        Defaults to scalar 1e-5 if unset.
    atol:
        Tolerance value absolute.
        Can be a scalar or non-scalar.
        Cannot have variances.
        Defaults to scalar 1e-8 if unset and takes units from y arg.
    equal_nan:
        If true, non-finite values at the same index in (x, y) are treated as equal.
        Signbit must match for infs.

    Returns
    -------
    :
        True if for all elements ``abs(x - y) <= atol + rtol * abs(y)``,
        otherwise False.

    See Also
    --------
    scipp.isclose:
        Compares element-wise with specified tolerances.
    """
    # Reduce the element-wise ``isclose`` result with ``_cpp.all`` and unwrap
    # the reduced result via ``.value``.
    return _call_cpp_func(
        _cpp.all, isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan)
    ).value