Coverage for install/scipp/core/comparison.py: 61%

46 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-04-28 01:28 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3# @author Simon Heybrock 

4# ruff: noqa: E501 

5 

6from __future__ import annotations 

7 

8import numpy as np 

9 

10from .._scipp import core as _cpp 

11from ..typing import VariableLike 

12from . import data_group 

13from ._cpp_wrapper_util import call_func as _call_cpp_func 

14from .cpp_classes import DataArray, Dataset, Variable 

15from .variable import scalar 

16 

17 

def less(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '<' (less).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x < y`.
    """
    return _call_cpp_func(_cpp.less, x, y)

38 

39 

def greater(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '>' (greater).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x > y`.
    """
    return _call_cpp_func(_cpp.greater, x, y)

60 

61 

def less_equal(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '<=' (less_equal).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x <= y`.
    """
    return _call_cpp_func(_cpp.less_equal, x, y)

82 

83 

def greater_equal(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '>=' (greater_equal).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x >= y`.
    """
    return _call_cpp_func(_cpp.greater_equal, x, y)

104 

105 

def equal(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '==' (equal).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x == y`.
    """
    return _call_cpp_func(_cpp.equal, x, y)

126 

127 

def not_equal(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '!=' (not_equal).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x != y`.
    """
    return _call_cpp_func(_cpp.not_equal, x, y)

148 

149 

def _identical_data_groups(
    x: data_group.DataGroup, y: data_group.DataGroup, *, equal_nan: bool
) -> bool:
    """Return True if two DataGroups have the same keys and matching items.

    Items are compared key by key:

    - scipp objects (Variable, DataArray, Dataset, nested DataGroup) with
      :func:`identical`,
    - numpy arrays with :func:`numpy.array_equal`,
    - everything else with ``==``, coerced to a plain ``bool``.

    ``equal_nan`` is forwarded to the first two kinds of comparison only.
    """

    def _items_match(lhs, rhs) -> bool:
        # One-directional type check: ``lhs`` must be an instance of
        # ``type(rhs)``; an instance of a subclass also passes.
        if not isinstance(lhs, type(rhs)):
            return False
        scipp_types = (Variable, DataArray, Dataset, data_group.DataGroup)
        if isinstance(lhs, scipp_types):
            return identical(lhs, rhs, equal_nan=equal_nan)
        if isinstance(lhs, np.ndarray):
            return np.array_equal(lhs, rhs, equal_nan=equal_nan)
        # ``==`` may yield something other than a bool (e.g. an array),
        # so force a plain bool here.
        return bool(lhs == rhs)

    if x.keys() != y.keys():
        return False
    for key in x.keys():
        if not _items_match(x[key], y[key]):
            return False
    return True

167 

168 

def identical(x: VariableLike, y: VariableLike, *, equal_nan: bool = False) -> bool:
    """Full comparison of x and y.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.
    equal_nan:
        If true, non-finite values at the same index in (x, y) are treated as equal.
        Signbit must match for infs.

    Returns
    -------
    :
        True if x and y have identical values, variances, dtypes, units,
        dims, shapes, coords, and masks. Else False.

    Raises
    ------
    TypeError
        If exactly one of x and y is a DataGroup.
    """
    x_is_group = isinstance(x, data_group.DataGroup)
    y_is_group = isinstance(y, data_group.DataGroup)
    # Check both directions so that identical(var, dg) reports the same clear
    # error as identical(dg, var), instead of failing inside the C++ binding.
    if x_is_group != y_is_group:
        raise TypeError("Both or neither of the arguments must be a DataGroup")
    if x_is_group:
        return _identical_data_groups(x, y, equal_nan=equal_nan)

    return _call_cpp_func(_cpp.identical, x, y, equal_nan=equal_nan)

194 

195 

def isclose(
    x: _cpp.Variable,
    y: _cpp.Variable,
    *,
    rtol: _cpp.Variable | None = None,
    atol: _cpp.Variable | None = None,
    equal_nan: bool = False,
) -> _cpp.Variable:
    """Checks element-wise if the inputs are close to each other.

    Compares values of x and y element by element against absolute
    and relative tolerances according to (non-symmetric)

    .. code-block:: python

        abs(x - y) <= atol + rtol * abs(y)

    If both x and y have variances, the variances are also compared
    between elements. In this case, both values and variances must
    be within the computed tolerance limits. That is:

    .. code-block:: python

        abs(x.value - y.value) <= atol + rtol * abs(y.value) and
        abs(sqrt(x.variance) - sqrt(y.variance)) <= atol + rtol * abs(sqrt(y.variance))

    Attention
    ---------
    Vectors and matrices are compared element-wise.
    This is not necessarily a good measure for the similarity of `spatial`
    dtypes like ``scipp.DType.rotation3`` or ``scipp.Dtype.affine_transform3``
    (see :mod:`scipp.spatial`).

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.
    rtol:
        Tolerance value relative (to y).
        Can be a scalar or non-scalar.
        Defaults to scalar 1e-5 if unset. The default is dimensionless, or
        has no unit if ``atol`` has no unit.
    atol:
        Tolerance value absolute.
        Can be a scalar or non-scalar.
        Defaults to scalar 1e-8 if unset and takes units from y arg.
    equal_nan:
        If true, non-finite values at the same index in (x, y) are treated as equal.
        Signbit must match for infs.

    Returns
    -------
    :
        Variable same size as input.
        Element True if absolute diff of value <= atol + rtol * abs(y),
        otherwise False.

    See Also
    --------
    scipp.allclose:
        Equivalent of ``sc.all(sc.isclose(...)).value``.
    """
    if atol is None:
        atol = scalar(1e-8, unit=y.unit)
    if rtol is None:
        # rtol must be dimensionless; mirror atol when atol carries no unit
        # at all so that the two tolerances stay combinable.
        rtol = scalar(1e-5, unit=None if atol.unit is None else _cpp.units.one)
    return _call_cpp_func(_cpp.isclose, x, y, rtol, atol, equal_nan)

264 

265 

def allclose(
    x: _cpp.Variable,
    y: _cpp.Variable,
    rtol: _cpp.Variable | None = None,
    atol: _cpp.Variable | None = None,
    equal_nan: bool = False,
) -> bool:
    """Checks if all elements in the inputs are close to each other.

    Verifies that ALL element-wise comparisons meet the condition:

    .. code-block:: python

        abs(x - y) <= atol + rtol * abs(y)

    If both x and y have variances, the variances are also compared
    between elements. In this case, both values and variances must
    be within the computed tolerance limits. That is:

    .. code-block:: python

        abs(x.value - y.value) <= atol + rtol * abs(y.value) and
        abs(sqrt(x.variance) - sqrt(y.variance)) <= atol + rtol * abs(sqrt(y.variance))

    Attention
    ---------
    Vectors and matrices are compared element-wise.
    This is not necessarily a good measure for the similarity of `spatial`
    dtypes like ``scipp.DType.rotation3`` or ``scipp.Dtype.affine_transform3``
    (see :mod:`scipp.spatial`).

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.
    rtol:
        Tolerance value relative (to y).
        Can be a scalar or non-scalar.
        Cannot have variances.
        Defaults to scalar 1e-5 if unset.
    atol:
        Tolerance value absolute.
        Can be a scalar or non-scalar.
        Cannot have variances.
        Defaults to scalar 1e-8 if unset and takes units from y arg.
    equal_nan:
        If true, non-finite values at the same index in (x, y) are treated as equal.
        Signbit must match for infs.

    Returns
    -------
    :
        True if ``abs(x - y) <= atol + rtol * abs(y)`` holds for all elements
        (and for the variances, if both inputs have them), otherwise False.

    See Also
    --------
    scipp.isclose:
        Compares element-wise with specified tolerances.
    """
    # Reduce the element-wise result of isclose to a single scalar and
    # return its value.
    return _call_cpp_func(
        _cpp.all, isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan)
    ).value