未验证 提交 2cc7d7a7 编写于 作者: HansBug's avatar HansBug 😆 提交者: GitHub

Merge pull request #75 from opendilab/dev/repr

dev(hansbug): optimize graphviz visualization
......@@ -66,17 +66,17 @@ jobs:
shell: bash
run: |
sudo apt-get update
sudo apt-get install -y tree cloc wget curl make
sudo apt-get install -y tree cloc wget curl make graphviz
- name: Set up system dependences on Windows
if: ${{ env.OS_NAME == 'Windows' }}
shell: bash
run: |
choco install tree cloc wget curl make zip
choco install tree cloc wget curl make zip graphviz
- name: Set up system dependences on MacOS
if: ${{ env.OS_NAME == 'MacOS' }}
shell: bash
run: |
brew install tree cloc wget curl make zip
brew install tree cloc wget curl make zip graphviz
- name: Set up python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
......
.PHONY: clean

# Jupyter executable and the nbconvert command built on top of it.
# Both can be overridden from the command line, e.g. `make JUPYTER=/path/to/jupyter clean`.
JUPYTER ?= $(shell which jupyter)
NBCONVERT ?= ${JUPYTER} nbconvert

# Directory that is searched recursively for notebooks.
SOURCE ?= .

# The name pattern is quoted so the shell passes it to `find` verbatim.
# Unquoted, *.ipynb is expanded by the shell as soon as a notebook exists in
# the working directory, which narrows (or breaks) the search.
# Paths containing .ipynb_checkpoints are filtered out.
IPYNBS := $(shell find ${SOURCE} -name '*.ipynb' | grep -v .ipynb_checkpoints)

# Clear the stored outputs of every notebook found above, in place.
clean:
	for nb in ${IPYNBS}; do \
		if [ -f "$$nb" ]; then \
			$(NBCONVERT) --clear-output --inplace "$$nb"; \
		fi; \
	done;
