# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
import warnings
from typing import Literal
import matplotlib.pyplot as plt
import numpy as np
import scipp as sc
from matplotlib import dates as mdates
from mpl_toolkits.axes_grid1 import make_axes_locatable
from ...core.utils import maybe_variable_to_number, scalar_to_string
from ...graphics.bbox import BoundingBox
from .utils import fig_to_bytes, is_sphinx_build, make_figure, make_legend
def _cursor_value_to_variable(x: float, dtype: sc.DType, unit: str) -> sc.Variable:
    """
    Turn a raw Matplotlib cursor coordinate into a scipp scalar.

    Parameters
    ----------
    x:
        The coordinate value reported by Matplotlib for the mouse pointer.
    dtype:
        The dtype of the axis the coordinate belongs to.
    unit:
        The unit of the resulting scalar. For datetime axes, the scalar is
        converted to this unit.
    """
    is_datetime = dtype == sc.DType.datetime64
    if not is_datetime:
        return sc.scalar(x, unit=unit)
    # Matplotlib has its own numeric encoding for dates (number of days since
    # epoch) and num2date hands back a Python datetime carrying tzinfo, while
    # scipp wants a numpy datetime64: drop the tzinfo, wrap, then convert unit.
    naive_datetime = mdates.num2date(x).replace(tzinfo=None)
    as_numpy = np.datetime64(naive_datetime)
    return sc.scalar(as_numpy).to(unit=unit)
def _cursor_formatter(x: float, dtype: sc.DType, unit: str) -> str:
    """
    Render a raw Matplotlib cursor coordinate as a human-readable string.

    Parameters
    ----------
    x:
        The coordinate value reported by Matplotlib for the mouse pointer.
    dtype:
        The dtype of the axis the coordinate belongs to. Datetime axes are
        rendered in ISO 8601 format.
    unit:
        The unit attached to the value for non-datetime axes.
    """
    is_datetime = dtype == sc.DType.datetime64
    if not is_datetime:
        return scalar_to_string(sc.scalar(x, unit=unit))
    # Matplotlib encodes dates as day numbers; decode and strip the tzinfo.
    moment = mdates.num2date(x).replace(tzinfo=None)
    return moment.isoformat()
def _maybe_trim_polar_limits(
    axis_type: str, limits: tuple[float, float]
) -> tuple[float, float]:
    """
    Clamp axis limits into the range [0, 2π] when the axes are polar.

    Parameters
    ----------
    axis_type:
        Name of the axes projection. Only ``'polar'`` triggers clamping; any
        other value returns ``limits`` unchanged (the very same object).
    limits:
        The requested axis limits.

    Returns
    -------
    :
        ``limits`` itself for non-polar axes, otherwise a tuple whose entries
        are clipped to lie within [0, 2π].
    """
    if axis_type == 'polar':
        full_turn = 2 * np.pi
        return tuple(np.clip(limits, 0, full_turn))
    return limits
