Source code for ess.reduce.time_of_flight.interpolator_numba
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
import numpy as np
from numba import njit, prange
[docs]
@njit(boundscheck=False, cache=True, fastmath=False, parallel=True)
def interpolate(
    x: np.ndarray,
    y: np.ndarray,
    values: np.ndarray,
    xp: np.ndarray,
    yp: np.ndarray,
    xoffset: np.ndarray | None,
    deltax: float,
    fill_value: float,
    out: np.ndarray,
):
    """
    Linear interpolation of data on a 2D regular grid.

    The result is written in-place into ``out``; nothing is returned.
    Points that fall outside of the grid, as well as points with a NaN
    coordinate, are assigned ``fill_value``.

    Parameters
    ----------
    x:
        1D array of grid edges along the x-axis (size nx). They must be linspaced.
    y:
        1D array of grid edges along the y-axis (size ny). They must be linspaced.
    values:
        2D array of values on the grid. The shape must be (ny, nx).
    xp:
        1D array of x-coordinates where to interpolate (size N).
    yp:
        1D array of y-coordinates where to interpolate (size N).
    xoffset:
        1D array of integer offsets to apply to the x-coordinates (size N).
        If ``None``, no offset is applied.
    deltax:
        Multiplier to apply to the integer offsets (i.e. the step size).
    fill_value:
        Value to use for points outside of the grid.
    out:
        1D array where the interpolated values will be stored (size N).

    Raises
    ------
    ValueError
        If ``xp``, ``yp``, ``out`` (and ``xoffset`` when given) differ in size,
        if the grid has fewer than two edges along an axis, or if ``values``
        is too small for the grid.
    """
    # Bounds checking is disabled in the compiled kernel, so every size that is
    # later used for indexing is validated here, before entering the loop.
    npoints = len(xp)
    if not (npoints == len(yp) == len(out)):
        raise ValueError("Interpolator: all input arrays must have the same size.")
    if xoffset is not None:
        if len(xoffset) != npoints:
            raise ValueError(
                "Interpolator: all input arrays must have the same size."
            )

    nx = len(x)
    ny = len(y)
    if nx < 2 or ny < 2:
        raise ValueError(
            "Interpolator: the grid needs at least two edges along each axis."
        )
    if values.shape[0] < ny or values.shape[1] < nx:
        raise ValueError(
            "Interpolator: values is too small for the grid, expected shape (ny, nx)."
        )

    xmin = x[0]
    xmax = x[nx - 1]
    ymin = y[0]
    ymax = y[ny - 1]
    # The grid is assumed to be linspaced: the first spacing is used everywhere.
    dx = x[1] - xmin
    dy = y[1] - ymin
    one_over_dx = 1.0 / dx
    one_over_dy = 1.0 / dy
    norm = one_over_dx * one_over_dy

    for i in prange(npoints):
        xx = xp[i] + (xoffset[i] * deltax if xoffset is not None else 0.0)
        yy = yp[i]

        # Written as a positive "inside" test so that NaN coordinates (for which
        # every comparison is False) are treated as out-of-range instead of
        # reaching the index computation below.
        inside = (xx >= xmin) and (xx <= xmax) and (yy >= ymin) and (yy <= ymax)
        if not inside:
            out[i] = fill_value
        else:
            # Clamp to the last cell: rounding in the multiplication can yield
            # nx - 1 (or ny - 1) for a point just below the upper edge, which
            # would make the `+ 1` accesses below run past the end of the arrays.
            ix = nx - 2 if xx == xmax else min(int((xx - xmin) * one_over_dx), nx - 2)
            iy = ny - 2 if yy == ymax else min(int((yy - ymin) * one_over_dy), ny - 2)

            x1 = x[ix]
            x2 = x[ix + 1]
            y1 = y[iy]
            y2 = y[iy + 1]

            # Values at the four corners of the enclosing cell.
            a11 = values[iy, ix]
            a21 = values[iy, ix + 1]
            a12 = values[iy + 1, ix]
            a22 = values[iy + 1, ix + 1]

            x2mxx = x2 - xx
            xxmx1 = xx - x1

            # Bilinear interpolation; `norm` is 1 / (dx * dy).
            out[i] = (
                (y2 - yy) * (x2mxx * a11 + xxmx1 * a21)
                + (yy - y1) * (x2mxx * a12 + xxmx1 * a22)
            ) * norm
[docs]
class Interpolator:
[docs]
def __init__(
self,
time_edges: np.ndarray,
distance_edges: np.ndarray,
values: np.ndarray,
fill_value: float = np.nan,
):
"""
Interpolator for 2D regular grid data (Numba implementation).
Parameters
----------
time_edges:
1D array of time edges.
distance_edges:
1D array of distance edges.
values:
2D array of values on the grid. The shape must be (ny, nx).
fill_value:
Value to use for points outside of the grid.
"""
self.time_edges = time_edges
self.distance_edges = distance_edges
self.values = values
self.fill_value = fill_value
def __call__(
self,
times: np.ndarray,
distances: np.ndarray,
pulse_period: float = 0.0,
pulse_index: np.ndarray | None = None,
) -> np.ndarray:
out = np.empty_like(times)
interpolate(
x=self.time_edges,
y=self.distance_edges,
values=self.values,
xp=times,
yp=distances,
xoffset=pulse_index,
deltax=pulse_period,
fill_value=self.fill_value,
out=out,
)
return out