# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
"""Workflow and workflow components for interacting with NeXus files."""
from collections.abc import Sequence
from typing import Any
import networkx as nx
import sciline
import scipp as sc
import scippnexus as snx
from scipp.constants import g
from . import _nexus_loader as nexus
from .types import (
CalibratedBeamline,
CalibratedDetector,
CalibratedMonitor,
Component,
DetectorBankSizes,
DetectorData,
DetectorPositionOffset,
Filename,
GravityVector,
MonitorData,
MonitorPositionOffset,
MonitorType,
NeXusClass,
NeXusComponent,
NeXusComponentLocationSpec,
NeXusData,
NeXusDataLocationSpec,
NeXusFileSpec,
NeXusName,
NeXusTransformation,
NeXusTransformationChain,
Position,
PreopenNeXusFile,
PulseSelection,
RunType,
UniqueComponent,
)
origin = sc.vector(value=[0.0, 0.0, 0.0], unit='m')
"""The origin, used as default sample position."""
no_offset = sc.vector(value=[0.0, 0.0, 0.0], unit='m')
"""Offset that does not change the position."""
[docs]
def file_path_to_file_spec(
    filename: Filename[RunType], preopen: PreopenNeXusFile
) -> NeXusFileSpec[RunType]:
    """
    Turn a filename into a file spec, optionally opening the file up front.

    Parameters
    ----------
    filename:
        The NeXus file to load from.
    preopen:
        If true, the file is opened here with ScippNexus using this module's
        ``definitions`` (which strip event and geometry data from detectors and
        monitors), and the open file object is wrapped. Otherwise the filename is
        wrapped unchanged.
    """
    if not preopen:
        return NeXusFileSpec[RunType](filename)
    # NOTE(review): the opened file is not closed in this function; its lifetime
    # presumably follows the returned spec — confirm.
    opened = snx.File(filename, definitions=definitions)
    return NeXusFileSpec[RunType](opened)
[docs]
def no_monitor_position_offset() -> MonitorPositionOffset[RunType, MonitorType]:
    """Default monitor position offset: the zero vector, i.e., no shift."""
    offset_type = MonitorPositionOffset[RunType, MonitorType]
    return offset_type(no_offset)
[docs]
def no_detector_position_offset() -> DetectorPositionOffset[RunType]:
    """Default detector position offset: the zero vector, i.e., no shift."""
    offset_type = DetectorPositionOffset[RunType]
    return offset_type(no_offset)
[docs]
def all_pulses() -> PulseSelection[RunType]:
    """Select all neutron pulses in the data.

    The selection is an open-ended slice, so no pulses are excluded.
    """
    everything = slice(None)
    return PulseSelection[RunType](everything)
[docs]
def gravity_vector_neg_y() -> GravityVector:
    """
    Gravity vector for default instrument coordinate system where y is up.

    The vector points along -y and is scaled by ``scipp.constants.g``.
    """
    downward = sc.vector(value=[0, -1, 0])
    return GravityVector(downward * g)
[docs]
def component_spec_by_name(
    filename: NeXusFileSpec[RunType], name: NeXusName[Component]
) -> NeXusComponentLocationSpec[Component, RunType]:
    """
    Build the location spec of a named component group in a NeXus file.

    Parameters
    ----------
    filename:
        NeXus file to use for the location spec.
    name:
        Name of the component group.
    """
    spec_type = NeXusComponentLocationSpec[Component, RunType]
    return spec_type(component_name=name, filename=filename.value)
[docs]
def unique_component_spec(
    filename: NeXusFileSpec[RunType],
) -> NeXusComponentLocationSpec[UniqueComponent, RunType]:
    """
    Build the location spec of a unique component group in a NeXus file.

    No component name is set on the spec; presumably the group is found by its
    NeXus class since it is unique in the file — TODO confirm in the loader.

    Parameters
    ----------
    filename:
        NeXus file to use for the location spec.
    """
    file_value = filename.value
    spec_type = NeXusComponentLocationSpec[UniqueComponent, RunType]
    return spec_type(filename=file_value)
[docs]
def data_by_name(
    filename: NeXusFileSpec[RunType],
    name: NeXusName[Component],
    selection: PulseSelection[RunType],
) -> NeXusDataLocationSpec[Component, RunType]:
    """
    Build the location spec for monitor or detector data in a NeXus file.

    Parameters
    ----------
    filename:
        NeXus file to use for the location spec.
    name:
        Name of the monitor or detector group.
    selection:
        Time range (start and stop as a Python slice object).
    """
    spec_type = NeXusDataLocationSpec[Component, RunType]
    return spec_type(
        component_name=name,
        filename=filename.value,
        selection=selection.value,
    )
