Coverage for install/scipp/io/hdf5.py: 64%
277 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-04-28 01:28 +0000
1# SPDX-License-Identifier: BSD-3-Clause
2# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
3# @file
4# @author Simon Heybrock
6from __future__ import annotations
8from pathlib import Path
9from typing import Any, ClassVar, Union
11import numpy as np
13from ..core.cpp_classes import Unit
14from ..logging import get_logger
15from ..typing import VariableLike
def _dtype_lut():
    """Return a lookup table mapping dtype names to scipp DType objects."""
    from .._scipp.core import DType as d

    # For types understood by numpy we do not actually need this special
    # handling, but will do as we add support for other types such as
    # variable-length strings.
    supported = (
        d.float64,
        d.float32,
        d.int64,
        d.int32,
        d.bool,
        d.datetime64,
        d.string,
        d.Variable,
        d.DataArray,
        d.Dataset,
        d.VariableView,
        d.DataArrayView,
        d.DatasetView,
        d.vector3,
        d.linear_transform3,
        d.affine_transform3,
        d.translation3,
        d.rotation3,
    )
    return {str(dtype): dtype for dtype in supported}
48def _as_hdf5_type(a):
49 if np.issubdtype(a.dtype, np.datetime64):
50 return a.view(np.int64)
51 return a
def collection_element_name(name, index):
    """
    Convert name into an ASCII string that can be used as an object name in HDF5.

    '.' and '/' are escaped as HTML character references because '/' is the
    HDF5 path separator and '.' can be ambiguous in attribute-style access;
    any remaining non-ASCII characters are escaped the same way via
    'xmlcharrefreplace'. The index prefix keeps insertion order stable.
    """
    ascii_name = (
        name.replace('.', '&#46;')
        .replace('/', '&#47;')
        .encode('ascii', 'xmlcharrefreplace')
        .decode('ascii')
    )
    return f'elem_{index:03d}_{ascii_name}'
class NumpyDataIO:
    """Reads/writes numpy-backed variable data as plain h5py datasets."""

    @staticmethod
    def write(group, data):
        dset = group.create_dataset('values', data=_as_hdf5_type(data.values))
        if data.variances is not None:
            var_dset = group.create_dataset('variances', data=data.variances)
            dset.attrs['variances'] = var_dset.ref
        return dset

    @staticmethod
    def read(group, data):
        values = data.values
        # h5py's read_direct method fails if any dim has zero size.
        # see https://github.com/h5py/h5py/issues/870
        if values.flags['C_CONTIGUOUS'] and values.size > 0:
            group['values'].read_direct(_as_hdf5_type(values))
        else:
            # Values of Eigen matrices are transposed
            data.values = group['values']
        if 'variances' in group and data.variances.size > 0:
            group['variances'].read_direct(data.variances)
class BinDataIO:
    """Reads/writes binned (event) data: begin/end indices plus a buffer."""

    @staticmethod
    def write(group, data):
        constituents = data.bins.constituents
        buffer_len = constituents['data'].sizes[constituents['dim']]
        # Crude mechanism to avoid writing large buffers, e.g., from
        # overallocation or when writing a slice of a larger variable. The
        # copy causes some overhead, but so would the (much more complicated)
        # solution to extract contents bin-by-bin. This approach will likely
        # need to be revisited in the future.
        if buffer_len > 1.5 * data.bins.size().sum().value:
            data = data.copy()
            constituents = data.bins.constituents
        values = group.create_group('values')
        VariableIO.write(values.create_group('begin'), var=constituents['begin'])
        VariableIO.write(values.create_group('end'), var=constituents['end'])
        data_group = values.create_group('data')
        data_group.attrs['dim'] = constituents['dim']
        HDF5IO.write(data_group, constituents['data'])
        return values

    @staticmethod
    def read(group):
        from .._scipp import core as sc

        values = group['values']
        return sc.bins(
            begin=VariableIO.read(values['begin']),
            end=VariableIO.read(values['end']),
            dim=values['data'].attrs['dim'],
            data=HDF5IO.read(values['data']),
        )
