Source code for plopp.backends.matplotlib.scatter

# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)

import uuid
from typing import Literal

import numpy as np
import scipp as sc
from matplotlib.lines import Line2D

from ...core.utils import merge_masks
from ...graphics.bbox import BoundingBox, axis_bounds
from ...graphics.colormapper import ColorMapper
from ..common import check_ndim
from .canvas import Canvas
from .utils import parse_dicts_in_kwargs


class Scatter:
    """
    Artist to represent a two-dimensional scatter plot.

    Parameters
    ----------
    canvas:
        The canvas that will display the scatter plot.
    data:
        The initial data to create the scatter points from. Must be one-dimensional.
    x:
        The name of the coordinate that is to be used for the X positions.
    y:
        The name of the coordinate that is to be used for the Y positions.
    uid:
        The unique identifier of the artist. If None, a random UUID is generated.
    size:
        The size of the markers. Either a number, or the name of a coordinate of
        ``data`` whose values give a per-point marker size.
    artist_number:
        Number of the artist (used to pick the default marker and the default
        color of the artist).
    colormapper:
        The colormapper to use for the scatter plot.
    mask_color:
        The color of the outline drawn around masked points.
    cbar:
        Whether to use a colorbar. If ``True``, no default color and no legend
        label are set.
    **kwargs:
        Additional keyword arguments forwarded to ``Axes.scatter`` (e.g.
        ``color`` or ``marker``; ``color`` is ignored if a colorbar is used).
        An ``ax`` argument is discarded, and ``s`` is rejected in favour of
        ``size``.
    """

    def __init__(
        self,
        canvas: Canvas,
        data: sc.DataArray,
        x: str = 'x',
        y: str = 'y',
        uid: str | None = None,
        size: str | float | None = None,
        artist_number: int = 0,
        colormapper: ColorMapper | None = None,
        mask_color: str = 'black',
        cbar: bool = False,
        **kwargs,
    ):
        check_ndim(data, ndim=1, origin='Scatter')
        self.uid = uid if uid is not None else uuid.uuid4().hex
        self._canvas = canvas
        self._ax = self._canvas.ax
        self._data = data
        self._x = x
        self._y = y
        self._size = size
        self._colormapper = colormapper
        # Because all keyword arguments from the figure are forwarded to both the
        # canvas and the scatter artist, we need to remove the arguments that
        # belong to the canvas.
        kwargs.pop('ax', None)
        if 's' in kwargs:
            raise ValueError("Use 'size' instead of 's' for scatter plot.")
        scatter_kwargs = parse_dicts_in_kwargs(kwargs, name=data.name)

        # No legend label is set when a colorbar is used.
        self.label = data.name if not cbar else None
        self._unit = self._data.unit
        # NOTE(review): not read within this class; appears redundant with self.uid.
        self._id = uuid.uuid4().hex

        # Skip matplotlib's "nothing" markers ('None', 'none', ' ', '') so that a
        # large artist_number can never yield invisible points. They are the last
        # entries of the table, so the markers picked for small numbers are the
        # same as cycling through the unfiltered table.
        markers = [
            marker
            for marker, description in Line2D.markers.items()
            if description != 'nothing'
        ]
        default_plot_style = {
            'marker': markers[(artist_number + 2) % len(markers)],
        }
        if not cbar:
            default_plot_style['color'] = f'C{artist_number}'
        # User-supplied keyword arguments override the defaults.
        merged_kwargs = {**default_plot_style, **scatter_kwargs}

        marker_size = self._marker_sizes()
        self._scatter = self._ax.scatter(
            self._data.coords[self._x].values,
            self._data.coords[self._y].values,
            s=marker_size,
            label=self.label,
            **merged_kwargs,
        )
        if self._colormapper is not None:
            self._colormapper.add_artist(self.uid, self)
            self._scatter.set_array(None)

        # A second, hollow scatter drawn on top of the first outlines masked
        # points; it is hidden when the data has no masks.
        xmask, ymask, has_masks = self._mask_positions()
        self._mask = self._ax.scatter(
            xmask,
            ymask,
            s=marker_size,
            marker=merged_kwargs['marker'],
            edgecolors=mask_color,
            facecolor="None",
            linewidth=3.0,
            zorder=self._scatter.get_zorder() + 1,
            visible=has_masks,
        )

    def _marker_sizes(self):
        """
        Return the marker sizes: the values of the ``size`` coordinate if ``size``
        is a coordinate name, otherwise the ``size`` value itself.
        """
        if isinstance(self._size, str):
            return self._data.coords[self._size].values
        return self._size

    @staticmethod
    def _nan_capable_copy(values):
        """
        Return a copy of ``values`` into which NaN can be assigned.

        Integer and boolean arrays are upcast to float, because assigning NaN
        into them raises; other dtypes are copied unchanged.
        """
        values = np.asarray(values)
        if values.dtype.kind in 'biu':
            return values.astype(float)
        return values.copy()

    def _mask_positions(self):
        """
        Compute the point positions for the mask-outline artist.

        Returns
        -------
        :
            A tuple ``(x, y, has_masks)``. ``x`` and ``y`` hold one entry per data
            point, with every unmasked point set to NaN so that only masked points
            keep a finite position. ``has_masks`` is ``True`` if the current data
            carries any masks.
        """
        xmask = self._nan_capable_copy(self._data.coords[self._x].values)
        ymask = self._nan_capable_copy(self._data.coords[self._y].values)
        if not self._data.masks:
            return xmask, ymask, False
        unmasked = ~merge_masks(self._data.masks).values
        xmask[unmasked] = np.nan
        ymask[unmasked] = np.nan
        return xmask, ymask, True

    def notify_artist(self, message: str) -> None:
        """
        Receive notification from the colormapper that its state has changed.
        We thus need to update the colors of the points.

        Parameters
        ----------
        message:
            The message from the colormapper.
        """
        self._update_colors()

    def _update_colors(self):
        """
        Update the colors of the scatter points.
        """
        self._scatter.set_facecolors(self._colormapper.rgba(self.data))

    def update(self, new_values: sc.DataArray):
        """
        Update the positions, marker sizes, masks and colors of the data points
        from new data.

        Parameters
        ----------
        new_values:
            New one-dimensional data to update the scatter points and masks from.
        """
        check_ndim(new_values, ndim=1, origin='Scatter')
        self._data = new_values
        self._scatter.set_offsets(
            np.stack(
                [self._data.coords[self._x].values, self._data.coords[self._y].values],
                axis=1,
            )
        )
        # Keep the mask outlines in sync with the new positions and masks;
        # otherwise they would remain where the previous data put them.
        xmask, ymask, has_masks = self._mask_positions()
        self._mask.set_offsets(np.stack([xmask, ymask], axis=1))
        self._mask.set_visible(has_masks)
        if isinstance(self._size, str):
            sizes = self._marker_sizes()
            self._scatter.set_sizes(sizes)
            self._mask.set_sizes(sizes)
        if self._colormapper is not None:
            self._update_colors()

    def remove(self):
        """
        Remove the scatter and mask artists from the canvas.
        """
        self._scatter.remove()
        self._mask.remove()
        if self._colormapper is not None:
            self._colormapper.remove_artist(self.uid)

    @property
    def data(self):
        """
        The data array currently represented by this artist.
        """
        return self._data

    def bbox(self, xscale: Literal['linear', 'log'], yscale: Literal['linear', 'log']):
        """
        The bounding box of the scatter points.

        Parameters
        ----------
        xscale:
            The scale of the horizontal axis.
        yscale:
            The scale of the vertical axis.
        """
        scatter_x = self._data.coords[self._x]
        scatter_y = self._data.coords[self._y]
        return BoundingBox(
            **axis_bounds(('xmin', 'xmax'), scatter_x, xscale, pad=True),
            **axis_bounds(('ymin', 'ymax'), scatter_y, yscale, pad=True),
        )