[docs]
def load_nexus_sample(
    location: NeXusComponentLocationSpec[snx.NXsample, RunType],
) -> NeXusComponent[snx.NXsample, RunType]:
    """
    Load a NeXus sample group from a file.

    If there is no sample group in the file, an empty group is returned. This should
    not happen, but handling it gracefully makes testing and working with
    pre-production files easier. There should be little harm in returning an empty
    group. Subsequent extraction of the sample position will then default to the
    origin.

    Parameters
    ----------
    location:
        Location spec for the sample group.
    """
    try:
        sample = nexus.load_component(location, nx_class=snx.NXsample)
    except ValueError:
        # Missing sample group: fall back to an empty group (see docstring).
        sample = sc.DataGroup()
    # Guarantee a transformation chain exists; the '.' value terminates a NeXus
    # depends_on chain, i.e., no transformations are applied.
    if 'depends_on' not in sample:
        sample['depends_on'] = snx.TransformationChain(parent='', value='.')
    return NeXusComponent[snx.NXsample, RunType](sample)
[docs]
def nx_class_for_monitor() -> NeXusClass[MonitorType]:
    """NeXus class that identifies monitor groups (``NXmonitor``)."""
    monitor_class = snx.NXmonitor
    return NeXusClass[MonitorType](monitor_class)
[docs]
def nx_class_for_detector() -> NeXusClass[snx.NXdetector]:
    """NeXus class that identifies detector groups (``NXdetector``)."""
    detector_class = snx.NXdetector
    return NeXusClass[snx.NXdetector](detector_class)
[docs]
def nx_class_for_source() -> NeXusClass[snx.NXsource]:
    """NeXus class that identifies source groups (``NXsource``)."""
    source_class = snx.NXsource
    return NeXusClass[snx.NXsource](source_class)
[docs]
def nx_class_for_sample() -> NeXusClass[snx.NXsample]:
    """NeXus class that identifies sample groups (``NXsample``)."""
    sample_class = snx.NXsample
    return NeXusClass[snx.NXsample](sample_class)
[docs]
def load_nexus_component(
    location: NeXusComponentLocationSpec[Component, RunType],
    nx_class: NeXusClass[Component],
) -> NeXusComponent[Component, RunType]:
    """
    Load a NeXus component group from a file.

    The group is loaded with this module's ``definitions``. When loading a detector
    or monitor, event data is therefore replaced by placeholders: the event data can
    be large and is not needed at this stage. A placeholder is used to allow for
    returning a scipp.DataArray, which is what most downstream code will expect.
    Currently the placeholder is the detector number (for detectors) or a size-0
    array (for monitors), but this may change in the future.

    The returned object is a scipp.DataGroup, as it may contain additional information
    about the component that cannot be represented as a single scipp.DataArray. Most
    downstream code will only be interested in the contained scipp.DataArray so this
    needs to be extracted. However, other processing steps may require the additional
    information, so it is kept in the DataGroup.

    Parameters
    ----------
    location:
        Location spec for the component group.
    nx_class:
        NX_class to identify the component.
    """
    group = nexus.load_component(location, nx_class=nx_class, definitions=definitions)
    return NeXusComponent[Component, RunType](group)
[docs]
def load_nexus_data(
    location: NeXusDataLocationSpec[Component, RunType],
) -> NeXusData[Component, RunType]:
    """
    Load event or histogram data from a NeXus component group.

    This is used for both detectors and monitors. Unlike
    :py:func:`load_nexus_component`, it does not pass this module's stripped
    ``definitions``, so the actual data of the group is loaded.

    Parameters
    ----------
    location:
        Location spec for the detector or monitor group, including the pulse
        selection to load.
    """
    data = nexus.load_data(
        file_path=location.filename,
        entry_name=location.entry_name,
        selection=location.selection,
        component_name=location.component_name,
    )
    return NeXusData[Component, RunType](data)
[docs]
def compute_position(
    transformation: NeXusTransformation[Component, RunType],
) -> Position[Component, RunType]:
    """Compute the position of a component from a transformation matrix.

    The position is the image of the origin under the component's transformation.
    """
    matrix = transformation.value
    return Position[Component, RunType](matrix * origin)