class ScippDataIO:
    """Reads/writes variables whose elements are themselves scipp objects."""

    @staticmethod
    def write(group, data):
        values = group.create_group('values')
        if data.shape:
            for i, item in enumerate(data.values):
                HDF5IO.write(values.create_group(f'value-{i}'), item)
        else:
            # Scalar variable: store the single element directly.
            HDF5IO.write(values, data.value)
        return values

    @staticmethod
    def read(group, data):
        values = group['values']
        if data.shape:
            for i in range(len(data.values)):
                data.values[i] = HDF5IO.read(values[f'value-{i}'])
        else:
            data.value = HDF5IO.read(values)
class StringDataIO:
    """Reads/writes string data using h5py's variable-length UTF-8 dtype."""

    @staticmethod
    def write(group, data):
        import h5py

        str_dtype = h5py.string_dtype(encoding='utf-8')
        dset = group.create_dataset('values', shape=data.shape, dtype=str_dtype)
        if data.shape:
            for i in range(len(data.values)):
                dset[i] = data.values[i]
        else:
            # Scalar string variable.
            dset[()] = data.value
        return dset

    @staticmethod
    def read(group, data):
        values = group['values']
        if data.shape:
            for i in range(len(data.values)):
                data.values[i] = values[i]
        else:
            data.value = values[()]
def _write_scipp_header(group, what):
    """Tag *group* with the writing scipp version and the stored object type."""
    from .._scipp import __version__

    group.attrs['scipp-type'] = what
    group.attrs['scipp-version'] = __version__
174def _check_scipp_header(group, what):
175 if 'scipp-version' not in group.attrs:
176 raise RuntimeError(
177 "This does not look like an HDF5 file/group written by Scipp."
178 )
179 if group.attrs['scipp-type'] != what:
180 raise RuntimeError(
181 f"Attempt to read {what}, found {group.attrs['scipp-type']}."
182 )
def _data_handler_lut():
    """Return a mapping from dtype name to the IO handler class for that dtype."""
    from .._scipp.core import DType as d

    numeric = (
        d.float64,
        d.float32,
        d.int64,
        d.int32,
        d.bool,
        d.datetime64,
        d.vector3,
        d.linear_transform3,
        d.rotation3,
        d.translation3,
        d.affine_transform3,
    )
    binned = (d.VariableView, d.DataArrayView, d.DatasetView)
    nested = (d.Variable, d.DataArray, d.Dataset)
    handler = {}
    for dtypes, io_cls in (
        (numeric, NumpyDataIO),
        (binned, BinDataIO),
        (nested, ScippDataIO),
        ((d.string,), StringDataIO),
    ):
        for dtype in dtypes:
            handler[str(dtype)] = io_cls
    return handler
212def _serialize_unit(unit):
213 unit_dict = unit.to_dict()
214 dtype = [('__version__', int), ('multiplier', float)]
215 vals = [unit_dict['__version__'], unit_dict['multiplier']]
216 if 'powers' in unit_dict:
217 dtype.append(('powers', [(name, int) for name in unit_dict['powers']]))
218 vals.append(tuple(val for val in unit_dict['powers'].values()))
219 return np.array(tuple(vals), dtype=dtype)
def _read_unit_attr(ds):
    """Reconstruct a Unit from the 'unit' attribute of dataset *ds*.

    Supports both the legacy plain-string encoding and the current
    structured-array encoding produced by ``_serialize_unit``.
    """
    raw = ds.attrs['unit']
    if isinstance(raw, str):
        return Unit(raw)  # legacy encoding as a string
    # raw is a structured numpy array
    unit_dict = {'__version__': raw['__version__'], 'multiplier': raw['multiplier']}
    if 'powers' in raw.dtype.names:
        powers = raw['powers']
        unit_dict['powers'] = {name: powers[name] for name in powers.dtype.names}
    return Unit.from_dict(unit_dict)
