# Source code for plopp.core.graph

# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)

from html import escape
from itertools import chain

from .node_class import Node
from .view import View


def _make_graphviz_digraph(*args, **kwargs):
    """
    Create a ``graphviz.Digraph``, importing ``graphviz`` only when called.

    Parameters
    ----------
    *args, **kwargs:
        Passed unchanged to the ``Digraph`` constructor.

    Returns
    -------
    :
        The newly constructed ``Digraph``.

    Raises
    ------
    RuntimeError
        If ``graphviz`` cannot be imported. The message contains installation
        hints, and the original ``ImportError`` is suppressed.
    """
    try:
        from graphviz import Digraph
    except ImportError:
        install_hint = (
            "Failed to import `graphviz`. "
            "Use `pip install graphviz` (requires installed `graphviz` executable) or "
            "`conda install -c conda-forge python-graphviz`."
        )
        # `from None` hides the ImportError traceback: the hint is the useful part.
        raise RuntimeError(install_hint) from None
    else:
        return Digraph(*args, **kwargs)


def _walk_graph(start, nodes, edges, views, labels):
    """
    Recursively collect everything reachable from ``start`` into the given dicts.

    All four dicts are filled in place; nothing is returned.

    Parameters
    ----------
    start:
        The node to visit. Its ``id``, ``name``, ``func``, ``children``,
        ``parents``, ``kwparents`` and ``views`` attributes are read.
    nodes:
        Maps node id to the text shown for that node: the escaped name, or the
        escaped ``str(func)`` followed by the id when the node has no name.
    edges:
        Maps a source id (node) to the set of target ids (nodes or views).
    views:
        Maps view id to the class name of the view.
    labels:
        Maps a parent id to the argument name under which it was first seen
        (``arg_<i>`` for positional parents, the keyword for keyword parents).
    """
    # Same four containers are threaded through every recursive call.
    shared = {'nodes': nodes, 'edges': edges, 'views': views, 'labels': labels}

    if start.name is None:
        nodes[start.id] = escape(str(start.func)) + '\nid = ' + start.id
    else:
        nodes[start.id] = escape(start.name)

    # Downstream: children are always visited, even if seen before.
    for child in start.children:
        target = child.id
        edges.setdefault(start.id, set()).add(target)
        _walk_graph(start=child, **shared)

    # Upstream: positional parents first, then keyword parents.
    positional = (
        (f'arg_{index}', parent) for index, parent in enumerate(start.parents)
    )
    for arg_name, parent in chain(positional, start.kwparents.items()):
        parent_id = parent.id
        # Only the first argument name recorded for a parent is kept.
        labels.setdefault(parent_id, arg_name)
        if parent_id in nodes:
            # Already visited: neither the edge nor the walk is repeated here.
            continue
        edges.setdefault(parent_id, set()).add(start.id)
        _walk_graph(start=parent, **shared)

    # Views attached to this node; a view's other nodes are walked only the
    # first time the view is encountered.
    for view in start.views:
        first_visit = view.id not in views
        views[view.id] = view.__class__.__name__
        edges.setdefault(start.id, set()).add(view.id)
        if not first_visit:
            continue
        for node in view.graph_nodes.values():
            _walk_graph(start=node, **shared)


def _make_graph(dot, nodes, edges, labels, views):
    """
    Add the collected nodes, views and edges to ``dot`` and return it.

    Parameters
    ----------
    dot:
        Object providing ``node(...)`` and ``edge(...)`` methods.
    nodes:
        Maps node id to its label.
    edges:
        Maps a source id to an iterable of target ids.
    labels:
        Maps a source id to the label used on its outgoing edges.
    views:
        Maps view id to its label. Views are drawn as filled light-grey ellipses.

    Returns
    -------
    :
        The same ``dot`` object that was passed in.
    """
    for node_id, text in nodes.items():
        dot.node(node_id, label=text)

    for view_id, text in views.items():
        dot.node(view_id, label=text, shape='ellipse', style='filled', color='lightgrey')

    for source, targets in edges.items():
        for target in targets:
            # Edges ending in a view carry no label; others use the source's label.
            edge_label = '' if target in views else labels.get(source, '')
            dot.edge(source, target, label=edge_label)

    return dot


def show_graph(entry: Node | View, **kwargs):
    """
    Display the connected nodes and views as a graph.

    Parameters
    ----------
    entry:
        An entry point in the graph (node or view). This can be any node/view in the
        graph. The graph will be searched from end to end to construct the diagram.
    **kwargs:
        Graph-level attributes. They are passed together as the ``graph_attr``
        argument of ``graphviz.Digraph``.

    Returns
    -------
    :
        A visual representation of the graph generated with Graphviz.

    Raises
    ------
    RuntimeError
        If the ``graphviz`` package cannot be imported.
    ValueError
        If ``entry`` is a view that has no nodes attached, so there is no node
        from which to start walking the graph.
    """
    dot = _make_graphviz_digraph(strict=True, graph_attr=kwargs)
    dot.attr('node', shape='box', height='0.1')
    nodes = {}
    edges = {}
    views = {}
    labels = {}
    # If input is a View, get the underlying node
    if hasattr(entry, 'graph_nodes'):
        first_node = next(iter(entry.graph_nodes.values()), None)
        if first_node is None:
            # Avoid leaking a bare StopIteration out of a regular function.
            raise ValueError(
                "Cannot show graph: the supplied view has no nodes attached."
            )
        entry = first_node
    _walk_graph(
        start=entry,
        nodes=nodes,
        edges=edges,
        views=views,
        labels=labels,
    )
    return _make_graph(dot=dot, nodes=nodes, edges=edges, labels=labels, views=views)