# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
import sciline
import scipp as sc
import scippnexus as snx
from scippneutron.conversion.beamline import (
beam_aligned_unit_vectors,
scattering_angles_with_gravity,
)
from scippneutron.conversion.graph import beamline, tof
from ess.reduce.uncertainty import broadcast_uncertainties
from .common import mask_range
from .types import (
BinnedQ,
BinnedQxQy,
CorrectedDetector,
CorrectForGravity,
Denominator,
GravityVector,
IofQPart,
MonitorTerm,
MonitorType,
NormalizedQ,
NormalizedQxQy,
Numerator,
Position,
QDetector,
QxyDetector,
RunType,
ScatteringRunType,
TofMonitor,
UncertaintyBroadcastMode,
WavelengthDetector,
WavelengthMask,
WavelengthMonitor,
)
[docs]
def cyl_unit_vectors(incident_beam: sc.Variable, gravity: sc.Variable):
    """
    Return the horizontal and vertical unit vectors of the cylindrical frame
    aligned with the incident beam.

    Both vectors come from ``beam_aligned_unit_vectors``. Its
    ``beam_aligned_unit_x`` and ``beam_aligned_unit_y`` entries are re-exported
    under the node names used by the SANS coordinate-transformation graph.
    """
    unit = beam_aligned_unit_vectors(incident_beam=incident_beam, gravity=gravity)
    x_hat = unit['beam_aligned_unit_x']
    y_hat = unit['beam_aligned_unit_y']
    return {'cyl_x_unit_vector': x_hat, 'cyl_y_unit_vector': y_hat}
[docs]
def cylindrical_x(
    cyl_x_unit_vector: sc.Variable, scattered_beam: sc.Variable
) -> sc.Variable:
    """
    Project the scattered-beam vector onto the horizontal axis of the cylindrical
    frame, i.e. the horizontal coordinate perpendicular to the incident beam.

    Assumes the incident beam is perpendicular to the gravity vector.
    """
    projection = sc.dot(scattered_beam, cyl_x_unit_vector)
    return projection
[docs]
def cylindrical_y(
    cyl_y_unit_vector: sc.Variable, scattered_beam: sc.Variable
) -> sc.Variable:
    """
    Project the scattered-beam vector onto the vertical axis of the cylindrical
    frame, i.e. the vertical coordinate perpendicular to the incident beam.

    Assumes the incident beam is perpendicular to the gravity vector.
    """
    projection = sc.dot(scattered_beam, cyl_y_unit_vector)
    return projection
[docs]
def phi_no_gravity(
    cylindrical_x: sc.Variable, cylindrical_y: sc.Variable
) -> sc.Variable:
    """
    Return the cylindrical angle phi of the scattered beam around the incident beam.

    The angle is ``atan2(cylindrical_y, cylindrical_x)``, measured from the
    horizontal cylindrical axis. Gravity is not taken into account, and the
    incident beam is assumed to be perpendicular to the gravity vector.
    """
    azimuth = sc.atan2(y=cylindrical_y, x=cylindrical_x)
    return azimuth
[docs]
def Qxy(Q: sc.Variable, phi: sc.Variable) -> dict[str, sc.Variable]:
    """
    Compute the Qx and Qy components of the scattering vector from its magnitude
    ``Q`` and the cylindrical angle ``phi`` around the incident beam.

    ``Qx = Q * cos(phi)`` and ``Qy = Q * sin(phi)``.

    Parameters
    ----------
    Q:
        Magnitude of the scattering vector (dense or binned).
    phi:
        Cylindrical angle of the scattered beam around the incident beam
        (dense or binned).

    Returns
    -------
    :
        Dict with the components under the keys ``'Qx'`` and ``'Qy'``.
    """
    Qx = sc.cos(phi)
    Qy = sc.sin(phi)
    if Q.bins is not None and phi.bins is not None:
        # Both operands are binned: scale the freshly computed cos/sin in place,
        # presumably to avoid allocating another copy of the event data.
        Qx *= Q
        Qy *= Q
    else:
        # Mixed dense/binned operands: use out-of-place multiplication.
        # NOTE(review): presumably because an in-place update of a dense result
        # with binned Q is not possible.
        Qx = Qx * Q
        Qy = Qy * Q
    return {'Qx': Qx, 'Qy': Qy}
