# Source code for ess.nmx.mcstas.nexus

# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
import io
import pathlib
import warnings
from collections.abc import Callable, Generator
from functools import partial, wraps
from typing import Any, TypeVar

import h5py
import numpy as np
import sciline as sl
import scipp as sc

from .types import (
    DetectorIndex,
    DetectorName,
    FilePath,
    NMXDetectorMetadata,
    NMXExperimentMetadata,
    NMXReducedDataGroup,
)


def _fallback_compute_positions(dg: sc.DataGroup) -> sc.DataGroup:
    """Compute positions after replacing empty time-dependent transformations.

    Every transformation in ``dg['depends_on'].transformations`` whose value
    has a ``time`` dimension of length zero (an empty log) is replaced
    **in place** by a zero scalar with the same unit and dtype.
    Then :func:`scippnexus.compute_positions` is called on ``dg``.

    Parameters
    ----------
    dg:
        Data group containing the ``depends_on`` transformation chain.
        Its transformations are mutated by this function.

    Returns
    -------
    :
        Data group with computed positions; the transformation matrix is
        stored under ``transform_matrix``.

    Warnings
    --------
    Always emits a ``UserWarning``, because replacing logs may lead to
    incorrect positions.
    """
    import scippnexus as snx

    warnings.warn(
        "Using fallback compute_positions due to empty log entries. "
        "This may lead to incorrect results. Please check the data carefully. "
        "The fallback will replace empty logs with a scalar value of zero.",
        UserWarning,
        stacklevel=2,
    )

    # Collect first, then mutate, so the mapping is not modified while
    # it is being iterated.
    empty_transformations = [
        transformation
        for transformation in dg['depends_on'].transformations.values()
        if transformation.value.sizes.get('time') == 0  # empty log
    ]
    for transformation in empty_transformations:
        original = transformation.value
        # A scalar (time-independent) value keeps the chain computable;
        # a zero-length time dimension would not.
        transformation.value = sc.scalar(
            0, unit=original.unit, dtype=original.dtype
        )
    return snx.compute_positions(dg, store_transform='transform_matrix')


def _compute_positions(
    dg: sc.DataGroup, auto_fix_transformations: bool = False
) -> sc.DataGroup:
    """Compute positions of the data group from transformations.

    Wraps :func:`scippnexus.compute_positions`. When that call raises a
    ``ValueError`` (e.g. because a transformation contains an empty log),
    it can optionally retry with a fallback that patches the empty logs.

    Parameters
    ----------
    dg:
        Data group containing the transformations and data.
    auto_fix_transformations:
        If ``True``, a failed first attempt is retried after replacing every
        empty time-dependent transformation with a zero scalar.
        A scalar is used rather than keeping a ``time`` dimension, because a
        time-dependent transformation would make it impossible to compute
        the positions of its children.

    Returns
    -------
    :
        Data group with computed positions.

    Raises
    ------
    ValueError
        If ``scippnexus.compute_positions`` fails and
        ``auto_fix_transformations`` is ``False``.

    Warnings
    --------
    When the fallback is used, a ``UserWarning`` is emitted, because the
    patched transformations may lead to incorrect results.

    """
    import scippnexus as snx

    try:
        return snx.compute_positions(dg, store_transform='transform_matrix')
    except ValueError:
        if auto_fix_transformations:
            # Called inside ``except`` so a failing fallback is chained to
            # the original error.
            return _fallback_compute_positions(dg)
        # Bare ``raise`` re-raises the original exception and traceback as-is.
        raise


def _create_dataset_from_string(*, root_entry: h5py.Group, name: str, var: str) -> None:
    """Store ``var`` in ``root_entry`` as a variable-length string dataset ``name``."""
    str_dtype = h5py.string_dtype()
    root_entry.create_dataset(name, data=var, dtype=str_dtype)


