Coverage for install/scipp/io/hdf5.py: 64%

277 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-04-28 01:28 +0000

1# SPDX-License-Identifier: BSD-3-Clause 

2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) 

3# @file 

4# @author Simon Heybrock 

5 

6from __future__ import annotations 

7 

8from pathlib import Path 

9from typing import Any, ClassVar, Union 

10 

11import numpy as np 

12 

13from ..core.cpp_classes import Unit 

14from ..logging import get_logger 

15from ..typing import VariableLike 

16 

17 

def _dtype_lut():
    """Return a mapping from dtype name (str) to the scipp DType object.

    Only the dtypes listed here are supported by the HDF5 reader/writer.
    """
    from .._scipp.core import DType as d

    # For types understood by numpy we do not actually need this special
    # handling, but will do as we add support for other types such as
    # variable-length strings.
    supported = (
        d.float64,
        d.float32,
        d.int64,
        d.int32,
        d.bool,
        d.datetime64,
        d.string,
        d.Variable,
        d.DataArray,
        d.Dataset,
        d.VariableView,
        d.DataArrayView,
        d.DatasetView,
        d.vector3,
        d.linear_transform3,
        d.affine_transform3,
        d.translation3,
        d.rotation3,
    )
    return {str(dtype): dtype for dtype in supported}

46 

47 

48def _as_hdf5_type(a): 

49 if np.issubdtype(a.dtype, np.datetime64): 

50 return a.view(np.int64) 

51 return a 

52 

53 

def collection_element_name(name, index):
    """
    Convert name into an ASCII string that can be used as an object name in HDF5.

    '/' is the HDF5 group-path separator and '.' refers to the current group,
    so both are escaped as XML character references before encoding; any
    remaining non-ASCII characters are escaped the same way via
    ``xmlcharrefreplace``. The index prefix keeps sibling names unique and
    preserves insertion order.
    """
    # The original code replaced '.' and '/' with themselves (a no-op, most
    # likely an HTML-unescaping artifact); they must become '&#46;'/'&#47;'
    # to be consistent with the xmlcharrefreplace escaping below.
    ascii_name = (
        name.replace('.', '&#46;')
        .replace('/', '&#47;')
        .encode('ascii', 'xmlcharrefreplace')
        .decode('ascii')
    )
    return f'elem_{index:03d}_{ascii_name}'

65 

66 

class NumpyDataIO:
    """IO handler for dtypes whose values map directly onto numpy arrays."""

    @staticmethod
    def write(group, data):
        """Store the values (and variances, if any) of *data* in *group*."""
        values = group.create_dataset('values', data=_as_hdf5_type(data.values))
        if data.variances is None:
            return values
        var_dset = group.create_dataset('variances', data=data.variances)
        # Cross-link so the values dataset records where its variances live.
        values.attrs['variances'] = var_dset.ref
        return values

    @staticmethod
    def read(group, data):
        """Fill the values/variances buffers of *data* from *group*."""
        values = data.values
        # h5py's read_direct method fails if any dim has zero size.
        # see https://github.com/h5py/h5py/issues/870
        if values.flags['C_CONTIGUOUS'] and values.size > 0:
            group['values'].read_direct(_as_hdf5_type(values))
        else:
            # Values of Eigen matrices are transposed; plain assignment
            # lets the property setter handle the conversion.
            data.values = group['values']
        if 'variances' in group and data.variances.size > 0:
            group['variances'].read_direct(data.variances)

87 

88 

class BinDataIO:
    """IO handler for binned (event) data variables."""

    @staticmethod
    def write(group, data):
        """Write begin/end bin indices and the underlying event buffer."""
        constituents = data.bins.constituents
        buffer_len = constituents['data'].sizes[constituents['dim']]
        # Crude mechanism to avoid writing large buffers, e.g., from
        # overallocation or when writing a slice of a larger variable. The
        # copy causes some overhead, but so would the (much more complicated)
        # solution to extract contents bin-by-bin. This approach will likely
        # need to be revisited in the future.
        if buffer_len > 1.5 * data.bins.size().sum().value:
            data = data.copy()
            constituents = data.bins.constituents
        values = group.create_group('values')
        VariableIO.write(values.create_group('begin'), var=constituents['begin'])
        VariableIO.write(values.create_group('end'), var=constituents['end'])
        data_group = values.create_group('data')
        data_group.attrs['dim'] = constituents['dim']
        HDF5IO.write(data_group, constituents['data'])
        return values

    @staticmethod
    def read(group):
        """Reassemble a binned variable written by :meth:`write`."""
        from .._scipp import core as sc

        values = group['values']
        return sc.bins(
            begin=VariableIO.read(values['begin']),
            end=VariableIO.read(values['end']),
            dim=values['data'].attrs['dim'],
            data=HDF5IO.read(values['data']),
        )

