Source code for plopp.graphics.colormapper

# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)

from copy import copy
from functools import reduce
from typing import Any, Literal

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipp as sc
from matplotlib.colorbar import ColorbarBase
from matplotlib.colors import Colormap, LinearSegmentedColormap, LogNorm, Normalize

from ..backends.matplotlib.utils import fig_to_bytes
from ..core.limits import find_limits, fix_empty_range
from ..core.utils import maybe_variable_to_number, merge_masks
from ..utils import parse_mutually_exclusive


def _shift_color(color: float, delta: float) -> float:
    """
    Shift a color (number from 0 to 1) by delta. If the result is out of bounds,
    the color is shifted in the opposite direction.
    """
    shifted = color + delta
    if shifted > 1.0 or shifted < 0.0:
        return color - delta
    return shifted


def _get_cmap(colormap: str | Colormap, nan_color: str | None = None) -> Colormap:
    """
    Get a colormap object from a colormap name.
    We also set the 'over', 'under' and 'bad' colors. The 'bad' color is set to
    ``nan_color`` if it is not None. The 'over' and 'under' colors are set to be
    slightly lighter or darker than the colormap's current 'over' and 'under'
    colors.

    Parameters
    ----------
    colormap:
        Name of the colormap. If the name is just a single html color, this will
        create a colormap with that single color (used for the whole range as well
        as for 'over', 'under' and 'bad'). If ``colormap`` is already a Colormap,
        it is returned as is.
    nan_color:
        The color to use for NAN values. It is only applied when ``colormap`` is the
        name of a registered colormap; it is ignored when a Colormap instance or a
        single color is given.

    Returns
    -------
    :
        The colormap object.
    """

    if isinstance(colormap, Colormap):
        return colormap

    try:
        if hasattr(mpl, 'colormaps'):
            cmap = copy(mpl.colormaps[colormap])
        else:
            cmap = mpl.cm.get_cmap(colormap)
    except (KeyError, ValueError):
        # Case where we have just a single color
        cmap = LinearSegmentedColormap.from_list('tmp', [colormap, colormap])
        cmap.set_over(colormap)
        cmap.set_under(colormap)
        cmap.set_bad(colormap)
        return cmap

    # Add under and over values to the cmap
    delta = 0.15
    cmap = cmap.copy()
    for getter, setter in (
        (cmap.get_over, cmap.set_over),
        (cmap.get_under, cmap.set_under),
    ):
        rgba = getter()
        # Shift towards lighter if the color is already bright, darker otherwise.
        signed_delta = delta if np.mean(rgba) > 0.5 else -delta
        # Only the RGB channels are shifted. The alpha channel is passed through
        # explicitly: supplying just three values would reset it to fully opaque.
        shifted = [_shift_color(c, signed_delta) for c in rgba[:3]]
        setter([*shifted, rgba[3]])

    if nan_color is not None:
        cmap.set_bad(color=nan_color)
    return cmap


def _get_normalizer(norm: str) -> Normalize:
    """
    Build the normalizer matching the requested colorscale.

    ``'linear'`` gives a ``Normalize`` with limits (0, 1); any other value gives a
    ``LogNorm`` with limits (1, 2).
    """
    if norm == 'linear':
        return Normalize(vmin=0, vmax=1)
    return LogNorm(vmin=1, vmax=2)


