tree.py 7.8 KB
Newer Older
1
import re
2
from functools import wraps
HansBug's avatar
HansBug 已提交
3
from queue import Queue
4
from typing import Optional, Mapping, Any, Callable
HansBug's avatar
HansBug 已提交
5

6
from graphviz import Digraph
HansBug's avatar
HansBug 已提交
7 8
from treelib import Tree as LibTree

9
from .func import dynamic_call
10 11
from .random import random_hex_with_timestamp

HansBug's avatar
HansBug 已提交
12 13 14 15
# Identifier given to the root node of the treelib tree created by ``build_tree``.
_ROOT_ID = '_root'
# Template for the identifiers of all non-root nodes in ``build_tree``;
# ``{id}`` is filled in with a running counter.
_NODE_ID_TEMP = '_node_{id}'


16
def build_tree(root_node, repr_gen=None, iter_gen=None) -> LibTree:
    """
    Overview:
        Build a treelib object by an object.

    Arguments:
        - root_node (:obj:`Any`): Root object.
        - repr_gen (:obj:`Optional[Callable]`): Represent function, default is primitive `repr`.
        - iter_gen (:obj:`Optional[Callable]`): Iterate function, \
            default is `lambda x: x.items() if hasattr(x, 'items') else None`. \
            It should return an iterable of `(key, value)` pairs for a container, \
            or `None` (or another false value) for a leaf.

    Returns:
        - tree (:obj:`treelib.Tree`): Built tree. When `iter_gen` returns `None` for \
            `root_node` itself, the tree consists of the root node only.

    Example:
         >>> t = build_tree(
         >>>     {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}, 'z': [1, 2], 'v': {'1': '2'}},
         >>>     repr_gen=lambda x: '<node>' if isinstance(x, dict) else repr(x),
         >>> )
         >>> print(t)

         The output should be

         >>> <node>
         >>> ├── 'a' --> 1
         >>> ├── 'b' --> 2
         >>> ├── 'v' --> <node>
         >>> │   └── '1' --> '2'
         >>> ├── 'x' --> <node>
         >>> │   ├── 'c' --> 3
         >>> │   └── 'd' --> 4
         >>> └── 'z' --> [1, 2]
    """
    repr_gen = repr_gen or repr
    iter_gen = iter_gen or (lambda x: x.items() if hasattr(x, 'items') else None)

    _tree = LibTree()
    _tree.create_node(repr_gen(root_node), _ROOT_ID)
    _index, _queue = 0, Queue()
    _queue.put((_ROOT_ID, root_node))

    # Breadth-first walk: every queued entry is a node whose children still
    # have to be attached to the tree.
    while not _queue.empty():
        _parent_id, _parent_tree = _queue.get()

        _children = iter_gen(_parent_tree)
        if _children is None:
            # Only the root can reach this point (children are queued only when
            # ``iter_gen`` is truthy for them). A leaf root simply has no
            # children; iterating over ``None`` would raise ``TypeError``.
            continue

        for key, value in _children:
            _index += 1
            _current_id = _NODE_ID_TEMP.format(id=_index)
            _tree.create_node(
                "{key} --> {value}".format(key=repr(key), value=repr_gen(value)),
                _current_id,
                _parent_id
            )
            # Only values that expose children of their own are expanded further.
            if iter_gen(value):
                _queue.put((_current_id, value))

    return _tree
72 73 74 75 76 77 78 79 80 81 82 83


# Matches a whole string made only of ASCII letters, digits and underscores
# whose first character is not a digit.
_NAME_PATTERN = re.compile('^[a-zA-Z_][a-zA-Z0-9_]*$')


def _title_flatten(title):
    title = re.sub(r'[^a-zA-Z0-9_]+', '_', str(title))
    title = re.sub(r'_+', '_', title)
    title = title.strip('_').lower()
    return title


84 85
def _no_none_value(dict_) -> dict:
    return type(dict_)({key: value for key, value in dict_.items() if value is not None})
86

87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181

