# Source code for plopp.plotting._inspector

# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)

from functools import partial
from typing import Literal

import numpy as np
import scipp as sc
from matplotlib.path import Path

from ..core import Node
from ..core.typing import Plottable
from ..core.utils import coord_as_bin_edges
from ..graphics import imagefigure, linefigure
from ..widgets import Box, PointsTool, PolygonTool, RectangleTool
from ._slicer import SlicerPlot
from .common import preprocess, require_interactive_figure


def _to_bin_edges(da: sc.DataArray, dim: str) -> sc.DataArray:
    """
    Convert dimension coords to bin edges.

    The coordinate of every dimension of ``da`` except ``dim`` is replaced by the
    result of ``coord_as_bin_edges`` for that dimension.

    Parameters
    ----------
    da:
        The data array whose dimension coordinates should be converted.
    dim:
        The dimension whose coordinate is left untouched.

    Returns
    -------
    :
        A shallow copy of ``da`` carrying the converted coordinates. The input
        ``da`` itself is not modified.
    """
    # Work on a shallow copy: the array handed to this function belongs to the
    # caller (in this module it is the output of an upstream ``Node``), and
    # rewriting its coords in place would silently change it for every other
    # consumer. A shallow copy shares the data buffers, so this stays cheap.
    out = da.copy(deep=False)
    for d in set(out.dims) - {dim}:
        out.coords[d] = coord_as_bin_edges(out, d)
    return out


def _to_bin_centers(da: sc.DataArray, dim: str) -> sc.DataArray:
    """
    Convert dimension coords to bin centers.

    The coordinate of every dimension of ``da`` except ``dim`` is replaced by its
    midpoints (``sc.midpoints``) along that dimension.

    Parameters
    ----------
    da:
        The data array whose dimension coordinates should be converted.
    dim:
        The dimension whose coordinate is left untouched.

    Returns
    -------
    :
        A shallow copy of ``da`` carrying the converted coordinates. The input
        ``da`` itself is not modified.
    """
    # Work on a shallow copy: in this module the input is the output of the
    # bin-edges node, which is also consumed by the 2d figure and by the point
    # and rectangle tools. Overwriting its coords in place would replace the
    # bin edges seen by those consumers with bin centers.
    out = da.copy(deep=False)
    for d in set(out.dims) - {dim}:
        out.coords[d] = sc.midpoints(out.coords[d], dim=d)
    return out


def _apply_op(da: sc.DataArray, op: str, dim: str) -> sc.DataArray:
    """
    Reduce ``da`` along ``dim`` using the scipp function named ``op``.

    Parameters
    ----------
    da:
        The data array to reduce.
    op:
        Name of a function in the ``scipp`` namespace; it is looked up with
        ``getattr`` and called as ``op(da, dim=dim)``.
    dim:
        The dimension to reduce over.

    Returns
    -------
    :
        The reduced array. If the result has a non-empty name, the name is
        prefixed with the operation (``'<op> of <name>'``).
    """
    reducer = getattr(sc, op)
    reduced = reducer(da, dim=dim)
    # Leave unnamed results unnamed instead of producing a dangling "<op> of ".
    if not reduced.name:
        return reduced
    reduced.name = f'{op} of {reduced.name}'
    return reduced


def _slice_xy(da: sc.DataArray, xy: dict[str, dict[str, int]]) -> sc.DataArray:
    """
    Slice ``da`` at a single position along the two dimensions given in ``xy``.

    Parameters
    ----------
    da:
        The data array to slice.
    xy:
        A dict with entries ``'x'`` and ``'y'``. Each entry is itself a dict
        holding the dimension name under ``'dim'`` and the position to slice at
        under ``'value'``.

    Returns
    -------
    :
        ``da`` sliced along both dimensions. If slicing raises an ``IndexError``,
        a NaN-filled float array shaped like the slice of ``da`` at index 0 of
        both dimensions is returned instead.
    """
    x_sel = xy['x']
    y_sel = xy['y']
    try:
        # Scipp raises when a 2D coordinate is sliced with label-based indexing.
        # Slicing the other dimension first reduces that coordinate to 1D, after
        # which label-based indexing works. This assumes that at most one
        # coordinate of the DataArray is multi-dimensional (which is very likely
        # the case).
        if da.coords[y_sel['dim']].ndim > 1:
            first, second = x_sel, y_sel
        else:
            first, second = y_sel, x_sel
        return da[first['dim'], first['value']][second['dim'], second['value']]
    except IndexError:
        # Out-of-bounds position: hand back an all-NaN profile of the right shape.
        template = da[y_sel['dim'], 0][x_sel['dim'], 0]
        return sc.full_like(template, value=np.nan, dtype=float)


