trees.py 3.2 KB
Newer Older
1 2 3
import builtins
import io
import os
4
from abc import ABCMeta
5 6
from functools import partial
from typing import Optional, Tuple, Callable
7

8 9
from treevalue import general_tree_value, TreeValue
from treevalue.tree.tree.tree import get_data_property
10

11
__all__ = [
12
    'BaseTreeStruct', "TreeObject", 'print_tree',
13 14
]

15

16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
def _tree_title(node: TreeValue):
    _tree = get_data_property(node)
    return "<{cls} {id}>".format(
        cls=node.__class__.__name__,
        id=hex(id(_tree.actual())),
    )


def print_tree(tree: TreeValue, repr_: Callable = str, ascii_: bool = False, file=None):
    print_to_file = partial(builtins.print, file=file)
    node_ids = {}
    if ascii_:
        _HORI, _VECT, _CROS, _SROS = '|', '-', '+', '+'
    else:
        _HORI, _VECT, _CROS, _SROS = '\u2502', '\u2500', '\u251c', '\u2514'

    def _print_layer(node, path: Tuple[str, ...], prefixes: Tuple[str, ...],
                     current_key: Optional[str] = None, is_last_key: bool = True):
        # noinspection PyShadowingBuiltins
        def print(*args, pid: Optional[int] = -1, **kwargs, ):
            if pid is not None:
                print_to_file(prefixes[pid], end='')
            print_to_file(*args, **kwargs)

        _need_iter = True
        if isinstance(node, TreeValue):
            _node_id = id(get_data_property(node).actual())
            _content = f'<{node.__class__.__name__} {hex(_node_id)}>'
            if _node_id in node_ids.keys():
                _str_old_path = '.'.join(('<root>', *node_ids[_node_id]))
                _content = f'{_content}{os.linesep}(The same address as {_str_old_path})'
                _need_iter = False
            else:
                node_ids[_node_id] = path
                _need_iter = True
        else:
            _content = repr_(node)
            _need_iter = False

        if current_key:
            _key_arrow = f'{current_key} --> '
            _appended_prefix = (_HORI if _need_iter and len(node) > 0 else ' ') + ' ' * (len(_key_arrow) - 1)
            for index, line in enumerate(_content.splitlines()):
                if index == 0:
                    print(f'{_CROS if not is_last_key else _SROS}{_VECT * 2} {_key_arrow}', pid=-2, end='')
                else:
                    print(_appended_prefix, end='')
                print(line, pid=None)
        else:
            print(_content)

        if _need_iter:
            _length = len(node)
            for index, (key, value) in enumerate(sorted(node)):
                _is_last_line = index + 1 >= _length
                _new_prefixes = (*prefixes, prefixes[-1] + f'{_HORI if not _is_last_line else " "}   ')
                _new_path = (*path, key)
                _print_layer(value, _new_path, _new_prefixes, key, _is_last_line)

    if isinstance(tree, TreeValue):
        _print_layer(tree, (), ('', '',))
    else:
        print(repr_(tree), file=file)


81 82 83 84 85
class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta):
    """
    Overview:
        Base structure of all the trees in ``treetensor``.
    """
86 87 88 89 90 91 92 93

    def __repr__(self):
        with io.StringIO() as sfile:
            print_tree(self, repr_=repr, ascii_=False, file=sfile)
            return sfile.getvalue()

    def __str__(self):
        return self.__repr__()
94 95


96 97
class TreeObject(BaseTreeStruct):
    pass