# Bragg-edge imaging with ODIN

This notebook illustrates how to convert recorded events on the ODIN detector to a single wavelength spectrum,
revealing a Bragg edge in the data.
The chopper cascade was operated in wavelength-frame multiplication (WFM) mode.

## Loading dataset

> The loader is not part of ``essimaging`` because the McStas dataset format is not yet stable.

In [None]:
import scipp as sc
import scippnexus as snx
import scipp.constants as scc
from typing import cast, NewType
from ess.reduce.nexus.types import FilePath


# Parameter types and default values used by the McStas loader functions
# in this cell. Each ``NewType`` gives a distinct name to a plain value.

# Path of the event table inside the file.
_DataPath = NewType('_DataPath', str)
_DefaultDataPath = _DataPath(
    "entry1/data/transmission_event_signal_dat_list_p_t_x_y_z_vx_vy_vz/events"
)
_FileLock = NewType('_FileLock', bool)
"""Lock the file to prevent concurrent access."""
_DefaultFileLock = _FileLock(True)
# Event data as returned by ``load_odin_simulation_data``.
OdinSimulationRawData = NewType('OdinSimulationRawData', sc.DataArray)
ProbabilityToCountsScaleFactor = NewType('ProbabilityToCountsScaleFactor', sc.Variable)
"""Translate the probability to counts."""
DefaultProbabilityToCountsScaleFactor = ProbabilityToCountsScaleFactor(
    sc.scalar(1_000, unit='dimensionless')
)
# Detector extent: defaults span -0.03 m to 0.03 m along both x and y.
DetectorStartX = NewType('DetectorStartX', sc.Variable)
"""Start of the detector in x direction."""
DefaultDetectorStartX = DetectorStartX(sc.scalar(-0.03, unit='m'))
DetectorStartY = NewType('DetectorStartY', sc.Variable)
"""Start of the detector in y direction."""
DefaultDetectorStartY = DetectorStartY(sc.scalar(-0.03, unit='m'))

DetectorEndX = NewType('DetectorEndX', sc.Variable)
"""End of the detector in x direction."""
DefaultDetectorEndX = DetectorEndX(sc.scalar(0.03, unit='m'))
DetectorEndY = NewType('DetectorEndY', sc.Variable)
"""End of the detector in y direction."""
DefaultDetectorEndY = DetectorEndY(sc.scalar(0.03, unit='m'))

# Number of pixels along (x, y); index 0 is x and index 1 is y.
McStasManualResolution = NewType('McStasManualResolution', tuple)
"""Manual resolution for McStas data (how many pixels per axis x, y)"""
DefaultMcStasManualResolution = McStasManualResolution((1024, 1024))

example_resolution = McStasManualResolution((128, 128))
# Small resolution for faster testing and documentation build.
# The later cells of this notebook pass it instead of the 1024 x 1024 default.


def _nth_col_or_row_lookup(
    start: sc.Variable, stop: sc.Variable, resolution: int, dim: str
) -> sc.Lookup:
    """Build a lookup from a position along ``dim`` to a column/row index.

    The range from ``start`` to ``stop`` is split into ``resolution``
    equal-width bins, and bin ``i`` (counted from zero) stores the value ``i``.
    The bin edges carry the unit of ``start``.
    """
    # One index per bin: 0, 1, ..., resolution - 1.
    bin_indices = sc.arange(dim=dim, start=0, stop=resolution, unit='dimensionless')
    # resolution + 1 edges delimit resolution bins.
    bin_edges = sc.linspace(
        dim, start=start, stop=stop, num=resolution + 1, unit=start.unit
    )
    index_histogram = sc.DataArray(data=bin_indices, coords={dim: bin_edges})
    return sc.lookup(index_histogram, dim)