120 

121 

class ScippDataIO:
    """IO handler for variables whose elements are themselves scipp objects."""

    @staticmethod
    def write(group, data):
        """Write each element of *data* as its own subgroup under 'values'."""
        values = group.create_group('values')
        if data.shape:
            for i, item in enumerate(data.values):
                HDF5IO.write(values.create_group(f'value-{i}'), item)
        else:
            # 0-d variable: the single element is stored directly in 'values'.
            HDF5IO.write(values, data.value)
        return values

    @staticmethod
    def read(group, data):
        """Read elements back into the pre-allocated variable *data*."""
        values = group['values']
        if data.shape:
            for i in range(len(data.values)):
                data.values[i] = HDF5IO.read(values[f'value-{i}'])
        else:
            data.value = HDF5IO.read(values)

141 

142 

class StringDataIO:
    """IO handler for string variables, using HDF5 variable-length UTF-8 strings."""

    @staticmethod
    def write(group, data):
        """Write the string element(s) of *data* into a 'values' dataset."""
        import h5py

        str_dtype = h5py.string_dtype(encoding='utf-8')
        dset = group.create_dataset('values', shape=data.shape, dtype=str_dtype)
        if data.shape:
            for i, s in enumerate(data.values):
                dset[i] = s
        else:
            dset[()] = data.value
        return dset

    @staticmethod
    def read(group, data):
        """Read string element(s) back into the pre-allocated variable *data*."""
        values = group['values']
        if data.shape:
            for i in range(len(data.values)):
                data.values[i] = values[i]
        else:
            data.value = values[()]

165 

166 

def _write_scipp_header(group, what):
    """Tag *group* with the writing scipp version and the type name *what*."""
    from .._scipp import __version__

    group.attrs['scipp-version'] = __version__
    group.attrs['scipp-type'] = what

172 

173 

174def _check_scipp_header(group, what): 

175 if 'scipp-version' not in group.attrs: 

176 raise RuntimeError( 

177 "This does not look like an HDF5 file/group written by Scipp." 

178 ) 

179 if group.attrs['scipp-type'] != what: 

180 raise RuntimeError( 

181 f"Attempt to read {what}, found {group.attrs['scipp-type']}." 

182 ) 

183 

184 

def _data_handler_lut():
    """Return a mapping from dtype name (str) to the IO handler class for it."""
    from .._scipp.core import DType as d

    # Plain array-like dtypes go through numpy.
    numpy_dtypes = (
        d.float64,
        d.float32,
        d.int64,
        d.int32,
        d.bool,
        d.datetime64,
        d.vector3,
        d.linear_transform3,
        d.rotation3,
        d.translation3,
        d.affine_transform3,
    )
    # View dtypes indicate binned data; nested scipp objects get ScippDataIO.
    binned_dtypes = (d.VariableView, d.DataArrayView, d.DatasetView)
    nested_dtypes = (d.Variable, d.DataArray, d.Dataset)

    handler = {str(dtype): NumpyDataIO for dtype in numpy_dtypes}
    handler.update((str(dtype), BinDataIO) for dtype in binned_dtypes)
    handler.update((str(dtype), ScippDataIO) for dtype in nested_dtypes)
    handler[str(d.string)] = StringDataIO
    return handler

210 

211 

212def _serialize_unit(unit): 

213 unit_dict = unit.to_dict() 

214 dtype = [('__version__', int), ('multiplier', float)] 

215 vals = [unit_dict['__version__'], unit_dict['multiplier']] 

216 if 'powers' in unit_dict: 

217 dtype.append(('powers', [(name, int) for name in unit_dict['powers']])) 

218 vals.append(tuple(val for val in unit_dict['powers'].values())) 

219 return np.array(tuple(vals), dtype=dtype) 

220 

221 

def _read_unit_attr(ds):
    """Decode the 'unit' attribute of *ds* into a Unit.

    Supports both the current structured-array encoding and the legacy
    plain-string encoding used by older files.
    """
    u = ds.attrs['unit']
    if isinstance(u, str):
        return Unit(u)  # legacy encoding as a string

    # u is a structured numpy array; rebuild the dict that Unit.from_dict expects.
    unit_dict = {'__version__': u['__version__'], 'multiplier': u['multiplier']}
    if 'powers' in u.dtype.names:
        powers = u['powers']
        unit_dict['powers'] = {name: powers[name] for name in powers.dtype.names}
    return Unit.from_dict(unit_dict)

