# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from functools import partial
from typing import Literal
import numpy as np
import scipp as sc
from matplotlib.path import Path
from ..core import Node
from ..core.typing import Plottable
from ..core.utils import coord_as_bin_edges
from ..graphics import imagefigure, linefigure
from ..widgets import Box, PointsTool, PolygonTool, RectangleTool
from ._slicer import SlicerPlot
from .common import preprocess, require_interactive_figure
def _to_bin_edges(da: sc.DataArray, dim: str) -> sc.DataArray:
    """
    Convert the coordinates of all dimensions except ``dim`` to bin edges.

    A new data array is returned and the input is left untouched. The input is the
    output of an upstream graph node which other nodes may also consume, so
    rewriting its coordinates in place would leak the conversion into them.

    Parameters
    ----------
    da:
        Input data array.
    dim:
        Dimension whose coordinate is left unchanged (the dimension that is reduced
        for the image and kept in the one-dimensional profiles).

    Returns
    -------
    :
        A data array whose coordinates for every dimension other than ``dim`` are
        bin edges, as computed by ``coord_as_bin_edges``.
    """
    return da.assign_coords(
        {d: coord_as_bin_edges(da, d) for d in da.dims if d != dim}
    )
def _to_bin_centers(da: sc.DataArray, dim: str) -> sc.DataArray:
    """
    Convert the (bin-edge) coordinates of all dimensions except ``dim`` to bin
    centers.

    A new data array is returned and the input is left untouched. Modifying the
    input in place would overwrite the bin-edge coordinates of the upstream node's
    output, which the image also uses. It would also make a repeated evaluation
    take midpoints of midpoints, shrinking the coordinate each time.

    Parameters
    ----------
    da:
        Input data array; its non-``dim`` dimension coordinates are expected to be
        bin edges.
    dim:
        Dimension whose coordinate is left unchanged.

    Returns
    -------
    :
        A data array whose coordinates for every dimension other than ``dim`` are
        the midpoints of the input coordinates.
    """
    return da.assign_coords(
        {d: sc.midpoints(da.coords[d], dim=d) for d in da.dims if d != dim}
    )
def _apply_op(da: sc.DataArray, op: str, dim: str) -> sc.DataArray:
    """
    Reduce ``da`` along ``dim`` with the Scipp function called ``op``.

    Parameters
    ----------
    da:
        Data to reduce.
    op:
        Name of a reduction function in the ``scipp`` namespace (e.g. ``'sum'``).
    dim:
        Dimension to reduce over.

    Returns
    -------
    :
        The reduced data. A non-empty name is prefixed with ``'<op> of '``.
    """
    reduction = getattr(sc, op)
    reduced = reduction(da, dim=dim)
    if not reduced.name:
        return reduced
    reduced.name = f'{op} of {reduced.name}'
    return reduced
def _slice_xy(da: sc.DataArray, xy: dict[str, dict[str, int]]) -> sc.DataArray:
    """
    Extract the profile at a single (x, y) position using label-based indexing.

    Parameters
    ----------
    da:
        Data to slice.
    xy:
        Mapping with ``'x'`` and ``'y'`` entries, each holding the ``'dim'`` to
        slice and the ``'value'`` (label) to slice at.

    Returns
    -------
    :
        The data at the requested position. If the position is out of bounds, a
        NaN-filled array shaped like the slice at the first pixel is returned.
    """
    horizontal, vertical = xy['x'], xy['y']
    try:
        # Label-based indexing on a 2D coordinate raises in Scipp. When the y
        # coordinate is multi-dimensional, slice x first so that the y coordinate is
        # 1D by the time it is indexed. We assume at most one multi-dimensional
        # coordinate in the data.
        if da.coords[vertical['dim']].ndim > 1:
            first_sel, second_sel = horizontal, vertical
        else:
            first_sel, second_sel = vertical, horizontal
        partial_slice = da[first_sel['dim'], first_sel['value']]
        return partial_slice[second_sel['dim'], second_sel['value']]
    except IndexError:
        # Out-of-range position: return NaNs with the shape of a single-pixel slice.
        pixel = da[vertical['dim'], 0][horizontal['dim'], 0]
        return sc.full_like(pixel, value=np.nan, dtype=float)