\ No newline at end of file
# Examples of TreeValue
## Examples
Here are some TreeValue examples that can be viewed on Colab.
* Visualization: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opendilab/treevalue/blob/main/examples/visualization.ipynb)
## Makefile
```shell
make clean # clean all output in notebook files
```
**Please make sure the output of all notebook files is cleared with `make clean` before committing**.
{
"cells": [
{
"cell_type": "markdown",
"id": "edb6ed95",
"metadata": {},
"source": [
"# Visualization of TreeValue"
]
},
{
"cell_type": "markdown",
"id": "9fb09e8c",
"metadata": {},
"source": [
"Create a simple `FastTreeValue` object."
]
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"!pip install treevalue"
],
"metadata": {
"collapsed": false
}
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc7e3dd2",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"from treevalue import FastTreeValue\n",
"\n",
"t = FastTreeValue({\n",
" 'a': 1,\n",
" 'b': np.random.randint(0, 20, (3, 4)),\n",
" 'x': {\n",
" 'c': 3,\n",
" 'd': torch.randn(2, 3),\n",
" },\n",
"})"
]
},
{
"cell_type": "markdown",
"id": "8b654200",
"metadata": {},
"source": [
"See it"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30771752",
"metadata": {},
"outputs": [],
"source": [
"t"
]
},
{
"cell_type": "markdown",
"id": "b2c1bb61",
"metadata": {},
"source": [
"Do some operations on `t`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc97351a",
"metadata": {},
"outputs": [],
"source": [
"1 / (t + 0.5)"
]
},
{
"cell_type": "markdown",
"id": "b25f3db4",
"metadata": {},
"source": [
"Create another `FastTreeValue` which is sharing storage with `t`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "af6345dc",
"metadata": {},
"outputs": [],
"source": [
"t2 = FastTreeValue({'ppp': t.x, 'x': {'t': t, 'y': t.x}})"
]
},
{
"cell_type": "markdown",
"id": "e039d8d9",
"metadata": {},
"source": [
"Just put `t` and `t2` into one diagram."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "57446c40",
"metadata": {},
"outputs": [],
"source": [
"from treevalue import graphics\n",
"\n",
"graphics((t, 't'), (t2, 't2'))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
......@@ -646,7 +646,7 @@ def get_fasttreevalue_test(treevalue_class: Type[FastTreeValue]):
title="This is a demo of 2 trees with dup value.",
cfg={'bgcolor': '#ffffffff'},
)
assert len(graph_1.source) <= 4960
assert len(graph_1.source) <= 5000 # a value for svg in windows
graph_2 = treevalue_class.graphics(
(t, 't'), (t1, 't1'),
......@@ -655,7 +655,7 @@ def get_fasttreevalue_test(treevalue_class: Type[FastTreeValue]):
title="This is a demo of 2 trees with dup value.",
cfg={'bgcolor': '#ffffffff'},
)
assert len(graph_2.source) <= 5480
assert len(graph_2.source) <= 5600
graph_3 = treevalue_class.graphics(
(t, 't'), (t1, 't1'),
......@@ -673,7 +673,7 @@ def get_fasttreevalue_test(treevalue_class: Type[FastTreeValue]):
title="This is a demo of 2 trees with dup value.",
cfg={'bgcolor': '#ffffffff'},
)
assert len(graph_4.source) <= 3780
assert len(graph_4.source) <= 4000
graph_6 = treevalue_class.graphics(
(t, 't'), (t1, 't1'),
......
......@@ -3,6 +3,7 @@ import re
from typing import Type
import pytest
from hbutils.testing import OS
from test.tree.tree.test_constraint import GreaterThanConstraint
from treevalue import raw, TreeValue, delayed, ValidationError
......@@ -749,4 +750,31 @@ def get_treevalue_test(treevalue_class: Type[TreeValue]):
assert newt1 == t1
assert newt1.constraint == t1.constraint
def test_repr_svg(self):
    """``_repr_svg_`` should exist and return SVG text of the expected size."""
    tree = get_demo_constraint_tree()
    assert hasattr(tree, '_repr_svg_')

    svg = tree._repr_svg_()
    assert isinstance(svg, str)

    svg_length = len(svg)
    assert 4500 <= svg_length <= 4900
def test_repr_png(self):
    """``_repr_png_`` should exist and return PNG bytes of the expected size."""
    tree = get_demo_constraint_tree()
    assert hasattr(tree, '_repr_png_')

    png = tree._repr_png_()
    assert isinstance(png, bytes)

    # The accepted size range is lower on Windows than on the other platforms.
    lower, upper = (12000, 16050) if OS.windows else (16050, 20500)
    assert lower <= len(png) <= upper
def test_repr_jpeg(self):
    """``_repr_jpeg_`` should exist and return JPEG bytes of the expected size."""
    tree = get_demo_constraint_tree()
    assert hasattr(tree, '_repr_jpeg_')

    jpeg = tree._repr_jpeg_()
    assert isinstance(jpeg, bytes)

    jpeg_length = len(jpeg)
    assert 10500 <= jpeg_length <= 14500
return _TestClass
......@@ -37,7 +37,7 @@ class TestTreeTreeGraph:
title="This is a demo of 2 trees with dup value.",
cfg={'bgcolor': '#ffffffff'},
)
assert len(graph_1.source) <= 4960
assert len(graph_1.source) <= 5000
graph_2 = graphics(
(t, 't'), (t1, 't1'),
......@@ -46,7 +46,7 @@ class TestTreeTreeGraph:
title="This is a demo of 2 trees with dup value.",
cfg={'bgcolor': '#ffffffff'},
)
assert len(graph_2.source) <= 5480
assert len(graph_2.source) <= 5600
graph_3 = graphics(
(t, 't'), (t1, 't1'),
......@@ -64,7 +64,7 @@ class TestTreeTreeGraph:
title="This is a demo of 2 trees with dup value.",
cfg={'bgcolor': '#ffffffff'},
)
assert len(graph_4.source) <= 3780
assert len(graph_4.source) <= 4000
graph_6 = graphics(
(t, 't'), (t1, 't1'),
......
from shutil import which
from unittest.mock import patch
import pytest
from treevalue.utils import build_graph
@pytest.fixture()
def no_dot():
    """Patch ``shutil.which`` so that the ``dot`` command looks missing while a test runs."""

    def _which_without_dot(cmd):
        # Every command other than 'dot' is still resolved by the real ``which``.
        return None if cmd == 'dot' else which(cmd)

    with patch('shutil.which', _which_without_dot):
        yield
@pytest.mark.unittest
class TestUtilsTree:
def test_build_graph(self):
......@@ -14,17 +29,20 @@ class TestUtilsTree:
g2 = build_graph(t, graph_title="Demo 2 of build_graph.")
assert "Demo 2 of build_graph." in g2.source
assert "<root_0>" in g2.source
assert "<root_0>.x" in g2.source
assert "root_0" in g2.source
assert len(g2.source) <= 580
g3 = build_graph((t,), graph_title="Demo 3 of build_graph.")
assert "Demo 3 of build_graph." in g3.source
assert "<root_0>" in g3.source
assert "<root_0>.x" in g3.source
assert "root_0" in g3.source
assert len(g3.source) <= 580
g4 = build_graph((), graph_title="Demo 4 of build_graph.")
assert "Demo 4 of build_graph." in g4.source
assert "node" not in g4.source
assert len(g4.source) <= 110
def test_build_graph_without_dot(self, no_dot):
    """``build_graph`` should raise ``EnvironmentError`` when the ``no_dot`` fixture is active."""
    tree_data = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}

    with pytest.raises(EnvironmentError):
        build_graph((tree_data, 't'), graph_title="Demo of build_graph.")
from functools import lru_cache
from typing import Type, Callable, Union, Optional, Tuple
from graphviz import Digraph, nohtml
from graphviz import Digraph
from hbutils.reflection import post_process, dynamic_call, freduce
from .tree import TreeValue
......@@ -134,6 +134,13 @@ _GENERIC_N = 36
_GENERIC_S = _GENERIC_N // 3
def _custom_repr(x):
if isinstance(x, (str,)):
return repr(x)
else:
return str(x)
def graphics(*trees, title: Optional[str] = None, cfg: Optional[dict] = None,
dup_value: Union[bool, Callable, type, Tuple[Type, ...]] = False,
repr_gen: Optional[Callable] = None,
......@@ -182,9 +189,9 @@ def graphics(*trees, title: Optional[str] = None, cfg: Optional[dict] = None,
return build_graph(
*trees,
node_id_gen=_node_tag,
graph_title=title or "<untitled>",
graph_title=title,
graph_cfg=cfg or {},
repr_gen=repr_gen or (lambda x: nohtml(repr(x))),
repr_gen=repr_gen or (lambda x: _custom_repr(x)),
iter_gen=lambda n: iter(n.items()) if isinstance(n, TreeValue) else None,
node_cfg_gen=_dict_call_merge(lambda n, p, np, pp, is_node, is_root, root: {
'fillcolor': _shape_color(root[2]),
......@@ -193,7 +200,7 @@ def graphics(*trees, title: Optional[str] = None, cfg: Optional[dict] = None,
'style': 'filled',
'shape': 'diamond' if is_root else ('ellipse' if is_node else 'box'),
'penwidth': 3 if is_root else 1.5,
'fontname': "Times-Roman bold" if is_node else "Times-Roman",
'fontname': "monospace bold" if is_node else "monospace",
}, (node_cfg_gen or (lambda: {}))),
edge_cfg_gen=_dict_call_merge(lambda n, p, np, pp, is_node, root: {
'arrowhead': 'vee' if is_node else 'dot',
......
......@@ -49,6 +49,8 @@ cdef class TreeValue:
cpdef void validate(self) except*
cdef object _get_tree_graph(self)
cdef str _prefix_fix(object text, object prefix)
cdef str _title_repr(TreeStorage st, object type_)
cdef object _build_tree(TreeStorage st, object type_, str prefix, dict id_pool, tuple path)
......
......@@ -972,6 +972,23 @@ cdef class TreeValue:
else:
return self._type(self._st, constraint=to_constraint([constraint, self.constraint]))
@cython.final
cdef inline object _get_tree_graph(self):
    # Build the graph object for this tree alone, titled '<root>'.
    # NOTE(review): graphics is imported inside the function, presumably to
    # avoid a circular import with the graph module - confirm.
    from .graph import graphics

    root = (self, '<root>')
    return graphics(root)
@cython.binding(True)
def _repr_svg_(self):
    """
    Overview:
        Render this tree with graphviz and return the SVG as text \
        (decoded with utf-8).
    """
    graph = self._get_tree_graph()
    return graph.pipe(format='svg', encoding='utf-8')
@cython.binding(True)
def _repr_png_(self):
    """
    Overview:
        Render this tree with graphviz and return the PNG image data.
    """
    graph = self._get_tree_graph()
    return graph.pipe(format='png')
@cython.binding(True)
def _repr_jpeg_(self):
    """
    Overview:
        Render this tree with graphviz and return the JPEG image data.
    """
    graph = self._get_tree_graph()
    return graph.pipe(format='jpeg')
cdef str _prefix_fix(object text, object prefix):
cdef list lines = []
cdef int i
......
import html
import re
import shutil
from functools import wraps
from queue import Queue
from typing import Optional, Mapping, Any, Callable
......@@ -6,8 +8,6 @@ from typing import Optional, Mapping, Any, Callable
from graphviz import Digraph
from hbutils.reflection import dynamic_call, post_process
from .random import random_hex_with_timestamp
def _title_flatten(title):
title = re.sub(r'[^a-zA-Z0-9_]+', '_', str(title))
......@@ -69,6 +69,35 @@ def _root_process(root, index):
return root, '<root_%d>' % (index,), index
_LEFT_ALIGN_LINESEP = r'\l'
def _custom_html_escape(x):
s = html.escape(x) \
.replace(' ', '&ensp;') \
.replace('\r\n', _LEFT_ALIGN_LINESEP) \
.replace('\n', _LEFT_ALIGN_LINESEP) \
.replace('\r', _LEFT_ALIGN_LINESEP)
if len(x.splitlines()) > 1 and not s.endswith(_LEFT_ALIGN_LINESEP):
s += _LEFT_ALIGN_LINESEP
return s
_DOT_NOT_INSTALL_TEXT = """
Command 'dot' not found in this environment.
Please install graphviz first, with:
- Ubuntu: apt install -y graphviz
- Windows: choco install graphviz
- MacOS: brew install graphviz
More detailed information can be found in documentation of graphviz python package: https://graphviz.readthedocs.io/en/stable/#installation
""".lstrip()
def _check_dot_installed():
    """
    Overview:
        Make sure the ``dot`` command can be located with ``shutil.which``.

    Raises:
        - EnvironmentError: When ``dot`` cannot be found. The message is \
            ``_DOT_NOT_INSTALL_TEXT``.
    """
    if shutil.which('dot'):
        return

    raise EnvironmentError(_DOT_NOT_INSTALL_TEXT)
def build_graph(*roots, node_id_gen: Optional[Callable] = None,
graph_title: Optional[str] = None, graph_name: Optional[str] = None,
graph_cfg: Optional[Mapping[str, Any]] = None,
......@@ -100,11 +129,13 @@ def build_graph(*roots, node_id_gen: Optional[Callable] = None,
Returns:
- dot (:obj:`Digraph`): Graphviz directed graph object.
"""
_check_dot_installed()
roots = [_root_process(root, index) for index, root in enumerate(roots)]
roots = [item for item in roots if item is not None]
node_id_gen = dynamic_call(suffixed_node_id(node_id_gen or _default_node_id))
graph_title = graph_title or ('untitled_' + random_hex_with_timestamp())
graph_title = graph_title or ''
graph_name = graph_name or _title_flatten(graph_title)
graph_cfg = _no_none_value(graph_cfg or {})
......@@ -125,7 +156,7 @@ def build_graph(*roots, node_id_gen: Optional[Callable] = None,
root_node_id = node_id_gen(root, None, [], [], True)
if root_node_id not in _queued_node_ids:
graph.node(
name=root_node_id, label=root_title,
name=root_node_id, label=_custom_html_escape(root_title),
**node_cfg_gen(root, None, [], [], True, True, root_info)
)
_queue.put((root_node_id, root, (root, root_title, root_index), []))
......@@ -145,14 +176,14 @@ def build_graph(*roots, node_id_gen: Optional[Callable] = None,
_current_label = repr_gen(_current_node, _current_path)
if _current_id not in _queued_node_ids:
graph.node(_current_id, label=_current_label,
graph.node(_current_id, label=_custom_html_escape(_current_label),
**node_cfg_gen(_current_node, _parent_node, _current_path, _parent_path,
_is_node, False, _root_info))
if iter_gen(_current_node, _current_path):
_queue.put((_current_id, _current_node, _root_info, _current_path))
_queued_node_ids.add(_current_id)
if (_parent_id, _current_id, key) not in _queued_edges:
graph.edge(_parent_id, _current_id, label=key,
graph.edge(_parent_id, _current_id, label=_custom_html_escape(key),
**edge_cfg_gen(_current_node, _parent_node, _current_path, _parent_path,
_is_node, _root_info))
_queued_edges.add((_parent_id, _current_id, key))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册