234 

235 

class VariableIO:
    """Reads and writes scipp Variables, dispatching on dtype."""

    _dtypes = _dtype_lut()
    _data_handlers = _data_handler_lut()

    @classmethod
    def _write_data(cls, group, data):
        # Delegate to the handler registered for this dtype.
        return cls._data_handlers[str(data.dtype)].write(group, data)

    @classmethod
    def _read_data(cls, group, data):
        return cls._data_handlers[str(data.dtype)].read(group, data)

    @classmethod
    def write(cls, group, var):
        """Write *var* into *group*; returns the group, or None if skipped."""
        if var.dtype not in cls._dtypes.values():
            # In practice this may make the file unreadable, e.g., if values
            # have unsupported dtype.
            get_logger().warning(
                'Writing with dtype=%s not implemented, skipping.', var.dtype
            )
            return
        _write_scipp_header(group, 'Variable')
        dset = cls._write_data(group, var)
        attrs = dset.attrs
        attrs['dims'] = [str(dim) for dim in var.dims]
        attrs['shape'] = var.shape
        attrs['dtype'] = str(var.dtype)
        if var.unit is not None:
            attrs['unit'] = _serialize_unit(var.unit)
        attrs['aligned'] = var.aligned
        return group

    @classmethod
    def read(cls, group):
        """Reconstruct a Variable from *group*."""
        _check_scipp_header(group, 'Variable')
        from .._scipp import core as sc
        from .._scipp.core import DType as d

        values = group['values']
        attrs = values.attrs
        contents = {
            'dims': attrs['dims'],
            'shape': attrs['shape'],
            'dtype': cls._dtypes[attrs['dtype']],
            # unit=None is essential, otherwise the default unit would be used.
            'unit': _read_unit_attr(values) if 'unit' in attrs else None,
            'with_variances': 'variances' in group,
            # Files written before alignment existed default to aligned.
            'aligned': attrs.get('aligned', True),
        }
        if contents['dtype'] in (d.VariableView, d.DataArrayView, d.DatasetView):
            # Binned data carries its own nested structure.
            return BinDataIO.read(group)
        var = sc.empty(**contents)
        cls._read_data(group, var)
        return var

288 

289 

def _write_mapping(parent, mapping, override=None):
    """Write each item of *mapping* into its own child group of *parent*.

    Items found in *override* are linked to the given existing group instead
    of being written again (used to share coords between dataset entries).
    """
    if override is None:
        override = {}
    for index, name in enumerate(mapping):
        child_name = collection_element_name(name, index)
        g = override.get(name)
        if g is not None:
            # Hard-link the already-written group instead of duplicating it.
            parent[child_name] = g
            g.attrs['name'] = str(name)
            continue
        g = HDF5IO.write(group=parent.create_group(child_name), data=mapping[name])
        if g is None:
            # Item could not be written (unsupported type); drop the empty group.
            del parent[child_name]
        else:
            g.attrs['name'] = str(name)

305 

306 

def _read_mapping(group, override=None):
    """Read a mapping written by ``_write_mapping``.

    Items whose stored name appears in *override* are taken from there
    instead of being read from file (used for shared dataset coords).
    """
    if override is None:
        override = {}

    def _read_item(g):
        name = g.attrs['name']
        value = override[name] if name in override else HDF5IO.read(g)
        return name, value

    return dict(_read_item(g) for g in group.values())

316 

317 

class DataArrayIO:
    """Reads and writes scipp DataArrays."""

    @staticmethod
    def write(group, data, override=None):
        """Write *data* into *group*; returns the group, or None on failure."""
        if override is None:
            override = {}
        _write_scipp_header(group, 'DataArray')
        group.attrs['name'] = data.name
        if VariableIO.write(group.create_group('data'), var=data.data) is None:
            return None
        # Note that we write aligned and unaligned coords into the same group.
        # Distinction is via an attribute, which is more natural than having
        # 2 separate groups.
        views = {'coords': data.coords, 'masks': data.masks, 'attrs': data.attrs}
        for view_name, view in views.items():
            subgroup = group.create_group(view_name)
            _write_mapping(subgroup, view, override.get(view_name))
        return group

    @staticmethod
    def read(group, override=None):
        """Reconstruct a DataArray from *group*."""
        _check_scipp_header(group, 'DataArray')
        if override is None:
            override = {}
        from ..core import DataArray

        contents = {
            'name': group.attrs['name'],
            'data': VariableIO.read(group['data']),
        }
        for category in ('coords', 'masks', 'attrs'):
            contents[category] = _read_mapping(group[category], override.get(category))
        return DataArray(**contents)