[docs]
def sans_elastic(
    correct_for_gravity: CorrectForGravity,
    *,
    sample_position: Position[snx.NXsample, RunType],
    source_position: Position[snx.NXsource, RunType],
    gravity: GravityVector,
) -> ElasticCoordTransformGraph[RunType]:
    """
    Generate a coordinate transformation graph for SANS elastic scattering.

    It is based on classical conversions from ``tof`` and pixel ``position`` to
    :math:`\\lambda` (``wavelength``), :math:`\\theta` (``theta``) and
    :math:`Q` (``Q``), but can take into account the Earth's gravitational field,
    which bends the flight path of the neutrons, to compute the scattering angle
    :math:`\\theta`.

    The angle can be found using the following expression
    (`Seeger & Hjelm 1991 <https://doi.org/10.1107/S0021889891004764>`_):

    .. math::

       \\theta = \\frac{1}{2}\\sin^{-1}\\left(\\frac{\\sqrt{ x^{2} + \\left( y + \\frac{g m_{\\rm n}^{2}}{2 h^{2}} \\lambda^{2} L_{2}^{2} \\right)^{2} } }{L_{2}}\\right)

    where :math:`x` and :math:`y` are the spatial coordinates of the pixels in the
    horizontal and vertical directions, respectively,
    :math:`m_{\\rm n}` is the neutron mass,
    :math:`L_{2}` is the distance between the sample and a detector pixel,
    :math:`g` is the acceleration due to gravity,
    and :math:`h` is Planck's constant.

    The gravity term is the vertical drop of the neutron along :math:`L_{2}`.
    A neutron of wavelength :math:`\\lambda` travels at speed
    :math:`v = h / (m_{\\rm n} \\lambda)`, so during the flight time
    :math:`t = L_{2} / v` it falls
    :math:`\\frac{1}{2} g t^{2} = \\frac{g m_{\\rm n}^{2}}{2 h^{2}} \\lambda^{2} L_{2}^{2}`.

    When ``correct_for_gravity`` is ``False``, the effects of gravity on the neutron
    flight paths are not included (equivalent to :math:`g = 0` in the expression
    above).

    Parameters
    ----------
    correct_for_gravity:
        Take into account the bending of the neutron flight paths from the
        Earth's gravitational field if ``True``.
    gravity:
        A vector indicating the strength and direction of gravity.
        Required even if ``correct_for_gravity`` is ``False``.
    sample_position:
        Position of the sample as a vector.
    source_position:
        Position of the source as a vector.

    Returns
    -------
    :
        Graph for ``transform_coords``. It combines scippneutron's scattering
        beamline graph and elastic ``Q`` graph (from ``tof``) with the given
        positions, the gravity vector, and nodes for ``phi``, the cylindrical
        coordinates, and ``Qx``/``Qy``. With gravity correction, ``two_theta``
        and ``phi`` are computed jointly by ``scattering_angles_with_gravity``.
        Otherwise ``phi`` is derived from the cylindrical projections of the
        scattered beam.
    """  # noqa: E501
    graph = {
        **beamline.beamline(scatter=True),
        **tof.elastic_Q('tof'),
        # Wrap the fixed inputs as zero-argument nodes of the graph.
        'sample_position': lambda: sample_position,
        'source_position': lambda: source_position,
        'gravity': lambda: gravity,
    }
    if correct_for_gravity:
        # Replace the straight-line two_theta by one that accounts for the
        # gravitational drop; the same function also provides phi.
        del graph['two_theta']
        graph[('two_theta', 'phi')] = scattering_angles_with_gravity
    else:
        graph['phi'] = phi_no_gravity
    # Cylindrical coordinates are always available as graph nodes; unused nodes
    # are not evaluated by transform_coords.
    graph[('cyl_x_unit_vector', 'cyl_y_unit_vector')] = cyl_unit_vectors
    graph['cylindrical_x'] = cylindrical_x
    graph['cylindrical_y'] = cylindrical_y
    graph[('Qx', 'Qy')] = Qxy
    return ElasticCoordTransformGraph(graph)