class VariableIO:
    """Reads/writes scipp Variables, dispatching on dtype to a data handler."""

    _dtypes = _dtype_lut()
    _data_handlers = _data_handler_lut()

    @classmethod
    def _write_data(cls, group, data):
        # Delegate to the handler registered for this dtype.
        return cls._data_handlers[str(data.dtype)].write(group, data)

    @classmethod
    def _read_data(cls, group, data):
        return cls._data_handlers[str(data.dtype)].read(group, data)

    @classmethod
    def write(cls, group, var):
        if var.dtype not in cls._dtypes.values():
            # In practice this may make the file unreadable, e.g., if values
            # have unsupported dtype.
            get_logger().warning(
                'Writing with dtype=%s not implemented, skipping.', var.dtype
            )
            return
        _write_scipp_header(group, 'Variable')
        dset = cls._write_data(group, var)
        dset.attrs['dims'] = [str(dim) for dim in var.dims]
        dset.attrs['shape'] = var.shape
        dset.attrs['dtype'] = str(var.dtype)
        if var.unit is not None:
            dset.attrs['unit'] = _serialize_unit(var.unit)
        dset.attrs['aligned'] = var.aligned
        return group

    @classmethod
    def read(cls, group):
        _check_scipp_header(group, 'Variable')
        from .._scipp import core as sc
        from .._scipp.core import DType as d

        values = group['values']
        contents = {
            'dims': values.attrs['dims'],
            'shape': values.attrs['shape'],
            'dtype': cls._dtypes[values.attrs['dtype']],
        }
        # Setting unit explicitly is essential; otherwise a default unit
        # would be used instead of None.
        contents['unit'] = (
            _read_unit_attr(values) if 'unit' in values.attrs else None
        )
        contents['with_variances'] = 'variances' in group
        contents['aligned'] = values.attrs.get('aligned', True)
        if contents['dtype'] in [d.VariableView, d.DataArrayView, d.DatasetView]:
            return BinDataIO.read(group)
        var = sc.empty(**contents)
        cls._read_data(group, var)
        return var
def _write_mapping(parent, mapping, override=None):
    """Write each item of *mapping* into a child group of *parent*.

    Entries found in *override* are hard-linked to the pre-written group
    instead of being serialized again; entries that cannot be written are
    removed so no empty group is left behind.
    """
    overrides = {} if override is None else override
    for index, key in enumerate(mapping):
        child_name = collection_element_name(key, index)
        existing = overrides.get(key)
        if existing is not None:
            parent[child_name] = existing
            continue
        written = HDF5IO.write(
            group=parent.create_group(child_name), data=mapping[key]
        )
        if written is None:
            # Unsupported type was skipped: drop the group we just created.
            del parent[child_name]
        else:
            written.attrs['name'] = str(key)
def _read_mapping(group, override=None):
    """Read all child groups of *group* into a dict keyed by stored name.

    Entries found in *override* are taken from there instead of being read
    from file (used for shared, already-read coords).
    """
    overrides = {} if override is None else override
    result = {}
    for child in group.values():
        name = child.attrs['name']
        result[name] = overrides[name] if name in overrides else HDF5IO.read(child)
    return result
class DataArrayIO:
    """Reads/writes scipp DataArray objects (data plus coords/masks/attrs)."""

    @staticmethod
    def write(group, data, override=None):
        overrides = {} if override is None else override
        _write_scipp_header(group, 'DataArray')
        group.attrs['name'] = data.name
        if VariableIO.write(group.create_group('data'), var=data.data) is None:
            return None
        # Note that we write aligned and unaligned coords into the same group.
        # Distinction is via an attribute, which is more natural than having
        # 2 separate groups.
        mappings = (
            ('coords', data.coords),
            ('masks', data.masks),
            ('attrs', data.attrs),
        )
        for view_name, view in mappings:
            subgroup = group.create_group(view_name)
            _write_mapping(subgroup, view, overrides.get(view_name))
        return group

    @staticmethod
    def read(group, override=None):
        _check_scipp_header(group, 'DataArray')
        overrides = {} if override is None else override
        from ..core import DataArray

        contents = {
            'name': group.attrs['name'],
            'data': VariableIO.read(group['data']),
        }
        for category in ('coords', 'masks', 'attrs'):
            contents[category] = _read_mapping(
                group[category], overrides.get(category)
            )
        return DataArray(**contents)
