Coverage for install/scipp/io/hdf5.py: 60%
297 statements
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
# @file
# @author Simon Heybrock
from __future__ import annotations

from collections.abc import Mapping
from io import BytesIO, StringIO
from os import PathLike
from typing import TYPE_CHECKING, Any, ClassVar, Protocol

import numpy as np
import numpy.typing as npt

from .._scipp import core as _cpp
from ..core import (
    DataArray,
    DataGroup,
    Dataset,
    DType,
    DTypeError,
    Unit,
    Variable,
    bins,
)
from ..logging import get_logger
from ..typing import VariableLike

if TYPE_CHECKING:
    import h5py as h5
else:
    h5 = Any


def _dtype_lut() -> dict[str, DType]:
    # For types understood by numpy we do not actually need this special
    # handling, but it will be needed as we add support for other types
    # such as variable-length strings.
    dtypes = [
        DType.float64,
        DType.float32,
        DType.int64,
        DType.int32,
        DType.bool,
        DType.datetime64,
        DType.string,
        DType.Variable,
        DType.DataArray,
        DType.Dataset,
        DType.VariableView,
        DType.DataArrayView,
        DType.DatasetView,
        DType.vector3,
        DType.linear_transform3,
        DType.affine_transform3,
        DType.translation3,
        DType.rotation3,
    ]
    names = [str(dtype) for dtype in dtypes]
    return dict(zip(names, dtypes, strict=True))
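

# The table maps the string form of each dtype back to the DType enum, e.g.
# (illustrative):
#
#     _dtype_lut()['float64'] is DType.float64  # True
#     _dtype_lut()['DataArrayView'] is DType.DataArrayView  # True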


def _as_hdf5_type(a: npt.NDArray[Any]) -> npt.NDArray[Any]:
    if np.issubdtype(a.dtype, np.datetime64):
        return a.view(np.int64)
    return a
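

# h5py cannot store datetime64 directly, so such arrays are written via their
# underlying int64 representation. A minimal sketch (made-up data):
#
#     t = np.array(['2023-01-01T00:00:00'], dtype='datetime64[s]')
#     _as_hdf5_type(t)  # int64 view: seconds since the numpy epoch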


def _collection_element_name(name: str, index: int) -> str:
    """
    Convert name into an ASCII string that can be used as an object name in HDF5.
    """
    ascii_name = (
        name.replace('.', '&#46;')
        .replace('/', '&#47;')
        .encode('ascii', 'xmlcharrefreplace')
        .decode('ascii')
    )
    return f'elem_{index:03d}_{ascii_name}'
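

# '.' and '/' have special meaning in HDF5 paths and are therefore escaped as
# XML character references; non-ASCII characters are escaped the same way via
# 'xmlcharrefreplace'. For example (illustrative):
#
#     _collection_element_name('temp/°C', 0)  # -> 'elem_000_temp&#47;&#176;C'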


class _DataWriter(Protocol):
    @staticmethod
    def write(group: h5.Group, data: Variable) -> h5.Dataset | h5.Group: ...


# All readers are writers and have the same interface for writing.
# But readers for array variables expect a `data` argument that is
# absent in other readers.
class _ArrayDataIO(_DataWriter, Protocol):
    @staticmethod
    def read(group: h5.Group, data: Variable) -> None: ...


class _NumpyDataIO:
    @staticmethod
    def write(group: h5.Group, data: Variable) -> h5.Dataset:
        dset = group.create_dataset('values', data=_as_hdf5_type(data.values))
        if data.variances is not None:
            variances = group.create_dataset('variances', data=data.variances)
            dset.attrs['variances'] = variances.ref
        return dset

    @staticmethod
    def read(group: h5.Group, data: Variable) -> None:
        # h5py's read_direct method fails if any dim has zero size.
        # See https://github.com/h5py/h5py/issues/870
        if data.values.flags['C_CONTIGUOUS'] and data.values.size > 0:
            group['values'].read_direct(_as_hdf5_type(data.values))
        else:
            # Values of Eigen matrices are transposed.
            data.values = group['values']
        if 'variances' in group and data.variances.size > 0:
            group['variances'].read_direct(data.variances)


class _BinDataIO:
    @staticmethod
    def write(group: h5.Group, data: Variable) -> h5.Group:
        if data.bins is None:
            raise DTypeError("Expected binned data")
        bins = data.bins.constituents
        buffer_len = bins['data'].sizes[bins['dim']]
        # Crude mechanism to avoid writing large buffers, e.g., from
        # overallocation or when writing a slice of a larger variable. The
        # copy causes some overhead, but so would the (much more complicated)
        # solution to extract contents bin-by-bin. This approach will likely
        # need to be revisited in the future.
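        # Example of the heuristic with made-up numbers: a buffer of 1000 rows
        # whose bins reference only 400 rows gives 1000 > 1.5 * 400, so the
        # variable is copied to drop the unreferenced rows before writing.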
        if buffer_len > 1.5 * data.bins.size().sum().value:
            data = data.copy()
            bins = data.bins.constituents  # type: ignore[union-attr]
        values = group.create_group('values')
        _VariableIO.write(values.create_group('begin'), var=bins['begin'])
        _VariableIO.write(values.create_group('end'), var=bins['end'])
        data_group = values.create_group('data')
        data_group.attrs['dim'] = bins['dim']
        _HDF5IO.write(data_group, bins['data'])
        return values

    @staticmethod
    def read(group: h5.Group) -> Variable:
        values = group['values']
        begin = _VariableIO.read(values['begin'])
        end = _VariableIO.read(values['end'])
        dim = values['data'].attrs['dim']
        data = _HDF5IO.read(values['data'])
        return bins(begin=begin, end=end, dim=dim, data=data)


class _ScippDataIO:
    @staticmethod
    def write(group: h5.Group, data: Variable) -> h5.Group:
        values = group.create_group('values')
        if len(data.shape) == 0:
            _HDF5IO.write(values, data.value)
        else:
            for i, item in enumerate(data.values):
                _HDF5IO.write(values.create_group(f'value-{i}'), item)
        return values

    @staticmethod
    def read(group: h5.Group, data: Variable) -> None:
        values = group['values']
        if len(data.shape) == 0:
            data.value = _HDF5IO.read(values)
        else:
            for i in range(len(data.values)):
                data.values[i] = _HDF5IO.read(values[f'value-{i}'])


class _StringDataIO:
    @staticmethod
    def write(group: h5.Group, data: Variable) -> h5.Dataset:
        import h5py

        dt = h5py.string_dtype(encoding='utf-8')
        dset = group.create_dataset('values', shape=data.shape, dtype=dt)
        if len(data.shape) == 0:
            dset[()] = data.value
        else:
            for i in range(len(data.values)):
                dset[i] = data.values[i]
        return dset

    @staticmethod
    def read(group: h5.Group, data: Variable) -> None:
        values = group['values']
        if len(data.shape) == 0:
            data.value = values[()]
        else:
            for i in range(len(data.values)):
                data.values[i] = values[i]


def _write_scipp_header(group: h5.Group, what: str) -> None:
    from ..core import __version__

    group.attrs['scipp-version'] = __version__
    group.attrs['scipp-type'] = what
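

# Readers use these two attributes to validate input and dispatch to the right
# handler. Illustrative attribute contents (the version string is made up):
#
#     {'scipp-version': '24.02.0', 'scipp-type': 'DataArray'}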


def _check_scipp_header(group: h5.Group, what: str) -> None:
    if 'scipp-version' not in group.attrs:
        raise RuntimeError(
            "This does not look like an HDF5 file/group written by Scipp."
        )
    if group.attrs['scipp-type'] != what:
        raise RuntimeError(
            f"Attempt to read {what}, found {group.attrs['scipp-type']}."
        )


def _array_data_io_lut() -> dict[str, _ArrayDataIO]:
    handler: dict[str, _ArrayDataIO] = {}
    for dtype in [
        DType.float64,
        DType.float32,
        DType.int64,
        DType.int32,
        DType.bool,
        DType.datetime64,
        DType.vector3,
        DType.linear_transform3,
        DType.rotation3,
        DType.translation3,
        DType.affine_transform3,
    ]:
        handler[str(dtype)] = _NumpyDataIO
    for dtype in [DType.Variable, DType.DataArray, DType.Dataset]:
        handler[str(dtype)] = _ScippDataIO
    for dtype in [DType.string]:
        handler[str(dtype)] = _StringDataIO
    return handler


def _data_writer_lut() -> dict[str, _DataWriter]:
    # Unpack and repack dict to cast to the correct value type.
    handler: dict[str, _DataWriter] = {**_array_data_io_lut()}
    for dtype in [DType.VariableView, DType.DataArrayView, DType.DatasetView]:
        handler[str(dtype)] = _BinDataIO
    return handler


def _serialize_unit(unit: Unit) -> npt.NDArray[Any]:
    unit_dict = unit.to_dict()
    dtype: list[tuple[str, Any]] = [('__version__', int), ('multiplier', float)]
    vals = [unit_dict['__version__'], unit_dict['multiplier']]
    if 'powers' in unit_dict:
        dtype.append(('powers', [(name, int) for name in unit_dict['powers']]))
        vals.append(tuple(val for val in unit_dict['powers'].values()))
    return np.array(tuple(vals), dtype=dtype)
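

# The unit is stored as a 0-d structured array rather than a string so that it
# round-trips exactly. A sketch of the layout, assuming a unit whose to_dict()
# reports powers {'m': 1, 's': -1} (the version value is hypothetical):
#
#     np.array(
#         (1, 1.0, (1, -1)),
#         dtype=[('__version__', int), ('multiplier', float),
#                ('powers', [('m', int), ('s', int)])],
#     )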


def _read_unit_attr(ds: h5.Dataset) -> Unit:
    u = ds.attrs['unit']
    if isinstance(u, str):
        return Unit(u)  # legacy encoding as a string

    # u is a structured numpy array.
    unit_dict = {'__version__': u['__version__'], 'multiplier': u['multiplier']}
    if 'powers' in u.dtype.names:
        unit_dict['powers'] = {
            name: u['powers'][name] for name in u['powers'].dtype.names
        }
    return Unit.from_dict(unit_dict)


class _VariableIO:
    _dtypes = _dtype_lut()
    _array_data_readers = _array_data_io_lut()
    _data_writers = _data_writer_lut()

    @classmethod
    def _write_data(cls, group: h5.Group, data: Variable) -> h5.Dataset | h5.Group:
        return cls._data_writers[str(data.dtype)].write(group, data)

    @classmethod
    def _read_array_data(cls, group: h5.Group, data: Variable) -> None:
        cls._array_data_readers[str(data.dtype)].read(group, data)

    @classmethod
    def write(cls, group: h5.Group, var: Variable) -> h5.Group | None:
        if var.dtype not in cls._dtypes.values():
            # In practice this may make the file unreadable, e.g., if values
            # have unsupported dtype.
            get_logger().warning(
                'Writing with dtype=%s not implemented, skipping.', var.dtype
            )
            return None
        _write_scipp_header(group, 'Variable')
        dset = cls._write_data(group, var)
        dset.attrs['dims'] = [str(dim) for dim in var.dims]
        dset.attrs['shape'] = var.shape
        dset.attrs['dtype'] = str(var.dtype)
        if var.unit is not None:
            dset.attrs['unit'] = _serialize_unit(var.unit)
        dset.attrs['aligned'] = var.aligned
        return group

    @classmethod
    def read(cls, group: h5.Group) -> Variable:
        _check_scipp_header(group, 'Variable')
        values = group['values']
        contents = {key: values.attrs[key] for key in ['dims', 'shape']}
        contents['dtype'] = cls._dtypes[values.attrs['dtype']]
        if 'unit' in values.attrs:
            contents['unit'] = _read_unit_attr(values)
        else:
            contents['unit'] = None  # essential, otherwise default unit is used
        contents['with_variances'] = 'variances' in group
        contents['aligned'] = values.attrs.get('aligned', True)
        if contents['dtype'] in [
            DType.VariableView,
            DType.DataArrayView,
            DType.DatasetView,
        ]:
            var = _BinDataIO.read(group)
        else:
            var = _cpp.empty(**contents)
            cls._read_array_data(group, var)
        return var


def _write_mapping(
    parent: h5.Group,
    mapping: Mapping[str, VariableLike],
    override: Mapping[str, h5.Group] | None = None,
) -> None:
    if override is None:
        override = {}
    for i, name in enumerate(mapping):
        var_group_name = _collection_element_name(name, i)
        if (g := override.get(name)) is not None:
            parent[var_group_name] = g
        else:
            g = _HDF5IO.write(
                group=parent.create_group(var_group_name), data=mapping[name]
            )
            if g is None:
                del parent[var_group_name]
            else:
                g.attrs['name'] = str(name)


def _read_mapping(
    group: h5.Group, override: Mapping[str, h5.Group] | None = None
) -> VariableLike:
    if override is None:
        override = {}
    return {
        g.attrs['name']: override[g.attrs['name']]
        if g.attrs['name'] in override
        else _HDF5IO.read(g)
        for g in group.values()
    }


class _DataArrayIO:
    @staticmethod
    def write(group: h5.Group, data, override=None):
        if override is None:
            override = {}
        _write_scipp_header(group, 'DataArray')
        group.attrs['name'] = data.name
        if _VariableIO.write(group.create_group('data'), var=data.data) is None:
            return None
        views = [data.coords, data.masks, data.attrs]
        # Note that we write aligned and unaligned coords into the same group.
        # The distinction is made via an attribute, which is more natural than
        # having two separate groups.
        for view_name, view in zip(['coords', 'masks', 'attrs'], views, strict=True):
            subgroup = group.create_group(view_name)
            _write_mapping(subgroup, view, override.get(view_name))
        return group

    @staticmethod
    def read(group: h5.Group, override=None):
        _check_scipp_header(group, 'DataArray')
        if override is None:
            override = {}
        contents = {}
        contents['name'] = group.attrs['name']
        contents['data'] = _VariableIO.read(group['data'])
        for category in ['coords', 'masks', 'attrs']:
            contents[category] = _read_mapping(group[category], override.get(category))
        return DataArray(**contents)


class _DatasetIO:
    @staticmethod
    def write(group: h5.Group, data):
        _write_scipp_header(group, 'Dataset')
        coords = group.create_group('coords')
        _write_mapping(coords, data.coords)
        entries = group.create_group('entries')
        # We cannot use the coords group directly since we need lookup by
        # name: the key used as the group name includes an integer index,
        # which may differ when writing the items, so it is not sufficient
        # as a lookup key.
        coords = {v.attrs['name']: v for v in coords.values()}
        for i, (name, da) in enumerate(data.items()):
            _HDF5IO.write(
                entries.create_group(_collection_element_name(name, i)),
                da,
                override={'coords': coords},
            )
        return group

    @staticmethod
    def read(group: h5.Group):
        _check_scipp_header(group, 'Dataset')
        coords = _read_mapping(group['coords'])
        return Dataset(
            coords=coords,
            data=_read_mapping(group['entries'], override={'coords': coords}),
        )


class _DataGroupIO:
    @staticmethod
    def write(group: h5.Group, data):
        _write_scipp_header(group, 'DataGroup')
        entries = group.create_group('entries')
        _write_mapping(entries, data)
        return group

    @staticmethod
    def read(group: h5.Group):
        _check_scipp_header(group, 'DataGroup')
        return DataGroup(_read_mapping(group['entries']))


def _direct_io(cls, convert=None):
    # Use the fully qualified name for numpy.bool to avoid confusion with the
    # builtin bool. Numpy bools have different names in 1.x and 2.x.
    # TODO: This should be fixed once we drop support for numpy<2.
    if cls.__module__ == 'numpy' and cls.__name__[:4] == 'bool':
        type_name = 'numpy.bool'
    else:
        type_name = cls.__name__
    if convert is None:
        convert = cls

    class GenericIO:
        @staticmethod
        def write(group, data):
            _write_scipp_header(group, type_name)
            group['entry'] = data
            return group

        @staticmethod
        def read(group):
            _check_scipp_header(group, type_name)
            return convert(group['entry'][()])

    return GenericIO
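

# _direct_io builds handlers for plain Python and numpy scalars: write() stores
# the value in a dataset named 'entry' and read() converts it back. For
# example, _direct_io(str, convert=lambda b: b.decode('utf-8')) (used below)
# yields a handler that decodes the bytes h5py returns for stored strings.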


class _HDF5IO:
    _handlers: ClassVar[dict[str, Any]] = {
        'Variable': _VariableIO,
        'DataArray': _DataArrayIO,
        'Dataset': _DatasetIO,
        'DataGroup': _DataGroupIO,
        'str': _direct_io(str, convert=lambda b: b.decode('utf-8')),
        'ndarray': _direct_io(np.ndarray, convert=lambda x: x),
        'numpy.bool': _direct_io(np.bool_),
        **{
            cls.__name__: _direct_io(cls)
            for cls in (
                int,
                np.int64,
                np.int32,
                np.uint64,
                np.uint32,
                float,
                np.float32,
                np.float64,
                bool,
                bytes,
            )
        },
    }

    @classmethod
    def write(cls, group: h5.Group, data, **kwargs):
        data_cls = data.__class__
        # Numpy bools have different names in 1.x and 2.x.
        # TODO: This should be fixed once we drop support for numpy<2.
        if data_cls.__module__ == 'numpy' and data_cls.__name__[:4] == 'bool':
            name = 'numpy.bool'
        else:
            name = data_cls.__name__.replace('View', '')
        try:
            handler = cls._handlers[name]
        except KeyError:
            get_logger().warning(
                "Writing type '%s' to HDF5 not implemented, skipping.", type(data)
            )
            return None
        return handler.write(group, data, **kwargs)

    @classmethod
    def read(cls, group: h5.Group, **kwargs):
        return cls._handlers[group.attrs['scipp-type']].read(group, **kwargs)


def save_hdf5(
    obj: VariableLike,
    filename: str | PathLike[str] | StringIO | BytesIO | h5.Group,
) -> None:
    """Write an object out to file in HDF5 format."""
    import h5py

    if isinstance(filename, h5py.Group):
        return _HDF5IO.write(filename, obj)

    with h5py.File(filename, 'w') as f:
        _HDF5IO.write(f, obj)
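

# Passing an open h5py.Group writes into that group instead of creating a new
# file, which allows embedding scipp objects in a larger HDF5 file.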


def load_hdf5(
    filename: str | PathLike[str] | StringIO | BytesIO | h5.Group,
) -> VariableLike:
    """Load a Scipp-HDF5 file."""
    import h5py

    if isinstance(filename, h5py.Group):
        return _HDF5IO.read(filename)

    with h5py.File(filename, 'r') as f:
        return _HDF5IO.read(f)
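

# A minimal round-trip sketch (assumes scipp is importable as ``sc``; the data
# and file name are made up):
#
#     import scipp as sc
#     var = sc.array(dims=['x'], values=[1.0, 2.0], unit='m')
#     sc.io.save_hdf5(var, 'var.h5')
#     assert sc.identical(sc.io.load_hdf5('var.h5'), var)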