[docs]
def sans_monitor(
    source_position: Position[snx.NXsource, RunType],
) -> MonitorCoordTransformGraph[RunType]:
    """
    Build the coordinate transformation graph for SANS monitors.

    Monitors record the direct beam, so the non-scattering beamline graph is
    combined with the elastic ``tof`` to ``wavelength`` conversion, and the
    source position is provided as a fixed input.
    """
    graph = dict(beamline.beamline(scatter=False))
    graph.update(tof.elastic_wavelength('tof'))
    graph['source_position'] = lambda: source_position
    return MonitorCoordTransformGraph(graph)
[docs]
def monitor_to_wavelength(
    monitor: TofMonitor[RunType, MonitorType],
    graph: MonitorCoordTransformGraph[RunType],
) -> WavelengthMonitor[RunType, MonitorType]:
    """
    Transform the coordinates of a monitor to ``wavelength`` using ``graph``.

    Input coordinates consumed by the transformation are dropped from the result.
    """
    converted = monitor.transform_coords(
        'wavelength', graph=graph, keep_inputs=False
    )
    return WavelengthMonitor[RunType, MonitorType](converted)
# TODO: This demonstrates a problem: transforming to wavelength should be possible
# for RawData, MaskedData, ...; there is no reason to necessarily restrict it.
# Would we be fine with just choosing one option, or will this get in the way
# for users?
[docs]
def detector_to_wavelength(
    detector: CorrectedDetector[ScatteringRunType, Numerator],
    graph: ElasticCoordTransformGraph[ScatteringRunType],
) -> WavelengthDetector[ScatteringRunType, Numerator]:
    """
    Transform the coordinates of the corrected detector data to ``wavelength``.

    Input coordinates consumed by the transformation are dropped from the result.
    """
    converted = detector.transform_coords(
        'wavelength', graph=graph, keep_inputs=False
    )
    return WavelengthDetector[ScatteringRunType, Numerator](converted)
[docs]
def mask_wavelength_q(
    da: BinnedQ[ScatteringRunType, Numerator], mask: WavelengthMask
) -> NormalizedQ[ScatteringRunType, Numerator]:
    """
    Apply the optional wavelength mask to the numerator binned in Q.

    If ``mask`` is ``None`` the data is passed through unchanged; otherwise
    ``mask_range`` is applied with it (presumably masking the given wavelength
    ranges; see ``mask_range``).
    """
    masked = da if mask is None else mask_range(da, mask=mask)
    return NormalizedQ[ScatteringRunType, Numerator](masked)
[docs]
def mask_wavelength_qxy(
    da: BinnedQxQy[ScatteringRunType, Numerator], mask: WavelengthMask
) -> NormalizedQxQy[ScatteringRunType, Numerator]:
    """
    Apply the optional wavelength mask to the numerator binned in Qx and Qy.

    If ``mask`` is ``None`` the data is passed through unchanged; otherwise
    ``mask_range`` is applied with it (presumably masking the given wavelength
    ranges; see ``mask_range``).
    """
    masked = da if mask is None else mask_range(da, mask=mask)
    return NormalizedQxQy[ScatteringRunType, Numerator](masked)
[docs]
def mask_and_scale_wavelength_q(
    da: BinnedQ[ScatteringRunType, Denominator],
    mask: WavelengthMask,
    wavelength_term: MonitorTerm[ScatteringRunType],
    uncertainties: UncertaintyBroadcastMode,
) -> NormalizedQ[ScatteringRunType, Denominator]:
    """
    Scale the denominator binned in Q by the monitor term, then apply the
    optional wavelength mask.

    The monitor term is broadcast to the shape of ``da`` with
    ``broadcast_uncertainties``, where ``uncertainties`` controls how its variances
    are handled. If ``mask`` is ``None``, no masking is applied.
    """
    scale = broadcast_uncertainties(wavelength_term, prototype=da, mode=uncertainties)
    scaled = da * scale
    if mask is None:
        return NormalizedQ[ScatteringRunType, Denominator](scaled)
    return NormalizedQ[ScatteringRunType, Denominator](mask_range(scaled, mask=mask))