def _slice_rectangular_region(da: sc.DataArray, rect: dict, op: str) -> sc.DataArray:
    """
    Apply the reduction ``op`` to the data inside a rectangle, over the rectangle's
    x and y dimensions.

    Parameters
    ----------
    da:
        Data to reduce. In the inspector this is the data with bin-edge coordinates
        along the image dimensions.
    rect:
        Rectangle description with ``'x'`` and ``'y'`` entries. Each holds the
        ``'dim'`` name and a ``'value'`` with the rectangle's corner coordinates
        along that dimension.
    op:
        Name of the reduction method to apply (e.g. ``'sum'``, ``'nanmean'``).

    Returns
    -------
    :
        The data reduced over the rectangle's x and y dimensions. If the rectangle
        is out of bounds, a NaN-filled array shaped like a single-pixel slice is
        returned.
    """
    x = rect['x']
    y = rect['y']
    xmin, xmax = x['value'].min(), x['value'].max()
    ymin, ymax = y['value'].min(), y['value'].max()
    try:
        # Label-based indexing on a 2D coordinate raises in Scipp. When the y
        # coordinate is multi-dimensional, slice x first so that the y coordinate is
        # 1D by the time it is indexed. We assume at most one multi-dimensional
        # coordinate in the data.
        if da.coords[y['dim']].ndim > 1:
            out = da[x['dim'], xmin:xmax][y['dim'], ymin:ymax]
        else:
            out = da[y['dim'], ymin:ymax][x['dim'], xmin:xmax]
        dims = (x['dim'], y['dim'])
        if 'mean' not in op:
            return getattr(out, op)(dims)
        # Scipp's mean over a subset of an array's dimensions currently gives wrong
        # results (https://github.com/scipp/scipp/issues/3841), so compute it
        # manually. The element count must be taken over the rectangle's x/y
        # dimensions only (one count per element of the remaining dimension), not
        # over the whole array.
        # NOTE(review): elements hidden by pre-existing masks are still counted in
        # the denominator — confirm whether that is intended.
        if 'nan' in op:
            numerator = out.nansum(dims)
            denominator = (~sc.isnan(out.data)).sum(dims)
            denominator.unit = ""
        else:
            numerator = out.sum(dims)
            denominator = out.sizes[x['dim']] * out.sizes[y['dim']]
        return numerator / denominator
    except IndexError:
        # Out-of-range rectangle: return NaNs with the shape of a single-pixel slice.
        return sc.full_like(da[y['dim'], 0][x['dim'], 0], value=np.nan, dtype=float)
def _mask_outside_polygon(
    da: sc.DataArray,
    poly: dict,
    points: np.ndarray,
    sizes: dict[str, int],
    op: str,
    non_nan: sc.Variable,
) -> sc.DataArray:
    """
    Apply the reduction ``op`` to the data inside a polygon, over the image
    dimensions.

    Parameters
    ----------
    da:
        Data to reduce. In the inspector this is the data with bin-center
        coordinates along the image dimensions.
    poly:
        Polygon description with ``'x'`` and ``'y'`` entries whose ``'value'``
        holds the vertex coordinates.
    points:
        ``(N, 2)`` array of (x, y) positions of every image pixel, flattened in the
        dimension order given by ``sizes``.
    sizes:
        Sizes of the image dimensions. The key order defines the layout of
        ``points``, and these are the dimensions that are reduced over.
    op:
        Name of the reduction method to apply (e.g. ``'sum'``, ``'nanmean'``).
    non_nan:
        Boolean variable that is ``True`` where ``da`` is not NaN. Used to count
        valid elements for ``'nanmean'``.

    Returns
    -------
    :
        The data reduced over the dimensions in ``sizes``. Pixels whose position in
        ``points`` lies outside the polygon are masked out before the reduction.
    """
    vertices = np.column_stack(
        [poly['x']['value'].values, poly['y']['value'].values]
    )
    dims = tuple(sizes)
    inside = sc.array(
        dims=dims,
        values=Path(vertices).contains_points(points).reshape(tuple(sizes.values())),
    )
    # Choose a mask name that cannot clash with (and so overwrite) an existing mask.
    mask_name = 'outside_polygon'
    while mask_name in da.masks:
        mask_name = f'_{mask_name}'
    masked = da.assign_masks({mask_name: ~inside})
    if 'mean' not in op:
        return getattr(masked, op)(dims)
    # Scipp's mean over a subset of an array's dimensions currently gives wrong
    # results (https://github.com/scipp/scipp/issues/3841), so compute it manually.
    # The count of valid elements must be taken over the image dimensions only
    # (one count per element of the remaining dimension).
    if 'nan' in op:
        numerator = masked.nansum(dims)
        denominator = (inside & non_nan).sum(dims)
    else:
        numerator = masked.sum(dims)
        # ``inside`` only spans the image dimensions, so a full sum is the count.
        denominator = inside.sum()
    denominator.unit = ""
    return numerator / denominator