def _create_dataset_from_var(
    *,
    root_entry: h5py.Group,
    var: sc.Variable,
    name: str,
    long_name: str | None = None,
    compression: str | None = None,
    compression_opts: int | tuple[int, int] | None = None,
    chunks: tuple[int, ...] | int | bool | None = None,
    dtype: Any = None,
) -> h5py.Dataset:
    """Write the values of a scipp variable into a new HDF5 dataset.

    Parameters
    ----------
    root_entry:
        Group in which the dataset is created.
    var:
        Variable whose ``values`` are written; its unit (if any) is stored
        in the ``units`` attribute.
    name:
        Name of the new dataset.
    long_name:
        Optional human-readable description stored as ``long_name``.
    compression, compression_opts:
        Forwarded to ``h5py`` only when not ``None``.
    chunks:
        Chunk layout forwarded to ``h5py``.
    dtype:
        If given, the values are converted with ``astype(dtype, copy=False)``.

    Returns
    -------
    :
        The created dataset.
    """
    # Only forward compression settings that were actually provided so that
    # h5py's own defaults apply otherwise.
    optional_kwargs = {
        key: setting
        for key, setting in (
            ("compression", compression),
            ("compression_opts", compression_opts),
        )
        if setting is not None
    }
    payload = var.values
    if dtype is not None:
        payload = payload.astype(dtype, copy=False)

    dataset = root_entry.create_dataset(
        name, data=payload, chunks=chunks, **optional_kwargs
    )
    if var.unit is not None:
        dataset.attrs["units"] = str(var.unit)
    if long_name is not None:
        dataset.attrs["long_name"] = long_name
    return dataset


def _create_compressed_dataset(
    *,
    root_entry: h5py.Group,
    var: sc.Variable,
    name: str,
    long_name: str | None = None,
    chunks: tuple[int, ...] | int | bool | None = None,
    dtype: Any = None,
) -> h5py.Dataset:
    """Create dataset with compression options.

    It will try to use ``bitshuffle`` for compression if available.
    Otherwise, it will fall back to ``gzip`` compression (level 4) and
    emit a ``UserWarning``.

    [``Bitshuffle/LZ4``](https://github.com/kiyo-masui/bitshuffle)
    is used for convenience.
    Since ``Dectris`` uses it for their Nexus file compression,
    it is compatible with DIALS.
    ``Bitshuffle/LZ4`` tends to give similar results to
    GZIP and other compression algorithms with better performance.
    A naive implementation of bitshuffle/LZ4 compression,
    shown in [issue #124](https://github.com/scipp/essnmx/issues/124),
    led to 80% file reduction (365 MB vs 1.8 GB).

    All other arguments have the same meaning as in
    ``_create_dataset_from_var``; the compression settings are chosen here
    and cannot be overridden by the caller.

    Returns
    -------
    :
        The created dataset.
    """
    # An explicit signature is used instead of ``functools.wraps``: ``wraps``
    # would overwrite this docstring and ``__name__`` with those of
    # ``_create_dataset_from_var``.
    try:
        import bitshuffle.h5
    except ImportError:
        warnings.warn(
            UserWarning(
                "Could not find the bitshuffle.h5 module from bitshuffle package. "
                "The bitshuffle package is not installed or only partially installed. "
                "Exporting to NeXus files with bitshuffle compression is not possible."
            ),
            stacklevel=2,
        )
        compression_filter = "gzip"
        compression_opts = 4
    else:
        compression_filter = bitshuffle.h5.H5FILTER
        compression_opts = (0, bitshuffle.h5.H5_COMPRESS_LZ4)

    return _create_dataset_from_var(
        root_entry=root_entry,
        var=var,
        name=name,
        long_name=long_name,
        chunks=chunks,
        dtype=dtype,
        compression=compression_filter,
        compression_opts=compression_opts,
    )


def _create_root_data_entry(file_obj: h5py.File) -> h5py.Group:
    """Create the ``NMX_data`` NXentry group of the legacy NMX NeXus layout."""
    entry = file_obj.create_group("NMX_data")
    for attr_key, attr_value in (
        ("NX_class", "NXentry"),
        ("default", "data"),
        ("name", "NMX"),
    ):
        entry.attrs[attr_key] = attr_value
    # These two are datasets (fields), not attributes.
    entry["name"] = "NMX"
    entry["definition"] = "TOFRAW"
    return entry


def _create_sample_group(data: sc.DataGroup, nx_entry: h5py.Group) -> h5py.Group:
    """Write the sample name and crystal rotation into a new ``NXsample`` group."""
    sample_group = nx_entry.create_group("NXsample")
    sample_group["name"] = data['sample_name'].value
    _create_dataset_from_var(
        root_entry=sample_group,
        name='crystal_rotation',
        var=data['crystal_rotation'],
        long_name='crystal rotation in Phi (XYZ)',
    )
    return sample_group


