未验证 提交 2cc7d7a7 编写于 作者: HansBug's avatar HansBug 😆 提交者: GitHub

Merge pull request #75 from opendilab/dev/repr

dev(hansbug): optimize graphviz visualization
...@@ -66,17 +66,17 @@ jobs: ...@@ -66,17 +66,17 @@ jobs:
shell: bash shell: bash
run: | run: |
sudo apt-get update sudo apt-get update
sudo apt-get install -y tree cloc wget curl make sudo apt-get install -y tree cloc wget curl make graphviz
- name: Set up system dependences on Windows - name: Set up system dependences on Windows
if: ${{ env.OS_NAME == 'Windows' }} if: ${{ env.OS_NAME == 'Windows' }}
shell: bash shell: bash
run: | run: |
choco install tree cloc wget curl make zip choco install tree cloc wget curl make zip graphviz
- name: Set up system dependences on MacOS - name: Set up system dependences on MacOS
if: ${{ env.OS_NAME == 'MacOS' }} if: ${{ env.OS_NAME == 'MacOS' }}
shell: bash shell: bash
run: | run: |
brew install tree cloc wget curl make zip brew install tree cloc wget curl make zip graphviz
- name: Set up python ${{ matrix.python-version }} - name: Set up python ${{ matrix.python-version }}
uses: actions/setup-python@v4 uses: actions/setup-python@v4
with: with:
......
# Utility Makefile for the example notebooks.
#
# Overridable knobs:
#   JUPYTER    path to the jupyter executable (looked up on PATH by default)
#   NBCONVERT  command used to invoke nbconvert
#   SOURCE     directory searched recursively for notebooks
.PHONY: clean

JUPYTER ?= $(shell which jupyter)
NBCONVERT ?= ${JUPYTER} nbconvert
SOURCE ?= .

# The glob is quoted so the shell hands it to find verbatim; unquoted, a
# notebook sitting in the current directory would make the shell expand
# *.ipynb first and find would then search for that single file name only.
# grep -F matches ".ipynb_checkpoints" literally instead of as a regex.
IPYNBS := $(shell find ${SOURCE} -name '*.ipynb' | grep -vF .ipynb_checkpoints)

# Clear the stored outputs of every notebook in place.
# The first nbconvert failure aborts the loop with its exit status, so a
# broken notebook is not hidden by a later one that succeeds.
clean:
	for nb in ${IPYNBS}; do \
		if [ -f "$$nb" ]; then \
			$(NBCONVERT) --clear-output --inplace "$$nb" || exit $$?; \
		fi; \
	done
