Source code for ess.powder.filtering

# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
# @author Jan-Lukas Wynen
"""
Prototype for event filtering.

IMPORTANT Will be moved to a different place and potentially modified.
"""

from contextlib import contextmanager
from numbers import Real

import scipp as sc

from .types import DetectorData, FilteredData, RunType


def _equivalent_bin_indices(a, b) -> bool:
    a_begin = a.bins.constituents["begin"].flatten(to="")
    a_end = a.bins.constituents["end"].flatten(to="")
    b_begin = b.bins.constituents["begin"].flatten(to="")
    b_end = b.bins.constituents["end"].flatten(to="")
    non_empty = a_begin != a_end
    return (
        sc.all((a_begin == b_begin)[non_empty]).value
        and sc.all((a_end == b_end)[non_empty]).value
    )


@contextmanager
def _temporary_bin_coord(data: sc.DataArray, name: str, coord: sc.Variable) -> None:
    if not _equivalent_bin_indices(data, coord):
        raise ValueError("data and coord do not have equivalent bin indices")
    coord = sc.bins(
        data=coord.bins.constituents["data"],
        begin=data.bins.coords["pulse_time"].bins.constituents["begin"],
        end=data.bins.coords["pulse_time"].bins.constituents["end"],
        dim=coord.bins.constituents["dim"],
    )
    data.bins.coords[name] = coord
    yield
    del data.bins.coords[name]


# TODO non-monotonic proton charge -> raise?
def _with_pulse_time_edges(da: sc.DataArray, dim: str) -> sc.DataArray:
    pulse_time = da.coords[dim]
    one = sc.scalar(1, dtype="int64", unit=pulse_time.unit)
    lo = pulse_time[0] - one
    hi = pulse_time[-1] + one
    mid = sc.midpoints(pulse_time)
    da.coords[dim] = sc.concat([lo, mid, hi], dim)
    return da


[docs] def remove_bad_pulses( data: sc.DataArray, *, proton_charge: sc.DataArray, threshold_factor: Real ) -> sc.DataArray: """ assumes that there are bad pulses """ min_charge = proton_charge.data.mean() * threshold_factor good_pulse = _with_pulse_time_edges(proton_charge >= min_charge, proton_charge.dim) with _temporary_bin_coord( data, "good_pulse", sc.lookup(good_pulse, good_pulse.dim)[data.bins.coords[good_pulse.dim]], ): filtered = data.group(sc.array(dims=["good_pulse"], values=[True])) filtered = filtered.squeeze("good_pulse").copy(deep=False) del filtered.coords["good_pulse"] return filtered
[docs] def filter_events(data: DetectorData[RunType]) -> FilteredData[RunType]: """Remove bad events. Attention --------- This function currently does nothing because it is unclear how to filter events at ESS. In the future, this function will filter out events that cannot be used for analysis. Parameters ---------- data: Input events to be filtered. Returns ------- : `data` with bad events removed. """ # TODO this needs to filter by proton charge once we know how return FilteredData[RunType](data)
providers = (filter_events,) """Sciline providers for event filtering."""