def _none_value_filter(func):
    """
    Overview:
        Decorator which removes the `None`-valued entries from the mapping returned by `func`.

    Arguments:
        - func (:obj:`Callable`): Function returning a mapping.

    Returns:
        - new_func (:obj:`Callable`): Wrapped function; its result is passed \
            through `_no_none_value` before being returned.
    """

    def _new_func(*args, **kwargs):
        raw_result = func(*args, **kwargs)
        return _no_none_value(raw_result)

    # Same effect as decorating ``_new_func`` with ``@wraps(func)``.
    return wraps(func)(_new_func)


# Name of the marker attribute set on functions produced by ``suffixed_node_id``;
# a function that already carries a truthy marker is returned as-is instead of
# being wrapped again.
_SUFFIXED = '__suffixed__'


def suffixed_node_id(func):
    """
    Overview:
        Wrap a node id function so that it can also produce ids for values \
        which are not nodes themselves, by suffixing the parent's id.

    Arguments:
        - func (:obj:`Callable`): Original id function, invoked here as \
            `func(obj, path)`. It is passed through `dynamic_call` first.

    Returns:
        - new_func (:obj:`Callable`): Function called as \
            `(current, parent, current_path, parent_path, is_node)`. \
            When `is_node` is true it returns `func(current, current_path)`, \
            otherwise it returns `'<id of parent>__<last item of current_path>'`. \
            If `func` is already marked as suffixed, `func` itself is returned.
    """
    # Already wrapped by this decorator (marker attribute present), do not wrap twice.
    if getattr(func, _SUFFIXED, None):
        return func

    # NOTE(review): ``dynamic_call`` comes from ``.func``; presumably it lets ``func``
    # be called with more positional arguments than it declares - confirm there.
    func = dynamic_call(func)

    @wraps(func)
    @dynamic_call
    def _new_func(current, parent, current_path, parent_path, is_node):
        if is_node:
            return func(current, current_path)
        else:
            # Not a node: derive the id from the parent's id and the last path segment.
            return '%s__%s' % (func(parent, parent_path), current_path[-1])

    # Mark the wrapper so that a later ``suffixed_node_id`` call returns it unchanged.
    setattr(_new_func, _SUFFIXED, True)
    return _new_func


@suffixed_node_id
@dynamic_call
def _default_node_id(current):
    """
    Overview:
        Default node id function, based on the object's `id`.

    Arguments:
        - current: Object to generate the id for.

    Returns:
        - node_id (:obj:`str`): String `node_` followed by `id(current)` in lower-case hexadecimal.
    """
    object_identity = id(current)
    return 'node_%x' % object_identity


def _root_process(root, index):
    if isinstance(root, tuple):
        if len(root) < 1:
            return None
        elif len(root) == 1:
            return _root_process(root[0], index)
        else:
            return root[0], str(root[1])
    else:
        return root, '<root_%d>' % (index,)


