# WFM

Wavelength-frame-multiplication (WFM) is a technique commonly used at long-pulse facilities to improve the wavelength resolution of the measurements at the neutron detectors.
See for example the article by [Schmakat et al. (2020)](https://www.sciencedirect.com/science/article/pii/S0168900220308640) for a description of how WFM works.

In this notebook, we show how `tof` can be used to convert a neutron time of arrival at the detector to a wavelength.

In [None]:
import scipp as sc
import plopp as pp
import tof

Hz = sc.Unit("Hz")
deg = sc.Unit("deg")
meter = sc.Unit("m")

## Create a source pulse

We first create a source with one pulse containing 500,000 neutrons whose distribution follows the ESS time and wavelength profiles (both thermal and cold neutrons are included).

In [None]:
source = tof.Source(facility="ess", neutrons=500_000)
source.plot()

In [None]:
source.data

## Component set-up

### Choppers

We create a list of choppers that will be included in our beamline.
In our case, we make two WFM choppers and two frame-overlap choppers.
All four of these choppers have 6 openings.

Finally, we also add a pulse-overlap chopper with a single opening.
These choppers are modelled after the [V20 ESS beamline at HZB](https://www.sciencedirect.com/science/article/pii/S0168900216309597).

In [None]:
choppers = [
    tof.Chopper(
        frequency=70.0 * Hz,
        open=sc.array(
            dims=["cutout"],
            values=[98.71, 155.49, 208.26, 257.32, 302.91, 345.3],
            unit="deg",
        ),
        close=sc.array(
            dims=["cutout"],
            values=[109.7, 170.79, 227.56, 280.33, 329.37, 375.0],
            unit="deg",
        ),
        phase=47.10 * deg,
        distance=6.6 * meter,
        name="WFM1",
    ),
    tof.Chopper(
        frequency=70 * Hz,
        open=sc.array(
            dims=["cutout"],
            values=[80.04, 141.1, 197.88, 250.67, 299.73, 345.0],
            unit="deg",
        ),
        close=sc.array(
            dims=["cutout"],
            values=[91.03, 156.4, 217.18, 269.97, 322.74, 375.0],
            unit="deg",
        ),
        phase=76.76 * deg,
        distance=7.1 * meter,
        name="WFM2",
    ),
    tof.Chopper(
        frequency=56 * Hz,
        open=sc.array(
            dims=["cutout"],
            values=[74.6, 139.6, 194.3, 245.3, 294.8, 347.2],
            unit="deg",
        ),
        close=sc.array(
            dims=["cutout"],
            values=[95.2, 162.8, 216.1, 263.1, 310.5, 371.6],
            unit="deg",
        ),
        phase=62.40 * deg,
        distance=8.8 * meter,
        name="Frame-overlap 1",
    ),
    tof.Chopper(
        frequency=28 * Hz,
        open=sc.array(
            dims=["cutout"],
            values=[98.0, 154.0, 206.8, 254.0, 299.0, 344.65],
            unit="deg",
        ),
        close=sc.array(
            dims=["cutout"],
            values=[134.6, 190.06, 237.01, 280.88, 323.56, 373.76],
            unit="deg",
        ),
        phase=12.27 * deg,
        distance=15.9 * meter,
        name="Frame-overlap 2",
    ),
    tof.Chopper(
        frequency=7 * Hz,
        open=sc.array(
            dims=["cutout"],
            values=[30.0],
            unit="deg",
        ),
        close=sc.array(
            dims=["cutout"],
            values=[140.0],
            unit="deg",
        ),
        phase=0 * deg,
        distance=22 * meter,
        name="Pulse-overlap",
    ),
]
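As a rough sanity check on the angles above: a disk chopper spinning at frequency $f$ sweeps $360 f$ degrees per second, so an opening spanning $\Delta\theta$ degrees stays in front of an idealized point-like beam for $\Delta\theta / (360 f)$ seconds. The plain-Python sketch below applies this to the WFM1 values copied from the cell above. It ignores the finite beam width and is an illustration only, not part of the `tof` API (`open_durations_us` is a hypothetical helper).

```python
# Opening and closing angles (degrees) of the WFM1 chopper, copied from the cell above
wfm1_open = [98.71, 155.49, 208.26, 257.32, 302.91, 345.3]
wfm1_close = [109.7, 170.79, 227.56, 280.33, 329.37, 375.0]
frequency_hz = 70.0


def open_durations_us(open_deg, close_deg, freq_hz):
    """Time (microseconds) each opening stays in front of an idealized point-like beam."""
    degrees_per_second = 360.0 * freq_hz
    return [(c - o) / degrees_per_second * 1e6 for o, c in zip(open_deg, close_deg)]


durations = open_durations_us(wfm1_open, wfm1_close, frequency_hz)
print([round(d, 1) for d in durations])
```

For WFM1 this gives openings from roughly 0.44 ms up to about 1.18 ms, getting progressively longer from one opening to the next.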

### Detectors

We add a single detector 32 meters from the source.

In [None]:
detectors = [
    tof.Detector(distance=32.0 * meter, name="detector"),
]

## Run the simulation

We propagate our pulse of neutrons through the chopper cascade and inspect the results.

In [None]:
model = tof.Model(source=source, choppers=choppers, detectors=detectors)
results = model.run()
results.plot(blocked_rays=5000)

## Wavelength as a function of time-of-arrival

### Plotting wavelength vs time-of-arrival

Since we know the true wavelength of our neutrons,
as well as the time at which the neutrons arrive at the detector
(coordinate named `toa` in the detector reading),
we can plot an image of the wavelengths as a function of time-of-arrival:

In [None]:
# Squeeze the pulse dimension since we only have one pulse
events = results['detector'].data.squeeze()
# Remove the events that don't make it to the detector
events = events[~events.masks['blocked_by_others']]
# Histogram and plot
events.hist(wavelength=500, toa=500).plot(norm='log', grid=True)

### Defining a conversion from `toa` to `wavelength`

The image above shows a tight correlation between time-of-arrival and wavelength.

We compute the weighted mean wavelength inside each `toa` bin to define a relation between `toa` and `wavelength`, as well as the variance of the wavelengths in each bin.

In [None]:
binned = events.bin(toa=500)

# Weighted mean of wavelength inside each bin
mu = (
    binned.bins.data * binned.bins.coords['wavelength']
).bins.sum() / binned.bins.sum()

# Per-event contributions to the weighted variance inside each bin;
# summing them over each bin (done by .hist() below) gives the variance
var = (
    binned.bins.data * (binned.bins.coords['wavelength'] - mu) ** 2
) / binned.bins.sum()
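The weighted mean and variance in each bin are ratios of sums over the events in that bin, $\mu = \sum w\lambda / \sum w$ and $\sigma^2 = \sum w(\lambda - \mu)^2 / \sum w$. As an illustration outside of scipp, here is a minimal NumPy sketch of the same computation on synthetic events. The linear relation, noise level and binning are made up, and `binned_weighted_stats` is a hypothetical helper, not part of `tof` or scipp.

```python
import numpy as np


def binned_weighted_stats(x, y, weights, edges):
    """Weighted mean and variance of y inside each x bin (NaN for empty bins)."""
    nbins = len(edges) - 1
    # Bin index of each event; events outside [edges[0], edges[-1]) are dropped
    idx = np.digitize(x, edges) - 1
    inside = (idx >= 0) & (idx < nbins)
    idx, y, w = idx[inside], y[inside], weights[inside]
    sum_w = np.bincount(idx, weights=w, minlength=nbins)
    sum_wy = np.bincount(idx, weights=w * y, minlength=nbins)
    with np.errstate(invalid="ignore", divide="ignore"):
        mean = sum_wy / sum_w
        # Weighted squared deviations from the mean of the bin each event belongs to
        sum_wdev2 = np.bincount(idx, weights=w * (y - mean[idx]) ** 2, minlength=nbins)
        var = sum_wdev2 / sum_w
    return mean, var


# Synthetic events: a made-up linear toa -> wavelength relation plus noise
rng = np.random.default_rng(seed=0)
toa = rng.uniform(0.0, 10.0, size=10_000)
wavelength = 0.5 * toa + rng.normal(0.0, 0.1, size=toa.size)
edges = np.linspace(0.0, 12.0, 25)  # 24 bins; the last 4 stay empty
mean, var = binned_weighted_stats(toa, wavelength, np.ones_like(toa), edges)
```

Empty bins come out as NaN, just as they do in the scipp version.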

We can now overlay the mean wavelength on the image above, and plot the standard deviation in a second panel:

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 1)

events.hist(wavelength=500, toa=500).plot(norm='log', cbar=False, ax=ax[0])
mu.name = 'Wavelength'
mu.plot(ax=ax[0], color='C1', grid=True)
stddev = sc.sqrt(var.hist())
stddev.name = 'Standard deviation'
stddev.plot(ax=ax[1], grid=True)
fig.set_size_inches(6, 8)
fig.tight_layout()

## Computing wavelengths

We set up an interpolator that will compute wavelengths given an array of `toa` values.

In [None]:
from scipp.scipy.interpolate import interp1d

# Set up interpolator
y = mu.copy()
y.coords['toa'] = sc.midpoints(y.coords['toa'])
f = interp1d(y, 'toa', bounds_error=False)

# Compute wavelengths
wavs = f(events.coords['toa'].rename_dims(event='toa'))
wavelengths = sc.DataArray(
    data=sc.ones(sizes=wavs.sizes, unit='counts'), coords={'wavelength': wavs.data}
).rename_dims(toa='event')
wavelengths
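Under the hood this is plain 1D linear interpolation through the bin midpoints of `mu`. A minimal NumPy sketch of the same idea, with a hypothetical helper `make_toa_to_wavelength` and a made-up three-bin table (an illustration, not the scipp implementation):

```python
import numpy as np


def make_toa_to_wavelength(edges, mean_wavelength):
    """Build a linear toa -> wavelength lookup through the bin midpoints.

    toa values outside the range of midpoints give NaN, similar to bounds_error=False above.
    """
    midpoints = 0.5 * (edges[:-1] + edges[1:])

    def interpolate(toa):
        toa = np.asarray(toa, dtype=float)
        out_of_range = (toa < midpoints[0]) | (toa > midpoints[-1])
        return np.where(out_of_range, np.nan, np.interp(toa, midpoints, mean_wavelength))

    return interpolate


# Hypothetical 3-bin lookup table: edges [0, 2, 4, 6], mean wavelengths [1, 3, 5]
lookup = make_toa_to_wavelength(np.array([0.0, 2.0, 4.0, 6.0]), np.array([1.0, 3.0, 5.0]))
result = lookup([1.0, 2.0, 5.0, 0.5, 6.0])
```

Here the midpoints are 1, 3 and 5, so `toa` values of 0.5 and 6.0 fall outside the table and return NaN.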

We can now compare our computed wavelengths to the true wavelengths of the neutrons.
We also include a naive computation of the neutron wavelengths, which uses the `toa` coordinate directly together with the detector distance.
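Concretely, the naive estimate assumes every neutron left the source at $t = 0$ and travelled the full source-to-detector distance $L$, so its speed is $v = L / t_{\mathrm{toa}}$, and the de Broglie relation gives

$$
\lambda_{\mathrm{naive}} = \frac{h}{m_n v} = \frac{h \, t_{\mathrm{toa}}}{m_n L},
$$

where $h$ is the Planck constant and $m_n$ the neutron mass. This is what the next cell computes with `sc.reciprocal(speed * sc.constants.m_n / sc.constants.h)`.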

In [None]:
naive = events.copy()
speed = detectors[0].distance / naive.coords['toa']
naive.coords['wavelength'] = sc.reciprocal(
    speed * sc.constants.m_n / sc.constants.h
).to(unit='angstrom')

In [None]:
pp.plot(
    {
        'naive': naive.hist(wavelength=300),
        'wfm': wavelengths.hist(wavelength=300),
        'original': events.hist(wavelength=300),
    }
)

The WFM estimate follows the true wavelength distribution much more closely than the naive computation.
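The naive estimate is systematically off because the source pulse is long: a neutron emitted at a time $t_0$ after the start of the pulse arrives at $t_{\mathrm{toa}} = t_0 + L/v$, so the naive formula overestimates its wavelength by $h\,t_0/(m_n L)$. A small plain-Python sketch of that bias, using the CODATA values of $h$ and $m_n$; the 4 Å wavelength and the 1 ms emission time are hypothetical values chosen for illustration, and the helper functions are not part of `tof`.

```python
# CODATA 2018 constants (SI units)
H = 6.62607015e-34       # Planck constant, J s
M_N = 1.67492749804e-27  # neutron mass, kg


def naive_wavelength_angstrom(toa_s, distance_m):
    """Wavelength assuming emission at t = 0 and flight over the full distance."""
    speed = distance_m / toa_s
    return H / (M_N * speed) * 1e10


def true_toa_s(wavelength_angstrom, distance_m, emission_time_s):
    """Arrival time of a neutron emitted emission_time_s seconds after the pulse start."""
    speed = H / (M_N * wavelength_angstrom * 1e-10)
    return emission_time_s + distance_m / speed


L = 32.0   # detector distance used above, m
t0 = 1e-3  # hypothetical emission time within the long pulse, s
toa = true_toa_s(4.0, L, t0)
naive = naive_wavelength_angstrom(toa, L)
bias = naive - 4.0  # overestimate in angstrom
```

With $t_0$ = 1 ms and $L$ = 32 m the bias is about 0.12 Å, and it does not depend on the wavelength itself.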

## Multiple detectors

Real detectors are usually composed of hundreds of thousands of pixels,
and each pixel can be at a different distance from the source.
For example, the pixels at the edges of a flat detector panel are slightly farther from the source than those at its center.

This does not mean we need to compute an interpolator for every detector pixel.
Instead, we can find the range of pixel distances
and compute a 2D interpolator over `toa` and `distance`, using a reasonable number of distance bins.

### Using a range of detectors

Here, we assume that the pixel distances range from 30 to 35 meters
(in practice, the range would typically be much narrower).

In [None]:
# Use 50 distances between 30m and 35m
distances = sc.linspace('distance', 30, 35, 50, unit='m')

detectors = [
    tof.Detector(distance=d, name=f"detector-{i}") for i, d in enumerate(distances)
]

# Re-run the simulation
model = tof.Model(source=source, choppers=choppers, detectors=detectors)
results = model.run()

We can now concatenate all the readings along the `distance` dimension into a single data array:

In [None]:
events = [res.data.squeeze() for res in results.detectors.values()]
events = sc.concat(
    [ev[~ev.masks['blocked_by_others']] for ev in events], dim='distance'
)
events.coords['distance'] = distances
events

### Relation between `toa` and `wavelength` in 2D

As in the previous section,
we compute the weighted mean of the wavelengths inside each `toa` bin.

This yields the wavelength as a 2D function of `toa` and `distance`.

In [None]:
binned = events.bin(toa=500, dim='event')

# Weighted mean of wavelength inside each bin
mu2d = (
    binned.bins.data * binned.bins.coords['wavelength']
).bins.sum() / binned.bins.sum()

mu2d.plot()

### Computing wavelengths with a 2D interpolator

We now set up a 2D grid interpolator to compute wavelengths for our neutrons.

In [None]:
from scipy.interpolate import RegularGridInterpolator

f = RegularGridInterpolator(
    (sc.midpoints(mu2d.coords['toa']).values, mu2d.coords['distance'].values),
    mu2d.values.T,
    method='linear',
    bounds_error=False,
)


# Flatten the event list
flat = events.flatten(to='event')

# Compute wavelengths
wavs = f((flat.coords['toa'].values, flat.coords['distance'].values))
flat.coords['wavelength'] = sc.array(dims=['event'], values=wavs, unit=mu2d.unit)
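The `.T` in the cell above matters: the grid is passed in `(toa, distance)` order, and `RegularGridInterpolator` expects the axes of the value array to follow the same order, whereas `mu2d.values` has its axes the other way round. A small synthetic check with SciPy (the grid and the table values are made up for illustration):

```python
import numpy as np
from scipy.interpolate import RegularGridInterpolator

# Hypothetical grid mimicking mu2d: rows are distances, columns are toa midpoints
toa_points = np.linspace(10.0, 50.0, 5)         # arbitrary toa units
distance_points = np.array([30.0, 32.5, 35.0])  # m
# Synthetic table with shape (n_distance, n_toa), laid out like mu2d.values
table = toa_points[np.newaxis, :] / distance_points[:, np.newaxis]

# Values must be shaped (n_toa, n_distance) to match the grid order, hence the transpose
f_check = RegularGridInterpolator(
    (toa_points, distance_points), table.T, method="linear", bounds_error=False
)

# One (toa, distance) pair per event; the last toa lies outside the grid
events_toa = np.array([20.0, 35.0, 60.0])
events_distance = np.array([31.0, 34.0, 31.0])
result = f_check((events_toa, events_distance))
```

Passing `table` without the transpose raises a `ValueError`, because the number of grid points no longer matches the array shape along each axis; out-of-range events come back as NaN.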

We can now compare the results to the original wavelengths.

Once again, we also include the naive computation for reference.

In [None]:
# Naive wavelength computation
naive = events.flatten(to='event')
speed = naive.coords['distance'] / naive.coords['toa']
naive.coords['wavelength'] = sc.reciprocal(
    speed * sc.constants.m_n / sc.constants.h
).to(unit='angstrom')

# True wavelengths
orig = events.hist(distance=40, wavelength=300)

# Plot
style = {'cmap': 'RdBu_r', 'vmin': -5, 'vmax': 5}
fig1 = ((flat.hist(**orig.coords) - orig) / orig).plot(title='WFM', **style)
fig2 = ((naive.hist(**orig.coords) - orig) / orig).plot(title='Naive', **style)
fig1 + fig2

This once again illustrates the superiority of the WFM estimate.

An alternative way of comparing the accuracy of the two methods is to plot, as a function of $x$,
the probability that a computed wavelength has a relative error above $x$:

In [None]:
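true_wavs = events.flatten(to='event').coords['wavelength']
bins = sc.geomspace('relative_error', 1e-3, 0.2, 101)


def prob_above(estimated):
    # Relative error of each event, wrapped in a DataArray so that it can be histogrammed
    rel = sc.abs(true_wavs - estimated) / true_wavs
    err = sc.DataArray(data=sc.ones(sizes=rel.sizes), coords={'relative_error': rel})
    # Count events below each upper bin edge; events below the first edge
    # are not in the histogram, so add them explicitly
    n_low = int(sc.sum(rel < bins[0]).value)
    below = sc.cumsum(err.hist(relative_error=bins)) + n_low
    # Normalize by the total number of events (NaN estimates count as exceeding x)
    return 1 - below / rel.sizes['event']


p = pp.plot(
    {
        'naive': prob_above(naive.coords['wavelength']),
        'wfm': prob_above(flat.coords['wavelength']),
    },
    scale={'relative_error': 'log'},
)
p.canvas.ylabel = 'Probability of $rel. err. > x$'
p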
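The plotted curve is the empirical survival function of the relative error, $P(|\lambda_{\mathrm{est}} - \lambda_{\mathrm{true}}| / \lambda_{\mathrm{true}} > x)$. For reference, a minimal NumPy sketch of the same quantity computed by sorting instead of histogramming; `prob_rel_error_above` is a hypothetical helper, and counting NaN estimates as failures is a choice made here, not something `tof` prescribes.

```python
import numpy as np


def prob_rel_error_above(true, estimated, thresholds):
    """Empirical probability that the relative error exceeds each threshold.

    NaN estimates (e.g. from out-of-range interpolation) are treated as
    exceeding every threshold.
    """
    rel_err = np.abs(true - estimated) / true
    rel_err = np.where(np.isnan(rel_err), np.inf, rel_err)
    sorted_err = np.sort(rel_err)
    # For each threshold x, count the events with relative error <= x
    n_at_or_below = np.searchsorted(sorted_err, thresholds, side="right")
    return 1.0 - n_at_or_below / rel_err.size


true = np.array([1.0, 2.0, 4.0, 8.0])
estimated = np.array([1.5, 2.0, 5.0, np.nan])  # relative errors 0.5, 0, 0.25, NaN
prob = prob_rel_error_above(true, estimated, np.array([0.0, 0.3, 0.5, 1.0]))
```

Sorting once and using `searchsorted` evaluates all thresholds in $O(n \log n)$ without choosing histogram bins, at the cost of holding a sorted copy of the errors in memory.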