[docs]
def get_calibrated_detector(
    detector: NeXusComponent[snx.NXdetector, RunType],
    *,
    transform: NeXusTransformation[snx.NXdetector, RunType],
    # Strictly speaking we could apply an offset by modifying the transformation chain,
    # using a more generic implementation. However, this may in general require
    # extending the chain and it is currently not clear if that is desirable. As far as
    # I am aware the offset is currently mainly used for handling files from other
    # facilities and it is not clear if it is needed for ESS data and should be kept at
    # all.
    offset: DetectorPositionOffset[RunType],
    bank_sizes: DetectorBankSizes,
) -> CalibratedDetector[RunType]:
    """
    Extract the data array corresponding to a detector's signal field.

    The returned data array includes coords and masks pertaining directly to the
    signal values array, but not additional information about the detector. If
    ``bank_sizes`` has an entry for this detector, the data array is reshaped to the
    logical detector shape by folding it along the ``detector_number`` dimension.
    A ``position`` coordinate is added, computed by applying ``transform`` to the
    pixel offsets and adding ``offset``.

    Parameters
    ----------
    detector:
        NeXus detector group.
    transform:
        Transformation of the detector, applied to the pixel offsets.
    offset:
        Offset to add to the detector position.
    bank_sizes:
        Dictionary of detector bank sizes, keyed by the NeXus component name of the
        detector.
    """
    da = nexus.extract_signal_data_array(detector)
    if (
        sizes := (bank_sizes or {}).get(detector.get('nexus_component_name'))
    ) is not None:
        da = da.fold(dim="detector_number", sizes=sizes)
    # Note: We apply the offset as early as possible, i.e., right in this function
    # where the detector array is extracted from the raw loaded NeXus group, to
    # prevent a source of bugs.
    position = transform.value * snx.zip_pixel_offsets(da.coords)
    return CalibratedDetector[RunType](
        da.assign_coords(position=position + offset.to(unit=position.unit))
    )
[docs]
def assemble_beamline(
    detector: CalibratedDetector[RunType],
    source_position: Position[snx.NXsource, RunType],
    sample_position: Position[snx.NXsample, RunType],
    gravity: GravityVector,
) -> CalibratedBeamline[RunType]:
    """
    Add beamline information (gravity vector, source- and sample-position) to detector.

    This is performed separately and after :py:func:`get_calibrated_detector` to avoid
    a false dependency of, e.g., the reshaped detector numbers on the sample position.
    The latter can change during a run, e.g., for a rotating sample. The detector
    numbers might be used, e.g., to mask certain detector pixels, and should not depend
    on the sample position.

    Parameters
    ----------
    detector:
        Calibrated detector data array.
    source_position:
        Position of the neutron source.
    sample_position:
        Position of the sample.
    gravity:
        Gravity vector.
    """
    with_beamline = detector.assign_coords(
        source_position=source_position,
        sample_position=sample_position,
        gravity=gravity,
    )
    return CalibratedBeamline[RunType](with_beamline)
[docs]
def assemble_detector_data(
    detector: CalibratedBeamline[RunType],
    event_data: NeXusData[snx.NXdetector, RunType],
) -> DetectorData[RunType]:
    """
    Combine a calibrated detector with its event data.

    Events are grouped by detector number. Variances are added to the event data
    if they are missing, and the detector's coords and masks are attached.

    Parameters
    ----------
    detector:
        Calibrated detector data array.
    event_data:
        Event data array.
    """
    detector_number = detector.coords['detector_number']
    grouped = nexus.group_event_data(
        event_data=event_data, detector_number=detector_number
    )
    with_variances = _add_variances(grouped)
    with_coords = with_variances.assign_coords(detector.coords)
    return DetectorData[RunType](with_coords.assign_masks(detector.masks))
[docs]
def get_calibrated_monitor(
    monitor: NeXusComponent[MonitorType, RunType],
    offset: MonitorPositionOffset[RunType, MonitorType],
    source_position: Position[snx.NXsource, RunType],
) -> CalibratedMonitor[RunType, MonitorType]:
    """
    Extract the data array corresponding to a monitor's signal field.

    The returned data array includes coords pertaining directly to the
    signal values array, but not additional information about the monitor.
    The monitor position (shifted by ``offset``) and the source position are
    attached as coordinates.

    Parameters
    ----------
    monitor:
        NeXus monitor group.
    offset:
        Offset to add to the monitor position.
    source_position:
        Position of the neutron source.
    """
    positioned = nexus.compute_component_position(monitor)
    signal = nexus.extract_signal_data_array(positioned)
    position = positioned['position']
    shifted = position + offset.to(unit=position.unit)
    calibrated = signal.assign_coords(
        position=shifted, source_position=source_position
    )
    return CalibratedMonitor[RunType, MonitorType](calibrated)
