Coverage for install/scipp/utils/comparison.py: 76%
17 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3# @author Owen Arnold
4"""
5Advanced comparisons.
6"""
from typing import Optional

from ..core import CoordError, DataArray, DType, Variable, all, identical, isclose
def isnear(
    x: DataArray,
    y: DataArray,
    rtol: Optional[Variable] = None,
    atol: Optional[Variable] = None,
    include_attrs: bool = True,
    include_data: bool = True,
    equal_nan: bool = True,
) -> bool:
    """
    Similar to scipp.isclose, but intended to compare whole DataArrays.
    Floating-point (float64/float32) coordinates are compared element by
    element with

    .. code-block:: python

       abs(x - y) <= atol + rtol * abs(y)

    Coordinates of any other dtype (e.g. integers, strings, datetimes) are
    compared exactly with scipp.identical, since a tolerance does not apply
    to them. Compared coord and attr pairs are only considered equal if all
    element-wise comparisons are True.

    See scipp.isclose for more details on how the comparisons on each
    item will be conducted.

    Parameters
    ----------
    x:
        lhs input
    y:
        rhs input
    rtol:
        relative tolerance (to y), applied to data and floating-point meta
    atol:
        absolute tolerance, applied to data and floating-point meta
    include_attrs:
        Compare all meta (coords and attrs) between x and y,
        otherwise only compare coordinates from meta
    include_data:
        Compare data element-wise between x, and y
    equal_nan:
        If ``True``, consider NaNs or infs to be equal
        providing that they match in location and, for infs,
        have the same sign

    Returns
    -------
    :
        ``True`` if near

    Raises
    ------
    CoordError:
        If a coord (or attr) present in both inputs has different shapes.
    Exception:
        If `x`, `y` are not going to be logically comparable
        for reasons relating to shape, item naming or non-finite elements.
    """
    # Data are compared before any meta check, so errors from comparing
    # incompatible data surface first.
    same_data = (
        all(isclose(x.data, y.data, rtol=rtol, atol=atol, equal_nan=equal_nan)).value
        if include_data
        else True
    )
    # Pick the mappings to compare once rather than per lookup.
    x_meta = x.deprecated_meta if include_attrs else x.coords
    y_meta = y.deprecated_meta if include_attrs else y.coords
    if len(x_meta) != len(y_meta):
        return False
    for key, a in x_meta.items():
        # Lookup fails if `y` has no item named `key` (an "item naming"
        # incompatibility).
        b = y_meta[key]
        if a.shape != b.shape:
            raise CoordError(
                f'Coord (or attr) with key {key} have different'
                f' shapes. For x, shape is {a.shape}. For y, shape = {b.shape}'
            )
        if a.dtype in (DType.float64, DType.float32):
            if not all(isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)).value:
                return False
        elif not identical(a, b, equal_nan=equal_nan):
            # Non-floating-point items must match exactly; skipping them would
            # report arrays that differ only in e.g. integer or string coords
            # as near.
            return False
    return same_data