class ColorMapper:
    """
    A class that handles conversion between data values and RGBA colors.
    It controls the normalization (linear or log), as well as the min and max limits
    for the color range.

    Parameters
    ----------
    canvas:
        The canvas the colormapper is attached to. If it has a ``cax`` attribute,
        those axes are used for the colorbar. Otherwise (and if ``cbar`` is ``True``)
        the ColorMapper will create its own axes. The canvas is also redrawn when the
        normalization is toggled.
    cbar:
        Create a colorbar if ``True``. If ``False``, no colorbar is made even if
        the canvas provides a ``cax``.
    cmap:
        The name of the colormap for the data
        (see https://matplotlib.org/stable/tutorials/colors/colormaps.html).
        In addition to the Matplotlib docs, if the name is just a single html color,
        a colormap with that single color will be used.
    mask_cmap:
        The name of the colormap for masked data.
    mask_color:
        A single color to use for masked data. If supplied, it is used instead of
        ``mask_cmap``.
    cmin:
        The minimum value for the colorscale range. If a number (without a unit) is
        supplied, it is assumed that the unit is the same as the data unit.
    cmax:
        The maximum value for the colorscale range. If a number (without a unit) is
        supplied, it is assumed that the unit is the same as the data unit.
    logc:
        If ``True``, use a logarithmic colorscale.
    clabel:
        The label for the colorbar axis.
    nan_color:
        The color used for representing NAN values.
    figsize:
        The size of the figure next to which the colorbar will be displayed. Only
        the second entry (the height) is used, to size the colorbar figure that is
        created when the canvas provides no ``cax``.
    norm:
        The colorscale normalization (``'linear'`` or ``'log'``). This is an old
        parameter name. Prefer using ``logc`` instead.
    vmin:
        The minimum value for the colorscale range. If a number (without a unit) is
        supplied, it is assumed that the unit is the same as the data unit.
        This is an old parameter name. Prefer using ``cmin`` instead.
    vmax:
        The maximum value for the colorscale range. If a number (without a unit) is
        supplied, it is assumed that the unit is the same as the data unit.
        This is an old parameter name. Prefer using ``cmax`` instead.
    """

    def __init__(
        self,
        canvas: Any | None = None,
        cbar: bool = True,
        cmap: str | Colormap = 'viridis',
        mask_cmap: str | Colormap = 'gray',
        mask_color: str | None = None,
        cmin: sc.Variable | float | None = None,
        cmax: sc.Variable | float | None = None,
        logc: bool | None = None,
        clabel: str | None = None,
        nan_color: str | None = None,
        figsize: tuple[float, float] | None = None,
        norm: Literal['linear', 'log'] | None = None,
        vmin: sc.Variable | float | None = None,
        vmax: sc.Variable | float | None = None,
    ):
        cmin = parse_mutually_exclusive(vmin=vmin, cmin=cmin)
        cmax = parse_mutually_exclusive(vmax=vmax, cmax=cmax)
        logc = parse_mutually_exclusive(norm=norm, logc=logc)
        if isinstance(logc, str):
            # The legacy ``norm`` argument is a string. If the helper hands it back
            # unchanged, any non-empty string (including 'linear') would be truthy,
            # so convert it to a proper flag here.
            logc = logc == 'log'

        self._canvas = canvas
        self.cax = self._canvas.cax if hasattr(self._canvas, 'cax') else None
        self.cmap = _get_cmap(cmap, nan_color=nan_color)
        self.mask_cmap = _get_cmap(mask_cmap if mask_color is None else mask_color)
        if mask_color is None:
            # If no single color was used for masks, it means a colormap was used
            # instead. However, `nan_color` was not passed to get_cmap for the masks
            # (it only applied to non-masked data), so we still need to set the 'bad'
            # color here. We choose to set it to the 'over' color for a lack of a
            # better idea.
            self.mask_cmap.set_bad(self.mask_cmap.get_over())

        # Inside the autoscale, we need to distinguish between a min value that was set
        # by the user and one that was found by looping over all the data.
        # So we basically always need to keep a backup of the user-set value.
        self._cmin = {"data": np.inf}
        self._cmax = {"data": -np.inf}
        if cmin is not None:
            self._cmin["user"] = cmin
        if cmax is not None:
            self._cmax["user"] = cmax

        self._clabel = clabel
        self._logc = False if logc is None else logc
        self.set_colors_on_update = True

        # Note that we need to set cmin/cmax for the LogNorm, if not an error is
        # raised when making the colorbar before any call to update is made.
        self.normalizer = _get_normalizer('log' if self._logc else 'linear')
        self.colorbar = None
        self._unit = None
        self.empty = True
        self.changed = False
        self.artists = {}
        self.widget = None

        if cbar:
            if self.cax is None:
                # No axes supplied by the canvas: make a standalone, narrow figure
                # to host the colorbar.
                dpi = 100
                height_inches = (figsize[1] / dpi) if figsize is not None else 6
                fig = plt.Figure(figsize=(height_inches * 0.2, height_inches))
                self.cax = fig.add_axes([0.05, 0.02, 0.2, 0.98])
            self.colorbar = ColorbarBase(self.cax, cmap=self.cmap, norm=self.normalizer)
            self.cax.yaxis.set_label_coords(-0.9, 0.5)
            if self._clabel is not None:
                self.cax.set_ylabel(self._clabel)

    def add_artist(self, key: str, artist: Any):
        """
        Register an artist under ``key``. Registered artists are used to compute the
        data range in :meth:`autoscale` and are notified when the limits change.
        """
        self.artists[key] = artist

    def remove_artist(self, key: str):
        """
        Unregister the artist stored under ``key``. Raises ``KeyError`` if no such
        artist exists.
        """
        del self.artists[key]

    def to_widget(self):
        """
        Convert the colorbar into a widget for use with other ``ipywidgets``.
        """
        import ipywidgets as ipw

        self.widget = ipw.HTML()
        self._update_colorbar_widget()
        return self.widget

    def _update_colorbar_widget(self):
        """
        Upon an updated colorscale range, we need to update the image inside the widget.
        """
        if self.widget is not None:
            self.widget.value = fig_to_bytes(self.cax.get_figure(), form='svg').decode()

    def rgba(self, data: sc.DataArray) -> np.ndarray:
        """
        Return rgba values given a data array.

        Parameters
        ----------
        data:
            The data array to be converted to rgba colors, taking masks into account.
        """
        colors = self.cmap(self.normalizer(data.values))
        if data.masks:
            # Merge all masks into one and color the masked entries with the
            # dedicated mask colormap.
            one_mask = sc.broadcast(merge_masks(data.masks), sizes=data.sizes).values
            colors[one_mask] = self.mask_cmap(self.normalizer(data.values[one_mask]))
        return colors

    def autoscale(self):
        """
        Re-compute the global min and max range of values by iterating over all the
        artists and adjust the limits.

        Raises
        ------
        ValueError
            If both limits were set by the user and ``cmin`` is not smaller than
            ``cmax``.
        """
        if ("user" in self._cmin) and ("user" in self._cmax):
            # Both limits are fixed by the user: no need to look at the data.
            if self._cmin["user"] >= self._cmax["user"]:
                raise ValueError('User-set limits: cmin must be smaller than cmax.')
            self._cmin["data"] = self._cmin["user"]
            self._cmax["data"] = self._cmax["user"]
            self.apply_limits()
            return

        limits = [
            fix_empty_range(
                find_limits(artist._data, scale='log' if self._logc else 'linear')
            )
            for artist in self.artists.values()
        ]
        if "user" not in self._cmin:
            self._cmin["data"] = reduce(min, [v[0] for v in limits]).value
        else:
            self._cmin["data"] = self._cmin["user"]
        if "user" not in self._cmax:
            self._cmax["data"] = reduce(max, [v[1] for v in limits]).value
        else:
            self._cmax["data"] = self._cmax["user"]

        if self._cmin["data"] >= self._cmax["data"]:
            # Degenerate range: move the limit that was NOT set by the user by 10%
            # of the other limit's magnitude.
            if "user" in self._cmax:
                self._cmin["data"] = self._cmax["data"] - abs(self._cmax["data"]) * 0.1
            else:
                self._cmax["data"] = self._cmin["data"] + abs(self._cmin["data"]) * 0.1

        self.apply_limits()

    def apply_limits(self):
        """
        Push the current limits (the ``"data"`` entries of the internal min/max
        state) to the normalizer, refresh the colorbar widget and notify the artists.
        """
        # Synchronize the underlying normalizer limits to the current state.
        # Note that the order matters here, as for a normalizer cmin cannot be set above
        # the current cmax.
        if self._cmin["data"] >= self.normalizer.vmax:
            self.normalizer.vmax = self._cmax["data"]
            self.normalizer.vmin = self._cmin["data"]
        else:
            self.normalizer.vmin = self._cmin["data"]
            self.normalizer.vmax = self._cmax["data"]

        if self.colorbar is not None:
            self._update_colorbar_widget()
        self.notify_artists()

    def notify_artists(self):
        """
        Notify the artists that the state of the colormapper has changed.
        """
        for artist in self.artists.values():
            artist.notify_artist('colormap changed')

    @property
    def vmin(self) -> float:
        """
        Get or set the minimum value of the colorbar.
        This is an old property name. Prefer using ``cmin`` instead.
        """
        return self.cmin

    @vmin.setter
    def vmin(self, value: sc.Variable | float):
        self.cmin = value

    @property
    def vmax(self) -> float:
        """
        Get or set the maximum value of the colorbar.
        This is an old property name. Prefer using ``cmax`` instead.
        """
        return self.cmax

    @vmax.setter
    def vmax(self, value: sc.Variable | float):
        self.cmax = value

    @property
    def cmin(self) -> float:
        """
        Get or set the minimum value of the colorbar.
        """
        return self._cmin.get("user", self._cmin["data"])

    @cmin.setter
    def cmin(self, value: sc.Variable | float):
        self._cmin["user"] = maybe_variable_to_number(value, unit=self._unit)
        # ``apply_limits`` reads the "data" entry, so mirror the user value there
        # (as ``autoscale`` does); otherwise the new limit would not reach the
        # normalizer until the next autoscale.
        self._cmin["data"] = self._cmin["user"]
        self.apply_limits()

    @property
    def cmax(self) -> float:
        """
        Get or set the maximum value of the colorbar.
        """
        return self._cmax.get("user", self._cmax["data"])

    @cmax.setter
    def cmax(self, value: sc.Variable | float):
        self._cmax["user"] = maybe_variable_to_number(value, unit=self._unit)
        # See the ``cmin`` setter: the "data" entry is what gets applied.
        self._cmax["data"] = self._cmax["user"]
        self.apply_limits()

    @property
    def unit(self) -> str | None:
        """
        Get or set the unit of the colorbar.
        """
        return self._unit

    @unit.setter
    def unit(self, unit: str | None):
        self._unit = unit
        # Re-express any user-set limits as plain numbers in the new unit.
        if "user" in self._cmin:
            self._cmin["user"] = maybe_variable_to_number(
                self._cmin["user"], unit=self._unit
            )
        if "user" in self._cmax:
            self._cmax["user"] = maybe_variable_to_number(
                self._cmax["user"], unit=self._unit
            )

    @property
    def clabel(self) -> str | None:
        """
        Get or set the label of the colorbar axis.
        Returns ``None`` if there are no colorbar axes.
        """
        if self.cax is not None:
            return self.cax.get_ylabel()
        return None

    @clabel.setter
    def clabel(self, lab: str):
        if self.cax is not None:
            self.cax.set_ylabel(lab)

    @property
    def ylabel(self) -> str | None:
        """
        Get or set the label of the colorbar axis.
        This is an old property name. Prefer using ``clabel`` instead.
        """
        return self.clabel

    @ylabel.setter
    def ylabel(self, lab: str):
        self.clabel = lab

    def toggle_norm(self):
        """
        Toggle the norm flag, between `linear` and `log`.
        """
        self._logc = not self._logc
        self.normalizer = _get_normalizer('log' if self._logc else 'linear')
        # Reset the data range so that it is recomputed for the new scale.
        self._cmin["data"] = np.inf
        self._cmax["data"] = -np.inf
        if self.colorbar is not None:
            self.colorbar.mappable.norm = self.normalizer
        self.autoscale()
        if self._canvas is not None:
            self._canvas.draw()

    @property
    def norm(self) -> Literal['linear', 'log']:
        """
        Get or set the colorscale normalization.
        """
        return 'log' if self._logc else 'linear'

    @norm.setter
    def norm(self, value: Literal['linear', 'log']):
        if value not in ['linear', 'log']:
            raise ValueError('norm must be either "linear" or "log".')
        if value != self.norm:
            self.toggle_norm()

    def has_user_clabel(self) -> bool:
        """
        Return ``True`` if the user has set a colorbar label.
        """
        return self._clabel is not None