trees.py 4.2 KB
Newer Older
1 2 3 4 5
import builtins
import io
import os
from functools import partial
from typing import Optional, Tuple, Callable
6

7
from treevalue import func_treelize as original_func_treelize
8
from treevalue import general_tree_value, TreeValue
9
from treevalue.tree.common import BaseTree
10
from treevalue.tree.tree.tree import get_data_property
11 12 13
from treevalue.utils import post_process

from ..utils import replaceable_partial, args_mapping
14

15
__all__ = [
16 17
    'BaseTreeStruct', "Object",
    'print_tree', 'clsmeta',
18 19
]

20

21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
def _tree_title(node: TreeValue):
    _tree = get_data_property(node)
    return "<{cls} {id}>".format(
        cls=node.__class__.__name__,
        id=hex(id(_tree.actual())),
    )


def print_tree(tree: TreeValue, repr_: Callable = str, ascii_: bool = False, file=None):
    print_to_file = partial(builtins.print, file=file)
    node_ids = {}
    if ascii_:
        _HORI, _VECT, _CROS, _SROS = '|', '-', '+', '+'
    else:
        _HORI, _VECT, _CROS, _SROS = '\u2502', '\u2500', '\u251c', '\u2514'

    def _print_layer(node, path: Tuple[str, ...], prefixes: Tuple[str, ...],
                     current_key: Optional[str] = None, is_last_key: bool = True):
        # noinspection PyShadowingBuiltins
        def print(*args, pid: Optional[int] = -1, **kwargs, ):
            if pid is not None:
                print_to_file(prefixes[pid], end='')
            print_to_file(*args, **kwargs)

        _need_iter = True
        if isinstance(node, TreeValue):
            _node_id = id(get_data_property(node).actual())
            _content = f'<{node.__class__.__name__} {hex(_node_id)}>'
            if _node_id in node_ids.keys():
                _str_old_path = '.'.join(('<root>', *node_ids[_node_id]))
                _content = f'{_content}{os.linesep}(The same address as {_str_old_path})'
                _need_iter = False
            else:
                node_ids[_node_id] = path
                _need_iter = True
        else:
            _content = repr_(node)
            _need_iter = False

        if current_key:
            _key_arrow = f'{current_key} --> '
            _appended_prefix = (_HORI if _need_iter and len(node) > 0 else ' ') + ' ' * (len(_key_arrow) - 1)
            for index, line in enumerate(_content.splitlines()):
                if index == 0:
                    print(f'{_CROS if not is_last_key else _SROS}{_VECT * 2} {_key_arrow}', pid=-2, end='')
                else:
                    print(_appended_prefix, end='')
                print(line, pid=None)
        else:
            print(_content)

        if _need_iter:
            _length = len(node)
            for index, (key, value) in enumerate(sorted(node)):
                _is_last_line = index + 1 >= _length
                _new_prefixes = (*prefixes, prefixes[-1] + f'{_HORI if not _is_last_line else " "}   ')
                _new_path = (*path, key)
                _print_layer(value, _new_path, _new_prefixes, key, _is_last_line)

    if isinstance(tree, TreeValue):
        _print_layer(tree, (), ('', '',))
    else:
        print(repr_(tree), file=file)


86
class BaseTreeStruct(general_tree_value()):
87 88 89 90
    """
    Overview:
        Base structure of all the trees in ``treetensor``.
    """
91 92 93 94 95 96 97 98

    def __repr__(self):
        with io.StringIO() as sfile:
            print_tree(self, repr_=repr, ascii_=False, file=sfile)
            return sfile.getvalue()

    def __str__(self):
        return self.__repr__()
99 100


101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
def clsmeta(cls: type, allow_dict: bool = False, allow_data: bool = True):
    class _TempTreeValue(TreeValue):
        pass

    _types = (
        TreeValue,
        *((dict,) if allow_dict else ()),
        *((BaseTree,) if allow_data else ()),
    )
    func_treelize = post_process(post_process(args_mapping(
        lambda i, x: TreeValue(x) if isinstance(x, _types) else x)))(
        replaceable_partial(original_func_treelize, return_type=_TempTreeValue)
    )

    _torch_size = func_treelize()(cls)

    class _MetaClass(type):
        def __call__(cls, *args, **kwargs):
            _result = _torch_size(*args, **kwargs)
            if isinstance(_result, _TempTreeValue):
                return type.__call__(cls, _result)
            else:
                return _result

    return _MetaClass


class Object(BaseTreeStruct):
129
    pass