Coverage for install/scipp/io/hdf5.py: 60% (297 statements)


# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
# @file
# @author Simon Heybrock

from __future__ import annotations

from collections.abc import Mapping
from io import BytesIO, StringIO
from os import PathLike
from typing import TYPE_CHECKING, Any, ClassVar, Protocol

import numpy as np
import numpy.typing as npt

from .._scipp import core as _cpp
from ..core import (
    DataArray,
    DataGroup,
    Dataset,
    DType,
    DTypeError,
    Unit,
    Variable,
    bins,
)
from ..logging import get_logger
from ..typing import VariableLike

if TYPE_CHECKING:
    import h5py as h5
else:
    h5 = Any


def _dtype_lut() -> dict[str, DType]:
    # For types understood by numpy we do not actually need this special
    # handling, but will do as we add support for other types such as
    # variable-length strings.
    dtypes = [
        DType.float64,
        DType.float32,
        DType.int64,
        DType.int32,
        DType.bool,
        DType.datetime64,
        DType.string,
        DType.Variable,
        DType.DataArray,
        DType.Dataset,
        DType.VariableView,
        DType.DataArrayView,
        DType.DatasetView,
        DType.vector3,
        DType.linear_transform3,
        DType.affine_transform3,
        DType.translation3,
        DType.rotation3,
    ]
    names = [str(dtype) for dtype in dtypes]
    return dict(zip(names, dtypes, strict=True))

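# Illustrative example (not part of the original source): the lookup maps the
# string form of each supported dtype back to the DType enum, e.g.
#     _dtype_lut()['float64'] is DType.float64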

def _as_hdf5_type(a: npt.NDArray[Any]) -> npt.NDArray[Any]:
    if np.issubdtype(a.dtype, np.datetime64):
        return a.view(np.int64)
    return a


def _collection_element_name(name: str, index: int) -> str:
    """
    Convert name into an ASCII string that can be used as an object name in HDF5.
    """
    ascii_name = (
        name.replace('.', '&#46;')
        .replace('/', '&#47;')
        .encode('ascii', 'xmlcharrefreplace')
        .decode('ascii')
    )
    return f'elem_{index:03d}_{ascii_name}'

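# Illustrative example (not part of the original source): '.' and '/' are
# escaped as HTML entities because they have special meaning in HDF5 object
# paths, e.g.
#     _collection_element_name('a/b.c', 4) == 'elem_004_a&#47;b&#46;c'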

class _DataWriter(Protocol):
    @staticmethod
    def write(group: h5.Group, data: Variable) -> h5.Dataset | h5.Group: ...


# All readers are writers and have the same interface for writing.
# But readers for array variables expect a `data` argument that is
# absent in other readers.
class _ArrayDataIO(_DataWriter, Protocol):
    @staticmethod
    def read(group: h5.Group, data: Variable) -> None: ...


class _NumpyDataIO:
    @staticmethod
    def write(group: h5.Group, data: Variable) -> h5.Dataset:
        dset = group.create_dataset('values', data=_as_hdf5_type(data.values))
        if data.variances is not None:
            variances = group.create_dataset('variances', data=data.variances)
            dset.attrs['variances'] = variances.ref
        return dset

    @staticmethod
    def read(group: h5.Group, data: Variable) -> None:
        # h5py's read_direct method fails if any dim has zero size.
        # see https://github.com/h5py/h5py/issues/870
        if data.values.flags['C_CONTIGUOUS'] and data.values.size > 0:
            group['values'].read_direct(_as_hdf5_type(data.values))
        else:
            # Values of Eigen matrices are transposed
            data.values = group['values']
        if 'variances' in group and data.variances.size > 0:
            group['variances'].read_direct(data.variances)

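# Illustrative on-disk layout for a dense variable written by _NumpyDataIO
# (not part of the original source):
#     <group>/values     dataset holding the value array (datetime64 is
#                        stored via its int64 view, see _as_hdf5_type)
#     <group>/variances  optional dataset, referenced from values.attrs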

class _BinDataIO:
    @staticmethod
    def write(group: h5.Group, data: Variable) -> h5.Group:
        if data.bins is None:
            raise DTypeError("Expected binned data")
        bins = data.bins.constituents
        buffer_len = bins['data'].sizes[bins['dim']]
        # Crude mechanism to avoid writing large buffers, e.g., from
        # overallocation or when writing a slice of a larger variable. The
        # copy causes some overhead, but so would the (much more complicated)
        # solution to extract contents bin-by-bin. This approach will likely
        # need to be revisited in the future.
        if buffer_len > 1.5 * data.bins.size().sum().value:
            data = data.copy()
            bins = data.bins.constituents  # type: ignore[union-attr]
        values = group.create_group('values')
        _VariableIO.write(values.create_group('begin'), var=bins['begin'])
        _VariableIO.write(values.create_group('end'), var=bins['end'])
        data_group = values.create_group('data')
        data_group.attrs['dim'] = bins['dim']
        _HDF5IO.write(data_group, bins['data'])
        return values

    @staticmethod
    def read(group: h5.Group) -> Variable:
        values = group['values']
        begin = _VariableIO.read(values['begin'])
        end = _VariableIO.read(values['end'])
        dim = values['data'].attrs['dim']
        data = _HDF5IO.read(values['data'])
        return bins(begin=begin, end=end, dim=dim, data=data)

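# Illustrative on-disk layout for binned data (not part of the original
# source): each bin is a [begin, end) slice of a shared underlying buffer.
#     <group>/values/begin  Variable of per-bin start indices
#     <group>/values/end    Variable of per-bin end indices
#     <group>/values/data   the underlying buffer; its attr 'dim' names the
#                           dimension that the begin/end indices slice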

class _ScippDataIO:
    @staticmethod
    def write(group: h5.Group, data: Variable) -> h5.Group:
        values = group.create_group('values')
        if len(data.shape) == 0:
            _HDF5IO.write(values, data.value)
        else:
            for i, item in enumerate(data.values):
                _HDF5IO.write(values.create_group(f'value-{i}'), item)
        return values

    @staticmethod
    def read(group: h5.Group, data: Variable) -> None:
        values = group['values']
        if len(data.shape) == 0:
            data.value = _HDF5IO.read(values)
        else:
            for i in range(len(data.values)):
                data.values[i] = _HDF5IO.read(values[f'value-{i}'])


class _StringDataIO:
    @staticmethod
    def write(group: h5.Group, data: Variable) -> h5.Dataset:
        import h5py

        dt = h5py.string_dtype(encoding='utf-8')
        dset = group.create_dataset('values', shape=data.shape, dtype=dt)
        if len(data.shape) == 0:
            dset[()] = data.value
        else:
            for i in range(len(data.values)):
                dset[i] = data.values[i]
        return dset

    @staticmethod
    def read(group: h5.Group, data: Variable) -> None:
        values = group['values']
        if len(data.shape) == 0:
            data.value = values[()]
        else:
            for i in range(len(data.values)):
                data.values[i] = values[i]


def _write_scipp_header(group: h5.Group, what: str) -> None:
    from ..core import __version__

    group.attrs['scipp-version'] = __version__
    group.attrs['scipp-type'] = what


def _check_scipp_header(group: h5.Group, what: str) -> None:
    if 'scipp-version' not in group.attrs:
        raise RuntimeError(
            "This does not look like an HDF5 file/group written by Scipp."
        )
    if group.attrs['scipp-type'] != what:
        raise RuntimeError(
            f"Attempt to read {what}, found {group.attrs['scipp-type']}."
        )

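# Illustrative example (not part of the original source): every scipp-written
# group carries two identifying attributes, e.g.
#     group.attrs['scipp-version']  # e.g. '23.8.0' (hypothetical version)
#     group.attrs['scipp-type']     # e.g. 'DataArray'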

def _array_data_io_lut() -> dict[str, _ArrayDataIO]:
    handler: dict[str, _ArrayDataIO] = {}
    for dtype in [
        DType.float64,
        DType.float32,
        DType.int64,
        DType.int32,
        DType.bool,
        DType.datetime64,
        DType.vector3,
        DType.linear_transform3,
        DType.rotation3,
        DType.translation3,
        DType.affine_transform3,
    ]:
        handler[str(dtype)] = _NumpyDataIO
    for dtype in [DType.Variable, DType.DataArray, DType.Dataset]:
        handler[str(dtype)] = _ScippDataIO
    for dtype in [DType.string]:
        handler[str(dtype)] = _StringDataIO
    return handler


def _data_writer_lut() -> dict[str, _DataWriter]:
    # Unpack and repack dict to cast to the correct value type.
    handler: dict[str, _DataWriter] = {**_array_data_io_lut()}
    for dtype in [DType.VariableView, DType.DataArrayView, DType.DatasetView]:
        handler[str(dtype)] = _BinDataIO
    return handler


def _serialize_unit(unit: Unit) -> npt.NDArray[Any]:
    unit_dict = unit.to_dict()
    dtype: list[tuple[str, Any]] = [('__version__', int), ('multiplier', float)]
    vals = [unit_dict['__version__'], unit_dict['multiplier']]
    if 'powers' in unit_dict:
        dtype.append(('powers', [(name, int) for name in unit_dict['powers']]))
        vals.append(tuple(val for val in unit_dict['powers'].values()))
    return np.array(tuple(vals), dtype=dtype)

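# Illustrative example (not part of the original source; the version and
# multiplier values are assumptions): a unit such as m/s would serialize to
# a structured scalar roughly like
#     np.array((1, 1.0, (1, -1)),
#              dtype=[('__version__', int), ('multiplier', float),
#                     ('powers', [('m', int), ('s', int)])])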

def _read_unit_attr(ds: h5.Dataset) -> Unit:
    u = ds.attrs['unit']
    if isinstance(u, str):
        return Unit(u)  # legacy encoding as a string

    # u is a structured numpy array
    unit_dict = {'__version__': u['__version__'], 'multiplier': u['multiplier']}
    if 'powers' in u.dtype.names:
        unit_dict['powers'] = {
            name: u['powers'][name] for name in u['powers'].dtype.names
        }
    return Unit.from_dict(unit_dict)


class _VariableIO:
    _dtypes = _dtype_lut()
    _array_data_readers = _array_data_io_lut()
    _data_writers = _data_writer_lut()

    @classmethod
    def _write_data(cls, group: h5.Group, data: Variable) -> h5.Dataset | h5.Group:
        return cls._data_writers[str(data.dtype)].write(group, data)

    @classmethod
    def _read_array_data(cls, group: h5.Group, data: Variable) -> None:
        cls._array_data_readers[str(data.dtype)].read(group, data)

    @classmethod
    def write(cls, group: h5.Group, var: Variable) -> h5.Group | None:
        if var.dtype not in cls._dtypes.values():
            # In practice this may make the file unreadable, e.g., if values
            # have unsupported dtype.
            get_logger().warning(
                'Writing with dtype=%s not implemented, skipping.', var.dtype
            )
            return None
        _write_scipp_header(group, 'Variable')
        dset = cls._write_data(group, var)
        dset.attrs['dims'] = [str(dim) for dim in var.dims]
        dset.attrs['shape'] = var.shape
        dset.attrs['dtype'] = str(var.dtype)
        if var.unit is not None:
            dset.attrs['unit'] = _serialize_unit(var.unit)
        dset.attrs['aligned'] = var.aligned
        return group

    @classmethod
    def read(cls, group: h5.Group) -> Variable:
        _check_scipp_header(group, 'Variable')
        values = group['values']
        contents = {key: values.attrs[key] for key in ['dims', 'shape']}
        contents['dtype'] = cls._dtypes[values.attrs['dtype']]
        if 'unit' in values.attrs:
            contents['unit'] = _read_unit_attr(values)
        else:
            contents['unit'] = None  # essential, otherwise default unit is used
        contents['with_variances'] = 'variances' in group
        contents['aligned'] = values.attrs.get('aligned', True)
        if contents['dtype'] in [
            DType.VariableView,
            DType.DataArrayView,
            DType.DatasetView,
        ]:
            var = _BinDataIO.read(group)
        else:
            var = _cpp.empty(**contents)
            cls._read_array_data(group, var)
        return var

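# Round-trip sketch for a single Variable (illustrative, not part of the
# original source; assumes h5py >= 2.9, which accepts file-like objects):
#     import io
#     import h5py
#     import scipp as sc
#     buf = io.BytesIO()
#     with h5py.File(buf, 'w') as f:
#         _VariableIO.write(f.create_group('var'), var=sc.scalar(1.2, unit='m'))
#     with h5py.File(buf, 'r') as f:
#         var = _VariableIO.read(f['var'])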

def _write_mapping(
    parent: h5.Group,
    mapping: Mapping[str, VariableLike],
    override: Mapping[str, h5.Group] | None = None,
) -> None:
    if override is None:
        override = {}
    for i, name in enumerate(mapping):
        var_group_name = _collection_element_name(name, i)
        if (g := override.get(name)) is not None:
            parent[var_group_name] = g
        else:
            g = _HDF5IO.write(
                group=parent.create_group(var_group_name), data=mapping[name]
            )
            if g is None:
                del parent[var_group_name]
            else:
                g.attrs['name'] = str(name)


def _read_mapping(
    group: h5.Group, override: Mapping[str, h5.Group] | None = None
) -> VariableLike:
    if override is None:
        override = {}
    return {
        g.attrs['name']: override[g.attrs['name']]
        if g.attrs['name'] in override
        else _HDF5IO.read(g)
        for g in group.values()
    }

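# Note (added for clarity, not part of the original source): when `override`
# supplies an already-written group for a name, the assignment
# `parent[var_group_name] = g` creates an HDF5 hard link to that group rather
# than writing a second copy. _DatasetIO uses this to share dataset coords
# with each entry.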

class _DataArrayIO:
    @staticmethod
    def write(group: h5.Group, data, override=None):
        if override is None:
            override = {}
        _write_scipp_header(group, 'DataArray')
        group.attrs['name'] = data.name
        if _VariableIO.write(group.create_group('data'), var=data.data) is None:
            return None
        views = [data.coords, data.masks, data.attrs]
        # Note that we write aligned and unaligned coords into the same group.
        # Distinction is via an attribute, which is more natural than having
        # 2 separate groups.
        for view_name, view in zip(['coords', 'masks', 'attrs'], views, strict=True):
            subgroup = group.create_group(view_name)
            _write_mapping(subgroup, view, override.get(view_name))
        return group

    @staticmethod
    def read(group: h5.Group, override=None):
        _check_scipp_header(group, 'DataArray')
        if override is None:
            override = {}
        contents = {}
        contents['name'] = group.attrs['name']
        contents['data'] = _VariableIO.read(group['data'])
        for category in ['coords', 'masks', 'attrs']:
            contents[category] = _read_mapping(group[category], override.get(category))
        return DataArray(**contents)


class _DatasetIO:
    @staticmethod
    def write(group: h5.Group, data):
        _write_scipp_header(group, 'Dataset')
        coords = group.create_group('coords')
        _write_mapping(coords, data.coords)
        entries = group.create_group('entries')
        # We cannot use coords directly, since we need lookup by name. The key used as
        # group name includes an integer index which may differ when writing items and
        # is not sufficient.
        coords = {v.attrs['name']: v for v in coords.values()}
        for i, (name, da) in enumerate(data.items()):
            _HDF5IO.write(
                entries.create_group(_collection_element_name(name, i)),
                da,
                override={'coords': coords},
            )
        return group

    @staticmethod
    def read(group: h5.Group):
        _check_scipp_header(group, 'Dataset')
        coords = _read_mapping(group['coords'])
        return Dataset(
            coords=coords,
            data=_read_mapping(group['entries'], override={'coords': coords}),
        )


class _DataGroupIO:
    @staticmethod
    def write(group: h5.Group, data):
        _write_scipp_header(group, 'DataGroup')
        entries = group.create_group('entries')
        _write_mapping(entries, data)
        return group

    @staticmethod
    def read(group: h5.Group):
        _check_scipp_header(group, 'DataGroup')
        return DataGroup(_read_mapping(group['entries']))


def _direct_io(cls, convert=None):
    # Use fully qualified name for numpy.bool to avoid confusion with builtin bool.
    # Numpy bools have different names in 1.x and 2.x
    # TODO: This should be fixed once we drop support for numpy<2
    if cls.__module__ == 'numpy' and cls.__name__[:4] == 'bool':
        type_name = 'numpy.bool'
    else:
        type_name = cls.__name__
    if convert is None:
        convert = cls

    class GenericIO:
        @staticmethod
        def write(group, data):
            _write_scipp_header(group, type_name)
            group['entry'] = data
            return group

        @staticmethod
        def read(group):
            _check_scipp_header(group, type_name)
            return convert(group['entry'][()])

    return GenericIO

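# Illustrative example (not part of the original source):
# _direct_io(str, convert=lambda b: b.decode('utf-8')) builds a handler that
# stores a str under <group>/entry with scipp-type 'str' and, on read,
# decodes the raw bytes that h5py returns for string data.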

class _HDF5IO:
    _handlers: ClassVar[dict[str, Any]] = {
        'Variable': _VariableIO,
        'DataArray': _DataArrayIO,
        'Dataset': _DatasetIO,
        'DataGroup': _DataGroupIO,
        'str': _direct_io(str, convert=lambda b: b.decode('utf-8')),
        'ndarray': _direct_io(np.ndarray, convert=lambda x: x),
        'numpy.bool': _direct_io(np.bool_),
        **{
            cls.__name__: _direct_io(cls)
            for cls in (
                int,
                np.int64,
                np.int32,
                np.uint64,
                np.uint32,
                float,
                np.float32,
                np.float64,
                bool,
                bytes,
            )
        },
    }

    @classmethod
    def write(cls, group: h5.Group, data, **kwargs):
        data_cls = data.__class__
        # Numpy bools have different names in 1.x and 2.x
        # TODO: This should be fixed once we drop support for numpy<2
        if data_cls.__module__ == 'numpy' and data_cls.__name__[:4] == 'bool':
            name = 'numpy.bool'
        else:
            name = data_cls.__name__.replace('View', '')
        try:
            handler = cls._handlers[name]
        except KeyError:
            get_logger().warning(
                "Writing type '%s' to HDF5 not implemented, skipping.", type(data)
            )
            return None
        return handler.write(group, data, **kwargs)

    @classmethod
    def read(cls, group: h5.Group, **kwargs):
        return cls._handlers[group.attrs['scipp-type']].read(group, **kwargs)


def save_hdf5(
    obj: VariableLike,
    filename: str | PathLike[str] | StringIO | BytesIO | h5.Group,
) -> None:
    """Write an object out to file in HDF5 format."""
    import h5py

    if isinstance(filename, h5py.Group):
        return _HDF5IO.write(filename, obj)

    with h5py.File(filename, 'w') as f:
        _HDF5IO.write(f, obj)


def load_hdf5(
    filename: str | PathLike[str] | StringIO | BytesIO | h5.Group,
) -> VariableLike:
    """Load a Scipp-HDF5 file."""
    import h5py

    if isinstance(filename, h5py.Group):
        return _HDF5IO.read(filename)

    with h5py.File(filename, 'r') as f:
        return _HDF5IO.read(f)
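

# Usage sketch (illustrative, not part of the original source; 'data.h5' is a
# hypothetical file name):
#     import scipp as sc
#     da = sc.DataArray(sc.arange('x', 4, unit='m'))
#     save_hdf5(da, 'data.h5')
#     assert sc.identical(load_hdf5('data.h5'), da)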