def _slice_rectangular_region(da: sc.DataArray, rect: dict, op: str) -> sc.DataArray:
    """
    Select a rectangular region of ``da`` and reduce it over the two rectangle
    dimensions with the operation ``op``.

    Parameters
    ----------
    da:
        The data array to slice.
    rect:
        A dict with entries ``'x'`` and ``'y'``. Each entry is itself a dict
        holding the dimension name under ``'dim'`` and the rectangle vertices
        along that dimension under ``'value'``; the minimum and maximum of the
        vertices define the selected range.
    op:
        Name of the reduction to apply over the two rectangle dimensions
        (e.g. ``'sum'``, ``'mean'``, ``'nanmean'``).

    Returns
    -------
    :
        The selected region reduced over the two rectangle dimensions. If slicing
        raises an ``IndexError``, a NaN-filled float array shaped like the slice
        of ``da`` at index 0 of both dimensions is returned instead.
    """
    x = rect['x']
    y = rect['y']
    xmin, xmax = x['value'].min(), x['value'].max()
    ymin, ymax = y['value'].min(), y['value'].max()
    try:
        # If there is a 2D coordinate in the data, we need to slice the other dimension
        # first, as trying to slice a 2D coordinate using label-based indexing raises an
        # error in Scipp. After slicing the other dimension, the 2D coordinate will be
        # 1D and can be sliced normally using label-based indexing.
        # We assume here that there would only be one multi-dimensional coordinate in a
        # given DataArray (which is very likely the case).
        if da.coords[y['dim']].ndim > 1:
            out = da[x['dim'], xmin:xmax][y['dim'], ymin:ymax]
        else:
            out = da[y['dim'], ymin:ymax][x['dim'], xmin:xmax]
        # If the operation is a mean, there is currently a bug in the implementation
        # in scipp where doing a mean over a subset of the array's dimensions gives the
        # wrong result: https://github.com/scipp/scipp/issues/3841
        # Instead, we manually compute the mean
        dims = (x['dim'], y['dim'])
        if 'mean' not in op:
            return getattr(out, op)(dims)
        # The numerator keeps the dimensions that are not part of the rectangle,
        # so the number of contributing elements must be counted over the two
        # rectangle dimensions only. Counting over all dimensions would divide
        # every remaining position by the element count of the whole region.
        if 'nan' in op:
            numerator = out.nansum(dims)
            denominator = (~sc.isnan(out.data)).sum(dims)
            denominator.unit = ""
        else:
            numerator = out.sum(dims)
            denominator = out.sizes[x['dim']] * out.sizes[y['dim']]
        return numerator / denominator
    except IndexError:
        # If the index is out of bounds, return an empty DataArray
        return sc.full_like(da[y['dim'], 0][x['dim'], 0], value=np.nan, dtype=float)


def _mask_outside_polygon(
    da: sc.DataArray,
    poly: dict,
    points: np.ndarray,
    sizes: dict[str, int],
    op: str,
    non_nan: sc.Variable,
) -> sc.DataArray:
    """
    Mask everything outside a polygon and reduce ``da`` over the polygon
    dimensions with the operation ``op``.

    Parameters
    ----------
    da:
        The data array to reduce.
    poly:
        A dict with entries ``'x'`` and ``'y'``. The polygon vertices are read from
        ``poly['x']['value'].values`` and ``poly['y']['value'].values``.
    points:
        Array of (x, y) positions that are tested against the polygon. Its
        flattened order must match ``sizes``, as the test result is reshaped to
        ``tuple(sizes.values())``.
    sizes:
        Dimension names and lengths of the grid spanned by ``points``. These are
        also the dimensions that are reduced.
    op:
        Name of the reduction to apply (e.g. ``'sum'``, ``'mean'``, ``'nanmean'``).
    non_nan:
        Boolean variable that is ``True`` where the data is not NaN. Only used
        when ``op`` is a NaN-aware mean.

    Returns
    -------
    :
        The data inside the polygon, reduced over the dimensions in ``sizes``.
    """
    vx = poly['x']['value'].values
    vy = poly['y']['value'].values
    verts = np.column_stack([vx, vy])
    path = Path(verts)
    # Materialise the dimension names once, so the same plain tuple is used for
    # building the mask and for every reduction below.
    dims = tuple(sizes)
    inside = sc.array(
        dims=dims,
        values=path.contains_points(points).reshape(tuple(sizes.values())),
    )
    # The string representation of the existing mask names is used as the name of
    # the new mask; it is built from all existing names, presumably so that it
    # cannot collide with any of them.
    masked = da.assign_masks({str(da.masks.keys()): ~inside})
    # If the operation is a mean, there is currently a bug in the implementation
    # in scipp where doing a mean over a subset of the array's dimensions gives the
    # wrong result: https://github.com/scipp/scipp/issues/3841
    # Instead, we manually compute the mean
    if 'mean' not in op:
        return getattr(masked, op)(dims)
    if 'nan' in op:
        numerator = masked.nansum(dims)
        # ``inside & non_nan`` also spans the dimensions of ``non_nan`` that are
        # not being reduced. Count over the polygon dimensions only, so that each
        # remaining position is divided by its own number of valid elements
        # rather than by the total over the whole array.
        denominator = (inside & non_nan).sum(dims)
    else:
        numerator = masked.sum(dims)
        denominator = inside.sum()
    denominator.unit = ""
    return numerator / denominator


