Source code for ess.reflectometry.tools

# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2025 Scipp contributors (https://github.com/scipp)
from __future__ import annotations

import re
import uuid
from collections.abc import Mapping, Sequence
from itertools import chain
from typing import Any

import numpy as np
import sciline as sl
import scipp as sc
import scipy.optimize as opt

from ess.reflectometry import orso
from ess.reflectometry.types import (
    Filename,
    ReducibleData,
    ReflectivityOverQ,
    SampleRun,
)
from ess.reflectometry.workflow import with_filenames

_STD_TO_FWHM = sc.scalar(2.0) * sc.sqrt(sc.scalar(2.0) * sc.log(sc.scalar(2.0)))


def fwhm_to_std(fwhm: sc.Variable) -> sc.Variable:
    """
    Convert from full-width half maximum to standard deviation.

    Parameters
    ----------
    fwhm:
        Full-width half maximum.

    Returns
    -------
    :
        Standard deviation.
    """
    # Enables the conversion from full width half
    # maximum to standard deviation
    return fwhm / _STD_TO_FWHM
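
# Usage sketch (hypothetical resolution value): the conversion divides by
# 2 * sqrt(2 * ln 2) ≈ 2.355.
#
#     fwhm = sc.scalar(0.1, unit='1/angstrom')
#     sigma = fwhm_to_std(fwhm)  # ≈ 0.0425 1/angstrom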


def linlogspace(
    dim: str,
    edges: list | np.ndarray,
    scale: list | str,
    num: list | int,
    unit: str | None = None,
) -> sc.Variable:
    """
    Generate a 1d array of bin edges with a mixture of linear and/or
    logarithmic spacings.

    Examples:

       - Create linearly spaced edges (equivalent to `sc.linspace`):

            linlogspace(dim='x', edges=[0.008, 0.08], scale='linear', num=50, unit='m')

       - Create logarithmically spaced edges (equivalent to `sc.geomspace`):

            linlogspace(dim='x', edges=[0.008, 0.08], scale='log', num=50, unit='m')

       - Create edges with a linear and a logarithmic part:

            linlogspace(dim='x', edges=[1, 3, 8], scale=['linear', 'log'], num=[16, 20])

    Parameters
    ----------
    dim:
        The dimension of the output Variable.
    edges:
        The edges for the different parts of the mesh.
    scale:
        A string or list of strings specifying the scaling for the different
        parts of the mesh. Possible values for the scaling are `"linear"`
        and `"log"`. If a list is supplied, the length of the list must be
        one less than the length of the `edges` parameter.
    num:
        An integer or a list of integers specifying the number of points to
        use in each part of the mesh. If a list is supplied, the length of
        the list must be one less than the length of the `edges` parameter.
    unit:
        The unit of the output Variable.

    Returns
    -------
    :
        Lin-log spaced Q-bin edges.
    """
    if not isinstance(scale, list):
        scale = [scale]
    if not isinstance(num, list):
        num = [num]
    if len(scale) != len(edges) - 1:
        raise ValueError(
            "Sizes do not match. The length of edges should be one greater than scale."
        )

    funcs = {"linear": sc.linspace, "log": sc.geomspace}
    grids = []
    for i in range(len(edges) - 1):
        # Skip the leading edge in the piece when concatenating
        start = int(i > 0)
        mesh = funcs[scale[i]](
            dim=dim, start=edges[i], stop=edges[i + 1], num=num[i] + start, unit=unit
        )
        grids.append(mesh[dim, start:])

    return sc.concat(grids, dim)
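
# Usage sketch (hypothetical Q values): the first part contributes num[0]
# edges and each subsequent part contributes num[i] more, since its leading
# edge coincides with the previous part's last edge.
#
#     qgrid = linlogspace(
#         dim='Q', edges=[0.008, 0.03, 0.3], scale=['linear', 'log'],
#         num=[16, 20], unit='1/angstrom',
#     )
#     # qgrid.sizes == {'Q': 36}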


class MultiGraphViz:
    """
    A dummy class to concatenate multiple graphviz visualizations into a single
    repr output for Jupyter notebooks.
    This combines the SVG representations of multiple graphs vertically with a
    small gap in between.
    """

    def __init__(self, graphs: Sequence):
        self.graphs = graphs

    def _repr_mimebundle_(self, include=None, exclude=None):
        gap = 10
        parsed = []
        for svg in [g._repr_image_svg_xml() for g in self.graphs]:
            # Extract width, height, and inner <g> content
            m = re.search(r'width="([\d.]+)pt".*?height="([\d.]+)pt"', svg, re.S)
            w, h = float(m.group(1)), float(m.group(2))
            inner = re.search(r'<svg[^>]*>(.*)</svg>', svg, re.S).group(1)
            parsed.append((w, h, inner))

        # Vertical shift
        total_width = max(w for w, _, _ in parsed)
        total_height = sum(h for _, h, _ in parsed) + gap * (len(parsed) - 1)

        pieces = []
        offset_x = offset_y = 0
        for _, h, inner in parsed:
            pieces.append(
                f'<g transform="translate({offset_x},{offset_y})">{inner}</g>'
            )
            offset_y += h + gap

        # TODO: for some reason, combining the svgs seems to scale them down. This
        # then means that the computed bounding box is too large. For now, we
        # apply a fudge factor of 0.75 to the width and height. It is unclear where
        # exactly this comes from.
        combined = f'''
<svg xmlns="http://www.w3.org/2000/svg"
     width="{total_width * 0.75}pt" height="{total_height * 0.75}pt">
{''.join(pieces)}
</svg>
'''
        return {"image/svg+xml": combined}


class BatchProcessor:
    """
    A collection of sciline workflows that can be used to compute multiple targets
    from multiple workflows.
    It can also be used to set parameters for all workflows in a single shot.
    """

    def __init__(self, workflows: Mapping[str, sl.Pipeline]):
        self.workflows = workflows

    def __setitem__(self, key: type, value: Mapping[str, Any]) -> None:
        """
        A mapping (dict or DataGroup) should be supplied as the value.
        The keys of the mapping should correspond to the names of the workflows
        in the collection.
        The node matching the key will be set to the corresponding value for
        each of the workflows.
        """
        for name, v in value.items():
            self.workflows[name][key] = v

    def __getitem__(self, name: str) -> BatchProcessor:
        """
        Get a new BatchProcessor where the workflows are the sub-workflows that
        lead to the node with the given name.
        """
        return BatchProcessor({k: wf[name] for k, wf in self.workflows.items()})

    def compute(self, targets: type | Sequence[type], **kwargs) -> Mapping[str, Any]:
        """
        Compute the given target(s) for all workflows in the collection.

        Parameters
        ----------
        targets:
            The target type(s) to compute.
        **kwargs:
            Additional keyword arguments passed to `sciline.Pipeline.compute`.
        """
        if not isinstance(targets, list | tuple):
            targets = [targets]
        out = {}
        for t in targets:
            out[t] = sc.DataGroup()
            for name, wf in self.workflows.items():
                try:
                    out[t][name] = wf.compute(t, **kwargs)
                except sl.UnsatisfiedRequirement as e:
                    try:
                        out[t][name] = sl.compute_mapped(
                            wf, t, **kwargs
                        ).values.tolist()
                    except (sl.UnsatisfiedRequirement, ValueError):
                        # ValueError is raised when the requested type is not mapped
                        raise e from e
        return next(iter(out.values())) if len(out) == 1 else out

    def copy(self) -> BatchProcessor:
        """
        Create a copy of the workflow collection.
        """
        return BatchProcessor({k: wf.copy() for k, wf in self.workflows.items()})

    def visualize(self, targets: type | Sequence[type], **kwargs) -> MultiGraphViz:
        """
        Visualize all workflows in the collection.

        Parameters
        ----------
        targets:
            The target type(s) to visualize.
        **kwargs:
            Additional keyword arguments passed to `sciline.Pipeline.visualize`.
        """
        from graphviz import Digraph

        # Place all the graphviz Digraphs side by side into a single one.
        if not isinstance(targets, list | tuple):
            targets = [targets]
        graphs = []
        for key, wf in self.workflows.items():
            v = wf.visualize(targets, **kwargs)
            g = Digraph(
                graph_attr=v.graph_attr, node_attr=v.node_attr, edge_attr=v.edge_attr
            )
            with g.subgraph(name=f"cluster_{key}") as c:
                c.attr(label=key, style="rounded", color="black")
                c.body.extend(v.body)
            graphs.append(g)
        return MultiGraphViz(graphs)
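
# Usage sketch (hypothetical run names and filenames): a BatchProcessor is
# normally created via `batch_processor` below. Per-run parameters can be set
# in one shot by assigning a mapping keyed by run name, and `compute` accepts
# a single target or a sequence of targets.
#
#     batch = batch_processor(workflow, runs)
#     batch[Filename[SampleRun]] = {'608': "file_608.hdf", '609': "file_609.hdf"}
#     results = batch.compute((ReflectivityOverQ, ReducibleData[SampleRun]))
#     r_608 = results[ReflectivityOverQ]['608']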


def _sort_by(a, by):
    return [x for x, _ in sorted(zip(a, by, strict=True), key=lambda x: x[1])]


def _find_interval_overlaps(intervals):
    '''Returns the intervals where at least two of the provided intervals
    are overlapping.'''
    edges = list(chain.from_iterable(intervals))
    is_start_edge = list(chain.from_iterable((True, False) for _ in intervals))
    edges_sorted = sorted(edges)
    is_start_edge_sorted = _sort_by(is_start_edge, edges)

    number_overlapping = 0
    overlap_intervals = []
    for x, is_start in zip(edges_sorted, is_start_edge_sorted, strict=True):
        if number_overlapping == 1 and is_start:
            start = x
        if number_overlapping == 2 and not is_start:
            overlap_intervals.append((start, x))
        if is_start:
            number_overlapping += 1
        else:
            number_overlapping -= 1
    return overlap_intervals


def _searchsorted(a, v):
    for i, e in enumerate(a):
        if e > v:
            return i
    return len(a)


def _create_qgrid_where_overlapping(qgrids):
    '''Given a number of Q-grids, construct a new grid covering the regions
    where (any two of the) provided grids overlap.'''
    pieces = []
    for start, end in _find_interval_overlaps([(q.min(), q.max()) for q in qgrids]):
        interval_sliced_from_qgrids = [
            q[max(_searchsorted(q, start) - 1, 0) : _searchsorted(q, end) + 1]
            for q in qgrids
        ]
        densest_grid_in_interval = max(interval_sliced_from_qgrids, key=len)
        pieces.append(densest_grid_in_interval)
    return sc.concat(pieces, dim='Q')


def _same_dtype(arrays):
    return [arr.to(dtype='float64') for arr in arrays]


def _interpolate_on_qgrid(curves, grid):
    return sc.concat(
        _same_dtype([sc.lookup(c, grid.dim)[sc.midpoints(grid)] for c in curves]),
        dim='curves',
    )
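
# Illustration of the overlap helpers with hypothetical plain-number intervals:
# only the region covered by at least two intervals is returned.
#
#     _find_interval_overlaps([(0.0, 2.0), (1.0, 3.0), (5.0, 6.0)])
#     # -> [(1.0, 2.0)]
#
# `_create_qgrid_where_overlapping` applies the same idea to Q-grids, keeping,
# within each overlap region, the slice of the densest input grid.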


def scale_for_reflectivity_overlap(
    reflectivities: sc.DataArray | Mapping[str, sc.DataArray] | sc.DataGroup,
    critical_edge_interval: tuple[sc.Variable, sc.Variable]
    | list[sc.Variable]
    | None = None,
) -> sc.DataArray | sc.DataGroup:
    '''
    Compute a scaling for 1D reflectivity curves that makes the curves overlap.
    One can supply either a single curve or a collection/DataGroup of curves.

    If :code:`critical_edge_interval` is not provided, all curves are scaled except
    the curve with the lowest Q-range, which is considered to be the reference
    curve. The scaling factors are determined by a maximum likelihood estimate
    (assuming the errors are normally distributed).

    If :code:`critical_edge_interval` is provided then all curves are scaled.

    All reflectivity curves must have the same unit for the data and for the
    Q-coordinate.

    Parameters
    ----------
    reflectivities:
        The reflectivity curves that should be scaled.
    critical_edge_interval:
        A tuple denoting an interval that is known to belong to the critical edge,
        i.e. where the reflectivity is known to be 1.

    Returns
    -------
    :
        A DataGroup with the same keys as the input containing the scaling factors
        for each reflectivity curve.
    '''
    only_one_curve = isinstance(reflectivities, sc.DataArray)
    if only_one_curve:
        reflectivities = {"": reflectivities}

    # First sort the dict of reflectivities by the minimum Q value
    curves = {
        k: v.hist() if v.bins is not None else v
        for k, v in sorted(
            reflectivities.items(), key=lambda item: item[1].coords['Q'].min().value
        )
    }

    critical_edge_key = uuid.uuid4().hex
    if critical_edge_interval is not None:
        q = {key: c.coords['Q'] for key, c in curves.items()}
        q = min(q.values(), key=lambda q_: q_.min())
        # TODO: This is slightly different from before: it extracts the bins from the
        # QBins variable that cover the critical edge interval. This means that the
        # resulting curve will not necessarily begin and end exactly at the values
        # specified, but rather at the closest bin edges.
        edge = sc.DataArray(
            data=sc.ones(sizes={q.dim: q.sizes[q.dim] - 1}, with_variances=True),
            coords={q.dim: q},
        )[q.dim, critical_edge_interval[0] : critical_edge_interval[1]]
        # Now place the critical edge at the beginning
        curves = {critical_edge_key: edge} | curves

    if len({c.data.unit for c in curves.values()}) != 1:
        raise ValueError('The reflectivity curves must have the same unit')
    if len({c.coords['Q'].unit for c in curves.values()}) != 1:
        raise ValueError('The Q-coordinates must have the same unit for each curve')

    qgrid = _create_qgrid_where_overlapping([c.coords['Q'] for c in curves.values()])

    r = _interpolate_on_qgrid(map(sc.values, curves.values()), qgrid).values
    v = _interpolate_on_qgrid(map(sc.variances, curves.values()), qgrid).values

    def cost(scaling_factors):
        scaling_factors = np.concatenate([[1.0], scaling_factors])[:, None]
        r_scaled = scaling_factors * r
        v_scaled = scaling_factors**2 * v
        v_scaled[v_scaled == 0] = np.nan
        inv_v_scaled = 1 / v_scaled
        r_avg = np.nansum(r_scaled * inv_v_scaled, axis=0) / np.nansum(
            inv_v_scaled, axis=0
        )
        return np.nansum((r_scaled - r_avg) ** 2 * inv_v_scaled)

    sol = opt.minimize(cost, [1.0] * (len(curves) - 1))
    scaling_factors = (1.0, *map(float, sol.x))
    out = sc.DataGroup(
        {
            k: v
            for k, v in zip(curves.keys(), scaling_factors, strict=True)
            if k != critical_edge_key
        }
    )
    return out[""] if only_one_curve else out
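
# Usage sketch (hypothetical curves r60 and r61): given two histogrammed
# reflectivity curves with overlapping Q-ranges, the returned DataGroup holds
# one scale factor per input key; the curve starting at the lowest Q keeps a
# factor of 1 unless a critical-edge interval is supplied.
#
#     factors = scale_for_reflectivity_overlap({'60': r60, '61': r61})
#     scaled = sc.DataGroup(
#         {k: factors[k] * v for k, v in {'60': r60, '61': r61}.items()}
#     )
#
# With a known critical edge (reflectivity == 1) all curves are scaled:
#
#     factors = scale_for_reflectivity_overlap(
#         {'60': r60, '61': r61},
#         critical_edge_interval=(
#             sc.scalar(0.01, unit='1/angstrom'),
#             sc.scalar(0.014, unit='1/angstrom'),
#         ),
#     )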


def combine_curves(
    curves: Sequence[sc.DataArray] | sc.DataGroup | Mapping[str, sc.DataArray],
    q_bin_edges: sc.Variable | None = None,
) -> sc.DataArray:
    '''Combines the given curves by interpolating them on a 1d grid defined by
    :code:`q_bin_edges` and averaging over the provided reflectivity curves.

    The averaging is done using a weighted mean where the weights are
    inversely proportional to the variances.

    Unless the curves are already scaled correctly they might need to be scaled
    using :func:`scale_for_reflectivity_overlap` before calling this function.

    All curves must have the same unit for the data and for the Q-coordinate.

    Parameters
    ----------
    curves:
        The reflectivity curves that should be combined.
    q_bin_edges:
        The Q bin edges of the resulting combined reflectivity curve.

    Returns
    -------
    :
        A data array representing the combined reflectivity curve.
    '''
    if hasattr(curves, 'items'):
        curves = list(curves.values())
    if len({c.data.unit for c in curves}) != 1:
        raise ValueError('The reflectivity curves must have the same unit')
    if len({c.coords['Q'].unit for c in curves}) != 1:
        raise ValueError('The Q-coordinates must have the same unit for each curve')

    r = _interpolate_on_qgrid(map(sc.values, curves), q_bin_edges).values
    v = _interpolate_on_qgrid(map(sc.variances, curves), q_bin_edges).values

    v[v == 0] = np.nan
    inv_v = 1.0 / v
    r_avg = np.nansum(r * inv_v, axis=0) / np.nansum(inv_v, axis=0)
    v_avg = 1 / np.nansum(inv_v, axis=0)
    return sc.DataArray(
        data=sc.array(
            dims='Q',
            values=r_avg,
            variances=v_avg,
            unit=next(iter(curves)).data.unit,
        ),
        coords={'Q': q_bin_edges},
    )
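
# Usage sketch (hypothetical, already-scaled curves): averages the curves onto
# a common Q grid using inverse-variance weights.
#
#     q_edges = sc.linspace('Q', 0.01, 0.3, 201, unit='1/angstrom')
#     r_combined = combine_curves({'60': r60_scaled, '61': r61_scaled}, q_edges)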


def batch_processor(
    workflow: sl.Pipeline, runs: Mapping[Any, Mapping[type, Any]]
) -> BatchProcessor:
    """
    Creates a collection of sciline workflows from the provided runs.

    Example:

    ```
    from ess.reflectometry import amor, tools

    workflow = amor.AmorWorkflow()
    runs = {
        '608': {
            SampleRotationOffset[SampleRun]: sc.scalar(0.05, unit='deg'),
            Filename[SampleRun]: "file_608.hdf",
        },
        '609': {
            SampleRotationOffset[SampleRun]: sc.scalar(0.05, unit='deg'),
            Filename[SampleRun]: "file_609.hdf",
        },
        '610': {
            SampleRotationOffset[SampleRun]: sc.scalar(0.05, unit='deg'),
            Filename[SampleRun]: "file_610.hdf",
        },
        '611': {
            SampleRotationOffset[SampleRun]: sc.scalar(0.05, unit='deg'),
            Filename[SampleRun]: "file_611.hdf",
        },
    }

    batch = tools.batch_processor(workflow, runs)
    results = batch.compute(ReflectivityOverQ)
    ```

    Additionally, if a list of filenames is provided for ``Filename[SampleRun]``,
    the events from the files will be concatenated into a single event list
    before processing.

    Example:

    ```
    runs = {
        '608': {
            Filename[SampleRun]: "file_608.hdf",
        },
        '609+610': {
            Filename[SampleRun]: ["file_609.hdf", "file_610.hdf"],
        },
    }
    ```

    Parameters
    ----------
    workflow:
        The sciline workflow used to compute the targets for each of the runs.
    runs:
        The sciline parameters to be used for each run.
        Should be a mapping where the keys are the names of the runs and the
        values are mappings of type to value pairs.
        In addition, if one of the values for ``Filename[SampleRun]`` is a list
        or a tuple, then the events from the files will be concatenated into a
        single event list.
    """
    workflows = {}
    for name, parameters in runs.items():
        wf = workflow.copy()
        for tp, value in parameters.items():
            if tp is Filename[SampleRun]:
                continue
            wf[tp] = value
        if Filename[SampleRun] in parameters:
            if isinstance(parameters[Filename[SampleRun]], list | tuple):
                wf = with_filenames(wf, SampleRun, parameters[Filename[SampleRun]])
            else:
                wf[Filename[SampleRun]] = parameters[Filename[SampleRun]]
        workflows[name] = wf
    return BatchProcessor(workflows)


def batch_compute(
    workflow: sl.Pipeline,
    runs: Sequence[Mapping[type, Any]] | Mapping[Any, Mapping[type, Any]],
    target: type | Sequence[type] = orso.OrsoIofQDataset,
    *,
    scale_to_overlap: bool
    | tuple[sc.Variable, sc.Variable]
    | list[sc.Variable] = False,
) -> list | Mapping:
    '''
    Computes requested target(s) from a supplied workflow for a number of runs.
    Each entry of :code:`runs` is a mapping of parameters and values needed to
    produce the targets.

    This is an alternative to using :func:`batch_processor`: instead of returning
    a BatchProcessor object which can operate on multiple workflows at once, this
    function directly computes the requested targets, reducing the risk of
    accidentally compromising the workflows in the collection.
    It also provides the option to scale the reflectivity curves so that they
    overlap in the regions where they have the same Q-value.

    Beginners should prefer this function over :func:`batch_processor` unless
    they need the extra flexibility of the latter (caching intermediate results,
    quickly exploring results, etc.).

    Example:

    ```
    from ess.reflectometry import amor, tools

    workflow = amor.AmorWorkflow()
    runs = {
        '608': {
            SampleRotationOffset[SampleRun]: sc.scalar(0.05, unit='deg'),
            Filename[SampleRun]: "file_608.hdf",
        },
        '609': {
            SampleRotationOffset[SampleRun]: sc.scalar(0.05, unit='deg'),
            Filename[SampleRun]: "file_609.hdf",
        },
        '610': {
            SampleRotationOffset[SampleRun]: sc.scalar(0.05, unit='deg'),
            Filename[SampleRun]: "file_610.hdf",
        },
        '611': {
            SampleRotationOffset[SampleRun]: sc.scalar(0.05, unit='deg'),
            Filename[SampleRun]: "file_611.hdf",
        },
    }

    r_of_q = tools.batch_compute(workflow, runs, target=ReflectivityOverQ)
    ```

    Additionally, if a list of filenames is provided for ``Filename[SampleRun]``,
    the events from the files will be concatenated into a single event list
    before processing.

    Example:

    ```
    runs = {
        '608': {
            Filename[SampleRun]: "file_608.hdf",
        },
        '609+610': {
            Filename[SampleRun]: ["file_609.hdf", "file_610.hdf"],
        },
    }
    ```

    Parameters
    ----------
    workflow:
        The sciline workflow used to compute `ReflectivityOverQ` for each of
        the runs.
    runs:
        The sciline parameters to be used for each run.
    target:
        The domain type(s) to compute for each run.
    scale_to_overlap:
        If ``True`` the loaded data will be scaled so that the computed
        reflectivity curves overlap.
        If a tuple is provided, it is interpreted as a critical edge interval
        where the reflectivity is known to be 1.
    '''
    batch = batch_processor(workflow=workflow, runs=runs)
    if scale_to_overlap:
        results = batch.compute((ReflectivityOverQ, ReducibleData[SampleRun]))
        scale_factors = scale_for_reflectivity_overlap(
            results[ReflectivityOverQ].hist(),
            critical_edge_interval=scale_to_overlap
            if isinstance(scale_to_overlap, tuple | list)
            else None,
        )
        batch[ReducibleData[SampleRun]] = (
            scale_factors * results[ReducibleData[SampleRun]]
        )
        batch[ReflectivityOverQ] = scale_factors * results[ReflectivityOverQ]
    return batch.compute(target)