def _create_instrument_group(data: sc.DataGroup, nx_entry: h5py.Group) -> h5py.Group:
    """Write proton charge and detector counts into a new ``NXinstrument`` group.

    The ``detector_1`` subgroup receives the (compressed) counts, the ``t``
    coordinate as ``t_bin`` and the (compressed) ``id`` coordinate as
    ``pixel_id``.
    """
    instrument = nx_entry.create_group("NXinstrument")
    instrument.create_dataset("proton_charge", data=data['proton_charge'].values)

    detector = instrument.create_group("detector_1")
    counts = data['counts']
    _create_compressed_dataset(root_entry=detector, name="counts", var=counts)
    # Time-of-arrival bins.
    _create_dataset_from_var(
        root_entry=detector,
        var=counts.coords['t'],
        name="t_bin",
        long_name="t_bin TOF (ms)",
    )
    _create_compressed_dataset(
        root_entry=detector,
        name="pixel_id",
        var=counts.coords['id'],
        long_name="pixel ID",
    )
    return instrument


def _create_detector_group(data: sc.DataGroup, nx_entry: h5py.Group) -> h5py.Group:
    """Write the legacy detector geometry into a new ``NXdetector`` group.

    ``origin`` is the position of the first pixel (lowest pixel ID) and is
    stored compressed. ``fast_axis`` is the direction along which the pixel ID
    increases by 1. ``slow_axis`` is the direction along which it increases by
    the number of pixels along the fast axis.
    """
    geometry = nx_entry.create_group("NXdetector")
    _create_compressed_dataset(
        root_entry=geometry, name="origin", var=data['origin_position']
    )
    for axis_key in ("fast_axis", "slow_axis"):
        _create_dataset_from_var(
            root_entry=geometry, var=data[axis_key], name=axis_key
        )
    return geometry


def _create_source_group(data: sc.DataGroup, nx_entry: h5py.Group) -> h5py.Group:
    """Write a static ESS description and the source distance into ``NXsource``.

    ``distance`` is the norm of ``data['source_position']`` and is written
    together with its unit in the ``units`` attribute.
    """
    nx_source = nx_entry.create_group("NXsource")
    nx_source["name"] = "European Spallation Source"
    nx_source["short_name"] = "ESS"
    nx_source["type"] = "Spallation Neutron Source"
    # Going through ``_create_dataset_from_var`` keeps the unit; assigning the
    # bare ``.value`` (as before) dropped it.
    _create_dataset_from_var(
        root_entry=nx_source,
        name="distance",
        var=sc.norm(data['source_position']),
    )
    nx_source["probe"] = "neutron"
    nx_source["target_material"] = "W"
    return nx_source


def export_as_nexus(
    data: sc.DataGroup, output_file: str | pathlib.Path | io.BytesIO
) -> None:
    """Export the reduced data to a NeXus file.

    Currently exporting step is not expected to be part of sciline pipelines.

    The file is opened in ``"w"`` mode, so existing content is replaced.
    A ``DeprecationWarning`` is emitted on every call; prefer
    ``export_as_nxlauetof``.
    """
    deprecation = DeprecationWarning(
        "Exporting to custom NeXus format will be deprecated in the near future "
        ">=26.12.0. "
        "Please use ``export_as_nxlauetof`` instead."
    )
    warnings.warn(deprecation, stacklevel=2)

    group_writers = (
        _create_sample_group,
        _create_instrument_group,
        _create_detector_group,
        _create_source_group,
    )
    with h5py.File(output_file, "w") as nexus_file:
        nexus_file.attrs["default"] = "NMX_data"
        root_entry = _create_root_data_entry(nexus_file)
        for write_group in group_writers:
            write_group(data, root_entry)
def _create_lauetof_data_entry(file_obj: h5py.File) -> h5py.Group:
    """Create the top-level ``entry`` group (``NX_class="NXentry"``)."""
    nx_entry = file_obj.create_group("entry")
    nx_entry.attrs["NX_class"] = "NXentry"
    return nx_entry


def _add_lauetof_definition(nx_entry: h5py.Group) -> None:
    """Write the ``definition`` string field with the value ``NXlauetof``."""
    _create_dataset_from_string(root_entry=nx_entry, name="definition", var="NXlauetof")


def _add_lauetof_instrument(nx_entry: h5py.Group) -> h5py.Group:
    """Create the ``instrument`` group (``NXinstrument``) named ``NMX``."""
    nx_instrument = nx_entry.create_group("instrument")
    nx_instrument.attrs["NX_class"] = "NXinstrument"
    _create_dataset_from_string(root_entry=nx_instrument, name="name", var="NMX")
    return nx_instrument


