Coverage for install/scipp/core/comparison.py: 60%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-11-17 01:51 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3# @author Simon Heybrock 

4# ruff: noqa: E501 

5 

6from __future__ import annotations 

7 

8from typing import Any 

9 

10import numpy as np 

11 

12from .._scipp import core as _cpp 

13from ..typing import VariableLike 

14from . import data_group 

15from ._cpp_wrapper_util import call_func as _call_cpp_func 

16from .cpp_classes import DataArray, Dataset, Variable 

17from .variable import scalar 

18 

19 

def less(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '<' (less).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x < y`.
    """
    return _call_cpp_func(_cpp.less, x, y)

40 

41 

def greater(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '>' (greater).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x > y`.
    """
    return _call_cpp_func(_cpp.greater, x, y)

62 

63 

def less_equal(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '<=' (less_equal).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x <= y`.
    """
    return _call_cpp_func(_cpp.less_equal, x, y)

84 

85 

def greater_equal(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '>=' (greater_equal).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x >= y`.
    """
    return _call_cpp_func(_cpp.greater_equal, x, y)

106 

107 

def equal(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '==' (equal).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x == y`.
    """
    return _call_cpp_func(_cpp.equal, x, y)

128 

129 

def not_equal(x: VariableLike, y: VariableLike) -> VariableLike:
    """Element-wise '!=' (not_equal).

    Warning: If one or both of the operands have variances (uncertainties)
    they are ignored silently, i.e., comparison is based exclusively on
    the values.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.

    Returns
    -------
    :
        Booleans that are true where `x != y`.
    """
    return _call_cpp_func(_cpp.not_equal, x, y)

150 

151 

def _identical_data_groups(
    x: data_group.DataGroup[Any], y: data_group.DataGroup[Any], *, equal_nan: bool
) -> bool:
    """Compare two DataGroups key by key and item by item.

    Scipp objects and nested DataGroups are compared with :func:`identical`,
    numpy arrays with :func:`numpy.array_equal`, and anything else with ``==``.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.
    equal_nan:
        Forwarded to the comparison of scipp objects and numpy arrays.

    Returns
    -------
    :
        True if both groups have the same keys and every pair of items
        compares equal, else False.
    """

    def compare(a: Any, b: Any) -> bool:
        # Check type compatibility in both directions so that the result does
        # not depend on argument order. A one-sided ``isinstance(a, type(b))``
        # accepts e.g. (True, 1) but rejects (1, True).
        if not (isinstance(a, type(b)) or isinstance(b, type(a))):
            return False
        if isinstance(a, Variable | DataArray | Dataset | data_group.DataGroup):
            return identical(a, b, equal_nan=equal_nan)
        if isinstance(a, np.ndarray):
            return np.array_equal(a, b, equal_nan=equal_nan)
        # Explicit conversion to bool in case __eq__ returns
        # something else like an array.
        return bool(a == b)

    if x.keys() != y.keys():
        return False
    return all(compare(x[k], y[k]) for k in x.keys())

169 

170 

def identical(x: VariableLike, y: VariableLike, *, equal_nan: bool = False) -> bool:
    """Full comparison of x and y.

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.
    equal_nan:
        If true, non-finite values at the same index in (x, y) are treated as equal.
        Signbit must match for infs.

    Returns
    -------
    :
        True if x and y have identical values, variances, dtypes, units,
        dims, shapes, coords, and masks. Else False.

    Raises
    ------
    TypeError
        If exactly one of ``x`` and ``y`` is a DataGroup.
    """
    if isinstance(x, data_group.DataGroup):
        if not isinstance(y, data_group.DataGroup):
            raise TypeError("Both or neither of the arguments must be a DataGroup")
        return _identical_data_groups(x, y, equal_nan=equal_nan)

    # Mirror the check above so a DataGroup ``y`` is rejected with the same
    # message instead of being forwarded to the C++ comparison.
    if isinstance(y, data_group.DataGroup):
        raise TypeError("Both or neither of the arguments must be a DataGroup")

    return _call_cpp_func(_cpp.identical, x, y, equal_nan=equal_nan)  # type: ignore[return-value]

196 

197 

def isclose(
    x: _cpp.Variable,
    y: _cpp.Variable,
    *,
    rtol: _cpp.Variable | None = None,
    atol: _cpp.Variable | None = None,
    equal_nan: bool = False,
) -> _cpp.Variable:
    """Checks element-wise if the inputs are close to each other.

    Compares values of x and y element by element against absolute
    and relative tolerances according to (non-symmetric)

    .. code-block:: python

        abs(x - y) <= atol + rtol * abs(y)

    If both x and y have variances, the variances are also compared
    between elements. In this case, both values and variances must
    be within the computed tolerance limits. That is:

    .. code-block:: python

        abs(x.value - y.value) <= atol + rtol * abs(y.value) and
        abs(sqrt(x.variance) - sqrt(y.variance)) <= atol + rtol * abs(sqrt(y.variance))

    Attention
    ---------
    Vectors and matrices are compared element-wise.
    This is not necessarily a good measure for the similarity of `spatial`
    dtypes like ``scipp.DType.rotation3`` or ``scipp.DType.affine_transform3``
    (see :mod:`scipp.spatial`).

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.
    rtol:
        Tolerance value relative (to y).
        Can be a scalar or non-scalar.
        Defaults to scalar 1e-5 if unset.
    atol:
        Tolerance value absolute.
        Can be a scalar or non-scalar.
        Defaults to scalar 1e-8 if unset and takes units from y arg.
    equal_nan:
        If true, non-finite values at the same index in (x, y) are treated as equal.
        Signbit must match for infs.

    Returns
    -------
    :
        Variable same size as input.
        Element True if absolute diff of value <= atol + rtol * abs(y),
        otherwise False.

    See Also
    --------
    scipp.allclose:
        Equivalent of ``sc.all(sc.isclose(...)).value``.
    """
    if atol is None:
        atol = scalar(1e-8, unit=y.unit)
    if rtol is None:
        # rtol scales abs(y), so it is dimensionless ('one'), unless atol has
        # no unit at all (unit=None), in which case rtol gets no unit either.
        rtol = scalar(1e-5, unit=None if atol.unit is None else _cpp.units.one)
    return _call_cpp_func(_cpp.isclose, x, y, rtol, atol, equal_nan)

266 

267 

def allclose(
    x: _cpp.Variable,
    y: _cpp.Variable,
    rtol: _cpp.Variable | None = None,
    atol: _cpp.Variable | None = None,
    equal_nan: bool = False,
) -> bool:
    """Checks if all elements in the inputs are close to each other.

    Verifies that ALL element-wise comparisons meet the condition:

    .. code-block:: python

        abs(x - y) <= atol + rtol * abs(y)

    If both x and y have variances, the variances are also compared
    between elements. In this case, both values and variances must
    be within the computed tolerance limits. That is:

    .. code-block:: python

        abs(x.value - y.value) <= atol + rtol * abs(y.value) and
        abs(sqrt(x.variance) - sqrt(y.variance)) <= atol + rtol * abs(sqrt(y.variance))

    Attention
    ---------
    Vectors and matrices are compared element-wise.
    This is not necessarily a good measure for the similarity of `spatial`
    dtypes like ``scipp.DType.rotation3`` or ``scipp.DType.affine_transform3``
    (see :mod:`scipp.spatial`).

    Parameters
    ----------
    x:
        Left input.
    y:
        Right input.
    rtol:
        Tolerance value relative (to y).
        Can be a scalar or non-scalar.
        Cannot have variances.
        Defaults to scalar 1e-5 if unset.
    atol:
        Tolerance value absolute.
        Can be a scalar or non-scalar.
        Cannot have variances.
        Defaults to scalar 1e-8 if unset and takes units from y arg.
    equal_nan:
        If true, non-finite values at the same index in (x, y) are treated as equal.
        Signbit must match for infs.

    Returns
    -------
    :
        True if for all elements ``abs(x - y) <= atol + rtol * abs(y)``,
        otherwise False.

    See Also
    --------
    scipp.isclose:
        Compares element-wise with specified tolerances.
    """
    # Reduce the element-wise result of ``isclose`` to a single boolean.
    return _call_cpp_func(  # type:ignore[no-any-return]
        _cpp.all, isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan)
    ).value  # type: ignore[union-attr]