import re
from collections import deque
from functools import wraps
from queue import Queue
from typing import Optional, Mapping, Any, Callable

from graphviz import Digraph
from hbutils.reflection import dynamic_call, post_process


def _title_flatten(title):
    title = re.sub(r'[^a-zA-Z0-9_]+', '_', str(title))
    title = re.sub(r'_+', '_', title)
    title = title.strip('_').lower()
    return title


17 18
def _no_none_value(dict_) -> dict:
    return type(dict_)({key: value for key, value in dict_.items() if value is not None})
19

20

21
def _cfg_func_wrap(func):
    """
    Overview:
        Wrap a configuration generator so its result can be passed to graphviz as keyword attributes.

        The wrapped function calls ``func`` with the same arguments, drops entries whose value is
        ``None``, and converts every remaining key and value to ``str``. A ``None`` result is
        treated as an empty configuration instead of raising ``AttributeError``.

    Arguments:
        - func: Callable returning a mapping of attributes, or ``None`` for no attributes.

    Returns:
        - wrapped: Callable taking the same arguments as ``func`` and returning a mapping of the \
            same type ``func`` returned (a plain ``dict`` when ``func`` returned ``None``).
    """

    @wraps(func)
    def _new_func(*args, **kwargs):
        cfg = func(*args, **kwargs)
        if cfg is None:  # "no configuration"; previously crashed on None.items()
            cfg = {}
        cfg = _no_none_value(cfg)
        # graphviz quotes attribute names/values with regexes, so they must all be strings.
        return type(cfg)({str(key): str(value) for key, value in cfg.items()})

    return _new_func


# Name of the marker attribute set on node-id functions already adapted by ``suffixed_node_id``.
SUFFIXED_TAG: str = '__suffixed__'


def suffixed_node_id(func):
    """
    Overview:
        Adapt a node id function into the 5-argument form used by :func:`build_graph`.

        The adapted function is called as ``(current, parent, current_path, parent_path, is_node)``.
        For inner nodes (``is_node`` true) it returns ``func(current, current_path)``. For leaves it
        returns ``func(parent, parent_path)`` followed by ``'__'`` and the last key of
        ``current_path``, so a leaf id is derived from its parent's id and key, not from the leaf value.
        ``func`` is wrapped with ``dynamic_call``, so it may declare fewer than two parameters.

    Arguments:
        - func: Node id function, or a function previously returned by ``suffixed_node_id``.

    Returns:
        - wrapped: The adapted function, marked with the ``SUFFIXED_TAG`` attribute. A function \
            that already carries the mark is returned unchanged.
    """
    # Idempotent: already-adapted functions are passed through as-is.
    if getattr(func, SUFFIXED_TAG, None):
        return func

    base = dynamic_call(func)

    @wraps(base)
    @dynamic_call
    def _suffixed(current, parent, current_path, parent_path, is_node):
        if not is_node:
            return '%s__%s' % (base(parent, parent_path), current_path[-1])
        return base(current, current_path)

    setattr(_suffixed, SUFFIXED_TAG, True)
    return _suffixed


@suffixed_node_id
@dynamic_call
def _default_node_id(current):
    """
    Overview:
        Default node id: ``node_`` followed by the hexadecimal ``id()`` of the node object.
        Through ``suffixed_node_id``, leaf ids become the parent's id plus ``__<key>``.
    """
    object_id = id(current)
    return 'node_%x' % object_id


def _root_process(root, index):
    if isinstance(root, tuple):
        if len(root) < 1:
            return None
        elif len(root) == 1:
            return _root_process(root[0], index)
        else:
HansBug's avatar
HansBug 已提交
65
            return root[0], str(root[1]), index
66
    else:
HansBug's avatar
HansBug 已提交
67
        return root, '<root_%d>' % (index,), index
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104