def _position_to_pixel_id(
    *,
    x_pos: sc.Variable,
    y_pos: sc.Variable,
    detector_start_x: DetectorStartX = DefaultDetectorStartX,
    detector_start_y: DetectorStartY = DefaultDetectorStartY,
    detector_end_x: DetectorEndX = DefaultDetectorEndX,
    detector_end_y: DetectorEndY = DefaultDetectorEndY,
    resolution: McStasManualResolution = DefaultMcStasManualResolution,
) -> sc.Variable:
    """Derive a pixel id for every event from its x and y position.

    The detector area is divided into ``resolution[0]`` columns along x and
    ``resolution[1]`` rows along y. Ids are numbered row by row:
    ``pixel_id = row * resolution[0] + column``.
    """
    n_columns = resolution[0]
    column_of = _nth_col_or_row_lookup(
        detector_start_x, detector_end_x, n_columns, 'x'
    )
    row_of = _nth_col_or_row_lookup(
        detector_start_y, detector_end_y, resolution[1], 'y'
    )
    column_index = column_of[x_pos]
    row_index = row_of[y_pos]
    return row_index * n_columns + column_index


McStasVelocities = NewType('McStasVelocities', sc.DataGroup)


def load_velocities(
    file_path: FilePath,
    _data_path: _DataPath = _DefaultDataPath,
    _file_lock: _FileLock = _DefaultFileLock,
) -> McStasVelocities:
    """Read the per-event velocity components from the McStas event table.

    Columns 5, 6 and 7 of the table are returned as ``vx``, ``vy`` and ``vz``,
    each tagged with the unit ``m/s``.
    """
    with snx.File(file_path, "r", locking=_file_lock) as f:
        events = f[_data_path][()].rename_dims({'dim_0': 'event'})
        velocity_columns = events['dim_1', 5:8]
        components = {}
        for column, name in enumerate(('vx', 'vy', 'vz')):
            component = cast(sc.Variable, velocity_columns['dim_1', column].copy())
            component.unit = 'm/s'
            components[name] = component
        # Add special tags if you want to use them as coordinates
        # for example, da.coords['vx_MC'] = vx
        # to distinguish them from the measurement
        return McStasVelocities(sc.DataGroup(**components))


LoadTrueVelocities = NewType('LoadTrueVelocities', bool)
DefaultLoadTrueVelocities = LoadTrueVelocities(True)


def load_odin_simulation_data(
    file_path: FilePath,
    _data_path: _DataPath = _DefaultDataPath,
    _file_lock: _FileLock = _DefaultFileLock,
    detector_start_x: DetectorStartX = DefaultDetectorStartX,
    detector_start_y: DetectorStartY = DefaultDetectorStartY,
    detector_end_x: DetectorEndX = DefaultDetectorEndX,
    detector_end_y: DetectorEndY = DefaultDetectorEndY,
    resolution: McStasManualResolution = DefaultMcStasManualResolution,
    probability_scale_factor: ProbabilityToCountsScaleFactor = DefaultProbabilityToCountsScaleFactor,
    load_true_velocities: LoadTrueVelocities = DefaultLoadTrueVelocities,
) -> OdinSimulationRawData:
    """Load a McStas event table as an event list with pixel ids.

    Each event's probability is divided by the largest probability, multiplied
    by ``probability_scale_factor`` and cast to ``int32`` to obtain counts.
    Pixel ids are derived from the x/y positions, the detector extent and
    ``resolution``. When ``load_true_velocities`` is set, a ``sim_wavelength``
    coordinate is computed from the simulated velocities.

    The returned data is converted to floating point.
    """
    with snx.File(file_path, "r", locking=_file_lock) as f:
        # The name p_t_x_y_z_vx_vy_vz lists the table columns in order:
        # probability, time of arrival, position (x, y, z), velocity (vx, vy, vz).
        # Column 0 is therefore the probability, column 1 the time of arrival, ...
        table = f[_data_path][()].rename_dims({'dim_0': 'event'})

        def column(index: int, unit: str) -> sc.Variable:
            # Copy one table column and tag it with its unit.
            # Units are hardcoded from the data.
            values = cast(sc.Variable, table['dim_1', index].copy())
            values.unit = unit
            return values

        probabilities = column(0, 'dimensionless')
        time_of_arrival = column(1, 's')
        x_pos = column(2, 'm')
        y_pos = column(3, 'm')

        relative_probability = probabilities / probabilities.max()
        counts = relative_probability * probability_scale_factor
        counts.unit = 'counts'

        pixel_id = _position_to_pixel_id(
            x_pos=x_pos,
            y_pos=y_pos,
            detector_start_x=detector_start_x,
            detector_start_y=detector_start_y,
            detector_end_x=detector_end_x,
            detector_end_y=detector_end_y,
            resolution=resolution,
        )
        event_coords = {
            'time_of_arrival': time_of_arrival.to(unit='us'),
            # Hardcoded from the data.
            'sample_position': sc.vector([0.0, 0.0, 60.5], unit='m'),
            # Hardcoded from the data.
            'source_position': sc.vector([0.0, 0.0, 0.0], unit="m"),
            'pixel_id': pixel_id,
        }
        da = sc.DataArray(
            data=counts.copy().astype(sc.DType.int32),
            coords=event_coords,
        )
        if load_true_velocities:
            velocities = load_velocities(file_path, _data_path, _file_lock)
            # Stack the components and reorder to one (vx, vy, vz) row per event.
            stacked = sc.concat(list(velocities.values()), 'speed')
            velocity_vectors = sc.vectors(
                dims=['event'],
                values=sc.transpose(stacked).values,
                unit='m/s',
            )
            speeds = sc.norm(velocity_vectors)
            # de Broglie relation: wavelength = h / (m_n * v)
            wavelength = scc.h / scc.neutron_mass / speeds
            da.coords['sim_wavelength'] = wavelength.to(unit='angstrom')

        return OdinSimulationRawData(da.to(dtype=float))


