Coverage for install/scipp/utils/comparison.py: 76%

17 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-04-28 01:28 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3# @author Owen Arnold 

4""" 

5Advanced comparisons. 

6""" 

7 

8from typing import Optional 

9 

10from ..core import CoordError, DataArray, DType, Variable, all, isclose 

11 

12 

def isnear(
    x: DataArray,
    y: DataArray,
    rtol: Optional[Variable] = None,
    atol: Optional[Variable] = None,
    include_attrs: bool = True,
    include_data: bool = True,
    equal_nan: bool = True,
) -> bool:
    """
    Similar to scipp.isclose, but intended to compare whole DataArrays.

    Data (when ``include_data`` is set) and floating-point coordinates
    (plus attrs, when ``include_attrs`` is set) are compared element by
    element with

    .. code-block:: python

        abs(x - y) <= atol + rtol * abs(y)

    A compared item pair is only considered equal if all element-wise
    comparisons are True. Only items of dtype ``float64`` or ``float32``
    are compared by value; items of any other dtype are only required to
    have matching shapes in ``x`` and ``y``.

    If ``x`` and ``y`` carry a different number of compared items, the
    result is ``False`` without inspecting the individual items.

    See scipp.isclose for more details on how the comparisons on each
    item will be conducted.

    Parameters
    ----------
    x:
        lhs input
    y:
        rhs input
    rtol:
        relative tolerance (to y)
    atol:
        absolute tolerance
    include_attrs:
        Compare all meta (coords and attrs) between x and y,
        otherwise only compare coordinates from meta
    include_data:
        Compare data element-wise between x and y
    equal_nan:
        If ``True``, consider NaNs or infs to be equal
        providing that they match in location and, for infs,
        have the same sign

    Returns
    -------
    :
        ``True`` if near

    Raises
    ------
    CoordError:
        If an item present in both inputs has a different shape in x and y.
    Exception:
        If `x`, `y` are not going to be logically comparable
        for reasons relating to shape, item naming or non-finite elements.
        (A key of x that is absent from y fails on lookup; the exception
        type comes from the mapping - NOTE(review): confirm which.)
    """
    # The data comparison runs first, so any error it raises (e.g. from
    # incompatible data) surfaces before the meta checks are reached.
    if include_data:
        same_data = all(
            isclose(x.data, y.data, rtol=rtol, atol=atol, equal_nan=equal_nan)
        ).value
    else:
        same_data = True

    # Pick the mapping of items to compare on each side once.
    if include_attrs:
        lhs_items, rhs_items = x.deprecated_meta, y.deprecated_meta
    else:
        lhs_items, rhs_items = x.coords, y.coords

    if len(lhs_items) != len(rhs_items):
        return False

    float_dtypes = (DType.float64, DType.float32)
    for name, lhs in lhs_items.items():
        rhs = rhs_items[name]
        if lhs.shape != rhs.shape:
            raise CoordError(
                f'Coord (or attr) with key {name} have different'
                f' shapes. For x, shape is {lhs.shape}. For y, shape = {rhs.shape}'
            )
        # Items of non-floating dtypes are shape-checked only.
        if lhs.dtype not in float_dtypes:
            continue
        close = isclose(lhs, rhs, rtol=rtol, atol=atol, equal_nan=equal_nan)
        if not all(close).value:
            return False
    return same_data