[docs]
def mask_and_scale_wavelength_qxy(
    da: BinnedQxQy[ScatteringRunType, Denominator],
    mask: WavelengthMask,
    wavelength_term: MonitorTerm[ScatteringRunType],
    uncertainties: UncertaintyBroadcastMode,
) -> NormalizedQxQy[ScatteringRunType, Denominator]:
    """
    Scale the denominator binned in Qx and Qy by the monitor term, then apply the
    optional wavelength mask.

    The monitor term is broadcast to the shape of ``da`` with
    ``broadcast_uncertainties``, where ``uncertainties`` controls how its variances
    are handled. If ``mask`` is ``None``, no masking is applied.
    """
    scale = broadcast_uncertainties(wavelength_term, prototype=da, mode=uncertainties)
    scaled = da * scale
    if mask is None:
        return NormalizedQxQy[ScatteringRunType, Denominator](scaled)
    return NormalizedQxQy[ScatteringRunType, Denominator](
        mask_range(scaled, mask=mask)
    )
def _compute_Q(
    data: sc.DataArray, graph: ElasticCoordTransformGraph, target: tuple[str, ...]
) -> sc.DataArray:
    """
    Transform ``data`` to the coordinates named in ``target`` using ``graph``.

    This is the shared implementation of ``compute_Q`` (target ``('Q',)``) and
    ``compute_Qxy`` (target ``('Qx', 'Qy')``). It returns the plain data array,
    and each caller wraps it in its own domain type. Intermediate coordinates are
    not kept.
    """
    # Keep naming of wavelength dim, subsequent steps use a (Q[xy], wavelength) binning.
    return data.transform_coords(
        target,
        graph=graph,
        keep_intermediate=False,
        rename_dims=False,
    )
[docs]
def compute_Q(
    data: WavelengthDetector[ScatteringRunType, IofQPart],
    graph: ElasticCoordTransformGraph[ScatteringRunType],
) -> QDetector[ScatteringRunType, IofQPart]:
    """
    Add the scattering-vector magnitude ``Q`` as a coordinate of wavelength data.

    The wavelength dimension keeps its name; only the ``Q`` coordinate is added.
    """
    converted = _compute_Q(data=data, graph=graph, target=('Q',))
    return QDetector[ScatteringRunType, IofQPart](converted)
[docs]
def compute_Qxy(
    data: WavelengthDetector[ScatteringRunType, IofQPart],
    graph: ElasticCoordTransformGraph[ScatteringRunType],
) -> QxyDetector[ScatteringRunType, IofQPart]:
    """
    Add the scattering-vector components ``Qx`` and ``Qy`` as coordinates of
    wavelength data.

    The wavelength dimension keeps its name; only the ``Qx`` and ``Qy``
    coordinates are added.
    """
    converted = _compute_Q(data=data, graph=graph, target=('Qx', 'Qy'))
    return QxyDetector[ScatteringRunType, IofQPart](converted)
# Providers defined in this module, grouped by processing stage.
# NOTE(review): presumably registered with a sciline pipeline by the workflow
# setup; confirm against callers.
providers = (
    # Coordinate-transformation graphs.
    sans_elastic,
    sans_monitor,
    # Conversion from time-of-flight to wavelength.
    monitor_to_wavelength,
    detector_to_wavelength,
    # Wavelength masking, plus monitor-term scaling for the denominator.
    mask_wavelength_q,
    mask_wavelength_qxy,
    mask_and_scale_wavelength_q,
    mask_and_scale_wavelength_qxy,
    # Conversion from wavelength to Q / (Qx, Qy).
    compute_Q,
    compute_Qxy,
)