In [None]:
from ess.imaging.data import get_mcstas_ob_images_path, get_mcstas_sample_images_path

# Paths to the McStas open-beam and sample files provided by ess.imaging.data.
ob_file_path = FilePath(get_mcstas_ob_images_path())
sample_file_path = FilePath(get_mcstas_sample_images_path())
# Load both runs on the reduced 128 x 128 pixel grid instead of the default.
ob_da = load_odin_simulation_data(ob_file_path, resolution=example_resolution)
sample_da = load_odin_simulation_data(sample_file_path, resolution=example_resolution)
sample_da

In [None]:
def _pixel_ids_to_x(
    *,
    pixel_id: sc.Variable,
    resolution: McStasManualResolution = DefaultMcStasManualResolution,
    detector_start_x: DetectorStartX = DefaultDetectorStartX,
    detector_end_x: DetectorEndX = DefaultDetectorEndX,
) -> sc.Variable:
    """Return the x coordinate of the center of each pixel.

    The column index is ``pixel_id % resolution[0]``; the detector range in x
    is split into ``resolution[0]`` columns of equal width.
    """
    n_columns = resolution[0]
    pixel_width = (detector_end_x - detector_start_x) / n_columns
    column_index = pixel_id % n_columns
    left_edge = detector_start_x + column_index * pixel_width
    # Shift by half a pixel to move from the edge to the center.
    return left_edge + pixel_width / 2


def _pixel_ids_to_y(
    *,
    pixel_id: sc.Variable,
    resolution: McStasManualResolution = DefaultMcStasManualResolution,
    detector_start_y: DetectorStartY = DefaultDetectorStartY,
    detector_end_y: DetectorEndY = DefaultDetectorEndY,
) -> sc.Variable:
    """Return the y coordinate of the center of each pixel.

    The row index is ``pixel_id // resolution[0]`` (the number of columns per
    row); the detector range in y is split into ``resolution[1]`` rows of
    equal height.
    """
    pixel_height = (detector_end_y - detector_start_y) / resolution[1]
    # Rows are resolution[0] pixels long, so integer division by the column
    # count yields the row.
    row_index = pixel_id // resolution[0]
    lower_edge = detector_start_y + row_index * pixel_height
    # Shift by half a pixel to move from the edge to the center.
    return lower_edge + pixel_height / 2