[docs]
def inspector(
obj: Plottable,
dim: str | None = None,
*,
aspect: Literal['auto', 'equal'] | None = None,
autoscale: bool = True,
cbar: bool = True,
clabel: str | None = None,
cmax: sc.Variable | float | None = None,
cmin: sc.Variable | float | None = None,
continuous_update: bool = True,
coords: list[str] | None = None,
errorbars: bool = True,
figsize: tuple[float, float] | None = None,
grid: bool = False,
legend: bool | tuple[float, float] = True,
logc: bool | None = None,
mask_cmap: str = 'gray',
mask_color: str = 'black',
mode: Literal['point', 'polygon', 'rectangle'] = 'point',
nan_color: str | None = None,
norm: Literal['linear', 'log'] | None = None,
operation: Literal[
'sum', 'mean', 'min', 'max', 'nansum', 'nanmean', 'nanmin', 'nanmax'
] = 'sum',
orientation: Literal['horizontal', 'vertical'] = 'horizontal',
title: str | None = None,
vmax: sc.Variable | float | None = None,
vmin: sc.Variable | float | None = None,
xlabel: str | None = None,
xmax: sc.Variable | float | None = None,
xmin: sc.Variable | float | None = None,
ylabel: str | None = None,
ymax: sc.Variable | float | None = None,
ymin: sc.Variable | float | None = None,
with_slider: bool = True,
**kwargs,
):
"""
Inspector takes in a three-dimensional input and applies a reduction operation
(``'sum'`` by default) along one of the dimensions specified by ``dim``.
It displays the result as a two-dimensional image.
In addition, an 'inspection' tool is available in the toolbar. In ``mode='point'``
it allows placing point markers on the image to slice at that position, retaining
only the third dimension and displaying the resulting one-dimensional slice in the
right-hand side figure. In ``mode='polygon'`` it allows drawing a polygon to compute
the total intensity inside the polygon as a function of the third dimension.
In ``mode='rectangle'`` it allows drawing a rectangle to compute the total intensity
inside the rectangle as a function of the third dimension.
Controls (point mode):
- Left-click to make new points
- Left-click and hold on point to move point
- Middle-click to delete point
Controls (rectangle mode):
- Left-click to make new rectangles
- Left-click and hold on rectangle vertices to resize rectangle
- Right-click and hold to drag/move the entire rectangle
- Middle-click to delete rectangle
Controls (polygon mode):
- Left-click to make new polygons
- Left-click and hold on polygon vertex to move vertex
- Right-click and hold to drag/move the entire polygon
- Middle-click to delete polygon
Notes
-----
Almost all the arguments for plot customization apply to the two-dimensional image
(unless specified).
In rectangle mode, if any part of a data pixel lies inside the rectangle, the whole
pixel is included in the selected region (as opposed to computing the overlap area).
This is because it is not always possible to know how the bin fraction should be
included, depending on the reduction operation which is applied to the data (sum,
mean, etc).
If one of the edges of the rectangle lies exactly on a bin edge
between two data points, the selected region does not include the data on the
outside of the rectangle, even though the edge is touching the rectangle.
In polygon mode, only data whose bin centers are inside the polygon are included in
the selected region.
Parameters
----------
obj:
The object to be plotted.
dim:
The dimension along which to apply the reduction operation. This will also be
the dimension that remains in the one-dimensional slices generated by adding
markers on the image. If no dim is provided, the last (inner) dim of the input
data will be used.
aspect:
Aspect ratio for the axes.
autoscale:
Automatically scale the axes/colormap on updates if ``True``.
cbar:
Show colorbar if ``True``.
clabel:
Label for colorscale.
cmax:
Upper limit for colorscale.
cmin:
Lower limit for colorscale.
continuous_update:
If ``True``, update the data selected by the markers, rectangles, or polygons
continuously as they are being moved or resized. If ``False``, only update the
data when the user releases the mouse button after moving/resizing.
coords:
If supplied, use these coords instead of the input's dimension coordinates.
errorbars:
Show errorbars if ``True`` (1d figure).
figsize:
The width and height of the figure, in inches.
grid:
Show grid if ``True``.
legend:
Show legend if ``True``. If ``legend`` is a tuple, it should contain the
``(x, y)`` coordinates of the legend's anchor point in axes coordinates
(1d figure).
logc:
If ``True``, use logarithmic scale for colorscale.
mask_cmap:
Colormap to use for masks.
mask_color:
Color of masks (overrides ``mask_cmap``).
mode:
Select ``'point'`` for point inspection, ``'polygon'`` for polygon selection,
or ``'rectangle'`` for rectangle selection with total intensity inside the
shape plotted as a function of ``dim``.
nan_color:
Color to use for NaN values.
norm:
Set to ``'log'`` for a logarithmic colorscale. Legacy, prefer ``logc`` instead.
operation:
The operation to apply along the third (undisplayed) dimension specified by
``dim``. The same operation is also applied to the data within the selected
region in the case of ``polygon`` and ``rectangle`` modes.
orientation:
Display the two panels side-by-side ('horizontal') or one below the other
('vertical').
title:
The figure title.
vmax:
Upper limit for data colorscale to be displayed.
Legacy, prefer ``cmax`` instead.
vmin:
Lower limit for data colorscale to be displayed.
Legacy, prefer ``cmin`` instead.
xlabel:
Label for x-axis.
xmax:
Upper limit for x-axis (1d figure)
xmin:
Lower limit for x-axis (1d figure)
ylabel:
Label for y-axis.
ymax:
Upper limit for y-axis (1d figure).
ymin:
Lower limit for y-axis (1d figure).
with_slider:
Show slider under 2d image for selecting data range if ``True``. A currently
selected range indicator will also be displayed on the 1d profile figure.
**kwargs:
Additional arguments forwarded to the underlying plotting library.
Returns
-------
:
A :class:`Box` which will contain two :class:`Figure` and one slider widget.
"""
if mode not in ['point', 'polygon', 'rectangle']:
raise ValueError(
f'Invalid mode: {mode}. Allowed modes are "point", "polygon", "rectangle".'
)
f1d = linefigure(
autoscale=autoscale,
errorbars=errorbars,
grid=grid,
legend=legend,
mask_color=mask_color,
xmax=xmax,
xmin=xmin,
ymax=ymax,
ymin=ymin,
)
require_interactive_figure(f1d, 'inspector')
in_node = Node(preprocess, obj, ignore_size=True, coords=coords)
data = in_node()
if data.ndim != 3:
raise ValueError(
'The inspector plot currently only works with '
f'three-dimensional data, found {data.ndim} dims.'
)
if dim is None:
dim = data.dims[-1]
bin_edges_node = Node(_to_bin_edges, in_node, dim=dim)
bin_centers_node = Node(_to_bin_centers, bin_edges_node, dim=dim)
f2d_args = dict(
aspect=aspect,
cbar=cbar,
clabel=clabel,
cmax=cmax,
cmin=cmin,
figsize=figsize,
grid=grid,
logc=logc,
mask_cmap=mask_cmap,
mask_color=mask_color,
nan_color=nan_color,
norm=norm,
title=title,
vmax=vmax,
vmin=vmin,
xlabel=xlabel,
ylabel=ylabel,
**kwargs,
)
if with_slider:
slicer_plot = SlicerPlot(
bin_edges_node, keep=set(data.dims) - {dim}, operation=operation, **f2d_args
)
f2d = slicer_plot.figure
span = f1d.ax.axvspan(
data.coords[dim].min().value,
data.coords[dim].max().value,
color='gray',
alpha=0.2,
zorder=-np.inf,
)
bin_edge_coord = coord_as_bin_edges(data, dim)
def update_span(change: dict) -> None:
start, end = change['owner'].controls[dim].value
start = bin_edge_coord[dim, start].value
end = bin_edge_coord[dim, end + 1].value
span.set_bounds(start, 0, end - start, 1)
f1d.canvas.draw()
slicer_plot.slicer.slider.observe(update_span, names='value')
else:
op_node = Node(_apply_op, da=bin_edges_node, op=operation, dim=dim)
f2d = imagefigure(op_node, **f2d_args)
match mode:
case 'point':
tool = PointsTool(
figure=f2d,
input_node=bin_edges_node,
func=_slice_xy,
destination=f1d,
tooltip="Activate inspector tool",
continuous_update=continuous_update,
)
case 'rectangle':
tool = RectangleTool(
figure=f2d,
input_node=bin_edges_node,
func=partial(_slice_rectangular_region, op=operation),
destination=f1d,
tooltip="Activate rectangle inspector tool",
continuous_update=continuous_update,
)
case 'polygon':
da = bin_centers_node()
xdim = f2d.canvas.dims['x']
ydim = f2d.canvas.dims['y']
x = da.coords[xdim]
y = da.coords[ydim]
sizes = {**x.sizes, **y.sizes}
xx = sc.broadcast(x, sizes=sizes)
yy = sc.broadcast(y, sizes=sizes)
points = np.column_stack([xx.values.ravel(), yy.values.ravel()])
non_nan = ~sc.isnan(da.data)
tool = PolygonTool(
figure=f2d,
input_node=bin_centers_node,
func=partial(
_mask_outside_polygon,
points=points,
sizes=sizes,
op=operation,
non_nan=non_nan,
),
destination=f1d,
tooltip="Activate polygon inspector tool",
continuous_update=continuous_update,
)
f2d.toolbar['inspect'] = tool
out = [f2d, f1d]
if orientation == 'horizontal':
out = [out]
return Box(out)