349 

350 

class DatasetIO:
    """Reads and writes scipp Datasets, sharing coords between entries."""

    @staticmethod
    def write(group, data):
        """Write coords once, then each entry linking back to the shared coords."""
        _write_scipp_header(group, 'Dataset')
        coords_group = group.create_group('coords')
        _write_mapping(coords_group, data.coords)
        entries = group.create_group('entries')
        # We cannot use coords directly, since we need lookup by name. The key used as
        # group name includes an integer index which may differ when writing items and
        # is not sufficient.
        coords_by_name = {v.attrs['name']: v for v in coords_group.values()}
        for i, (name, da) in enumerate(data.items()):
            HDF5IO.write(
                entries.create_group(collection_element_name(name, i)),
                da,
                override={'coords': coords_by_name},
            )
        return group

    @staticmethod
    def read(group):
        """Reconstruct a Dataset, re-using the shared coords for every entry."""
        _check_scipp_header(group, 'Dataset')
        from ..core import Dataset

        coords = _read_mapping(group['coords'])
        entries = _read_mapping(group['entries'], override={'coords': coords})
        return Dataset(coords=coords, data=entries)

380 

381 

class DataGroupIO:
    """Reads and writes scipp DataGroups."""

    @staticmethod
    def write(group, data):
        """Write every item of *data* under an 'entries' subgroup."""
        _write_scipp_header(group, 'DataGroup')
        _write_mapping(group.create_group('entries'), data)
        return group

    @staticmethod
    def read(group):
        """Reconstruct a DataGroup from *group*."""
        _check_scipp_header(group, 'DataGroup')
        from ..core import DataGroup

        return DataGroup(_read_mapping(group['entries']))

396 

397 

def _direct_io(cls, convert=None):
    """Build an IO handler class for a type that h5py can store directly.

    *convert* maps the raw value read back from HDF5 onto the desired type;
    it defaults to calling the type itself.
    """
    type_name = cls.__name__
    converter = cls if convert is None else convert

    class GenericIO:
        @staticmethod
        def write(group, data):
            _write_scipp_header(group, type_name)
            group['entry'] = data
            return group

        @staticmethod
        def read(group):
            _check_scipp_header(group, type_name)
            return converter(group['entry'][()])

    return GenericIO

416 

417 

class HDF5IO:
    """Top-level dispatcher from scipp-type names to their IO handler classes."""

    _handlers: ClassVar[dict[str, Any]] = {
        'Variable': VariableIO,
        'DataArray': DataArrayIO,
        'Dataset': DatasetIO,
        'DataGroup': DataGroupIO,
        # str round-trips through HDF5 as bytes, so decode on the way back.
        'str': _direct_io(str, convert=lambda b: b.decode('utf-8')),
        'ndarray': _direct_io(np.ndarray, convert=lambda x: x),
    }
    # Scalar number, bool, and bytes types map onto HDF5 scalars directly.
    for _scalar_cls in (
        int,
        np.int64,
        np.int32,
        np.uint64,
        np.uint32,
        float,
        np.float32,
        np.float64,
        bool,
        np.bool_,
        bytes,
    ):
        _handlers[_scalar_cls.__name__] = _direct_io(_scalar_cls)
    del _scalar_cls

    @classmethod
    def write(cls, group, data, **kwargs):
        """Write *data* into *group*; returns the group or None if unsupported."""
        # Views (slices) are written as their owning type, e.g. DataArrayView
        # as DataArray.
        name = data.__class__.__name__.replace('View', '')
        handler = cls._handlers.get(name)
        if handler is None:
            get_logger().warning(
                "Writing type '%s' to HDF5 not implemented, skipping.", type(data)
            )
            return None
        return handler.write(group, data, **kwargs)

    @classmethod
    def read(cls, group, **kwargs):
        """Read whatever object the group's 'scipp-type' attribute declares."""
        return cls._handlers[group.attrs['scipp-type']].read(group, **kwargs)

459 

460 

def save_hdf5(obj: VariableLike, filename: Union[str, Path]) -> None:
    """Write an object out to file in HDF5 format.

    The file is created (or truncated) and closed before returning.
    """
    import h5py

    with h5py.File(filename, 'w') as file:
        HDF5IO.write(file, obj)

467 

468 

def load_hdf5(filename: Union[str, Path]) -> VariableLike:
    """Load a Scipp-HDF5 file.

    Returns whatever object type the file's root group declares.
    """
    import h5py

    with h5py.File(filename, 'r') as file:
        return HDF5IO.read(file)