diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 7ce3b5afc6141ffe7c190c9457400e31d98e6155..4ef3ab559fa4084b34bde203a212bf59591e71d9 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -66,17 +66,17 @@ jobs: shell: bash run: | sudo apt-get update - sudo apt-get install -y tree cloc wget curl make + sudo apt-get install -y tree cloc wget curl make graphviz - name: Set up system dependences on Windows if: ${{ env.OS_NAME == 'Windows' }} shell: bash run: | - choco install tree cloc wget curl make zip + choco install tree cloc wget curl make zip graphviz - name: Set up system dependences on MacOS if: ${{ env.OS_NAME == 'MacOS' }} shell: bash run: | - brew install tree cloc wget curl make zip + brew install tree cloc wget curl make zip graphviz - name: Set up python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: diff --git a/examples/Makefile b/examples/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..13e0fe6e44f9e3c58fe0faf4aec656b0997c5517 --- /dev/null +++ b/examples/Makefile @@ -0,0 +1,14 @@ +.PHONY: clean + +JUPYTER ?= $(shell which jupyter) +NBCONVERT ?= ${JUPYTER} nbconvert + +SOURCE ?= . +IPYNBS := $(shell find ${SOURCE} -name *.ipynb | grep -v .ipynb_checkpoints) + +clean: + for nb in ${IPYNBS}; do \ + if [ -f $$nb ]; then \ + $(NBCONVERT) --clear-output --inplace $$nb; \ + fi; \ + done; \ No newline at end of file diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7c1f567e825ea786a3d8ea1d7c834ad8ded7ee11 --- /dev/null +++ b/examples/README.md @@ -0,0 +1,16 @@ +# Examples of TreeValue + +## Examples + +Here are some TreeValue examples that can be viewed on Colab. + +* Visualization: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opendilab/treevalue/blob/main/examples/visualization.ipynb) + +## Makefile + +```shell +make clean # clean all output in notebook files + +``` + +**Please make sure the output of all notebook files are cleared with `make clean` before committing**. diff --git a/examples/visualization.ipynb b/examples/visualization.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..6d6f9bf617888db910a0bdadd32f6ee7c51b00a6 --- /dev/null +++ b/examples/visualization.ipynb @@ -0,0 +1,147 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "edb6ed95", + "metadata": {}, + "source": [ + "# Visualization of TreeValue" + ] + }, + { + "cell_type": "markdown", + "id": "9fb09e8c", + "metadata": {}, + "source": [ + "Create a simple `FastTreeValue` object." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "!pip install treevalue" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc7e3dd2", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "from treevalue import FastTreeValue\n", + "\n", + "t = FastTreeValue({\n", + " 'a': 1,\n", + " 'b': np.random.randint(0, 20, (3, 4)),\n", + " 'x': {\n", + " 'c': 3,\n", + " 'd': torch.randn(2, 3),\n", + " },\n", + "})" + ] + }, + { + "cell_type": "markdown", + "id": "8b654200", + "metadata": {}, + "source": [ + "See it" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30771752", + "metadata": {}, + "outputs": [], + "source": [ + "t" + ] + }, + { + "cell_type": "markdown", + "id": "b2c1bb61", + "metadata": {}, + "source": [ + "Do some operations on `t`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc97351a", + "metadata": {}, + "outputs": [], + "source": [ + "1 / (t + 0.5)" + ] + }, + { + "cell_type": "markdown", + "id": "b25f3db4", + "metadata": {}, + "source": [ + "Create another `FastTreeValue` which is sharing storage with `t`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "af6345dc", + "metadata": {}, + "outputs": [], + "source": [ + "t2 = FastTreeValue({'ppp': t.x, 'x': {'t': t, 'y': t.x}})" + ] + }, + { + "cell_type": "markdown", + "id": "e039d8d9", + "metadata": {}, + "source": [ + "Just put `t` and `t2` into one diagram." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "57446c40", + "metadata": {}, + "outputs": [], + "source": [ + "from treevalue import graphics\n", + "\n", + "graphics((t, 't'), (t2, 't2'))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/test/tree/general/base.py b/test/tree/general/base.py index 84507554f39d1c0d3905e1534a8fad0b452b2e09..3eac30009e8a0b3f4822c1adb3a70facaf98b5f9 100644 --- a/test/tree/general/base.py +++ b/test/tree/general/base.py @@ -646,7 +646,7 @@ def get_fasttreevalue_test(treevalue_class: Type[FastTreeValue]): title="This is a demo of 2 trees with dup value.", cfg={'bgcolor': '#ffffffff'}, ) - assert len(graph_1.source) <= 4960 + assert len(graph_1.source) <= 5000 # a value for svg in windows graph_2 = treevalue_class.graphics( (t, 't'), (t1, 't1'), @@ -655,7 +655,7 @@ def get_fasttreevalue_test(treevalue_class: Type[FastTreeValue]): title="This is a demo of 2 trees with dup value.", cfg={'bgcolor': '#ffffffff'}, ) - assert len(graph_2.source) <= 5480 + assert len(graph_2.source) <= 5600 graph_3 = treevalue_class.graphics( (t, 't'), (t1, 't1'), @@ -673,7 +673,7 @@ def get_fasttreevalue_test(treevalue_class: Type[FastTreeValue]): title="This is a demo of 2 trees with dup value.", cfg={'bgcolor': '#ffffffff'}, ) - assert len(graph_4.source) <= 3780 + assert len(graph_4.source) <= 4000 graph_6 = treevalue_class.graphics( (t, 't'), (t1, 't1'), diff --git a/test/tree/tree/base.py b/test/tree/tree/base.py index 9aa0aeb1902bfc642bd82443753597fe4f64e40a..4dc0e0b6e38288e5393d2559cd16d147e0fd3a66 100644 --- a/test/tree/tree/base.py +++ b/test/tree/tree/base.py @@ -3,6 +3,7 @@ import re from typing import Type import pytest +from hbutils.testing import OS from test.tree.tree.test_constraint import GreaterThanConstraint from treevalue import raw, TreeValue, delayed, ValidationError @@ -749,4 +750,31 @@ def get_treevalue_test(treevalue_class: Type[TreeValue]): assert newt1 == t1 assert newt1.constraint == t1.constraint + def test_repr_svg(self): + t1 = get_demo_constraint_tree() + assert hasattr(t1, '_repr_svg_') + + _repr_svg_ = t1._repr_svg_() + assert isinstance(_repr_svg_, str) + assert 4500 <= len(_repr_svg_) <= 4900 + + def test_repr_png(self): + t1 = get_demo_constraint_tree() + assert hasattr(t1, '_repr_png_') + + _repr_png_ = t1._repr_png_() + assert isinstance(_repr_png_, bytes) + if OS.windows: + assert 12000 <= len(_repr_png_) <= 16050 + else: + assert 16050 <= len(_repr_png_) <= 20500 + + def test_repr_jpeg(self): + t1 = get_demo_constraint_tree() + assert hasattr(t1, '_repr_jpeg_') + + _repr_jpeg_ = t1._repr_jpeg_() + assert isinstance(_repr_jpeg_, bytes) + assert 10500 <= len(_repr_jpeg_) <= 14500 + return _TestClass diff --git a/test/tree/tree/test_graph.py b/test/tree/tree/test_graph.py index 6fe9c54f731b4d6b385f35a4164bec7c12827db2..833cdbf72f2d56864bc181489b1e1e56b8422ad4 100644 --- a/test/tree/tree/test_graph.py +++ b/test/tree/tree/test_graph.py @@ -37,7 +37,7 @@ class TestTreeTreeGraph: title="This is a demo of 2 trees with dup value.", cfg={'bgcolor': '#ffffffff'}, ) - assert len(graph_1.source) <= 4960 + assert len(graph_1.source) <= 5000 graph_2 = graphics( (t, 't'), (t1, 't1'), @@ -46,7 +46,7 @@ class TestTreeTreeGraph: title="This is a demo of 2 trees with dup value.", cfg={'bgcolor': '#ffffffff'}, ) - assert len(graph_2.source) <= 5480 + assert len(graph_2.source) <= 5600 graph_3 = graphics( (t, 't'), (t1, 't1'), @@ -64,7 +64,7 @@ class TestTreeTreeGraph: title="This is a demo of 2 trees with dup value.", cfg={'bgcolor': '#ffffffff'}, ) - assert len(graph_4.source) <= 3780 + assert len(graph_4.source) <= 4000 graph_6 = graphics( (t, 't'), (t1, 't1'), diff --git a/test/utils/test_tree.py b/test/utils/test_tree.py index 9c6ffc3df99483090e1034505d88f1386a1b35f8..431ef700918532611273483e80e6a1bb8b428700 100644 --- a/test/utils/test_tree.py +++ b/test/utils/test_tree.py @@ -1,8 +1,23 @@ +from shutil import which +from unittest.mock import patch + import pytest from treevalue.utils import build_graph +@pytest.fixture() +def no_dot(): + def _no_dot_which(x): + if x == 'dot': + return None + else: + return which(x) + + with patch('shutil.which', _no_dot_which): + yield + + @pytest.mark.unittest class TestUtilsTree: def test_build_graph(self): @@ -14,17 +29,20 @@ class TestUtilsTree: g2 = build_graph(t, graph_title="Demo 2 of build_graph.") assert "Demo 2 of build_graph." in g2.source - assert "" in g2.source - assert ".x" in g2.source + assert "root_0" in g2.source assert len(g2.source) <= 580 g3 = build_graph((t,), graph_title="Demo 3 of build_graph.") assert "Demo 3 of build_graph." in g3.source - assert "" in g3.source - assert ".x" in g3.source + assert "root_0" in g3.source assert len(g3.source) <= 580 g4 = build_graph((), graph_title="Demo 4 of build_graph.") assert "Demo 4 of build_graph." in g4.source assert "node" not in g4.source assert len(g4.source) <= 110 + + def test_build_graph_without_dot(self, no_dot): + t = {'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}} + with pytest.raises(EnvironmentError): + _ = build_graph((t, 't'), graph_title="Demo of build_graph.") diff --git a/treevalue/tree/tree/graph.py b/treevalue/tree/tree/graph.py index 39aafcbdb871dc3287c55d0f9c32b3985f72cc23..dbd4b77ca3513581d156038fef307018ea7c4f63 100644 --- a/treevalue/tree/tree/graph.py +++ b/treevalue/tree/tree/graph.py @@ -1,7 +1,7 @@ from functools import lru_cache from typing import Type, Callable, Union, Optional, Tuple -from graphviz import Digraph, nohtml +from graphviz import Digraph from hbutils.reflection import post_process, dynamic_call, freduce from .tree import TreeValue @@ -134,6 +134,13 @@ _GENERIC_N = 36 _GENERIC_S = _GENERIC_N // 3 +def _custom_repr(x): + if isinstance(x, (str,)): + return repr(x) + else: + return str(x) + + def graphics(*trees, title: Optional[str] = None, cfg: Optional[dict] = None, dup_value: Union[bool, Callable, type, Tuple[Type, ...]] = False, repr_gen: Optional[Callable] = None, @@ -182,9 +189,9 @@ def graphics(*trees, title: Optional[str] = None, cfg: Optional[dict] = None, return build_graph( *trees, node_id_gen=_node_tag, - graph_title=title or "", + graph_title=title, graph_cfg=cfg or {}, - repr_gen=repr_gen or (lambda x: nohtml(repr(x))), + repr_gen=repr_gen or (lambda x: _custom_repr(x)), iter_gen=lambda n: iter(n.items()) if isinstance(n, TreeValue) else None, node_cfg_gen=_dict_call_merge(lambda n, p, np, pp, is_node, is_root, root: { 'fillcolor': _shape_color(root[2]), @@ -193,7 +200,7 @@ def graphics(*trees, title: Optional[str] = None, cfg: Optional[dict] = None, 'style': 'filled', 'shape': 'diamond' if is_root else ('ellipse' if is_node else 'box'), 'penwidth': 3 if is_root else 1.5, - 'fontname': "Times-Roman bold" if is_node else "Times-Roman", + 'fontname': "monospace bold" if is_node else "monospace", }, (node_cfg_gen or (lambda: {}))), edge_cfg_gen=_dict_call_merge(lambda n, p, np, pp, is_node, root: { 'arrowhead': 'vee' if is_node else 'dot', diff --git a/treevalue/tree/tree/tree.pxd b/treevalue/tree/tree/tree.pxd index 00bbf34aa0adb7fce9be2773c0300312b2078744..446e154ed1223389ee2da1c234996beb621016a4 100644 --- a/treevalue/tree/tree/tree.pxd +++ b/treevalue/tree/tree/tree.pxd @@ -49,6 +49,8 @@ cdef class TreeValue: cpdef void validate(self) except* + cdef object _get_tree_graph(self) + cdef str _prefix_fix(object text, object prefix) cdef str _title_repr(TreeStorage st, object type_) cdef object _build_tree(TreeStorage st, object type_, str prefix, dict id_pool, tuple path) diff --git a/treevalue/tree/tree/tree.pyx b/treevalue/tree/tree/tree.pyx index 7b97ac0f8775cb820b47a718e8e59a4687e4ed0e..f584653ea0f5714c45780b64824b05bb2a38e1b3 100644 --- a/treevalue/tree/tree/tree.pyx +++ b/treevalue/tree/tree/tree.pyx @@ -972,6 +972,23 @@ cdef class TreeValue: else: return self._type(self._st, constraint=to_constraint([constraint, self.constraint])) + @cython.final + cdef inline object _get_tree_graph(self): + from .graph import graphics + return graphics((self, '')) + + @cython.binding(True) + def _repr_svg_(self): + return self._get_tree_graph().pipe(format='svg', encoding='utf-8') + + @cython.binding(True) + def _repr_png_(self): + return self._get_tree_graph().pipe(format='png') + + @cython.binding(True) + def _repr_jpeg_(self): + return self._get_tree_graph().pipe(format='jpeg') + cdef str _prefix_fix(object text, object prefix): cdef list lines = [] cdef int i diff --git a/treevalue/utils/tree.py b/treevalue/utils/tree.py index aad1a63dcb5a6fab9dab55797b8a512cb993abdd..bd5085d0d0bb5871e434b447024e6873dc45e762 100644 --- a/treevalue/utils/tree.py +++ b/treevalue/utils/tree.py @@ -1,4 +1,6 @@ +import html import re +import shutil from functools import wraps from queue import Queue from typing import Optional, Mapping, Any, Callable @@ -6,8 +8,6 @@ from typing import Optional, Mapping, Any, Callable from graphviz import Digraph from hbutils.reflection import dynamic_call, post_process -from .random import random_hex_with_timestamp - def _title_flatten(title): title = re.sub(r'[^a-zA-Z0-9_]+', '_', str(title)) @@ -69,6 +69,35 @@ def _root_process(root, index): return root, '' % (index,), index +_LEFT_ALIGN_LINESEP = r'\l' + + +def _custom_html_escape(x): + s = html.escape(x) \ + .replace(' ', ' ') \ + .replace('\r\n', _LEFT_ALIGN_LINESEP) \ + .replace('\n', _LEFT_ALIGN_LINESEP) \ + .replace('\r', _LEFT_ALIGN_LINESEP) + if len(x.splitlines()) > 1 and not s.endswith(_LEFT_ALIGN_LINESEP): + s += _LEFT_ALIGN_LINESEP + return s + + +_DOT_NOT_INSTALL_TEXT = """ +Command 'dot' not found in this environment. +Please install graphviz first, with: + - Ubuntu: apt install -y graphviz + - Windows: choco install graphviz + - MacOS: brew install graphviz +More detailed information can be found in documentation of graphviz python package: https://graphviz.readthedocs.io/en/stable/#installation +""".lstrip() + + +def _check_dot_installed(): + if not shutil.which('dot'): + raise EnvironmentError(_DOT_NOT_INSTALL_TEXT) + + def build_graph(*roots, node_id_gen: Optional[Callable] = None, graph_title: Optional[str] = None, graph_name: Optional[str] = None, graph_cfg: Optional[Mapping[str, Any]] = None, @@ -100,11 +129,13 @@ def build_graph(*roots, node_id_gen: Optional[Callable] = None, Returns: - dot (:obj:`Digraph`): Graphviz directed graph object. """ + _check_dot_installed() + roots = [_root_process(root, index) for index, root in enumerate(roots)] roots = [item for item in roots if item is not None] node_id_gen = dynamic_call(suffixed_node_id(node_id_gen or _default_node_id)) - graph_title = graph_title or ('untitled_' + random_hex_with_timestamp()) + graph_title = graph_title or '' graph_name = graph_name or _title_flatten(graph_title) graph_cfg = _no_none_value(graph_cfg or {}) @@ -125,7 +156,7 @@ def build_graph(*roots, node_id_gen: Optional[Callable] = None, root_node_id = node_id_gen(root, None, [], [], True) if root_node_id not in _queued_node_ids: graph.node( - name=root_node_id, label=root_title, + name=root_node_id, label=_custom_html_escape(root_title), **node_cfg_gen(root, None, [], [], True, True, root_info) ) _queue.put((root_node_id, root, (root, root_title, root_index), [])) @@ -145,14 +176,14 @@ def build_graph(*roots, node_id_gen: Optional[Callable] = None, _current_label = repr_gen(_current_node, _current_path) if _current_id not in _queued_node_ids: - graph.node(_current_id, label=_current_label, + graph.node(_current_id, label=_custom_html_escape(_current_label), **node_cfg_gen(_current_node, _parent_node, _current_path, _parent_path, _is_node, False, _root_info)) if iter_gen(_current_node, _current_path): _queue.put((_current_id, _current_node, _root_info, _current_path)) _queued_node_ids.add(_current_id) if (_parent_id, _current_id, key) not in _queued_edges: - graph.edge(_parent_id, _current_id, label=key, + graph.edge(_parent_id, _current_id, label=_custom_html_escape(key), **edge_cfg_gen(_current_node, _parent_node, _current_path, _parent_path, _is_node, _root_info)) _queued_edges.add((_parent_id, _current_id, key))