class DatasetIO:
    """Reads/writes scipp Dataset objects with coords shared between items."""

    @staticmethod
    def write(group, data):
        _write_scipp_header(group, 'Dataset')
        coords_group = group.create_group('coords')
        _write_mapping(coords_group, data.coords)
        entries = group.create_group('entries')
        # We cannot use coords directly, since we need lookup by name. The key used as
        # group name includes an integer index which may differ when writing items and
        # is not sufficient.
        coords_by_name = {v.attrs['name']: v for v in coords_group.values()}
        for index, (name, item) in enumerate(data.items()):
            HDF5IO.write(
                entries.create_group(collection_element_name(name, index)),
                item,
                override={'coords': coords_by_name},
            )
        return group

    @staticmethod
    def read(group):
        _check_scipp_header(group, 'Dataset')
        from ..core import Dataset

        coords = _read_mapping(group['coords'])
        entries = _read_mapping(group['entries'], override={'coords': coords})
        return Dataset(coords=coords, data=entries)
class DataGroupIO:
    """Reads/writes scipp DataGroup objects as a flat mapping of entries."""

    @staticmethod
    def write(group, data):
        _write_scipp_header(group, 'DataGroup')
        _write_mapping(group.create_group('entries'), data)
        return group

    @staticmethod
    def read(group):
        _check_scipp_header(group, 'DataGroup')
        from ..core import DataGroup

        return DataGroup(_read_mapping(group['entries']))
def _direct_io(cls, convert=None):
    """Build an IO handler for types that h5py stores directly.

    *convert* maps the raw value read back from the file to the target
    type; if omitted, the class itself is called as the converter.
    """
    type_name = cls.__name__
    converter = cls if convert is None else convert

    class GenericIO:
        @staticmethod
        def write(group, data):
            _write_scipp_header(group, type_name)
            group['entry'] = data
            return group

        @staticmethod
        def read(group):
            _check_scipp_header(group, type_name)
            return converter(group['entry'][()])

    return GenericIO
class HDF5IO:
    """Top-level dispatcher: routes objects to the handler for their type."""

    _handlers: ClassVar[dict[str, Any]] = {
        'Variable': VariableIO,
        'DataArray': DataArrayIO,
        'Dataset': DatasetIO,
        'DataGroup': DataGroupIO,
        'str': _direct_io(str, convert=lambda b: b.decode('utf-8')),
        'ndarray': _direct_io(np.ndarray, convert=lambda x: x),
        **{
            scalar_type.__name__: _direct_io(scalar_type)
            for scalar_type in (
                int,
                np.int64,
                np.int32,
                np.uint64,
                np.uint32,
                float,
                np.float32,
                np.float64,
                bool,
                np.bool_,
                bytes,
            )
        },
    }

    @classmethod
    def write(cls, group, data, **kwargs):
        # Views ('DataArrayView' etc.) share handlers with their owning types.
        key = data.__class__.__name__.replace('View', '')
        handler = cls._handlers.get(key)
        if handler is None:
            get_logger().warning(
                "Writing type '%s' to HDF5 not implemented, skipping.", type(data)
            )
            return None
        return handler.write(group, data, **kwargs)

    @classmethod
    def read(cls, group, **kwargs):
        return cls._handlers[group.attrs['scipp-type']].read(group, **kwargs)
def save_hdf5(obj: VariableLike, filename: Union[str, Path]) -> None:
    """Write an object out to file in HDF5 format."""
    import h5py

    with h5py.File(filename, 'w') as outfile:
        HDF5IO.write(outfile, obj)
def load_hdf5(filename: Union[str, Path]) -> VariableLike:
    """Load a Scipp-HDF5 file."""
    import h5py

    with h5py.File(filename, 'r') as infile:
        return HDF5IO.read(infile)