def _add_lauetof_source_group(
    dg: NMXExperimentMetadata, nx_instrument: h5py.Group
) -> None:
    """Create the ``source`` group (``NXsource``) under the instrument.

    Static ESS descriptions are written as strings. ``distance`` is the norm
    of ``dg["source_position"]`` and is stored with its unit.
    """
    nx_source = nx_instrument.create_group("source")
    nx_source.attrs["NX_class"] = "NXsource"
    _create_dataset_from_string(
        root_entry=nx_source, name="name", var="European Spallation Source"
    )
    _create_dataset_from_string(root_entry=nx_source, name="short_name", var="ESS")
    _create_dataset_from_string(
        root_entry=nx_source, name="type", var="Spallation Neutron Source"
    )
    _create_dataset_from_var(
        root_entry=nx_source, name="distance", var=sc.norm(dg["source_position"])
    )
    # Legacy probe information.
    _create_dataset_from_string(root_entry=nx_source, name="probe", var="neutron")


def _add_lauetof_detector_group(dg: sc.DataGroup, nx_instrument: h5py.Group) -> None:
    """Create one ``NXdetector`` group named after ``dg["detector_name"]``.

    Pixel sizes and the legacy geometry (origin, fast/slow axes) are taken
    from ``dg``. ``polar_angle``, ``azimuthal_angle`` and ``distance`` are
    still zero placeholders.
    """
    nx_detector = nx_instrument.create_group(dg["detector_name"].value)  # Detector name
    nx_detector.attrs["NX_class"] = "NXdetector"
    _create_dataset_from_var(
        name="polar_angle",
        root_entry=nx_detector,
        var=sc.scalar(0, unit='deg'),  # TODO: Add real data
    )
    _create_dataset_from_var(
        name="azimuthal_angle",
        root_entry=nx_detector,
        var=sc.scalar(0, unit='deg'),  # TODO: Add real data
    )
    _create_dataset_from_var(
        name="x_pixel_size", root_entry=nx_detector, var=dg["x_pixel_size"]
    )
    _create_dataset_from_var(
        name="y_pixel_size", root_entry=nx_detector, var=dg["y_pixel_size"]
    )
    _create_dataset_from_var(
        name="distance",
        root_entry=nx_detector,
        var=sc.scalar(0, unit='m'),  # TODO: Add real data
    )
    # Legacy geometry information until we have a better way to store it
    _create_dataset_from_var(
        name="origin", root_entry=nx_detector, var=dg['origin_position']
    )
    # Fast axis, along where the pixel ID increases by 1
    _create_dataset_from_var(
        root_entry=nx_detector, var=dg['fast_axis'], name="fast_axis"
    )
    # Slow axis, along where the pixel ID increases
    # by the number of pixels in the fast axis
    _create_dataset_from_var(
        root_entry=nx_detector, var=dg['slow_axis'], name="slow_axis"
    )


def _add_lauetof_sample_group(dg: NMXExperimentMetadata, nx_entry: h5py.Group) -> None:
    """Create the ``sample`` group (``NXsample``).

    Crystal rotation and sample name come from ``dg``. The orientation matrix
    (identity) and unit cell (``1, 1, 1, 90, 90, 90``) are placeholders.
    """
    nx_sample = nx_entry.create_group("sample")
    nx_sample.attrs["NX_class"] = "NXsample"
    _create_dataset_from_var(
        root_entry=nx_sample,
        var=dg['crystal_rotation'],
        name='crystal_rotation',
        long_name='crystal rotation in Phi (XYZ)',
    )
    _create_dataset_from_string(
        root_entry=nx_sample,
        name='name',
        var=dg['sample_name'].value,
    )
    _create_dataset_from_var(
        name='orientation_matrix',
        root_entry=nx_sample,
        var=sc.array(
            dims=['i', 'j'],
            values=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
            unit="dimensionless",
        ),  # TODO: Add real data, the sample orientation matrix
    )
    _create_dataset_from_var(
        name='unit_cell',
        root_entry=nx_sample,
        var=sc.array(
            dims=['i'],
            values=[1.0, 1.0, 1.0, 90.0, 90.0, 90.0],
            unit="dimensionless",  # TODO: Add real data,
            # a, b, c, alpha, beta, gamma
        ),
    )


