Source code for plopp.widgets.drawing

# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)

from collections.abc import Callable
from functools import partial
from typing import Any

import scipp as sc

from ..core import Node, node
from ..core.typing import FigureLike
from ..graphics import BaseFig
from .tools import ToggleTool


def is_figure(x):
    """
    Return ``True`` if ``x`` is an instance of ``BaseFig``, ``False`` otherwise.
    """
    return isinstance(x, BaseFig)


class DrawingTool(ToggleTool):
    """
    Interface between Plopp and Mpltoolbox.

    Parameters
    ----------
    figure:
        The figure where the tool will draw things (points, lines, shapes...).
    input_node:
        The node that provides the raw data which is shown in ``figure``.
    tool:
        The Mpltoolbox tool to use (Points, Lines, Rectangles, Ellipses...).
    func:
        The function to be used to make a node whose parents will be the
        ``input_node`` and a node yielding the current state of the tool
        (current position, size).
    destination:
        Where the output from the ``func`` node will be then sent on. This can
        either be a figure, or another graph node.
    get_artist_info:
        A function that returns another function which will convert the
        properties of the artist that produced the event to something
        (usually a dict) that is usable by the ``destination``.
    value:
        Activate the tool upon creation if ``True``.
    continuous_update:
        If ``True``, the tool will update the nodes as a drawing object
        changes. If ``False``, destination will be updated only when the user
        releases the mouse button. In other words, it can be set ``True`` for
        tools that need fast feedback, or ``False`` for tools that use
        computationally expensive functions.
    **kwargs:
        Additional arguments are forwarded to the ``ToggleTool`` constructor.
    """

    def __init__(
        self,
        figure: FigureLike,
        input_node: Node,
        tool: Any,
        func: Callable,
        destination: FigureLike | Node,
        get_artist_info: Callable,
        value: bool = False,
        continuous_update: bool = True,
        **kwargs,
    ):
        super().__init__(callback=self.start_stop, value=value, **kwargs)

        self._figure = figure
        self._input_node = input_node
        # Both dicts are keyed by the id of the draw node, which is also stored
        # on the artist as ``artist.nodeid``.
        self._draw_nodes = {}
        self._output_nodes = {}
        self._func = func
        self._tool = tool(ax=self._figure.ax, autostart=False)
        self._destination = destination
        self._destination_is_fig = is_figure(self._destination)
        self._get_artist_info = get_artist_info

        self._tool.on_create(self.make_node)
        self._tool.on_remove(self.remove_node)
        if continuous_update:
            self._tool.on_change(self.update_node)
        else:
            # Only refresh the nodes once the user lets go of the mouse button.
            self._tool.on_vertex_release(self.update_node)
            self._tool.on_drag_release(self.update_node)

    def make_node(self, artist):
        """
        Create the draw node and output node for a newly created ``artist``,
        and connect the output node to the destination.
        """
        draw_node = Node(self._get_artist_info(artist=artist, figure=self._figure))
        draw_node.name = f'Draw node {len(self._draw_nodes)}'
        nodeid = draw_node.id
        self._draw_nodes[nodeid] = draw_node
        # Remember on the artist which node it belongs to, so that later
        # update/remove events can find the node again.
        artist.nodeid = nodeid

        output_node = node(self._func)(self._input_node, draw_node)
        output_node.name = f'Output node {len(self._output_nodes)}'
        self._output_nodes[nodeid] = output_node

        if self._destination_is_fig:
            output_node.add_view(self._destination.view)
            self._destination.update({output_node.id: output_node()})
            # Give the artist in the destination figure the same color as the
            # drawn artist.
            self._destination.artists[output_node.id].color = (
                artist.color if hasattr(artist, 'color') else artist.edgecolor
            )
        elif isinstance(self._destination, Node):
            self._destination.add_parents(output_node)
            self._destination.notify_children(artist)

    def update_node(self, artist):
        """
        Refresh the draw node associated with ``artist`` and notify its
        children of the change.
        """
        n = self._draw_nodes[artist.nodeid]
        n.func = self._get_artist_info(artist=artist, figure=self._figure)
        n.notify_children(artist)

    def remove_node(self, artist):
        """
        Remove the draw node and output node associated with ``artist``, as
        well as the corresponding artist in the destination figure (if the
        destination is a figure).
        """
        nodeid = artist.nodeid
        # Pop from both dicts: leaving the output node behind would keep a
        # reference to a removed node alive, and would make the two dicts
        # (whose lengths are used to name new nodes) go out of sync.
        draw_node = self._draw_nodes.pop(nodeid)
        output_node = self._output_nodes.pop(nodeid)
        if self._destination_is_fig:
            self._destination.artists[output_node.id].remove()
            del self._destination.artists[output_node.id]
            self._destination.canvas.draw()
        output_node.remove()
        draw_node.remove()

    def start_stop(self):
        """
        Toggle start or stop of the tool.
        """
        if self.value:
            self._tool.start()
        else:
            self._tool.stop()
def _get_points_info(artist, figure):
    """
    Build a callable that converts the raw (x, y) position of a point artist
    into a dict holding, for each axis, the dimension of that axis and the
    position as a scalar value with a unit.

    The position is read from the artist every time the returned callable is
    invoked, not when this function is called.
    """
    return lambda: {
        axis: {
            'dim': figure.canvas.dims[axis],
            'value': sc.scalar(getattr(artist, axis), unit=figure.canvas.units[axis]),
        }
        for axis in ('x', 'y')
    }


def _make_points(**kwargs):
    """
    Intermediate function needed for giving to `partial` to avoid making
    mpltoolbox a hard dependency: the import only happens when a tool is
    actually created.
    """
    import mpltoolbox

    return mpltoolbox.Points(**kwargs)


_points_factory = partial(_make_points, mec='w')

PointsTool = partial(
    DrawingTool,
    tool=_points_factory,
    get_artist_info=_get_points_info,
    icon='crosshairs',
)
"""
Tool to add point markers onto a figure.

Parameters
----------
figure:
    The figure where the tool will draw things (points, lines, shapes...).
input_node:
    The node that provides the raw data which is shown in ``figure``.
func:
    The function to be used to make a node whose parents will be the
    ``input_node`` and a node yielding the current state of the tool
    (current position, size).
destination:
    Where the output from the ``func`` node will be then sent on. This can
    either be a figure, or another graph node.
value:
    Activate the tool upon creation if ``True``.
continuous_update:
    If ``True`` (the default), the nodes are updated as a point changes.
    If ``False``, they are updated only when the mouse button is released.
**kwargs:
    Additional arguments are forwarded to the ``ToggleTool`` constructor.
"""