trees.py 4.0 KB
Newer Older
1 2 3 4 5
import builtins
import io
import os
from functools import partial
from typing import Optional, Tuple, Callable
6

7
from treevalue import func_treelize as original_func_treelize
8
from treevalue import general_tree_value, TreeValue
9
from treevalue.tree.common import BaseTree
10
from treevalue.tree.tree.tree import get_data_property
11 12 13
from treevalue.utils import post_process

from ..utils import replaceable_partial, args_mapping
14

15
__all__ = [
16
    'BaseTreeStruct',
17
    'print_tree', 'clsmeta',
18 19
]

20

21 22
def print_tree(tree: TreeValue, repr_: Callable = str,
               ascii_: bool = False, show_node_id: bool = True, file=None):
    """
    Overview:
        Print ``tree`` as a box-drawn diagram.

        A ``TreeValue`` node is shown as ``<ClassName 0x...>`` (or ``<ClassName>`` when
        ``show_node_id`` is ``False``). Its children follow in sorted order as ``key --> child`` lines.
        Non-tree values are shown with ``repr_``.

        A node whose data object (``get_data_property(node).actual()``) was already printed
        is not expanded again. Instead a ``(The same address as <root>.a.b)`` note is appended.
        This also stops the recursion on shared or cyclic structures.

    Arguments:
        - tree (:obj:`TreeValue`): Tree to print. Anything that is not a ``TreeValue`` is \
            printed as ``repr_(tree)``.
        - repr_ (:obj:`Callable`): Renders leaf values, default is ``str``.
        - ascii_ (:obj:`bool`): Draw with ``|``, ``-`` and ``+`` instead of Unicode box characters.
        - show_node_id (:obj:`bool`): Include the hex ``id`` of each node's data object.
        - file: Stream forwarded to :func:`print`; ``None`` writes to ``sys.stdout``.
    """
    print_to_file = partial(builtins.print, file=file)
    # id of a node's data object -> key path at which that object was first printed
    first_seen_paths = {}
    if ascii_:
        _vertical, _horizontal, _branch, _last_branch = '|', '-', '+', '+'
    else:
        _vertical, _horizontal, _branch, _last_branch = '\u2502', '\u2500', '\u251c', '\u2514'

    def _print_layer(node, path: Tuple[str, ...], prefixes: Tuple[str, ...],
                     current_key: Optional[str] = None, is_last_key: bool = True):
        # ``prefixes[-2]`` is the indentation of this node's own connector line.
        # ``prefixes[-1]`` is the indentation used for everything printed below it.
        def _emit(*args, pid: Optional[int] = -1, **kwargs):
            # Print ``prefixes[pid]`` first (nothing when ``pid`` is None), then the payload.
            if pid is not None:
                print_to_file(prefixes[pid], end='')
            print_to_file(*args, **kwargs)

        if isinstance(node, TreeValue):
            _node_id = id(get_data_property(node).actual())
            _class_name = node.__class__.__name__
            _content = f'<{_class_name} {hex(_node_id)}>' if show_node_id else f'<{_class_name}>'
            if _node_id in first_seen_paths:
                # Already printed elsewhere: point at the first occurrence instead of recursing.
                _str_old_path = '.'.join(('<root>', *first_seen_paths[_node_id]))
                _content = f'{_content}{os.linesep}(The same address as {_str_old_path})'
                _need_iter = False
            else:
                first_seen_paths[_node_id] = path
                _need_iter = True
        else:
            _content = repr_(node)
            _need_iter = False

        # ``''.splitlines()`` is empty; keep one (empty) line so the node is still printed.
        _lines = _content.splitlines() or ['']

        # Only the root has no key. Test against None so a falsy key (e.g. '') keeps its arrow.
        if current_key is not None:
            _key_arrow = f'{current_key} --> '
            # Continuation lines keep a vertical bar when this node's children are printed below.
            _has_children = _need_iter and len(node) > 0
            _appended_prefix = (_vertical if _has_children else ' ') + ' ' * (len(_key_arrow) - 1)
            for index, line in enumerate(_lines):
                if index == 0:
                    _connector = _last_branch if is_last_key else _branch
                    _emit(f'{_connector}{_horizontal * 2} {_key_arrow}', pid=-2, end='')
                else:
                    _emit(_appended_prefix, end='')
                _emit(line, pid=None)
        else:
            _emit(_content)

        if _need_iter:
            _children = sorted(node)
            for index, (key, value) in enumerate(_children):
                _is_last = index + 1 >= len(_children)
                # The last child needs no vertical bar continuing past it.
                _new_prefixes = (*prefixes, prefixes[-1] + f'{" " if _is_last else _vertical}   ')
                _print_layer(value, (*path, key), _new_prefixes, key, _is_last)

    if isinstance(tree, TreeValue):
        _print_layer(tree, (), ('', ''))
    else:
        print_to_file(repr_(tree))


82
class BaseTreeStruct(general_tree_value()):
    """
    Overview:
        Common base class of the tree structures in ``treetensor``.

        Both ``repr()`` and ``str()`` return the diagram produced by ``print_tree``.
        The diagram uses Unicode connectors, and leaf values are rendered with ``repr``.
    """

    def __repr__(self):
        # Capture the diagram in an in-memory buffer instead of writing it to stdout.
        buffer = io.StringIO()
        try:
            print_tree(self, repr_=repr, ascii_=False, file=buffer)
            text = buffer.getvalue()
        finally:
            buffer.close()
        return text

    def __str__(self):
        # Same text as ``repr()``; dispatches through ``__repr__`` so subclass overrides apply.
        return self.__repr__()
95 96


97
def clsmeta(func, allow_dict: bool = False):
    """
    Overview:
        Build a metaclass whose instance creation is routed through a treelized ``func``.

        ``func`` is wrapped by treevalue's ``func_treelize`` with ``return_type=_TempTreeValue``.
        The ``args_mapping`` / ``post_process`` chain applies ``TreeValue(x)`` to arguments that are
        ``TreeValue`` or ``BaseTree`` instances. When ``allow_dict`` is ``True``, ``dict`` instances
        are converted too.
        NOTE(review): this presumably happens before the treelized call; confirm against ``..utils``.

        Calling a class that uses the returned metaclass calls the wrapped function with the
        constructor arguments. A ``_TempTreeValue`` result is passed to the class's normal
        construction via ``type.__call__(cls, result)``. Any other result is returned unchanged.

    Arguments:
        - func: Function to treelize; it receives the constructor arguments.
        - allow_dict (:obj:`bool`): Also convert plain ``dict`` arguments to ``TreeValue``.

    Returns:
        - metaclass: A ``type`` subclass overriding ``__call__`` as described above.
    """

    class _TempTreeValue(TreeValue):
        pass

    _tree_types = (TreeValue, BaseTree, dict) if allow_dict else (TreeValue, BaseTree)

    def _as_plain_tree(i, x):
        return TreeValue(x) if isinstance(x, _tree_types) else x

    _map_arguments = post_process(args_mapping(_as_plain_tree))
    _treelize = post_process(_map_arguments)(
        replaceable_partial(original_func_treelize, return_type=_TempTreeValue),
    )
    _treelized_func = _treelize()(func)

    class _MetaClass(type):
        def __call__(cls, *args, **kwargs):
            result = _treelized_func(*args, **kwargs)
            if not isinstance(result, _TempTreeValue):
                return result
            return type.__call__(cls, result)

    return _MetaClass