[docs]
def assemble_monitor_data(
    monitor: CalibratedMonitor[RunType, MonitorType],
    data: NeXusData[MonitorType, RunType],
) -> MonitorData[RunType, MonitorType]:
    """
    Combine a calibrated monitor with its data.

    The monitor's coords and masks are attached to the data. If the data is binned
    (event data) and has no variances, variances are added.

    Parameters
    ----------
    monitor:
        Calibrated monitor data array.
    data:
        Data array with neutron counts.
    """
    annotated = data.assign_coords(monitor.coords)
    annotated = annotated.assign_masks(monitor.masks)
    return MonitorData[RunType, MonitorType](_add_variances(annotated))
def _drop(
    children: dict[str, snx.Field | snx.Group],
    classes: tuple[type[snx.NXobject], ...],
) -> dict[str, snx.Field | snx.Group]:
    """Return a new dict of ``children`` without groups of the given NeXus classes.

    ``classes`` holds NeXus base-class types (e.g. ``snx.NXevent_data``), which are
    compared against each group's ``nx_class``. Fields are always kept.
    """
    return {
        name: child
        for name, child in children.items()
        if not (isinstance(child, snx.Group) and child.nx_class in classes)
    }
class _StrippedDetector(snx.NXdetector):
    """Detector definition for ScippNexus that skips bulky geometry and event data.

    NXoff_geometry and NXevent_data child groups are removed and the ``data`` entry
    is replaced by the ``detector_number`` field, so loading yields a lightweight
    placeholder signal.
    """

    def __init__(
        self, attrs: dict[str, Any], children: dict[str, snx.Field | snx.Group]
    ):
        stripped = _drop(children, (snx.NXoff_geometry, snx.NXevent_data))
        # detector_number stands in for the signal so a DataArray can still be built.
        placeholder = stripped['detector_number']
        super().__init__(attrs=attrs, children={**stripped, 'data': placeholder})
class _DummyField:
    """Stand-in for ``snx.Field`` used as the ``data`` entry of NXmonitor.

    It mimics an empty one-dimensional field along ``event_time_zero``. Indexing it
    with any key returns a new length-0 variable without a unit.
    """

    def __init__(self):
        self.attrs = {}
        self.sizes = {'event_time_zero': 0}
        self.dims = tuple(self.sizes)
        self.shape = tuple(self.sizes.values())

    def __getitem__(self, key: Any) -> sc.Variable:
        # The key (selection) is ignored: the field is always empty.
        return sc.empty(dims=self.dims, shape=self.shape, unit=None)
class _StrippedMonitor(snx.NXmonitor):
    """Monitor definition for ScippNexus that skips event data.

    NXevent_data child groups are removed and the ``data`` entry is replaced by an
    empty dummy field.
    """

    def __init__(
        self, attrs: dict[str, Any], children: dict[str, snx.Field | snx.Group]
    ):
        stripped = _drop(children, (snx.NXevent_data,))
        super().__init__(attrs=attrs, children={**stripped, 'data': _DummyField()})
def _add_variances(da: sc.DataArray) -> sc.DataArray:
    """Return a shallow copy of ``da``, adding variances to binned content if missing.

    For binned data whose bin content has no variances, the variances are set equal
    to the values. Dense data is returned (as a shallow copy) unchanged.
    """
    result = da.copy(deep=False)
    if result.bins is None:
        return result
    buffer = result.bins.constituents['data']
    if buffer.variances is None:
        buffer.variances = buffer.values
    return result