def _pixel_ids_to_position(
    *, x: sc.Variable, y: sc.Variable, z_pos: sc.Variable
) -> sc.Variable:
    """Combine x, y and a common z into one position vector per event.

    The raw values of ``x``, ``y`` and ``z`` are used as they are and the
    result is labelled with unit ``m``; no unit conversion is performed.
    """
    n_events = len(x)
    # Broadcast the scalar z to the same length as x.
    z = sc.zeros_like(x) + z_pos
    # Lay out all x, then all y, then all z, and view that as (pos, event).
    stacked = sc.concat([x, y, z], 'event')
    components_first = stacked.fold(
        'event', dims=['pos', 'event'], shape=[3, n_events]
    )
    # One (x, y, z) row per event.
    xyz_rows = components_first.transpose(dims=['event', 'pos']).values
    return sc.vectors(dims=['event'], values=xyz_rows, unit='m')


In [None]:
import scipp as sc
from scippneutron.conversion import graph

# Start from the scippneutron beamline and tof-kinematics graphs, then
# override/add the nodes that are specific to this McStas dataset.
# NOTE(review): the positional ``False`` is presumably ``scatter=False`` - confirm.
plane_graph = {**graph.beamline.beamline(False), **graph.tof.kinematic("tof")}

# TODO: Replace this with actual WFM stitching method
# For now the time of arrival is used unchanged as the time-of-flight.
plane_graph['tof'] = lambda time_of_arrival: time_of_arrival
# Pixel centers are computed on the same reduced grid used when loading.
plane_graph['x'] = lambda pixel_id: _pixel_ids_to_x(
    pixel_id=pixel_id, resolution=example_resolution
)
plane_graph['y'] = lambda pixel_id: _pixel_ids_to_y(
    pixel_id=pixel_id, resolution=example_resolution
)
plane_graph['position'] = lambda x, y: _pixel_ids_to_position(
    x=x,
    y=y,
    z_pos=sc.scalar(60.5, unit='m'),  # Hardcoded from the data.
)

sc.show_graph(plane_graph, simplified=True)

In [None]:
# Coordinates to compute (or keep) on both the sample and open-beam data.
coords = ["tof", "position", "x", "y", "sim_wavelength", "Ltotal"]

sample_da = sample_da.transform_coords(
    coords,
    graph=plane_graph,
    keep_intermediate=False,
)
# Apply the identical transformation to the open-beam run.
ob_da = ob_da.transform_coords(
    coords,
    graph=plane_graph,
    keep_intermediate=False,
)

sample_da

In [None]:
# Histogram all sample events into 300 tof bins and plot the spectrum.
sample_da.hist(tof=300).plot()

## Convert McStas raw data to NeXus