def build_graph(*roots, node_id_gen: Optional[Callable] = None,
                graph_title: Optional[str] = None, graph_name: Optional[str] = None,
                graph_cfg: Optional[Mapping[str, Any]] = None,
                repr_gen: Optional[Callable] = None, iter_gen: Optional[Callable] = None,
                node_cfg_gen: Optional[Callable] = None, edge_cfg_gen: Optional[Callable] = None) -> Digraph:
    """
    Overview:
        Build a graphviz graph based on given tree structure.

    Arguments:
        - roots: Root nodes of the graph. Each one is either the root object itself \
            or a tuple of `(root, title)`; empty tuples are ignored.
        - node_id_gen (:obj:`Optional[Callable]`): Node id generation function, \
            default is `None` which means based on object id.
        - graph_title (:obj:`Optional[str]`): Title of the graph, \
            default is `None` which means generate automatically based on timestamp.
        - graph_name (:obj:`Optional[str]`): Name of the graph, \
            default is `None` which means auto generated based on graph title.
        - graph_cfg (:obj:`Optional[Mapping[str, Any]]`): Configuration of graph, \
            default is `None` which means no configuration.
        - repr_gen (:obj:`Optional[Callable]`): Representation format generator, \
            default is `None` which means using `repr` function.
        - iter_gen (:obj:`Optional[Callable]`): Iterator generator, \
            default is `None` which means load from `items` method. \
            It should return `None` (or another false value) for values without children.
        - node_cfg_gen (:obj:`Optional[Callable]`): Node configuration generator, \
            default is `None` which means no configuration.
        - edge_cfg_gen (:obj:`Optional[Callable]`): Edge configuration generator, \
            default is `None` which means no configuration.

    Returns:
        - dot (:obj:`Digraph`): Graphviz directed graph object.
    """
    roots = [_root_process(root, index) for index, root in enumerate(roots)]
    roots = [item for item in roots if item is not None]

    node_id_gen = dynamic_call(suffixed_node_id(node_id_gen or _default_node_id))
    graph_title = graph_title or ('untitled_' + random_hex_with_timestamp())
    graph_name = graph_name or _title_flatten(graph_title)
    graph_cfg = _no_none_value(graph_cfg or {})

    repr_gen = dynamic_call(repr_gen or repr)
    iter_gen = dynamic_call(iter_gen or (lambda x: x.items() if hasattr(x, 'items') else None))
    node_cfg_gen = _none_value_filter(dynamic_call(node_cfg_gen or (lambda: {})))
    edge_cfg_gen = _none_value_filter(dynamic_call(edge_cfg_gen or (lambda: {})))

    graph = Digraph(name=graph_name, comment=graph_title)
    graph.graph_attr.update(graph_cfg)
    # The title is applied last, so it wins over a 'label' entry of ``graph_cfg``.
    graph.graph_attr.update({'label': graph_title})

    _queue = Queue()
    _queued_node_ids = set()  # ids of nodes already drawn, each node is drawn once
    _queued_edges = set()  # (parent id, child id) pairs already drawn
    for root, root_title in roots:
        root_node_id = node_id_gen(root, None, [], [], True)
        if root_node_id not in _queued_node_ids:
            graph.node(
                name=root_node_id, label=root_title,
                **node_cfg_gen(root, [])
            )
            _queue.put((root_node_id, root, root_title, []))
            _queued_node_ids.add(root_node_id)

    # Breadth-first walk over everything reachable from the roots.
    while not _queue.empty():
        _parent_id, _parent_node, _root_title, _parent_path = _queue.get()

        _children = iter_gen(_parent_node, _parent_path)
        if _children is None:
            # Only a root can reach this point (children are queued only when
            # ``iter_gen`` is truthy for them). A leaf root has nothing to
            # expand; iterating over ``None`` would raise ``TypeError``.
            continue

        for key, _current_node in _children:
            _current_path = [*_parent_path, key]
            # Evaluate once whether the child has children of its own; the id,
            # the label and the queueing decision below all depend on it.
            _is_node = bool(iter_gen(_current_node, _current_path))
            _current_id = node_id_gen(_current_node, _parent_node, _current_path, _parent_path, _is_node)
            if _is_node:
                # Keys are not guaranteed to be strings, ``str.join`` requires them to be.
                _current_label = '.'.join([_root_title, *map(str, _current_path)])
            else:
                _current_label = repr_gen(_current_node, _current_path)

            if _current_id not in _queued_node_ids:
                graph.node(_current_id, label=_current_label, **node_cfg_gen(_current_node, _current_path))
                if _is_node:
                    _queue.put((_current_id, _current_node, _root_title, _current_path))
                _queued_node_ids.add(_current_id)
            if (_parent_id, _current_id) not in _queued_edges:
                # graphviz labels have to be strings, so convert non-string keys.
                graph.edge(_parent_id, _current_id, label=str(key),
                           **edge_cfg_gen(_current_node, _parent_node, _current_path, _parent_path))
                _queued_edges.add((_parent_id, _current_id))

    return graph