def _add_lauetof_monitor_group(data: sc.DataGroup, nx_entry: h5py.Group) -> None:
    """Create a placeholder ``control`` group (``NXmonitor``).

    All values written here are placeholders; ``data`` is currently unused.
    """
    nx_monitor = nx_entry.create_group("control")
    nx_monitor.attrs["NX_class"] = "NXmonitor"
    _create_dataset_from_string(root_entry=nx_monitor, name='mode', var='monitor')
    nx_monitor["preset"] = 0.0  # Check if this is the correct value
    data_dset = _create_dataset_from_var(
        name='data',
        root_entry=nx_monitor,
        var=sc.array(
            dims=['tof'], values=[1, 1, 1], unit="counts"
        ),  # TODO: Add real data, bin values
    )
    data_dset.attrs["signal"] = 1
    data_dset.attrs["primary"] = 1
    _create_dataset_from_var(
        name='time_of_flight',
        root_entry=nx_monitor,
        var=sc.array(
            dims=['tof'], values=[1, 1, 1], unit="s"
        ),  # TODO: Add real data, bin edges
    )


def _add_arbitrary_metadata(
    nx_entry: h5py.Group, **arbitrary_metadata: sc.Variable
) -> None:
    """Write extra ``scipp.Variable`` values into an ``entry/metadata`` group.

    Values that are not ``scipp.Variable`` are skipped with a ``UserWarning``.
    Nothing is created if no metadata is given.
    """
    if not arbitrary_metadata:
        return
    metadata_group = nx_entry.create_group("metadata")
    for key, value in arbitrary_metadata.items():
        if not isinstance(value, sc.Variable):
            msg = f"Skipping metadata key '{key}' as it is not a scipp.Variable."
            warnings.warn(UserWarning(msg), stacklevel=2)
            continue
        _create_dataset_from_var(
            name=key,
            root_entry=metadata_group,
            var=value,
        )


def _export_static_metadata_as_nxlauetof(
    experiment_metadata: NMXExperimentMetadata,
    output_file: str | pathlib.Path | io.BytesIO,
    **arbitrary_metadata: sc.Variable,
) -> None:
    """Export the metadata to a NeXus file with the LAUE_TOF application definition.

    ``Metadata`` in this context refers to the information that is not part of
    the reduced detector counts itself, but is necessary for the interpretation
    of the reduced data.
    This function writes only the experiment-wide metadata: the definition,
    sample, instrument, source, a placeholder monitor and any
    ``arbitrary_metadata``. Detector-specific metadata is appended separately
    by ``_export_detector_metadata_as_nxlauetof``.
    The file is opened in ``"w"`` mode, so existing content is replaced.

    Parameters
    ----------
    experiment_metadata:
        Experiment metadata object.
    output_file:
        Output file path.
    arbitrary_metadata:
        Arbitrary metadata that does not fit into the existing metadata objects.

    """
    with h5py.File(output_file, "w") as f:
        # NOTE(review): NeXus files usually mark the root as ``NXroot``;
        # confirm that ``NXlauetof`` on the root is intended.
        f.attrs["NX_class"] = "NXlauetof"
        nx_entry = _create_lauetof_data_entry(f)
        _add_lauetof_definition(nx_entry)
        _add_lauetof_sample_group(experiment_metadata, nx_entry)
        nx_instrument = _add_lauetof_instrument(nx_entry)
        _add_lauetof_source_group(experiment_metadata, nx_instrument)
        # Placeholder for ``monitor`` group
        _add_lauetof_monitor_group(experiment_metadata, nx_entry)
        # Skipping ``NXdata``(name) field with data link
        # Add arbitrary metadata
        _add_arbitrary_metadata(nx_entry, **arbitrary_metadata)


def _export_detector_metadata_as_nxlauetof(
    *detector_metadatas: NMXDetectorMetadata,
    output_file: str | pathlib.Path | io.BytesIO,
    append_mode: bool = True,
) -> None:
    """Export the detector specific metadata to a NeXus file.

    Since NMX can have arbitrary number of detectors,
    this function can take multiple detector metadata objects.
    The file must already contain an ``entry`` group; ``entry/instrument``
    is created if it is missing.

    Parameters
    ----------
    detector_metadatas:
        Detector metadata objects.
    output_file:
        Output file path.
    append_mode:
        Must be ``True``; the file is opened in ``"r+"`` mode.

    Raises
    ------
    NotImplementedError
        If ``append_mode`` is ``False``.

    """
    if not append_mode:
        raise NotImplementedError("Only append mode is supported for now.")

    with h5py.File(output_file, "r+") as f:
        nx_entry = f["entry"]
        if "instrument" not in nx_entry:
            nx_instrument = _add_lauetof_instrument(nx_entry)
        else:
            nx_instrument = nx_entry["instrument"]
        # Add detector group metadata
        for detector_metadata in detector_metadatas:
            _add_lauetof_detector_group(detector_metadata, nx_instrument)