class Canvas:
    """
    Matplotlib-based canvas used to render 2D graphics.
    It provides a figure and some axes, as well as functions for controlling the zoom,
    panning, and the scale of the axes.

    Parameters
    ----------
    ax:
        If supplied, use these axes to create the figure. If none are supplied, the
        canvas will create its own axes.
    cax:
        If supplied, use these axes for the colorbar. If none are supplied, and a
        colorbar is required, the canvas will create its own axes.
    figsize:
        The width and height of the figure, in inches. Only used when the canvas
        creates its own figure (defaults to ``(6.0, 4.0)`` in that case).
    title:
        The title to be placed above the figure.
    grid:
        Display the figure grid if ``True``.
    user_vmin:
        The minimum value for the y axis (1d plots only).
    user_vmax:
        The maximum value for the y axis (1d plots only).
    aspect:
        The aspect ratio for the axes.
    cbar:
        Add axes to host a colorbar if ``True``.
    legend:
        Show legend if ``True``. If ``legend`` is a tuple, it should contain the
        ``(x, y)`` coordinates of the legend's anchor point in axes coordinates.
    """

    def __init__(
        self,
        ax: plt.Axes | None = None,
        cax: plt.Axes | None = None,
        figsize: tuple[float, float] | None = None,
        title: str | None = None,
        grid: bool = False,
        user_vmin: sc.Variable | float | None = None,
        user_vmax: sc.Variable | float | None = None,
        aspect: Literal['auto', 'equal', None] = None,
        cbar: bool = False,
        legend: bool | tuple[float, float] = True,
        **ignored,
    ):
        # Note on the `**ignored`` keyword arguments: the figure which owns the canvas
        # creates both the canvas and an artist object (Line or Image). The figure
        # accepts keyword arguments, and has to somehow forward them to the canvas and
        # the artist. Since the figure has no detailed knowledge of the underlying
        # backend that implements the canvas, it cannot have specific arguments (such
        # as `ax` for specifying Matplotlib axes).
        # Instead, we forward all the kwargs from the figure to both the canvas and the
        # artist, and filter out the artist kwargs with `**ignored`.
        self.fig = None
        self.ax = ax
        self.cax = cax
        self.bbox = BoundingBox()
        self._user_vmin = user_vmin
        self._user_vmax = user_vmax
        self.units = {}
        self.dims = {}
        # The following are (re)populated by ``set_axes``; they are initialized here
        # so the attributes exist even before the axes have been configured.
        self.dtypes = {}
        self._cursor_x_prefix = ''
        self._cursor_y_prefix = ''
        self._legend = legend
        self._coord_formatters = []

        if self.ax is None:
            self.fig = make_figure(figsize=(6.0, 4.0) if figsize is None else figsize)
            self.ax = self.fig.add_subplot()
            if self.is_widget():
                # Hide the built-in toolbar and header of the widget canvas.
                self.fig.canvas.toolbar_visible = False
                self.fig.canvas.header_visible = False
        else:
            self.fig = self.ax.get_figure()

        if aspect is not None:
            self.ax.set_aspect(aspect)

        if cbar and (self.cax is None):
            if self.ax.name == 'polar':
                # For polar axes, place the colorbar axes manually, just to the right
                # of the main axes, instead of using an axes divider.
                bounds = self.ax.get_position().bounds
                self.cax = self.fig.add_axes(
                    [bounds[0] + bounds[2] + 0.1, 0.1, 0.03, 0.8]
                )
            else:
                divider = make_axes_locatable(self.ax)
                self.cax = divider.append_axes("right", "4%", pad="5%")

        self.ax.grid(grid)
        if title:
            self.ax.set_title(title)

    def is_widget(self):
        """
        Return ``True`` if the figure canvas is a widget-based canvas (e.g. ipympl),
        detected by the presence of an ``on_widget_constructed`` attribute.
        """
        return hasattr(self.fig.canvas, "on_widget_constructed")

    def to_image(self):
        """
        Convert the underlying Matplotlib figure to an image widget from ``ipywidgets``.
        """
        from ipywidgets import Image

        return Image(value=fig_to_bytes(self.fig), format='png')

    def to_widget(self):
        """
        Return a widget displaying the figure.

        When the canvas is a widget canvas and we are not inside a Sphinx build, the
        live canvas is returned (wrapped in a ``VBox``). Otherwise, a static PNG image
        widget is returned (see :meth:`to_image`).
        """
        from ipywidgets import VBox

        if self.is_widget() and not is_sphinx_build():
            # Best-effort layout tightening: warnings are silenced and a
            # RuntimeError from tight_layout is deliberately ignored.
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    self.fig.tight_layout()
            except RuntimeError:
                pass
            # The Matplotlib canvas tries to fill the entire width of the output cell,
            # which can add unnecessary whitespace between it and other widgets. To
            # prevent this, we wrap the canvas in a VBox, which seems to help.
            return VBox([self.fig.canvas])
        return self.to_image()

    def draw(self):
        """
        Make a draw call to the underlying figure.
        """
        self.fig.canvas.draw_idle()

    def update_legend(self):
        """
        Update the legend on the canvas.

        A legend is only drawn when legends are enabled and there is more than one
        labelled artist; otherwise any existing legend is removed.
        """
        if self._legend:
            handles, labels = self.ax.get_legend_handles_labels()
            if len(handles) > 1:
                self.ax.legend(handles, labels, **make_legend(self._legend))
            elif (leg := self.ax.get_legend()) is not None:
                leg.remove()

    def save(self, filename: str, **kwargs):
        """
        Save the figure to file.

        The default directory for writing the file is the same as the
        directory where the script or notebook is running.

        Parameters
        ----------
        filename:
            Name of the output file. Possible file extensions are ``.jpg``, ``.png``,
            ``.svg``, and ``.pdf``.
        **kwargs:
            Additional arguments forwarded to Matplotlib's ``Figure.savefig``.
            ``bbox_inches`` defaults to ``'tight'`` unless overridden here.
        """
        self.fig.savefig(filename, **{'bbox_inches': 'tight', **kwargs})

    def set_axes(self, dims, units, dtypes):
        """
        Set the axes dimensions and units.

        Parameters
        ----------
        dims:
            The dimensions of the data.
        units:
            The units of the data.
        dtypes:
            The data types of the data.
        """
        self.units = units
        self.dims = dims
        self.dtypes = dtypes
        self._cursor_x_prefix = ''
        self._cursor_y_prefix = ''
        # With a 'y' dimension (2D data), prefix each cursor value with its
        # dimension name so the readout is unambiguous.
        if 'y' in self.dims:
            self._cursor_x_prefix = self.dims['x'] + '='
            self._cursor_y_prefix = self.dims['y'] + '='
        self.ax.format_coord = self.format_coord
        # The vertical axis shows 'y' if present, otherwise the data values.
        key = 'y' if 'y' in self.units else 'data'
        self.bbox = BoundingBox(
            ymin=maybe_variable_to_number(self._user_vmin, unit=self.units[key]),
            ymax=maybe_variable_to_number(self._user_vmax, unit=self.units[key]),
        )

    def register_format_coord(self, formatter):
        """
        Register an extra formatter for the mouse-cursor coordinate readout.

        Each formatter is called by :meth:`format_coord` as ``formatter(xpos, ypos)``,
        where ``xpos`` is a ``(dim, value)`` tuple for the x axis and ``ypos`` is the
        same for the y axis (or ``None`` if there is no 'y' dimension). It should
        return a string to append to the readout, or ``None`` to contribute nothing.
        """
        self._coord_formatters.append(formatter)

    def format_coord(self, x: float, y: float) -> str:
        """
        Format the coordinates of the mouse pointer to show the value of the
        data at that point.

        Parameters
        ----------
        x:
            The x coordinate of the mouse pointer.
        y:
            The y coordinate of the mouse pointer.
        """
        xstr = _cursor_formatter(x, self.dtypes['x'], self.units['x'])
        key = 'y' if 'y' in self.dtypes else 'data'
        ystr = _cursor_formatter(y, self.dtypes[key], self.units[key])
        out = f"({self._cursor_x_prefix}{xstr}, {self._cursor_y_prefix}{ystr})"
        if not self._coord_formatters:
            return out
        xpos = (
            self.dims['x'],
            _cursor_value_to_variable(x, self.dtypes['x'], self.units['x']),
        )
        ypos = (
            (
                self.dims['y'],
                _cursor_value_to_variable(y, self.dtypes['y'], self.units['y']),
            )
            if 'y' in self.dims
            else None
        )
        extra = [formatter(xpos, ypos) for formatter in self._coord_formatters]
        extra = [e for e in extra if e is not None]
        if extra:
            out += ": {" + ", ".join(extra) + "}"
        return out

    @property
    def empty(self) -> bool:
        """
        Check if the canvas is empty.
        """
        return not self.dims

    @property
    def title(self) -> str:
        """
        Get or set the title of the plot.
        """
        return self.ax.get_title()

    @title.setter
    def title(self, text: str):
        self.ax.set_title(text)

    @property
    def xlabel(self) -> str:
        """
        Get or set the label of the x-axis.
        """
        return self.ax.get_xlabel()

    @xlabel.setter
    def xlabel(self, lab: str):
        self.ax.set_xlabel(lab)

    @property
    def ylabel(self) -> str:
        """
        Get or set the label of the y-axis.
        """
        return self.ax.get_ylabel()

    @ylabel.setter
    def ylabel(self, lab: str):
        self.ax.set_ylabel(lab)

    @property
    def cblabel(self) -> str:
        """
        Get or set the label of the colorbar.
        """
        return self.cax.get_ylabel()

    @cblabel.setter
    def cblabel(self, lab: str):
        self.cax.set_ylabel(lab)

    @property
    def xscale(self) -> Literal['linear', 'log']:
        """
        Get or set the scale of the x-axis ('linear' or 'log').
        """
        return self.ax.get_xscale()

    @xscale.setter
    def xscale(self, scale: Literal['linear', 'log']):
        self.ax.set_xscale(scale)

    @property
    def yscale(self) -> Literal['linear', 'log']:
        """
        Get or set the scale of the y-axis ('linear' or 'log').
        """
        return self.ax.get_yscale()

    @yscale.setter
    def yscale(self, scale: Literal['linear', 'log']):
        self.ax.set_yscale(scale)

    @property
    def xmin(self) -> float:
        """
        Get or set the lower (left) bound of the x-axis.
        """
        return self.ax.get_xlim()[0]

    @xmin.setter
    def xmin(self, value: float):
        self.ax.set_xlim(
            _maybe_trim_polar_limits(axis_type=self.ax.name, limits=(value, self.xmax))
        )

    @property
    def xmax(self) -> float:
        """
        Get or set the upper (right) bound of the x-axis.
        """
        return self.ax.get_xlim()[1]

    @xmax.setter
    def xmax(self, value: float):
        self.ax.set_xlim(
            _maybe_trim_polar_limits(axis_type=self.ax.name, limits=(self.xmin, value))
        )

    @property
    def xrange(self) -> tuple[float, float]:
        """
        Get or set the range/limits of the x-axis.
        """
        return self.ax.get_xlim()

    @xrange.setter
    def xrange(self, value: tuple[float, float]):
        self.ax.set_xlim(_maybe_trim_polar_limits(axis_type=self.ax.name, limits=value))

    @property
    def ymin(self) -> float:
        """
        Get or set the lower (bottom) bound of the y-axis.
        """
        return self.ax.get_ylim()[0]

    @ymin.setter
    def ymin(self, value: float):
        self.ax.set_ylim(value, self.ymax)

    @property
    def ymax(self) -> float:
        """
        Get or set the upper (top) bound of the y-axis.
        """
        return self.ax.get_ylim()[1]

    @ymax.setter
    def ymax(self, value: float):
        self.ax.set_ylim(self.ymin, value)

    @property
    def yrange(self) -> tuple[float, float]:
        """
        Get or set the range/limits of the y-axis.
        """
        return self.ax.get_ylim()

    @yrange.setter
    def yrange(self, value: tuple[float, float]):
        self.ax.set_ylim(value)

    @property
    def grid(self) -> bool:
        """
        Get or set the visibility of the grid.

        The visibility is read from the first x-axis gridline. If the axis currently
        has no gridlines (e.g. no ticks), ``False`` is returned instead of raising.
        """
        gridlines = self.ax.get_xgridlines()
        return bool(gridlines) and gridlines[0].get_visible()

    @grid.setter
    def grid(self, visible: bool):
        self.ax.grid(visible)

    def reset_mode(self):
        """
        Reset the Matplotlib toolbar mode to nothing, to disable all Zoom/Pan tools.
        """
        # Calling zoom()/pan() while that tool is active toggles it off.
        if self.fig.canvas.toolbar.mode == 'zoom rect':
            self.zoom()
        elif self.fig.canvas.toolbar.mode == 'pan/zoom':
            self.pan()

    def zoom(self):
        """
        Activate the underlying Matplotlib zoom tool.
        """
        self.fig.canvas.toolbar.zoom()

    def pan(self):
        """
        Activate the underlying Matplotlib pan tool.
        """
        self.fig.canvas.toolbar.pan()

    def panzoom(self, value: Literal['pan', 'zoom', None]):
        """
        Activate or deactivate the pan or zoom tool, depending on the input value.
        """
        if value == 'zoom':
            self.zoom()
        elif value == 'pan':
            self.pan()
        elif value is None:
            self.reset_mode()

    def download_figure(self):
        """
        Save the figure to a PNG file via a pop-up dialog.
        """
        self.fig.canvas.toolbar.save_figure()

    def logx(self):
        """
        Toggle the scale between ``linear`` and ``log`` along the horizontal axis.
        """
        self.xscale = 'log' if self.xscale == 'linear' else 'linear'

    def logy(self):
        """
        Toggle the scale between ``linear`` and ``log`` along the vertical axis.
        """
        self.yscale = 'log' if self.yscale == 'linear' else 'linear'