# ScippNexus application definitions used by this module: the base definitions,
# with detectors and monitors swapped for variants that skip event data.
definitions = snx.base_definitions()
definitions.update(NXdetector=_StrippedDetector, NXmonitor=_StrippedMonitor)
# Providers shared by all workflows in this module. The order of entries matters:
# as noted inline, unique_component_spec must come after component_spec_by_name
# because it partially overrides it.
_common_providers = (
    gravity_vector_neg_y,
    file_path_to_file_spec,
    all_pulses,
    component_spec_by_name,
    unique_component_spec,  # after component_spec_by_name, partially overrides
    get_transformation_chain,
    to_transformation,
    compute_position,
    load_nexus_data,
    load_nexus_component,
    data_by_name,
    nx_class_for_detector,
    nx_class_for_monitor,
    nx_class_for_source,
    nx_class_for_sample,
)
# Providers only needed for monitors: default (zero) offset, calibrated monitor, and
# the final MonitorData.
_monitor_providers = (
    no_monitor_position_offset,
    get_calibrated_monitor,
    assemble_monitor_data,
)
# Providers only needed for detectors: default (zero) offset, the sample group
# (presumably the source of the sample position used by assemble_beamline),
# calibrated detector, beamline info, and the final DetectorData.
_detector_providers = (
    no_detector_position_offset,
    load_nexus_sample,
    get_calibrated_detector,
    assemble_beamline,
    assemble_detector_data,
)
[docs]
def LoadMonitorWorkflow() -> sciline.Pipeline:
    """Generic workflow for loading monitor data from a NeXus file.

    Files are not pre-opened by default (``PreopenNeXusFile`` is set to ``False``).
    """
    providers = (*_common_providers, *_monitor_providers)
    workflow = sciline.Pipeline(providers)
    workflow[PreopenNeXusFile] = PreopenNeXusFile(False)
    return workflow
[docs]
def LoadDetectorWorkflow() -> sciline.Pipeline:
    """Generic workflow for loading detector data from a NeXus file.

    By default no detector banks are reshaped (empty ``DetectorBankSizes``) and
    files are not pre-opened (``PreopenNeXusFile`` is set to ``False``).
    """
    providers = (*_common_providers, *_detector_providers)
    workflow = sciline.Pipeline(providers)
    workflow[DetectorBankSizes] = DetectorBankSizes({})
    workflow[PreopenNeXusFile] = PreopenNeXusFile(False)
    return workflow
[docs]
def GenericNeXusWorkflow(
    *,
    run_types: Sequence[sciline.typing.Key] | None = None,
    monitor_types: Sequence[sciline.typing.Key] | None = None,
) -> sciline.Pipeline:
    """
    Generic workflow for loading detector and monitor data from a NeXus file.

    Parameters
    ----------
    run_types:
        List of run types to include in the workflow. If not provided, all run types
        are included. It is recommended to specify run types to avoid creating very
        large workflows.
    monitor_types:
        List of monitor types to include in the workflow. If not provided, all monitor
        types are included. It is recommended to specify monitor types to avoid creating
        very large workflows. May only be given together with ``run_types``.

    Returns
    -------
    :
        The workflow.

    Raises
    ------
    ValueError
        If ``monitor_types`` is given but ``run_types`` is not.
    """
    from typing import get_args, get_origin

    if monitor_types is not None and run_types is None:
        raise ValueError("run_types must be specified if monitor_types is specified")
    wf = sciline.Pipeline(
        (*_common_providers, *_monitor_providers, *_detector_providers)
    )
    wf[DetectorBankSizes] = DetectorBankSizes({})
    wf[PreopenNeXusFile] = PreopenNeXusFile(False)
    if run_types is None:
        # No pruning requested: keep every run type and monitor type.
        return wf

    # Named ``graph`` rather than ``g`` to avoid shadowing the module-level gravity
    # constant ``g`` imported from scipp.constants.
    graph = wf.underlying_graph
    # Materialize once so a one-shot iterable is not exhausted after the first
    # run type.
    selected_monitors = None if monitor_types is None else tuple(monitor_types)

    # DetectorData and MonitorData are the "final" outputs, so keeping only them and
    # their ancestors is what we need to strip unused run and monitor types.
    outputs = set()
    for rt in run_types:
        outputs.add(DetectorData[rt])
        if selected_monitors is None:
            # Documented behavior: without monitor_types, keep every monitor type
            # present in the graph for this run type (previously all were dropped).
            outputs.update(
                node
                for node in graph.nodes
                if get_origin(node) is MonitorData and get_args(node)[:1] == (rt,)
            )
        else:
            outputs.update(MonitorData[rt, mt] for mt in selected_monitors)

    keep = set(outputs)
    for node in outputs:
        keep |= nx.ancestors(graph, node)
    graph.remove_nodes_from(set(graph.nodes) - keep)
    return wf