def _extract_counts(dg: sc.DataGroup) -> sc.Variable:
    """Return the counts variable, folding ``id`` into ``x``/``y`` if present.

    The fold sizes come from ``dg["detector_shape"]`` as ``(num_x, num_y)``.
    """
    counts: sc.Variable = dg['counts'].data
    if 'id' in counts.dims:
        num_x, num_y = dg["detector_shape"].value
        return sc.fold(counts, dim='id', sizes={'x': num_x, 'y': num_y})
    else:
        # If there is no 'id' dimension, we assume it is already in the correct shape
        return counts


def _export_reduced_data_as_nxlauetof(
    dg: NMXReducedDataGroup,
    output_file: str | pathlib.Path | io.BytesIO,
    *,
    append_mode: bool = True,
    compress_counts: bool = True,
) -> None:
    """Export the reduced data to a NeXus file with the LAUE_TOF application definition.

    Even though this function only exports
    reduced data (detector counts and their coordinates),
    the input should contain all the necessary metadata
    for minimum sanity check.
    The detector group ``entry/instrument/<detector_name>`` must already exist.

    Parameters
    ----------
    dg:
        Reduced data and metadata.
    output_file:
        Output file path.
    append_mode:
        If ``True``, the file is opened in append (``"r+"``) mode.
        Non-append mode is not supported yet.
    compress_counts:
        If ``True``, the detector counts are compressed with bitshuffle/LZ4,
        or with gzip if bitshuffle is not available.
        Only the counts are compressed because only they are expected to be large.

    Raises
    ------
    NotImplementedError
        If ``append_mode`` is ``False``.

    """
    if not append_mode:
        raise NotImplementedError("Only append mode is supported for now.")

    with h5py.File(output_file, "r+") as f:
        nx_detector: h5py.Group = f[f"entry/instrument/{dg['detector_name'].value}"]
        # Data - shape: [n_x_pixels, n_y_pixels, n_tof_bins]
        # The application definition specifies integer counts, so the values
        # are cast to ``np.uint`` below.
        # NOTE(review): the cast truncates fractional (e.g. weighted) counts
        # and wraps negative values; confirm this is intended.
        num_x, num_y = dg["detector_shape"].value  # Probably better way to do this
        counts = _extract_counts(dg)
        if compress_counts:
            data_dset = _create_compressed_dataset(
                name="data",
                root_entry=nx_detector,
                var=counts,
                # One chunk per time bin, spanning the whole panel.
                chunks=(num_x, num_y, 1),
                dtype=np.uint,
            )
        else:
            data_dset = _create_dataset_from_var(
                name="data",
                root_entry=nx_detector,
                var=counts,
                dtype=np.uint,
            )

        data_dset.attrs["signal"] = 1
        # Midpoints of ``t`` (presumably bin edges, giving bin centres).
        _create_dataset_from_var(
            name='time_of_flight',
            root_entry=nx_detector,
            var=sc.midpoints(dg['counts'].coords['t'], dim='t'),
        )


def _check_file(
    filename: str | pathlib.Path | io.BytesIO, overwrite: bool
) -> pathlib.Path | io.BytesIO:
    """Normalize ``filename`` and refuse to clobber an existing file.

    Strings are converted to ``pathlib.Path``; buffers are returned unchanged.

    Raises
    ------
    FileExistsError
        If the path exists and ``overwrite`` is ``False``.
    """
    if isinstance(filename, str | pathlib.Path):
        filename = pathlib.Path(filename)
        if filename.exists() and not overwrite:
            raise FileExistsError(
                f"File '{filename}' already exists. Use `overwrite=True` to overwrite."
            )
    return filename