\ No newline at end of file
# Examples of TreeValue
## Examples
Here are some TreeValue examples that can be viewed on Colab.
* Visualization: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opendilab/treevalue/blob/main/examples/visualization.ipynb)
## Makefile
```shell
make clean # clean all output in notebook files
```
**Please make sure the output of all notebook files is cleared with `make clean` before committing**.
{
"cells": [
{
"cell_type": "markdown",
"id": "edb6ed95",
"metadata": {},
"source": [
"# Visualization of TreeValue"
]
},
{
"cell_type": "markdown",
"id": "9fb09e8c",
"metadata": {},
"source": [
"Create a simple `FastTreeValue` object."
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"!pip install treevalue"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc7e3dd2",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"from treevalue import FastTreeValue\n",
"\n",
"t = FastTreeValue({\n",
" 'a': 1,\n",
" 'b': np.random.randint(0, 20, (3, 4)),\n",
" 'x': {\n",
" 'c': 3,\n",
" 'd': torch.randn(2, 3),\n",
" },\n",
"})"
]
},
{
"cell_type": "markdown",
"id": "8b654200",
"metadata": {},
"source": [
"See it"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30771752",
"metadata": {},
"outputs": [],
"source": [
"t"
]
},
{
"cell_type": "markdown",
"id": "b2c1bb61",
"metadata": {},
"source": [
"Do some operations on `t`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc97351a",
"metadata": {},
"outputs": [],
"source": [
"1 / (t + 0.5)"
]
},
{
"cell_type": "markdown",
"id": "b25f3db4",
"metadata": {},
"source": [
"Create another `FastTreeValue` which is sharing storage with `t`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af6345dc",
"metadata": {},
"outputs": [],
"source": [
"t2 = FastTreeValue({'ppp': t.x, 'x': {'t': t, 'y': t.x}})"
]
},
{
"cell_type": "markdown",
"id": "e039d8d9",
"metadata": {},
"source": [
"Just put `t` and `t2` into one diagram."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "57446c40",
"metadata": {},
"outputs": [],
"source": [
"from treevalue import graphics\n",
"\n",
"graphics((t, 't'), (t2, 't2'))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
...@@ -646,7 +646,7 @@ def get_fasttreevalue_test(treevalue_class: Type[FastTreeValue]): ...@@ -646,7 +646,7 @@ def get_fasttreevalue_test(treevalue_class: Type[FastTreeValue]):
title="This is a demo of 2 trees with dup value.", title="This is a demo of 2 trees with dup value.",
cfg={'bgcolor': '#ffffffff'}, cfg={'bgcolor': '#ffffffff'},
) )
assert len(graph_1.source) <= 4960 assert len(graph_1.source) <= 5000 # a value for svg in windows
graph_2 = treevalue_class.graphics( graph_2 = treevalue_class.graphics(
(t, 't'), (t1, 't1'), (t, 't'), (t1, 't1'),
...@@ -655,7 +655,7 @@ def get_fasttreevalue_test(treevalue_class: Type[FastTreeValue]): ...@@ -655,7 +655,7 @@ def get_fasttreevalue_test(treevalue_class: Type[FastTreeValue]):
title="This is a demo of 2 trees with dup value.", title="This is a demo of 2 trees with dup value.",
cfg={'bgcolor': '#ffffffff'}, cfg={'bgcolor': '#ffffffff'},
) )
assert len(graph_2.source) <= 5480 assert len(graph_2.source) <= 5600
graph_3 = treevalue_class.graphics( graph_3 = treevalue_class.graphics(
(t, 't'), (t1, 't1'), (t, 't'), (t1, 't1'),
...@@ -673,7 +673,7 @@ def get_fasttreevalue_test(treevalue_class: Type[FastTreeValue]): ...@@ -673,7 +673,7 @@ def get_fasttreevalue_test(treevalue_class: Type[FastTreeValue]):
title="This is a demo of 2 trees with dup value.", title="This is a demo of 2 trees with dup value.",
cfg={'bgcolor': '#ffffffff'}, cfg={'bgcolor': '#ffffffff'},
) )
assert len(graph_4.source) <= 3780 assert len(graph_4.source) <= 4000
graph_6 = treevalue_class.graphics( graph_6 = treevalue_class.graphics(
(t, 't'), (t1, 't1'), (t, 't'), (t1, 't1'),
......
...@@ -3,6 +3,7 @@ import re ...@@ -3,6 +3,7 @@ import re
from typing import Type from typing import Type
import pytest import pytest
from hbutils.testing import OS
from test.tree.tree.test_constraint import GreaterThanConstraint from test.tree.tree.test_constraint import GreaterThanConstraint
from treevalue import raw, TreeValue, delayed, ValidationError from treevalue import raw, TreeValue, delayed, ValidationError
...@@ -749,4 +750,31 @@ def get_treevalue_test(treevalue_class: Type[TreeValue]): ...@@ -749,4 +750,31 @@ def get_treevalue_test(treevalue_class: Type[TreeValue]):
assert newt1 == t1 assert newt1 == t1
assert newt1.constraint == t1.constraint assert newt1.constraint == t1.constraint
def test_repr_svg(self):
    """The demo constraint tree exposes ``_repr_svg_`` and it yields a string of the expected size."""
    tree = get_demo_constraint_tree()
    assert hasattr(tree, '_repr_svg_')

    svg = tree._repr_svg_()
    assert isinstance(svg, str)
    # Only the length of the rendered text is checked, within a tolerance band.
    assert 4500 <= len(svg) <= 4900
def test_repr_png(self):
    """The demo constraint tree exposes ``_repr_png_`` and it yields bytes of the expected size."""
    tree = get_demo_constraint_tree()
    assert hasattr(tree, '_repr_png_')

    png = tree._repr_png_()
    assert isinstance(png, bytes)
    # The accepted size band differs between Windows and the other platforms.
    lower, upper = (12000, 16050) if OS.windows else (16050, 20500)
    assert lower <= len(png) <= upper
def test_repr_jpeg(self):
    """The demo constraint tree exposes ``_repr_jpeg_`` and it yields bytes of the expected size."""
    tree = get_demo_constraint_tree()
    assert hasattr(tree, '_repr_jpeg_')

    jpeg = tree._repr_jpeg_()
    assert isinstance(jpeg, bytes)
    # Only the length of the rendered image is checked, within a tolerance band.
    assert 10500 <= len(jpeg) <= 14500
return _TestClass return _TestClass
...@@ -37,7 +37,7 @@ class TestTreeTreeGraph: ...@@ -37,7 +37,7 @@ class TestTreeTreeGraph:
title="This is a demo of 2 trees with dup value.", title="This is a demo of 2 trees with dup value.",
cfg={'bgcolor': '#ffffffff'}, cfg={'bgcolor': '#ffffffff'},
) )
assert len(graph_1.source) <= 4960 assert len(graph_1.source) <= 5000
graph_2 = graphics( graph_2 = graphics(
(t, 't'), (t1, 't1'), (t, 't'), (t1, 't1'),
...@@ -46,7 +46,7 @@ class TestTreeTreeGraph: ...@@ -46,7 +46,7 @@ class TestTreeTreeGraph:
title="This is a demo of 2 trees with dup value.", title="This is a demo of 2 trees with dup value.",
cfg={'bgcolor': '#ffffffff'}, cfg={'bgcolor': '#ffffffff'},
) )
assert len(graph_2.source) <= 5480 assert len(graph_2.source) <= 5600
graph_3 = graphics( graph_3 = graphics(
(t, 't'), (t1, 't1'), (t, 't'), (t1, 't1'),
...@@ -64,7 +64,7 @@ class TestTreeTreeGraph: ...@@ -64,7 +64,7 @@ class TestTreeTreeGraph:
title="This is a demo of 2 trees with dup value.", title="This is a demo of 2 trees with dup value.",
cfg={'bgcolor': '#ffffffff'}, cfg={'bgcolor': '#ffffffff'},
) )
assert len(graph_4.source) <= 3780 assert len(graph_4.source) <= 4000
graph_6 = graphics( graph_6 = graphics(
(t, 't'), (t1, 't1'), (t, 't'), (t1, 't1'),
......
from shutil import which
from unittest.mock import patch
import pytest import pytest
from treevalue.utils import build_graph from treevalue.utils import build_graph
@pytest.fixture()
def no_dot():
    """
    Fixture that patches ``shutil.which`` so that looking up ``'dot'`` returns ``None``.

    Every other lookup is forwarded to the ``which`` function imported at module level.
    The patch is active for as long as the fixture is in use.
    """

    def _which_without_dot(x):
        return None if x == 'dot' else which(x)

    with patch('shutil.which', _which_without_dot):
        yield
@pytest.mark.unittest @pytest.mark.unittest
class TestUtilsTree: class TestUtilsTree:
def test_build_graph(self): def test_build_graph(self):
...@@ -14,17 +29,20 @@ class TestUtilsTree: ...@@ -14,17 +29,20 @@ class TestUtilsTree:
g2 = build_graph(t, graph_title="Demo 2 of build_graph.") g2 = build_graph(t, graph_title="Demo 2 of build_graph.")
assert "Demo 2 of build_graph." in g2.source assert "Demo 2 of build_graph." in g2.source
assert "<root_0>" in g2.source assert "root_0" in g2.source
assert "<root_0>.x" in g2.source
assert len(g2.source) <= 580 assert len(g2.source) <= 580
g3 = build_graph((t,), graph_title="Demo 3 of build_graph.") g3 = build_graph((t,), graph_title="Demo 3 of build_graph.")
assert "Demo 3 of build_graph." in g3.source assert "Demo 3 of build_graph." in g3.source
assert "<root_0>" in g3.source assert "root_0" in g3.source
assert "<root_0>.x" in g3.source
assert len(g3.source) <= 580 assert len(g3.source) <= 580
g4 = build_graph((), graph_title="Demo 4 of build_graph.") g4 = build_graph((), graph_title="Demo 4 of build_graph.")
assert "Demo 4 of build_graph." in g4.source assert "Demo 4 of build_graph." in g4.source
assert "node" not in g4.source assert "node" not in g4.source
assert len(g4.source) <= 110 assert len(g4.source) <= 110
def test_build_graph_without_dot(self, no_dot):
    """``build_graph`` raises ``EnvironmentError`` while the ``no_dot`` fixture is active."""
    tree_data = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}
    with pytest.raises(EnvironmentError):
        build_graph((tree_data, 't'), graph_title="Demo of build_graph.")
from functools import lru_cache from functools import lru_cache
from typing import Type, Callable, Union, Optional, Tuple from typing import Type, Callable, Union, Optional, Tuple
from graphviz import Digraph, nohtml from graphviz import Digraph
from hbutils.reflection import post_process, dynamic_call, freduce from hbutils.reflection import post_process, dynamic_call, freduce
from .tree import TreeValue from .tree import TreeValue
...@@ -134,6 +134,13 @@ _GENERIC_N = 36 ...@@ -134,6 +134,13 @@ _GENERIC_N = 36
_GENERIC_S = _GENERIC_N // 3 _GENERIC_S = _GENERIC_N // 3
def _custom_repr(x):
if isinstance(x, (str,)):
return repr(x)
else:
return str(x)
def graphics(*trees, title: Optional[str] = None, cfg: Optional[dict] = None, def graphics(*trees, title: Optional[str] = None, cfg: Optional[dict] = None,
dup_value: Union[bool, Callable, type, Tuple[Type, ...]] = False, dup_value: Union[bool, Callable, type, Tuple[Type, ...]] = False,
repr_gen: Optional[Callable] = None, repr_gen: Optional[Callable] = None,
...@@ -182,9 +189,9 @@ def graphics(*trees, title: Optional[str] = None, cfg: Optional[dict] = None, ...@@ -182,9 +189,9 @@ def graphics(*trees, title: Optional[str] = None, cfg: Optional[dict] = None,
return build_graph( return build_graph(
*trees, *trees,
node_id_gen=_node_tag, node_id_gen=_node_tag,
graph_title=title or "<untitled>", graph_title=title,
graph_cfg=cfg or {}, graph_cfg=cfg or {},
repr_gen=repr_gen or (lambda x: nohtml(repr(x))), repr_gen=repr_gen or (lambda x: _custom_repr(x)),
iter_gen=lambda n: iter(n.items()) if isinstance(n, TreeValue) else None, iter_gen=lambda n: iter(n.items()) if isinstance(n, TreeValue) else None,
node_cfg_gen=_dict_call_merge(lambda n, p, np, pp, is_node, is_root, root: { node_cfg_gen=_dict_call_merge(lambda n, p, np, pp, is_node, is_root, root: {
'fillcolor': _shape_color(root[2]), 'fillcolor': _shape_color(root[2]),
...@@ -193,7 +200,7 @@ def graphics(*trees, title: Optional[str] = None, cfg: Optional[dict] = None, ...@@ -193,7 +200,7 @@ def graphics(*trees, title: Optional[str] = None, cfg: Optional[dict] = None,
'style': 'filled', 'style': 'filled',
'shape': 'diamond' if is_root else ('ellipse' if is_node else 'box'), 'shape': 'diamond' if is_root else ('ellipse' if is_node else 'box'),
'penwidth': 3 if is_root else 1.5, 'penwidth': 3 if is_root else 1.5,
'fontname': "Times-Roman bold" if is_node else "Times-Roman", 'fontname': "monospace bold" if is_node else "monospace",
}, (node_cfg_gen or (lambda: {}))), }, (node_cfg_gen or (lambda: {}))),
edge_cfg_gen=_dict_call_merge(lambda n, p, np, pp, is_node, root: { edge_cfg_gen=_dict_call_merge(lambda n, p, np, pp, is_node, root: {
'arrowhead': 'vee' if is_node else 'dot', 'arrowhead': 'vee' if is_node else 'dot',
......
...@@ -49,6 +49,8 @@ cdef class TreeValue: ...@@ -49,6 +49,8 @@ cdef class TreeValue:
cpdef void validate(self) except* cpdef void validate(self) except*
cdef object _get_tree_graph(self)
cdef str _prefix_fix(object text, object prefix) cdef str _prefix_fix(object text, object prefix)
cdef str _title_repr(TreeStorage st, object type_) cdef str _title_repr(TreeStorage st, object type_)
cdef object _build_tree(TreeStorage st, object type_, str prefix, dict id_pool, tuple path) cdef object _build_tree(TreeStorage st, object type_, str prefix, dict id_pool, tuple path)
......
...@@ -972,6 +972,23 @@ cdef class TreeValue: ...@@ -972,6 +972,23 @@ cdef class TreeValue:
else: else:
return self._type(self._st, constraint=to_constraint([constraint, self.constraint])) return self._type(self._st, constraint=to_constraint([constraint, self.constraint]))
@cython.final
cdef inline object _get_tree_graph(self):
    # Build the graph object for this tree, labelling it '<root>'.
    # The import is done at call time rather than at module level
    # (presumably to avoid a circular import with the graph module -- TODO confirm).
    from .graph import graphics as _graphics
    rooted = (self, '<root>')
    return _graphics(rooted)
@cython.binding(True)
def _repr_svg_(self):
    """
    Render this tree's graph in ``svg`` format, decoded as UTF-8 text.
    """
    graph = self._get_tree_graph()
    return graph.pipe(format='svg', encoding='utf-8')
@cython.binding(True)
def _repr_png_(self):
    """
    Render this tree's graph in ``png`` format.
    """
    graph = self._get_tree_graph()
    return graph.pipe(format='png')
@cython.binding(True)
def _repr_jpeg_(self):
    """
    Render this tree's graph in ``jpeg`` format.
    """
    graph = self._get_tree_graph()
    return graph.pipe(format='jpeg')
cdef str _prefix_fix(object text, object prefix): cdef str _prefix_fix(object text, object prefix):
cdef list lines = [] cdef list lines = []
cdef int i cdef int i
......
import html
import re import re
import shutil
from functools import wraps from functools import wraps
from queue import Queue from queue import Queue
from typing import Optional, Mapping, Any, Callable from typing import Optional, Mapping, Any, Callable
...@@ -6,8 +8,6 @@ from typing import Optional, Mapping, Any, Callable ...@@ -6,8 +8,6 @@ from typing import Optional, Mapping, Any, Callable
from graphviz import Digraph from graphviz import Digraph
from hbutils.reflection import dynamic_call, post_process from hbutils.reflection import dynamic_call, post_process
from .random import random_hex_with_timestamp
def _title_flatten(title): def _title_flatten(title):
title = re.sub(r'[^a-zA-Z0-9_]+', '_', str(title)) title = re.sub(r'[^a-zA-Z0-9_]+', '_', str(title))
...@@ -69,6 +69,35 @@ def _root_process(root, index): ...@@ -69,6 +69,35 @@ def _root_process(root, index):
return root, '<root_%d>' % (index,), index return root, '<root_%d>' % (index,), index
_LEFT_ALIGN_LINESEP = r'\l'
def _custom_html_escape(x):
s = html.escape(x) \
.replace(' ', '&ensp;') \
.replace('\r\n', _LEFT_ALIGN_LINESEP) \
.replace('\n', _LEFT_ALIGN_LINESEP) \
.replace('\r', _LEFT_ALIGN_LINESEP)
if len(x.splitlines()) > 1 and not s.endswith(_LEFT_ALIGN_LINESEP):
s += _LEFT_ALIGN_LINESEP
return s
# Message carried by the EnvironmentError that ``_check_dot_installed`` raises when
# no ``dot`` executable is found. ``.lstrip()`` drops the newline that follows the
# opening triple quote.
_DOT_NOT_INSTALL_TEXT = """
Command 'dot' not found in this environment.
Please install graphviz first, with:
- Ubuntu: apt install -y graphviz
- Windows: choco install graphviz
- MacOS: brew install graphviz
More detailed information can be found in documentation of graphviz python package: https://graphviz.readthedocs.io/en/stable/#installation
""".lstrip()
def _check_dot_installed():
if not shutil.which('dot'):
raise EnvironmentError(_DOT_NOT_INSTALL_TEXT)
def build_graph(*roots, node_id_gen: Optional[Callable] = None, def build_graph(*roots, node_id_gen: Optional[Callable] = None,
graph_title: Optional[str] = None, graph_name: Optional[str] = None, graph_title: Optional[str] = None, graph_name: Optional[str] = None,
graph_cfg: Optional[Mapping[str, Any]] = None, graph_cfg: Optional[Mapping[str, Any]] = None,
...@@ -100,11 +129,13 @@ def build_graph(*roots, node_id_gen: Optional[Callable] = None, ...@@ -100,11 +129,13 @@ def build_graph(*roots, node_id_gen: Optional[Callable] = None,
Returns: Returns:
- dot (:obj:`Digraph`): Graphviz directed graph object. - dot (:obj:`Digraph`): Graphviz directed graph object.
""" """
_check_dot_installed()
roots = [_root_process(root, index) for index, root in enumerate(roots)] roots = [_root_process(root, index) for index, root in enumerate(roots)]
roots = [item for item in roots if item is not None] roots = [item for item in roots if item is not None]
node_id_gen = dynamic_call(suffixed_node_id(node_id_gen or _default_node_id)) node_id_gen = dynamic_call(suffixed_node_id(node_id_gen or _default_node_id))
graph_title = graph_title or ('untitled_' + random_hex_with_timestamp()) graph_title = graph_title or ''
graph_name = graph_name or _title_flatten(graph_title) graph_name = graph_name or _title_flatten(graph_title)
graph_cfg = _no_none_value(graph_cfg or {}) graph_cfg = _no_none_value(graph_cfg or {})
...@@ -125,7 +156,7 @@ def build_graph(*roots, node_id_gen: Optional[Callable] = None, ...@@ -125,7 +156,7 @@ def build_graph(*roots, node_id_gen: Optional[Callable] = None,
root_node_id = node_id_gen(root, None, [], [], True) root_node_id = node_id_gen(root, None, [], [], True)
if root_node_id not in _queued_node_ids: if root_node_id not in _queued_node_ids:
graph.node( graph.node(
name=root_node_id, label=root_title, name=root_node_id, label=_custom_html_escape(root_title),
**node_cfg_gen(root, None, [], [], True, True, root_info) **node_cfg_gen(root, None, [], [], True, True, root_info)
) )
_queue.put((root_node_id, root, (root, root_title, root_index), [])) _queue.put((root_node_id, root, (root, root_title, root_index), []))
...@@ -145,14 +176,14 @@ def build_graph(*roots, node_id_gen: Optional[Callable] = None, ...@@ -145,14 +176,14 @@ def build_graph(*roots, node_id_gen: Optional[Callable] = None,
_current_label = repr_gen(_current_node, _current_path) _current_label = repr_gen(_current_node, _current_path)
if _current_id not in _queued_node_ids: if _current_id not in _queued_node_ids:
graph.node(_current_id, label=_current_label, graph.node(_current_id, label=_custom_html_escape(_current_label),
**node_cfg_gen(_current_node, _parent_node, _current_path, _parent_path, **node_cfg_gen(_current_node, _parent_node, _current_path, _parent_path,
_is_node, False, _root_info)) _is_node, False, _root_info))
if iter_gen(_current_node, _current_path): if iter_gen(_current_node, _current_path):
_queue.put((_current_id, _current_node, _root_info, _current_path)) _queue.put((_current_id, _current_node, _root_info, _current_path))
_queued_node_ids.add(_current_id) _queued_node_ids.add(_current_id)
if (_parent_id, _current_id, key) not in _queued_edges: if (_parent_id, _current_id, key) not in _queued_edges:
graph.edge(_parent_id, _current_id, label=key, graph.edge(_parent_id, _current_id, label=_custom_html_escape(key),
**edge_cfg_gen(_current_node, _parent_node, _current_path, _parent_path, **edge_cfg_gen(_current_node, _parent_node, _current_path, _parent_path,
_is_node, _root_info)) _is_node, _root_info))
_queued_edges.add((_parent_id, _current_id, key)) _queued_edges.add((_parent_id, _current_id, key))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册