# Reducing WFM data

This notebook aims to illustrate how to work with the wavelength frame multiplication submodule `wfm`.

We will create a beamline that resembles the ODIN instrument beamline,
generate some fake neutron data,
and then show how to convert the neutron arrival times at the detector to neutron time-of-flight,
from which a wavelength can then be computed (a process commonly known as 'stitching').

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipp as sc
from scipp import constants
import scippneutron as scn
import ess.wfm as wfm
import ess.choppers as ch
np.random.seed(1) # Fixed for reproducibility

## Create beamline components

We first create all the components necessary for a beamline to run in WFM mode
(see [Introduction to WFM](introduction-to-wfm.ipynb) for the meanings of the different symbols).
The beamline will contain

- a neutron source, located at the origin ($x = y = z =  0$)
- a pulse with a defined length ($2860 ~\mu s$) and $t_0$ ($130 ~\mu s$)
- a single pixel detector, located at $z = 60$ m
- two WFM choppers, located at $z = 6.775$ m and $z = 7.225$ m, each with 6 frame windows/openings

The `wfm` module provides a helper function to quickly create such a beamline.
It returns a `dict` of coordinates, which can then be added to a data container.

In [None]:
coords = wfm.make_fake_beamline(nframes=6)
coords

## Generate some fake data

Next, we will generate some fake imaging data (no scattering will be considered),
that is supposed to mimic a spectrum with a Bragg edge located at $4\unicode{x212B}$.
We start by defining a function that will act as our underlying distribution:

In [None]:
x = np.linspace(1, 10.0, 100000)
a = 20.0
b = 4.0
y1 = 0.7 / (np.exp(-a * (x - b)) + 1.0)
y2 = 1.4-0.2*x
y = y1 + y2
fig1, ax1 = plt.subplots()
ax1.plot(x, y)
ax1.set_xlabel("Wavelength [angstroms]")

We then proceed to generate two sets of 1,000,000 events:
- one for the `sample` using the distribution defined above
- and one for the `vanadium` which will be just a flat random distribution

For the events in both `sample` and `vanadium`,
we define a wavelength for the neutrons as well as a birth time,
which is a random time between the pulse $t_0$ and the end of the usable pulse, $t_0$ + `pulse_length`.

In [None]:
nevents = 1_000_000
events = {
    "sample": {
        "wavelengths": sc.array(
            dims=["event"],
            values=np.random.choice(x, size=nevents, p=y/np.sum(y)),
            unit="angstrom"),
        "birth_times": sc.array(
            dims=["event"],
            values=np.random.random(nevents) * coords["source_pulse_length"].value,
            unit="us") + coords["source_pulse_t_0"]
    },
    "vanadium": {
        "wavelengths": sc.array(
            dims=["event"],
            values=np.random.random(nevents) * 9.0 + 1.0,
            unit="angstrom"),
        "birth_times": sc.array(
            dims=["event"],
            values=np.random.random(nevents) * coords["source_pulse_length"].value,
            unit="us") + coords["source_pulse_t_0"]
    }
}
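
The sampling step above draws wavelengths from a discretized version of the distribution, with each grid point weighted by its function value. A minimal standard-library sketch of the same idea follows; it uses `random.choices` in place of `np.random.choice`, and the coarser grid, the helper name `edge_spectrum` and the sample size are chosen here for illustration only.

```python
import math
import random


def edge_spectrum(wavelength, a=20.0, b=4.0):
    """Smooth step of height 0.7 at wavelength b, on top of a linear background."""
    step = 0.7 / (math.exp(-a * (wavelength - b)) + 1.0)
    background = 1.4 - 0.2 * wavelength
    return step + background


rng = random.Random(1)  # fixed seed for reproducibility

# Coarse wavelength grid between 1 and 10 angstrom (illustrative resolution)
grid = [1.0 + 9.0 * i / 999 for i in range(1000)]
weights = [edge_spectrum(w) for w in grid]

# Weighted sampling: grid points with larger weights are drawn more often.
# random.choices normalizes the weights internally.
samples = rng.choices(grid, weights=weights, k=10_000)
print(min(samples), max(samples))
```

The weights stay positive over the whole range from 1 to 10 angstrom, which is what allows them to be used as sampling probabilities.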

We can then take a quick look at our fake data by histogramming the events:

In [None]:
# Histogram and plot the event data
bins = np.linspace(1.0, 10.0, 129)
fig2, ax2 = plt.subplots()
for key in events:
    h = ax2.hist(events[key]["wavelengths"].values, bins=bins, alpha=0.5, label=key)
ax2.set_xlabel("Wavelength [angstroms]")
ax2.set_ylabel("Counts")
ax2.legend()

We can also verify that the birth times fall within the expected range:

In [None]:
for key, item in events.items():
    print(key)
    print(sc.min(item["birth_times"]))
    print(sc.max(item["birth_times"]))

We can then compute the arrival times of the events at the detector pixel:

In [None]:
# The ratio of neutron mass to the Planck constant
alpha = sc.to_unit(constants.m_n / constants.h, 'us/m/angstrom')
# The distance between the source and the detector
dz = sc.norm(coords['position'] - coords['source_position'])
for key, item in events.items():
    item["arrival_times"] = alpha * dz * item["wavelengths"] + item["birth_times"]
events["sample"]["arrival_times"]
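
To make the arrival-time formula concrete, here is a plain-Python sketch for a single neutron. The physical constants are the CODATA 2018 values written out by hand, and the wavelength and birth time are illustrative numbers, not values taken from the generated events.

```python
# Sketch: arrival time at the detector for one neutron, without scipp.
M_N = 1.67492749804e-27  # neutron mass [kg]
H = 6.62607015e-34       # Planck constant [J s]

# m_n / h has units of s/m^2. Convert to us / (m * angstrom):
# 1 s = 1e6 us and 1 m = 1e10 angstrom.
alpha_us_per_m_angstrom = M_N / H * 1.0e6 / 1.0e10


def arrival_time_us(wavelength_angstrom, birth_time_us, distance_m):
    """Time [us] at which a neutron reaches distance_m, counted from the pulse origin."""
    flight_time = alpha_us_per_m_angstrom * distance_m * wavelength_angstrom
    return flight_time + birth_time_us


# Illustrative values: a 4 angstrom neutron born at 1000 us, detector at 60 m
example_arrival = arrival_time_us(4.0, 1000.0, 60.0)
print(f"alpha = {alpha_us_per_m_angstrom:.3f} us/m/angstrom")
print(f"arrival time = {example_arrival:.1f} us")
```

The flight time scales linearly with both wavelength and distance, which is why neutron paths are straight lines in a time-distance diagram.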

## Visualize the beamline's chopper cascade

We first attach the beamline geometry to a Dataset:

In [None]:
ds = sc.Dataset(coords=coords)
ds

The `wfm.plot` submodule provides a useful tool to visualize the chopper cascade as a time-distance diagram.
This is achieved by calling

In [None]:
f = wfm.plot.time_distance_diagram(ds)

This shows the 6 frames, generated by the WFM choppers,
as well as their predicted time boundaries at the position of the detector.

Each frame has a time window during which neutrons are allowed to pass through,
as well as minimum and maximum allowed wavelengths.

This information is obtained from the beamline geometry by calling

In [None]:
frames = wfm.get_frames(ds)
frames

## Discard neutrons that do not make it through the chopper windows

Once we have the parameters of the 6 wavelength frames,
we need to run through all our generated neutrons and filter out all the neutrons with invalid flight paths,
i.e. the ones that do not make it through both chopper openings in a given frame.

In [None]:
for item in events.values():
    item["valid_indices"] = []
near_wfm_chopper = ds.coords["chopper_wfm_1"].value
far_wfm_chopper = ds.coords["chopper_wfm_2"].value
near_time_open = ch.time_open(near_wfm_chopper)
near_time_close = ch.time_closed(near_wfm_chopper)
far_time_open = ch.time_open(far_wfm_chopper)
far_time_close = ch.time_closed(far_wfm_chopper)

for item in events.values():
    # Compute event arrival times at wfm choppers 1 and 2
    slopes = 1.0 / (alpha * item["wavelengths"])
    intercepts = -slopes * item["birth_times"]
    times_at_wfm1 = (sc.norm(near_wfm_chopper["position"].data) - intercepts) / slopes
    times_at_wfm2 = (sc.norm(far_wfm_chopper["position"].data) - intercepts) / slopes
    # Create a mask to see if neutrons go through one of the openings
    mask = sc.zeros(dims=times_at_wfm1.dims, shape=times_at_wfm1.shape, dtype=bool)
    for i in range(len(frames["time_min"])):
        mask |= ((times_at_wfm1 >= near_time_open["frame", i]) &
                 (times_at_wfm1 <= near_time_close["frame", i]) &
                 (item["wavelengths"] >= frames["wavelength_min"]["frame", i]).data &
                 (item["wavelengths"] <= frames["wavelength_max"]["frame", i]).data)
    item["valid_indices"] = np.ravel(np.where(mask.values))
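
The slope and intercept used in the cell above describe a neutron path as a straight line in the time-distance diagram. The following sketch, in plain Python with a rounded value for the constant, shows that this form gives the same crossing time as adding the flight time to the birth time. The chopper opening window is a hypothetical pair of numbers, not one taken from the fake beamline.

```python
ALPHA = 252.778  # m_n / h in us/m/angstrom (rounded)


def time_at_distance(wavelength, birth_time, distance):
    """Time [us] at which a neutron crosses `distance` [m] from the source."""
    # Neutron path in the time-distance diagram: distance = slope * t + intercept
    slope = 1.0 / (ALPHA * wavelength)  # neutron speed [m/us]
    intercept = -slope * birth_time     # path extrapolated back to t = 0 [m]
    return (distance - intercept) / slope


def passes_window(t, t_open, t_close):
    """True if the crossing time falls inside the chopper opening."""
    return t_open <= t <= t_close


# A 4 angstrom neutron born at 500 us, chopper at z = 6.775 m
t_chopper = time_at_distance(wavelength=4.0, birth_time=500.0, distance=6.775)

# Equivalent direct form: flight time plus birth time
direct = ALPHA * 6.775 * 4.0 + 500.0

# Hypothetical opening window [us], for illustration only
accepted = passes_window(t_chopper, t_open=7000.0, t_close=8000.0)
print(t_chopper, direct, accepted)
```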

## Create a realistic Dataset

We now create a dataset that contains:
- the beamline geometry
- the time coordinate
- the histogrammed events

In [None]:
for item in events.values():
    item["valid_times"] = item["arrival_times"].values[item["valid_indices"]]

tmin = min([item["valid_times"].min() for item in events.values()])
tmax = max([item["valid_times"].max() for item in events.values()])

dt = 0.1 * (tmax - tmin)
time_coord = sc.linspace(dim='time',
                         start=tmin - dt,
                         stop=tmax + dt,
                         num=257,
                         unit=events["sample"]["arrival_times"].unit)

# Histogram the data
for key, item in events.items():
    da = sc.DataArray(
        data=sc.ones(dims=['time'], shape=[len(item["valid_times"])],
                     unit=sc.units.counts, with_variances=True),
        coords={
            'time': sc.array(dims=['time'], values=item["valid_times"], unit=sc.units.us)})
    ds[key] = da.hist(time=time_coord)

ds

In [None]:
ds.plot()

## Stitch the frames

Wavelength frame multiplication consists of making 6 new pulses from the original pulse.
This implies that the WFM choppers are acting as a source chopper.
Hence, to compute a wavelength from a time and a distance between source and detector,
the location of the source must now be at the position of the WFM choppers,
or more exactly at the mid-point between the two WFM choppers.

The stitching operation amounts to converting the `time` dimension to time-of-flight (`tof`),
by subtracting from each frame a time shift that corresponds to the mid-point between the two WFM choppers.

This is performed with the `stitch` function in the `wfm` module:

In [None]:
stitched = wfm.stitch(frames=frames,
                      data=ds,
                      dim='time',
                      bins=257)
stitched

In [None]:
stitched.plot()
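
Conceptually, the stitching described above can be sketched as follows. This is not the implementation used by `wfm.stitch`; it is a simplified per-event illustration, and the `Frame` class, its field names and all numbers are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class Frame:
    """Hypothetical description of one WFM frame at the detector."""
    time_min: float         # earliest arrival time in the frame [us]
    time_max: float         # latest arrival time in the frame [us]
    time_correction: float  # shift subtracted from arrival times [us]


def stitch_events(arrival_times, frames):
    """Convert detector arrival times to time-of-flight, frame by frame."""
    tofs = []
    for t in arrival_times:
        for frame in frames:
            if frame.time_min <= t <= frame.time_max:
                tofs.append(t - frame.time_correction)
                break  # an event belongs to at most one frame
        # Events that fall between frames are dropped
    return tofs


example_frames = [
    Frame(time_min=10000.0, time_max=20000.0, time_correction=2000.0),
    Frame(time_min=21000.0, time_max=30000.0, time_correction=3500.0),
]
example_arrivals = [9000.0, 15000.0, 20500.0, 25000.0]
stitched_tofs = stitch_events(example_arrivals, example_frames)
print(stitched_tofs)
```

Each frame gets its own shift because each sub-pulse passes the WFM choppers at a different time.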

For diagnostic purposes,
it can be useful to visualize the individual frames before and after the stitching process.
The `wfm.plot` module provides two helper functions to do just this:

In [None]:
wfm.plot.frames_before_stitching(data=ds['sample'], frames=frames, dim='time')

In [None]:
wfm.plot.frames_after_stitching(data=ds['sample'], frames=frames, dim='time')

## Convert to wavelength

Now that the data coordinate is time-of-flight (`tof`),
we can use `scippneutron` to perform the unit conversion from `tof` to `wavelength`.

In [None]:
from scippneutron.conversion.graph import beamline, tof
graph = {**beamline.beamline(scatter=False), **tof.elastic("tof")}
converted = stitched.transform_coords("wavelength", graph=graph)
converted.plot()
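
For a straight flight path with no scattering, the conversion reduces to dividing the time-of-flight by the flight path length and the constant $m_n/h$. The sketch below assumes, as described in the stitching section, that the effective source sits at the mid-point between the two WFM choppers. The rounded constant and the function name are chosen here for illustration.

```python
ALPHA = 252.778  # m_n / h in us/m/angstrom (rounded)


def wavelength_from_tof(tof_us, flight_path_m):
    """Wavelength [angstrom] from time-of-flight [us] over a straight path [m]."""
    return tof_us / (ALPHA * flight_path_m)


# Effective source: mid-point of the WFM choppers at z = 6.775 m and z = 7.225 m
wfm_midpoint = 0.5 * (6.775 + 7.225)
flight_path = 60.0 - wfm_midpoint  # detector at z = 60 m

# Round trip: time-of-flight of a 4 angstrom neutron, converted back to wavelength
tof = ALPHA * flight_path * 4.0
recovered = wavelength_from_tof(tof, flight_path)
print(flight_path, tof, recovered)
```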

## Normalization

Normalization is performed simply by dividing the counts of the `sample` run by the counts of the `vanadium` run.

In [None]:
normalized = converted['sample'] / converted['vanadium']
normalized.plot()
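
Because the histograms were created with variances, the uncertainties are carried through the division. The sketch below shows what such a bin-by-bin ratio looks like, assuming Poisson counts (variance equal to the count) and first-order propagation for uncorrelated operands. The function name and the example counts are illustrative.

```python
import math


def normalize(sample_counts, vanadium_counts):
    """Bin-by-bin ratio of counts with first-order variance propagation."""
    ratios, variances = [], []
    for s, v in zip(sample_counts, vanadium_counts):
        if v == 0:
            # No reference counts in this bin: the ratio is undefined
            ratios.append(math.nan)
            variances.append(math.nan)
            continue
        ratios.append(s / v)
        # var(s / v) = var_s / v^2 + s^2 * var_v / v^4, with var_s = s and var_v = v
        variances.append(s / v**2 + s**2 / v**3)
    return ratios, variances


ratios, variances = normalize([400, 0, 50], [1000, 200, 0])
print(ratios, variances)
```

Bins where the vanadium run has no counts yield undefined values, which is one reason to restrict plots to the wavelength range covered by the frames.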

## Comparing to the raw wavelengths

The final step is a sanity check to verify that the wavelength-dependent data obtained from the stitching process
agrees (to within the beamline resolution) with the original wavelength distribution that was generated at
the start of the workflow.

For this, we simply histogram the raw neutron events using the same bins as the `normalized` data,
filtering out the neutrons with invalid flight paths.

In [None]:
for item in events.values():
    item["wavelength_counts"], _ = np.histogram(
        item["wavelengths"].values[item["valid_indices"]],
        bins=normalized.coords['wavelength'].values)

We then normalize the `sample` by the `vanadium` run,
and plot the resulting spectrum alongside the one obtained from the stitching.

In [None]:
original = sc.DataArray(
    data=sc.array(dims=['wavelength'],
                  values=events["sample"]["wavelength_counts"] /
                         events["vanadium"]["wavelength_counts"]),
    coords = {"wavelength": normalized.coords['wavelength']})

sc.plot({"stitched": normalized, "original": original})

We can see that the counts in the `stitched` data agree very well with the original data.
There is some smoothing of the data seen in the `stitched` result,
and this is expected because of the resolution limitations of the beamline due to its long source pulse.
This smoothing (or smearing) would, however, be much stronger if WFM choppers were not used.

## Without WFM choppers

In this section, we compare the results obtained above to a beamline that does not have a WFM chopper system.
We make a new set of events,
where the number of events is equal to the number of neutrons that make it through the chopper cascade in the previous case.

In [None]:
nevents_no_wfm = len(events["sample"]["valid_times"])
events_no_wfm = {
    "sample": {
        "wavelengths": sc.array(
            dims=["event"],
            values=np.random.choice(x, size=nevents_no_wfm, p=y/np.sum(y)),
            unit="angstrom"),
        "birth_times": sc.array(
            dims=["event"],
            values=np.random.random(nevents_no_wfm) * coords["source_pulse_length"].value,
            unit="us") + coords["source_pulse_t_0"]
    },
    "vanadium": {
        "wavelengths": sc.array(
            dims=["event"],
            values=np.random.random(nevents_no_wfm) * 9.0 + 1.0,
            unit="angstrom"),
        "birth_times": sc.array(
            dims=["event"],
            values=np.random.random(nevents_no_wfm) * coords["source_pulse_length"].value,
            unit="us") + coords["source_pulse_t_0"]
    }
}
for key, item in events_no_wfm.items():
    item["arrival_times"] = alpha * dz * item["wavelengths"] + item["birth_times"]
events_no_wfm["sample"]["arrival_times"]

We then histogram these events to create a new Dataset.
Because we no longer make new pulses with the WFM choppers,
the event time-of-flight is simply the arrival time of the event at the detector.

In [None]:
tmin = min([item["arrival_times"].values.min() for item in events_no_wfm.values()])
tmax = max([item["arrival_times"].values.max() for item in events_no_wfm.values()])

dt = 0.1 * (tmax - tmin)
time_coord_no_wfm = sc.linspace(dim='tof',
                         start=tmin - dt,
                         stop=tmax + dt,
                         num=257,
                         unit=events_no_wfm["sample"]["arrival_times"].unit)

ds_no_wfm = sc.Dataset(coords=coords)

# Histogram the data
for key, item in events_no_wfm.items():
    da = sc.DataArray(
        data=sc.ones(dims=['tof'], shape=[len(item["arrival_times"])],
                     unit=sc.units.counts, with_variances=True),
        coords={
            'tof': sc.array(dims=['tof'], values=item["arrival_times"].values, unit=sc.units.us)})
    ds_no_wfm[key] = da.hist(tof=time_coord_no_wfm)

ds_no_wfm

In [None]:
sc.plot(ds_no_wfm)

We then perform the standard unit conversion and normalization:

In [None]:
converted_no_wfm = ds_no_wfm.transform_coords("wavelength", graph=graph)
normalized_no_wfm = converted_no_wfm['sample'] / converted_no_wfm['vanadium']
normalized_no_wfm.plot()

In the same manner as in the previous section, we compare to the real neutron wavelengths:

In [None]:
for item in events_no_wfm.values():
    item["wavelength_counts"], _ = np.histogram(
        item["wavelengths"].values,
        bins=normalized_no_wfm.coords['wavelength'].values)

In [None]:
original_no_wfm = sc.DataArray(
    data=sc.array(dims=['wavelength'],
                  values=events_no_wfm["sample"]["wavelength_counts"] /
                         events_no_wfm["vanadium"]["wavelength_counts"]),
    coords = {"wavelength": normalized_no_wfm.coords['wavelength']})

w_min = 2.0 * sc.units.angstrom
w_max = 5.5 * sc.units.angstrom
sc.plot({"without WFM": normalized_no_wfm['wavelength', w_min:w_max],
         "original": original_no_wfm['wavelength', w_min:w_max]},
        errorbars=False)

We can see that there is a significant shift between the calculated wavelength of the Bragg edge around $4\unicode{x212B}$
and the original underlying wavelengths.
In comparison, the same plot for the WFM run yields a much better agreement:

In [None]:
sc.plot({"stitched": normalized['wavelength', w_min:w_max],
         "original": original['wavelength', w_min:w_max]},
        errorbars=False)
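
The size of the shift seen without WFM choppers can be estimated by hand. When the arrival time is used directly as time-of-flight, the unknown birth time is counted as flight time, so the computed wavelength is too large. The sketch below uses the pulse parameters listed at the start of the notebook ($t_0 = 130~\mu s$, pulse length $2860~\mu s$) and a rounded value for $m_n/h$; the function name is illustrative.

```python
ALPHA = 252.778           # m_n / h in us/m/angstrom (rounded)
DETECTOR_DISTANCE = 60.0  # [m]
PULSE_T0 = 130.0          # [us]
PULSE_LENGTH = 2860.0     # [us]


def apparent_wavelength(true_wavelength, birth_time):
    """Wavelength obtained when the arrival time is treated as time-of-flight."""
    arrival = ALPHA * DETECTOR_DISTANCE * true_wavelength + birth_time
    return arrival / (ALPHA * DETECTOR_DISTANCE)


# Neutrons at the Bragg edge (4 angstrom), born at the start and end of the pulse
earliest = apparent_wavelength(4.0, PULSE_T0)
latest = apparent_wavelength(4.0, PULSE_T0 + PULSE_LENGTH)

mean_shift = 0.5 * (earliest + latest) - 4.0  # average offset of the edge
spread = latest - earliest                    # width of the smearing
print(f"mean shift = {mean_shift:.3f} angstrom, spread = {spread:.3f} angstrom")
```

With uniformly distributed birth times, the edge is displaced by roughly 0.1 angstrom and smeared over almost 0.2 angstrom, which is consistent with the shift visible in the plot without WFM.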

## Working in event mode

It is also possible to work with WFM data in event mode.
The `stitch` utility will accept both histogrammed and binned (event) data.

We first create a new dataset, with the same events as in the first example,
but this time we bin the data with `.bin()` instead of using `.hist()`,
so we can retain the raw events.

In [None]:
for item in events.values():
    item["valid_times"] = item["arrival_times"].values[item["valid_indices"]]

tmin = min([item["valid_times"].min() for item in events.values()])
tmax = max([item["valid_times"].max() for item in events.values()])

dt = 0.1 * (tmax - tmin)
time_coord = sc.linspace(dim='time',
                         start=tmin - dt,
                         stop=tmax + dt,
                         num=257,
                         unit=events["sample"]["arrival_times"].unit)

ds_event = sc.Dataset(coords=coords)

# Bin the data
for key, item in events.items():
    da = sc.DataArray(
        data=sc.ones(dims=['event'], shape=[len(item["valid_times"])], unit=sc.units.counts,
                     with_variances=True),
        coords={
            'time': sc.array(dims=['event'], values=item["valid_times"], unit=sc.units.us)})
    ds_event[key] = da.bin(time=time_coord)

ds_event

The underlying events can be inspected by using the `.bins.constituents['data']` property of our objects:

In [None]:
ds_event["sample"].bins.constituents['data']

We can visualize this to make sure it looks the same as the histogrammed case above:

In [None]:
sc.plot(ds_event.hist())

As explained above, the `stitch` routine accepts both histogrammed and binned (event) data.
So stitching the binned data works in the exact same way as above, namely

In [None]:
stitched_event = wfm.stitch(frames=frames,
                            data=ds_event,
                            dim='time')
stitched_event

The `stitch` function will return a data structure with a single bin in the `'tof'` dimension.
Visualizing this data is therefore slightly more tricky,
because the data needs to be histogrammed using a finer binning before a useful plot can be made.

In [None]:
sc.plot({key: item.hist(tof=256) for key, item in stitched_event.items()})

At this point, it may be useful to compare the results of the two different stitching operations.

In [None]:
rebinned = stitched_event["sample"].bin(tof=stitched["sample"].coords['tof'])
sc.plot({"events": rebinned.hist(), "histogram": stitched["sample"]}, errorbars=False)

We note that histogramming the data early introduces some smoothing in the data.

We can of course continue in event mode and perform the unit conversion and normalization to the vanadium run.

In [None]:
converted_event = stitched_event.transform_coords("wavelength", graph=graph)

In [None]:
# Normalizing binned data is done using the sc.lookup helper
hist = converted_event["vanadium"].hist(wavelength=converted["sample"].coords['wavelength'])
normalized_event = converted_event["sample"].bins / sc.lookup(func=hist, dim='wavelength')
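
The idea behind normalizing events with a lookup is that each event weight is divided by the value of the reference histogram in the bin that contains the event. A minimal plain-Python sketch of this follows; it is not how `sc.lookup` is implemented, and the function name, bin edges and counts are illustrative.

```python
import math
from bisect import bisect_right


def normalize_events(event_wavelengths, event_weights, bin_edges, reference_counts):
    """Divide each event weight by the reference count of the bin it falls into."""
    normalized = []
    for wav, weight in zip(event_wavelengths, event_weights):
        # Index of the half-open bin [edge_i, edge_i+1) containing the event
        i = bisect_right(bin_edges, wav) - 1
        if 0 <= i < len(reference_counts) and reference_counts[i] > 0:
            normalized.append(weight / reference_counts[i])
        else:
            # Outside the histogram range, or no reference counts in the bin
            normalized.append(math.nan)
    return normalized


edges = [1.0, 2.0, 3.0, 4.0]     # wavelength bin edges [angstrom]
vanadium_hist = [10, 20, 40]     # reference counts per bin
wavelengths = [1.5, 2.5, 2.7, 3.9, 5.0]
weights_in = [1.0] * len(wavelengths)

normalized_weights = normalize_events(wavelengths, weights_in, edges, vanadium_hist)
print(normalized_weights)
```

Keeping the events and only rescaling their weights means the normalized data can still be rebinned later with any choice of bin edges.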

Finally, we compare the end result with the original wavelengths, and see that the agreement is once again good.

In [None]:
to_plot = normalized_event.bin(wavelength=converted["sample"].coords['wavelength']).hist()
sc.plot({"stitched_event": to_plot, "original": original})

We can also compare directly to the histogrammed version,
to see that both methods remain in agreement to a high degree. 

In [None]:
sc.plot({"stitched": normalized['wavelength', w_min:w_max],
         "original": original['wavelength', w_min:w_max],
         "stitched_event": to_plot['wavelength', w_min:w_max]})