T = TypeVar("T", bound=sc.DataArray)
class NXLauetofWriter:
    """Stream-reduce NMX detector panels and write them to an NXlauetof file.

    On construction, the output file is checked (see ``overwrite``).
    Then the experiment metadata computed from ``workflow`` is written to it.
    Each :meth:`add_panel` call appends one detector panel.

    Parameters
    ----------
    output_filename:
        Output file path or in-memory buffer.
    workflow:
        Sciline pipeline used to compute the input file path, metadata,
        detector name, pixel IDs and reduced data.
    chunk_generator:
        Callable taking ``(input file path, detector name)`` and yielding
        data chunks for that detector.
    chunk_insert_key:
        Pipeline key under which each chunk is fed to the stream processor.
    extra_meta:
        Extra ``scipp.Variable`` metadata written to ``entry/metadata``.
    compress_counts:
        Whether the detector counts are compressed.
    overwrite:
        Whether an existing file at ``output_filename`` may be overwritten.
    """

    def __init__(
        self,
        *,
        output_filename: str | pathlib.Path | io.BytesIO,
        workflow: sl.Pipeline,
        chunk_generator: Callable[[FilePath, DetectorName], Generator[T, None, None]],
        chunk_insert_key: type[T],
        extra_meta: dict[str, sc.Variable] | None = None,
        compress_counts: bool = True,
        overwrite: bool = False,
    ) -> None:
        from ess.reduce.streaming import EternalAccumulator, StreamProcessor

        from .types import FilePath, NMXReducedCounts

        self.compress_counts = compress_counts
        self._chunk_generator = chunk_generator
        self._chunk_insert_key = chunk_insert_key
        self._workflow = workflow
        self._output_filename = _check_file(output_filename, overwrite)
        self._input_filename = workflow.compute(FilePath)
        # Factory: binding a workflow gives a processor that accumulates the
        # reduced counts over all chunks fed to it.
        self._final_stream_processor = partial(
            StreamProcessor,
            dynamic_keys=(chunk_insert_key,),
            target_keys=(NMXReducedDataGroup,),
            accumulators={NMXReducedCounts: EternalAccumulator},
        )
        # Not populated by ``add_panel``.
        self._detector_metas: dict[DetectorName, NMXDetectorMetadata] = {}
        self._detector_reduced: dict[DetectorName, NMXReducedDataGroup] = {}
        _export_static_metadata_as_nxlauetof(
            experiment_metadata=self._workflow.compute(NMXExperimentMetadata),
            output_file=self._output_filename,
            **(extra_meta or {}),
        )

    def add_panel(
        self, *, detector_id: DetectorIndex | DetectorName
    ) -> NMXReducedDataGroup:
        """Reduce one detector panel chunk by chunk and append it to the file.

        Parameters
        ----------
        detector_id:
            Detector index (``int``) or detector name (``str``).

        Returns
        -------
        :
            The reduced data group from the last processed chunk, with counts
            accumulated over all non-empty chunks.

        Raises
        ------
        TypeError
            If ``detector_id`` is neither ``int`` nor ``str``.
        ValueError
            If the chunk generator yields no non-empty chunk.
        """
        from .types import PixelIds

        temp_wf = self._workflow.copy()
        if isinstance(detector_id, int):
            temp_wf[DetectorIndex] = detector_id
        elif isinstance(detector_id, str):
            temp_wf[DetectorName] = detector_id
        else:
            raise TypeError(
                f"Expected detector_id to be an int or str, got {type(detector_id)}"
            )

        _export_detector_metadata_as_nxlauetof(
            temp_wf.compute(NMXDetectorMetadata),
            output_file=self._output_filename,
        )
        # First compute static information
        detector_name = temp_wf.compute(DetectorName)
        temp_wf[PixelIds] = temp_wf.compute(PixelIds)
        processor = self._final_stream_processor(temp_wf)

        # Then iterate over the chunks
        results = None
        for da in self._chunk_generator(self._input_filename, detector_name):
            # A chunk with any zero-length dimension holds no data. The former
            # check ``any(sizes) == 0`` compared a bool to 0 and only skipped
            # chunks whose dimensions were all empty.
            if 0 in da.sizes.values():
                continue
            results = processor.add_chunk({self._chunk_insert_key: da})

        if results is None:
            # Previously this surfaced as an UnboundLocalError.
            raise ValueError(
                f"No non-empty data chunks were found for detector "
                f"'{detector_name}'; nothing to export."
            )

        reduced = results[NMXReducedDataGroup]
        _export_reduced_data_as_nxlauetof(
            reduced,
            self._output_filename,
            compress_counts=self.compress_counts,
        )
        return reduced