def build_graph(*roots, node_id_gen: Optional[Callable] = None,
                graph_title: Optional[str] = None, graph_name: Optional[str] = None,
                graph_cfg: Optional[Mapping[str, Any]] = None,
                repr_gen: Optional[Callable] = None, iter_gen: Optional[Callable] = None,
                node_cfg_gen: Optional[Callable] = None, edge_cfg_gen: Optional[Callable] = None) -> Digraph:
    """
    Overview:
        Build a graphviz graph based on given tree structure.

        The trees are walked breadth-first from every root. A child for which ``iter_gen``
        returns a truthy value is drawn as an inner node labelled with its dotted path and is
        expanded in turn; any other child is drawn as a leaf labelled by ``repr_gen``. A node id,
        or an edge ``(parent_id, child_id, key)``, that has already been drawn is not drawn again.
        All generator callables are wrapped with ``dynamic_call``, so they may declare only a
        leading subset of the arguments listed below.

    Arguments:
        - roots: Root nodes of the graph. Each item is either an object (titled ``<root_N>``, \
            where ``N`` is its position) or a tuple ``(object, title)``; empty tuples are skipped.
        - node_id_gen (:obj:`Optional[Callable]`): Node id generation function called with \
            ``(node, path)``, default is `None` which means based on object id. Leaf ids are \
            built from the parent's id plus the leaf's key.
        - graph_title (:obj:`Optional[str]`): Title of the graph, used as its label and comment, \
            default is `None` which means an empty title.
        - graph_name (:obj:`Optional[str]`): Name of the graph, \
            default is `None` which means auto generated based on graph title.
        - graph_cfg (:obj:`Optional[Mapping[str, Any]]`): Configuration of graph, \
            default is `None` which means no configuration. `None` values are dropped, other \
            keys and values are converted to `str`, and `label` is always set to the title.
        - repr_gen (:obj:`Optional[Callable]`): Leaf label generator called with ``(value, path)``, \
            default is `None` which means using `repr` function.
        - iter_gen (:obj:`Optional[Callable]`): Child iterator generator called with \
            ``(node, path)``, returning ``(key, child)`` pairs or a falsy value for leaves, \
            default is `None` which means using the `items` method when the object has one.
        - node_cfg_gen (:obj:`Optional[Callable]`): Node configuration generator called with \
            ``(current, parent, current_path, parent_path, is_node, is_root, root_info)``, \
            default is `None` which means no configuration.
        - edge_cfg_gen (:obj:`Optional[Callable]`): Edge configuration generator called with \
            ``(current, parent, current_path, parent_path, is_node, root_info)``, \
            default is `None` which means no configuration.

    Returns:
        - dot (:obj:`Digraph`): Graphviz directed graph object.
    """
    roots = [_root_process(root, index) for index, root in enumerate(roots)]
    roots = [item for item in roots if item is not None]

    node_id_gen = dynamic_call(suffixed_node_id(node_id_gen or _default_node_id))
    graph_title = graph_title or ''
    graph_name = graph_name or _title_flatten(graph_title)
    # Stringify like node/edge configs: graphviz cannot quote non-str attribute values.
    graph_cfg = {str(key): str(value) for key, value in _no_none_value(graph_cfg or {}).items()}

    repr_gen = dynamic_call(repr_gen or repr)
    iter_gen = dynamic_call(iter_gen or (lambda x: x.items() if hasattr(x, 'items') else None))
    node_cfg_gen = _cfg_func_wrap(dynamic_call(node_cfg_gen or (lambda: {})))
    edge_cfg_gen = _cfg_func_wrap(dynamic_call(edge_cfg_gen or (lambda: {})))

    graph = Digraph(name=graph_name, comment=graph_title)
    graph.graph_attr.update(graph_cfg)
    graph.graph_attr.update({'label': graph_title})  # the title always wins over graph_cfg

    pending = deque()  # BFS frontier of (node_id, node, root_info, path)
    drawn_node_ids = set()
    drawn_edges = set()

    for root_info in roots:
        root, root_title, _ = root_info
        root_id = node_id_gen(root, None, [], [], True)
        if root_id not in drawn_node_ids:
            graph.node(
                name=root_id, label=root_title,
                **node_cfg_gen(root, None, [], [], True, True, root_info)
            )
            pending.append((root_id, root, root_info, []))
            drawn_node_ids.add(root_id)

    while pending:
        parent_id, parent, root_info, parent_path = pending.popleft()
        root_title = root_info[1]

        # `or ()`: a root that is itself a leaf (iter_gen -> None) simply has no children.
        for key, current in iter_gen(parent, parent_path) or ():
            current_path = [*parent_path, key]
            # Single call: is_node decides the label, the id suffixing and whether to expand.
            is_node = bool(iter_gen(current, current_path))
            current_id = node_id_gen(current, parent, current_path, parent_path, is_node)

            if current_id not in drawn_node_ids:
                if is_node:
                    # map(str): keys need not be strings (e.g. int dict keys).
                    label = '.'.join(map(str, [root_title, *current_path]))
                else:
                    label = repr_gen(current, current_path)
                graph.node(current_id, label=label,
                           **node_cfg_gen(current, parent, current_path, parent_path,
                                          is_node, False, root_info))
                if is_node:
                    pending.append((current_id, current, root_info, current_path))
                drawn_node_ids.add(current_id)

            edge_key = (parent_id, current_id, key)
            if edge_key not in drawn_edges:
                # graphviz quotes labels with regexes, so non-str keys must be converted.
                graph.edge(parent_id, current_id, label=str(key),
                           **edge_cfg_gen(current, parent, current_path, parent_path,
                                          is_node, root_info))
                drawn_edges.add(edge_key)

    return graph