Source code for plopp.plotting.slicer

# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)

from functools import partial
from itertools import groupby
from typing import Literal

import scipp as sc

from ..core import widget_node
from ..core.typing import FigureLike, PlottableMulti
from ..graphics import imagefigure, linefigure
from .common import (
    categorize_args,
    input_to_nodes,
    preprocess,
    raise_multiple_inputs_for_2d_plot_error,
    require_interactive_figure,
)


class Slicer:
    """
    Class that slices out dimensions from the data and displays the resulting data as
    either a 1D line or a 2D image.

    Note:

    This class primarily exists to facilitate unit testing. When running unit tests, we
    are not in a Jupyter notebook, and the generated figures are not widgets that can
    be placed in the `Box` widget container at the end of the `slicer` function.
    We therefore place most of the code for creating a Slicer in this class, which is
    under unit test coverage. The thin `slicer` wrapper is not covered by unit tests.

    Parameters
    ----------
    obj:
        The input data.
    coords:
        If supplied, use these coords instead of the input's dimension coordinates.
    enable_player:
        If ``True``, add a play button to the sliders to automatically step through
        the slices.
    keep:
        The dimensions to be kept, all remaining dimensions will be sliced. This should
        be a list of dims. If no dims are provided, the last dim will be kept in the
        case of a 2-dimensional input, while the last two dims will be kept in the case
        of higher dimensional inputs.
    **kwargs:
        The additional arguments are forwarded to the underlying 1D or 2D figures.
    """

    def __init__(
        self,
        obj: PlottableMulti,
        *,
        coords: list[str] | None = None,
        enable_player: bool = False,
        keep: list[str] | None = None,
        **kwargs,
    ):
        nodes = input_to_nodes(
            obj,
            processor=partial(preprocess, ignore_size=True, coords=coords),
        )

        dims = nodes[0]().dims
        if keep is None:
            keep = dims[-(2 if len(dims) > 2 else 1) :]

        if isinstance(keep, str):
            keep = [keep]

        # Ensure all dims in keep have the same size
        sizes = [
            {dim: shape for dim, shape in node().sizes.items() if dim not in keep}
            for node in nodes
        ]
        g = groupby(sizes)
        if not (next(g, True) and not next(g, False)):
            raise ValueError(
                'Slicer plot: all inputs must have the same sizes, but '
                f'the following sizes were found: {sizes}'
            )

        if len(keep) == 0:
            raise ValueError(
                'Slicer plot: the list of dims to be kept cannot be empty.'
            )
        if not all(dim in dims for dim in keep):
            raise ValueError(
                f"Slicer plot: one or more of the requested dims to be kept {keep} "
                f"were not found in the input's dimensions {dims}."
            )

        from ..widgets import SliceWidget, slice_dims

        self.slider = SliceWidget(
            nodes[0](),
            dims=[dim for dim in dims if dim not in keep],
            enable_player=enable_player,
        )
        self.slider_node = widget_node(self.slider)
        self.slice_nodes = [slice_dims(node, self.slider_node) for node in nodes]

        args = categorize_args(**kwargs)

        ndims = len(keep)
        if ndims == 1:
            make_figure = partial(linefigure, **args['1d'])
        elif ndims == 2:
            if len(self.slice_nodes) > 1:
                raise_multiple_inputs_for_2d_plot_error(origin='slicer')
            make_figure = partial(imagefigure, **args['2d'])
        else:
            raise ValueError(
                f'Slicer plot: the number of dims to be kept must be 1 or 2, '
                f'but {ndims} were requested.'
            )

        self.figure = make_figure(*self.slice_nodes)
        require_interactive_figure(self.figure, 'slicer')
        self.figure.bottom_bar.add(self.slider)