[docs] def inspector( obj: Plottable, dim: str | None = None, *, aspect: Literal['auto', 'equal'] | None = None, autoscale: bool = True, cbar: bool = True, clabel: str | None = None, cmax: sc.Variable | float | None = None, cmin: sc.Variable | float | None = None, continuous_update: bool = True, coords: list[str] | None = None, errorbars: bool = True, figsize: tuple[float, float] | None = None, grid: bool = False, legend: bool | tuple[float, float] = True, logc: bool | None = None, mask_cmap: str = 'gray', mask_color: str = 'black', mode: Literal['point', 'polygon', 'rectangle'] = 'point', nan_color: str | None = None, norm: Literal['linear', 'log'] | None = None, operation: Literal[ 'sum', 'mean', 'min', 'max', 'nansum', 'nanmean', 'nanmin', 'nanmax' ] = 'sum', orientation: Literal['horizontal', 'vertical'] = 'horizontal', title: str | None = None, vmax: sc.Variable | float | None = None, vmin: sc.Variable | float | None = None, xlabel: str | None = None, xmax: sc.Variable | float | None = None, xmin: sc.Variable | float | None = None, ylabel: str | None = None, ymax: sc.Variable | float | None = None, ymin: sc.Variable | float | None = None, with_slider: bool = True, **kwargs, ): """ Inspector takes in a three-dimensional input and applies a reduction operation (``'sum'`` by default) along one of the dimensions specified by ``dim``. It displays the result as a two-dimensional image. In addition, an 'inspection' tool is available in the toolbar. In ``mode='point'`` it allows placing point markers on the image to slice at that position, retaining only the third dimension and displaying the resulting one-dimensional slice in the right-hand side figure. In ``mode='polygon'`` it allows drawing a polygon to compute the total intensity inside the polygon as a function of the third dimension. In ``mode='rectangle'`` it allows drawing a rectangle to compute the total intensity inside the rectangle as a function of the third dimension. 
Controls (point mode): - Left-click to make new points - Left-click and hold on point to move point - Middle-click to delete point Controls (rectangle mode): - Left-click to make new rectangles - Left-click and hold on rectangle vertices to resize rectangle - Right-click and hold to drag/move the entire rectangle - Middle-click to delete rectangle Controls (polygon mode): - Left-click to make new polygons - Left-click and hold on polygon vertex to move vertex - Right-click and hold to drag/move the entire polygon - Middle-click to delete polygon Notes ----- Almost all the arguments for plot customization apply to the two-dimensional image (unless specified). In rectangle mode, if any part of a data pixel lies inside the rectangle, the whole pixel is included in the selected region (as opposed to computing the overlap area). This is because it is not always possible to know how the bin fraction should be included, depending on the reduction operation which is applied to the data (sum, mean, etc). If one of the edges of the rectangle lies exactly on a bin edge between two data points, the selected region does not include the data on the outside of the rectangle, even though the edge is touching the rectangle. In polygon mode, only data whose bin centers are inside the polygon are included in the selected region. Parameters ---------- obj: The object to be plotted. dim: The dimension along which to apply the reduction operation. This will also be the dimension that remains in the one-dimensional slices generated by adding markers on the image. If no dim is provided, the last (inner) dim of the input data will be used. aspect: Aspect ratio for the axes. autoscale: Automatically scale the axes/colormap on updates if ``True``. cbar: Show colorbar if ``True``. clabel: Label for colorscale. cmax: Upper limit for colorscale. cmin: Lower limit for colorscale. 
continuous_update: If ``True``, update the data selected by the markers, rectangles, or polygons continuously as they are being moved or resized. If ``False``, only update the data when the user releases the mouse button after moving/resizing. coords: If supplied, use these coords instead of the input's dimension coordinates. errorbars: Show errorbars if ``True`` (1d figure). figsize: The width and height of the figure, in inches. grid: Show grid if ``True``. legend: Show legend if ``True``. If ``legend`` is a tuple, it should contain the ``(x, y)`` coordinates of the legend's anchor point in axes coordinates (1d figure). logc: If ``True``, use logarithmic scale for colorscale. mask_cmap: Colormap to use for masks. mask_color: Color of masks (overrides ``mask_cmap``). mode: Select ``'point'`` for point inspection, ``'polygon'`` for polygon selection, or ``'rectangle'`` for rectangle selection with total intensity inside the shape plotted as a function of ``dim``. nan_color: Color to use for NaN values. norm: Set to ``'log'`` for a logarithmic colorscale. Legacy, prefer ``logc`` instead. operation: The operation to apply along the third (undisplayed) dimension specified by ``dim``. The same operation is also applied to the data within the selected region in the case of ``polygon`` and ``rectangle`` modes. orientation: Display the two panels side-by-side ('horizontal') or one below the other ('vertical'). title: The figure title. vmax: Upper limit for data colorscale to be displayed. Legacy, prefer ``cmax`` instead. vmin: Lower limit for data colorscale to be displayed. Legacy, prefer ``cmin`` instead. xlabel: Label for x-axis. xmax: Upper limit for x-axis (1d figure) xmin: Lower limit for x-axis (1d figure) ylabel: Label for y-axis. ymax: Upper limit for y-axis (1d figure). ymin: Lower limit for y-axis (1d figure). with_slider: Show slider under 2d image for selecting data range if ``True``. 
A currently selected range indicator will also be displayed on the 1d profile figure. **kwargs: Additional arguments forwarded to the underlying plotting library. Returns ------- : A :class:`Box` which will contain two :class:`Figure` and one slider widget. """ if mode not in ['point', 'polygon', 'rectangle']: raise ValueError( f'Invalid mode: {mode}. Allowed modes are "point", "polygon", "rectangle".' ) f1d = linefigure( autoscale=autoscale, errorbars=errorbars, grid=grid, legend=legend, mask_color=mask_color, xmax=xmax, xmin=xmin, ymax=ymax, ymin=ymin, ) require_interactive_figure(f1d, 'inspector') in_node = Node(preprocess, obj, ignore_size=True, coords=coords) data = in_node() if data.ndim != 3: raise ValueError( 'The inspector plot currently only works with ' f'three-dimensional data, found {data.ndim} dims.' ) if dim is None: dim = data.dims[-1] bin_edges_node = Node(_to_bin_edges, in_node, dim=dim) bin_centers_node = Node(_to_bin_centers, bin_edges_node, dim=dim) f2d_args = dict( aspect=aspect, cbar=cbar, clabel=clabel, cmax=cmax, cmin=cmin, figsize=figsize, grid=grid, logc=logc, mask_cmap=mask_cmap, mask_color=mask_color, nan_color=nan_color, norm=norm, title=title, vmax=vmax, vmin=vmin, xlabel=xlabel, ylabel=ylabel, **kwargs, ) if with_slider: slicer_plot = SlicerPlot( bin_edges_node, keep=set(data.dims) - {dim}, operation=operation, **f2d_args ) f2d = slicer_plot.figure span = f1d.ax.axvspan( data.coords[dim].min().value, data.coords[dim].max().value, color='gray', alpha=0.2, zorder=-np.inf, ) bin_edge_coord = coord_as_bin_edges(data, dim) def update_span(change: dict) -> None: start, end = change['owner'].controls[dim].value start = bin_edge_coord[dim, start].value end = bin_edge_coord[dim, end + 1].value span.set_bounds(start, 0, end - start, 1) f1d.canvas.draw() slicer_plot.slicer.slider.observe(update_span, names='value') else: op_node = Node(_apply_op, da=bin_edges_node, op=operation, dim=dim) f2d = imagefigure(op_node, **f2d_args) match mode: case 
'point': tool = PointsTool( figure=f2d, input_node=bin_edges_node, func=_slice_xy, destination=f1d, tooltip="Activate inspector tool", continuous_update=continuous_update, ) case 'rectangle': tool = RectangleTool( figure=f2d, input_node=bin_edges_node, func=partial(_slice_rectangular_region, op=operation), destination=f1d, tooltip="Activate rectangle inspector tool", continuous_update=continuous_update, ) case 'polygon': da = bin_centers_node() xdim = f2d.canvas.dims['x'] ydim = f2d.canvas.dims['y'] x = da.coords[xdim] y = da.coords[ydim] sizes = {**x.sizes, **y.sizes} xx = sc.broadcast(x, sizes=sizes) yy = sc.broadcast(y, sizes=sizes) points = np.column_stack([xx.values.ravel(), yy.values.ravel()]) non_nan = ~sc.isnan(da.data) tool = PolygonTool( figure=f2d, input_node=bin_centers_node, func=partial( _mask_outside_polygon, points=points, sizes=sizes, op=operation, non_nan=non_nan, ), destination=f1d, tooltip="Activate polygon inspector tool", continuous_update=continuous_update, ) f2d.toolbar['inspect'] = tool out = [f2d, f1d] if orientation == 'horizontal': out = [out] return Box(out)