The raw McStas data looks different from what data in a NeXus file would look like.
The time-of-flight recorded by the McStas monitor is an unwrapped time of arrival
(see https://scipp.github.io/scippneutron/user-guide/chopper/frame-unwrapping.html);
the `tof` coordinate has values beyond the 71 ms pulse period (1 / 14 Hz),
as can be seen in the plot above.

The workflow that computes wavelengths from the WFM chopper cascade expects data in the NeXus format,
so we transform the data here.

In [None]:
def to_nexus(da):
    """Re-bin unwrapped McStas events into a NeXus-like, pulse-resolved layout.

    Events are binned by pulse using their ``tof`` event coordinate and the
    period of a 14 Hz source. Every pulse bin receives an ``event_time_zero``
    coordinate and every event an ``event_time_offset`` equal to
    ``tof % period``. The ``tof``, ``time_of_arrival`` and ``Ltotal`` event
    coordinates are removed.

    Parameters
    ----------
    da:
        Event data carrying ``tof``, ``time_of_arrival`` and ``Ltotal``
        coordinates.

    Returns
    -------
    :
        Binned data whose outer dimension is ``event_time_zero``.
    """
    unit = da.coords['tof'].unit
    period = (1.0 / sc.scalar(14., unit='Hz')).to(unit=unit)
    # Number of complete pulse periods spanned by the events of *this* array.
    # The maximum must come from ``da`` itself (not from another dataset),
    # otherwise events beyond that other dataset's last pulse would be dropped.
    n = int((da.coords['tof'].max() / period).value)
    # Bin the data into bins with a 71ms period; n + 2 edges give n + 1 bins,
    # so the last, partially filled period is included.
    da = da.bin(tof=sc.arange('tof', n + 2) * period)
    # Add an event_time_zero coord for each bin, but not as bin edges, as all
    # events in the same pulse have the same event_time_zero, hence the `[:-1]`
    # which drops the final edge.
    da.coords['event_time_zero'] = (
        sc.scalar(1730450434078980000, unit='ns').to(unit=unit) + da.coords['tof']
    )[:-1]
    # Remove the meaningless tof coord at the top level
    del da.coords['tof']
    # Remove the time_of_arrival and Ltotal event coords inside the bins
    del da.bins.coords['time_of_arrival']
    del da.bins.coords['Ltotal']
    # The outer dimension now enumerates pulses, so rename it accordingly
    da = da.rename_dims(tof='event_time_zero')
    # Compute a proper event_time_offset as tof % period
    da.bins.coords['event_time_offset'] = da.bins.coords.pop('tof') % period
    return da

def add_positions(da):
    """Attach one position per pixel and recompute ``Ltotal`` from it.

    The per-event ``position`` coordinate is averaged within each bin after
    merging all ``event_time_zero`` bins. The result replaces the event
    coordinate as a top-level ``position`` coordinate, and ``Ltotal`` is then
    derived via ``plane_graph``.
    """
    # Merge all pulses so the mean is taken over every event of a pixel.
    merged_pulses = da.bins.concat('event_time_zero').copy()
    mean_position = merged_pulses.bins.coords['position'].bins.mean()
    result = da.copy()
    result.coords['position'] = mean_position
    # The per-event positions are no longer needed.
    del result.bins.coords['position']
    return result.transform_coords(
        "Ltotal",
        graph=plane_graph,
        keep_intermediate=True,
    )

# Convert to the pulse-resolved layout, group the events by pixel,
# then attach per-pixel positions and Ltotal.
sample_nexus = add_positions(to_nexus(sample_da).group('pixel_id'))
ob_nexus = add_positions(to_nexus(ob_da).group('pixel_id'))

sample_nexus

In [None]:
# Visualize: event_time_offset spectra of sample and open beam, side by side.
# ``fig_nexus`` is reused by a later cell for comparison with the Tof simulation.
fig_nexus = sample_nexus.bins.concat().hist(event_time_offset=300).plot(title='McStas simulation: sample')
fig_nexus + ob_nexus.bins.concat().hist(event_time_offset=300).plot(title='McStas simulation: open beam')

In [None]:
import plopp as pp

# Sum over pulses and show the per-pixel counts at the pixel positions in 3d.
pp.scatter3d(sample_nexus.sum('event_time_zero'), pos='position', cbar=True, pixel_size=0.0005)

## Choppers

To accurately compute the wavelengths of the neutrons from their time-of-arrival,
we need the parameters of the choppers in the beamline.

In [None]:
import sciline as sl
from scippneutron.chopper import DiskChopper
from scippneutron.tof import unwrap
from scippneutron.tof import chopper_cascade

# Unit shorthands. NOTE(review): ``meter`` is not used in the visible code.
Hz = sc.Unit("Hz")
deg = sc.Unit("deg")
meter = sc.Unit("m")

# Chopper settings, keyed by chopper name. As used when building the
# DiskChopper objects further down in this cell:
#   frequency -> Hz, phase -> deg, distance -> m (z of the axle position),
#   open / close -> slit begin / end angles in deg, one entry per cutout.
parameters = {
    "WFMC_1": {
        "frequency": 56.0,
        "phase": 93.244,
        "distance": 6.85,
        "open": [-1.9419, 49.5756, 98.9315, 146.2165, 191.5176, 234.9179],
        "close": [1.9419, 55.7157, 107.2332, 156.5891, 203.8741, 249.1752]
    },
    "WFMC_2": {
        "frequency": 56.0,
        "phase": 97.128,
        "distance": 7.15,
        "open": [-1.9419, 51.8318, 103.3493, 152.7052, 199.9903, 245.2914],
        "close": [1.9419, 57.9719, 111.6510, 163.0778, 212.3468, 259.5486]
    },
    "FOC_1": {
        "frequency": 42.0,
        "phase": 81.303297,
        "distance": 8.4,
        "open": [-5.1362, 42.5536, 88.2425, 132.0144, 173.9497, 216.7867],
        "close": [5.1362, 54.2095, 101.2237, 146.2653, 189.417, 230.7582]
    },
    "BP_1": {
        "frequency": 7.0,
        "phase": 31.080,
        "distance": 8.45,
        "open": [-23.6029],
        "close": [23.6029]
    },
    "FOC_2": {
        "frequency": 42.0,
        "phase": 107.013442,
        "distance": 12.2,
        "open": [-16.3227, 53.7401, 120.8633, 185.1701, 246.7787, 307.0165],
        "close": [16.3227, 86.8303, 154.3794, 218.7551, 280.7508, 340.3188]
    },
    "BP_2": {
        "frequency": 7.0,
        "phase": 44.224,
        "distance": 12.25,
        "open": [-34.4663],
        "close": [34.4663]
    },
    "T0_alpha": {
        "frequency": 14.0,
        "phase": 179.672,
        "distance": 13.5,
        "open": [-167.8986],
        "close": [167.8986]
    },
    "T0_beta": {
        "frequency": 14.0,
        "phase": 179.672,
        "distance": 13.7,
        "open": [-167.8986],
        "close": [167.8986]
    },
    "FOC_3": {
        "frequency": 28.0,
        "phase": 92.993,
        "distance": 17.0,
        "open": [-20.302, 45.247, 108.0457, 168.2095, 225.8489, 282.2199],
        "close": [20.302, 85.357, 147.6824, 207.3927, 264.5977, 319.4024]
    },
    "FOC_4": {
        "frequency": 14.0,
        "phase": 61.584,
        "distance": 23.69,
        "open": [-16.7157, 29.1882, 73.1661, 115.2988, 155.6636, 195.5254],
        "close": [16.7157, 61.8217, 105.0352, 146.4355, 186.0987, 224.0978]
    },
    "FOC_5": {
        "frequency": 14.0,
        "phase": 82.581,
        "distance": 33.0,
        "open": [-25.8514, 38.3239, 99.8064, 160.1254, 217.4321, 272.5426],
        "close": [25.8514, 88.4621, 147.4729, 204.0245, 257.7603, 313.7139]
    },

}

# Build one DiskChopper per entry. Frequency and phase are passed with
# inverted sign.
# NOTE(review): presumably to match the rotation-direction convention of
# DiskChopper - confirm against the scippneutron documentation.
disk_choppers = {key: DiskChopper(
    frequency=-ch["frequency"] * Hz,
    beam_position=sc.scalar(0.0, unit="deg"),
    phase=-ch["phase"] * deg,
    axle_position=sc.vector(value=[0, 0, ch["distance"]], unit="m"),
    slit_begin=sc.array(dims=["cutout"], values=ch["open"], unit="deg"),
    slit_end=sc.array(dims=["cutout"], values=ch["close"], unit="deg")
) for key, ch in parameters.items() }

In [None]:
# Display one chopper to inspect the constructed object.
disk_choppers["WFMC_1"]

In [None]:
# Convert every DiskChopper into a chopper_cascade.Chopper for a 14 Hz
# source, considering a single pulse (npulses=1).
choppers = {
    key: chopper_cascade.Chopper.from_disk_chopper(
        chop,
        pulse_frequency=sc.scalar(14.0, unit="Hz"),
        npulses=1,
    )
    for key, chop in disk_choppers.items()
}

### Check that the chopper settings make sense with a quick `tof` run

A useful sanity check is to run a basic simulation,
propagating neutrons through the chopper cascade,
using the [Tof](https://tof.readthedocs.io) package.

In [None]:
from scippneutron.tof.fakes import FakeBeamlineEss

# Use the mean flight path of the sample events as the monitor distance.
Ltotal = sample_da.coords['Ltotal'].mean()
ess_beamline = FakeBeamlineEss(
    choppers=choppers,
    monitors={"detector": Ltotal},
    # Eight pulse periods of a 14 Hz source.
    run_length=sc.scalar(1 / 14, unit="s") * 8,
    events_per_pulse=100_000,
)

ess_beamline.model_result.plot()

We observe that the WFM choppers make 6 distinct frames at the detector,
and that the other choppers skip every other pulse to maximize wavelength coverage.

We can now compare the counts on the detector to our raw data,
to make sure they broadly resemble each other.

In [None]:
# Events recorded by the "detector" monitor of the fake beamline.
raw_data = ess_beamline.get_monitor("detector")[0]

# Visualize: McStas sample spectrum next to the Tof simulation, summed over pulses.
fig_nexus + raw_data.hist(event_time_offset=300).sum("pulse").plot(title='Tof simulation')

## Use WFM workflow

We now set up the workflow which will convert the raw neutron arrival times to a real time-of-flight,
and thus a wavelength.

### Chopper cascade

In [None]:
# Take the time and wavelength extent of the source from the first pulse.
one_pulse = ess_beamline.source.data["pulse", 0]
pulse_tmin = one_pulse.coords["time"].min()
pulse_tmax = one_pulse.coords["time"].max()
pulse_wmin = one_pulse.coords["wavelength"].min()
pulse_wmax = one_pulse.coords["wavelength"].max()

frames = chopper_cascade.FrameSequence.from_source_pulse(
    time_min=pulse_tmin,
    time_max=pulse_tmax,
    wavelength_min=pulse_wmin,
    wavelength_max=pulse_wmax,
)

# Chop the frames
chopped = frames.chop(choppers.values())

# Propagate the neutrons to the detector
# (``Ltotal`` is the mean flight path of the sample events; despite its name,
# ``at_sample`` therefore holds the frames at that distance.)
at_sample = chopped.propagate_to(Ltotal)

# Visualize the results
cascade_fig, cascade_ax = at_sample.draw()

### Pipeline

In [None]:
# Build the unwrapping pipeline from its default providers and parameters.
workflow = sl.Pipeline(unwrap.providers(), params=unwrap.params())

# Source and chopper description.
workflow[unwrap.PulsePeriod] = sc.reciprocal(ess_beamline.source.frequency)
workflow[unwrap.PulseStride] = 2  # Needed for pulse-skipping
workflow[unwrap.SourceTimeRange] = pulse_tmin, pulse_tmax
workflow[unwrap.SourceWavelengthRange] = pulse_wmin, pulse_wmax
workflow[unwrap.Choppers] = choppers

# Data to process: the sample run and its per-pixel flight paths.
workflow[unwrap.Ltotal] = sample_nexus.coords['Ltotal']
workflow[unwrap.RawData] = sample_nexus

workflow.visualize(unwrap.TofData)

In [None]:
# Run the pipeline for the sample data.
sample_tofs = workflow.compute(unwrap.TofData)
sample_tofs

In [None]:
# Derive the wavelength coordinate using the same coordinate graph as before.
sample_wavs = sample_tofs.transform_coords('wavelength', graph=plane_graph)
sample_wavs

We can now compare our computed wavelengths to the true wavelengths of the neutrons in the McStas simulation:

In [None]:
# Histogram the simulated ("true") wavelengths and rename the dimension so
# both spectra share the name ``wavelength``.
true_wavs = sample_da.hist(sim_wavelength=300).rename(sim_wavelength='wavelength')

# Histogram the computed wavelengths on the same bin edges for a direct comparison.
pp.plot({
    'true': true_wavs,
    'wfm': sample_wavs.bins.concat().hist(wavelength=true_wavs.coords['wavelength'])
}, title="ODIN McStas simulation")

## Region of interest

Looking at the counts on the 2d detector panel,
we see that there is a central rectangular darker region,
surrounded by brighter edges.

In [None]:
# Merge all pulses, then fold the flat pixel_id dimension back into the 2d panel.
# Pixel ids were built row by row on the ``example_resolution`` grid
# (index 0 = x, index 1 = y), so the sizes are taken from it instead of being
# hard-coded; the two can then not get out of sync.
sample_folded = sample_wavs.bins.concat('event_time_zero').fold(
    dim='pixel_id',
    sizes={'y': example_resolution[1], 'x': example_resolution[0]},
)
sample_folded.hist().plot(aspect='equal')

The dark region is where the beam was absorbed by the sample,
and this is the region of interest.
The brighter edges need to be discarded.

We crop the data using simple array slicing:

In [None]:
# Keep pixel indices 11..115 along both axes.
# NOTE: these indices are tuned for the 128 x 128 example grid.
sel = slice(11, 116, 1)
sample_cropped = sample_folded['y', sel]['x', sel]
sample_cropped.hist().plot(aspect='equal')

### Repeat for the open-beam

We repeat the conversion to wavelength and crop the edges of the open-beam measurement.

In [None]:
# Give the same pixel positions to both sample and open beam.
# Note: this is only because of the way we computed the positions.
# In practice, the geometry should come from the nexus file and this won't be needed.
# Give the same pixel positions to both sample and open beam.
# Note: this is only because of the way we computed the positions.
# In practice, the geometry should come from the nexus file and this won't be needed.
ob_nexus.coords.update({key: sample_nexus.coords[key] for key in ('position', 'Ltotal')})

# Re-run the same pipeline with the open-beam data as input.
workflow[unwrap.Ltotal] = ob_nexus.coords['Ltotal']
workflow[unwrap.RawData] = ob_nexus

ob_tofs = workflow.compute(unwrap.TofData)
ob_wavs = ob_tofs.transform_coords('wavelength', graph=plane_graph)
# Fold back into the 2d panel using the grid the pixel ids were built on
# (``example_resolution``: index 0 = x, index 1 = y) rather than a
# hard-coded size.
ob_folded = ob_wavs.bins.concat('event_time_zero').fold(
    dim='pixel_id',
    sizes={'y': example_resolution[1], 'x': example_resolution[0]},
)
# Apply the same crop as for the sample.
ob_cropped = ob_folded['y', sel]['x', sel]

## Normalize the signal

Finally, we are able to normalize our sample measurement to the open-beam data.

Here, we sum over all pixels before normalizing.
There is no spatial structure in the signal, and we are only interested in the wavelength spectrum (where the Bragg edge is).
So this is effectively like degrading the detector resolution to a single pixel.

In [None]:
# Common set of bins
bins = sc.linspace('wavelength', 1.1, 9.5, 301, unit='angstrom')

num = sample_cropped.bins.concat().hist(wavelength=bins)
den = ob_cropped.bins.concat().hist(wavelength=bins)

# Add variances
num.variances = num.values
den.variances = den.values

normalized = num / den
normalized

In [None]:
# Plot the normalized spectrum.
normalized.plot()

## Save the final result

In [None]:
from scippneutron.io import save_xye

# Work on a shallow copy so the coordinate of ``normalized`` is left untouched.
to_disk = normalized.copy(deep=False)
# Replace the bin edges by the bin centers.
# NOTE(review): presumably save_xye needs a point-valued coordinate - confirm.
to_disk.coords['wavelength'] = sc.midpoints(to_disk.coords['wavelength'])

save_xye('fe_bragg_edge.xye', to_disk)