[docs] def slicer( obj: PlottableMulti, keep: list[str] | None = None, *, aspect: Literal['auto', 'equal', None] = None, autoscale: bool = True, cbar: bool = True, clabel: str | None = None, cmap: str = 'viridis', cmax: sc.Variable | float | None = None, cmin: sc.Variable | float | None = None, coords: list[str] | None = None, enable_player: bool = False, errorbars: bool = True, figsize: tuple[float, float] | None = None, grid: bool = False, legend: bool | tuple[float, float] = True, logc: bool | None = None, logx: bool | None = None, logy: bool | None = None, mask_cmap: str = 'gray', mask_color: str | None = None, nan_color: str | None = None, norm: Literal['linear', 'log', None] = None, scale: dict[str, str] | None = None, title: str | None = None, vmax: sc.Variable | float | None = None, vmin: sc.Variable | float | None = None, xlabel: str | None = None, xmax: sc.Variable | float | None = None, xmin: sc.Variable | float | None = None, ylabel: str | None = None, ymax: sc.Variable | float | None = None, ymin: sc.Variable | float | None = None, **kwargs, ) -> FigureLike: """ Plot a multi-dimensional object by slicing one or more of the dimensions. This will produce one slider per sliced dimension, below the figure. Parameters ---------- obj: The object to be plotted. keep: The single dimension to be kept, all remaining dimensions will be sliced. This should be a single string. If no dim is provided, the last/inner dim will be kept. aspect: Aspect ratio for the axes. autoscale: Automatically scale the axes/colormap on updates if ``True``. cbar: Show colorbar in 2d plots if ``True``. clabel: Label for colorscale (2d plots only). cmap: The colormap to be used for the colorscale (2d plots only). cmax: Upper limit for colorscale (2d plots only). cmin: Lower limit for colorscale (2d plots only). coords: If supplied, use these coords instead of the input's dimension coordinates. enable_player: If ``True``, add a play button to the sliders to automatically step through the slices. errorbars: Show errorbars in 1d plots if ``True``. figsize: The width and height of the figure, in inches. grid: Show grid if ``True``. legend: Show legend if ``True``. If ``legend`` is a tuple, it should contain the ``(x, y)`` coordinates of the legend's anchor point in axes coordinates. logc: If ``True``, use logarithmic scale for colorscale (2d plots only). logx: If ``True``, use logarithmic scale for x-axis. logy: If ``True``, use logarithmic scale for y-axis. mask_cmap: Colormap to use for masks in 2d plots. mask_color: Color of masks. nan_color: Color to use for NaN values in 2d plots. norm: Set to ``'log'`` for a logarithmic y-axis (1d plots) or logarithmic colorscale (2d plots). Legacy, prefer ``logy`` and ``logc`` instead. scale: Change axis scaling between ``log`` and ``linear``. For example, specify ``scale={'time': 'log'}`` if you want log-scale for the ``time`` dimension. Legacy, prefer ``logx`` and ``logy`` instead. title: The figure title. vmax: Upper limit for data to be displayed (y-axis for 1d plots, colorscale for 2d plots). Legacy, prefer ``ymax`` and ``cmax`` instead. vmin: Lower limit for data to be displayed (y-axis for 1d plots, colorscale for 2d plots). Legacy, prefer ``ymin`` and ``cmin`` instead. xlabel: Label for x-axis. xmax: Upper limit for x-axis. xmin: Lower limit for x-axis. ylabel: Label for y-axis. ymax: Upper limit for y-axis. ymin: Lower limit for y-axis. **kwargs: Additional arguments forwarded to the underlying plotting library. """ return Slicer( obj, keep=keep, aspect=aspect, autoscale=autoscale, cbar=cbar, clabel=clabel, cmap=cmap, cmax=cmax, cmin=cmin, coords=coords, enable_player=enable_player, errorbars=errorbars, figsize=figsize, grid=grid, legend=legend, logc=logc, logx=logx, logy=logy, mask_color=mask_color, nan_color=nan_color, norm=norm, scale=scale, title=title, vmax=vmax, vmin=vmin, xlabel=xlabel, xmax=xmax, xmin=xmin, ylabel=ylabel, ymax=ymax, ymin=ymin, **kwargs, ).figure