# trees.py - tree printing helper, base tree structure, metaclass factory and auto-conversion utilities.
import builtins
import io
import os
from functools import partial
from typing import Optional, Tuple, Callable
from typing import Type

from hbutils.reflection import post_process
from treevalue import func_treelize as original_func_treelize
from treevalue import general_tree_value, TreeValue, typetrans
from treevalue.tree.common import TreeStorage

from ..utils import replaceable_partial, args_mapping

# Public names exported by this module.
__all__ = [
    'BaseTreeStruct',
    'print_tree', 'clsmeta', 'auto_tree',
]


def print_tree(tree: TreeValue, repr_: Callable = str,
               ascii_: bool = False, show_node_id: bool = True, file=None):
    """
    Overview:
        Print a tree structure to the given file.

        A sub tree which has already been printed (detected by the ``id`` of the object
        returned by ``_detach()``) is not expanded again; a ``(The same address as ...)``
        line pointing to its first path is printed instead.
        If ``tree`` is not a :class:`TreeValue`, ``repr_(tree)`` is printed directly.

    Arguments:
        - tree (:obj:`TreeValue`): Given tree object.
        - repr\\_ (:obj:`Callable`): Representation function, default is ``str``.
        - ascii\\_ (:obj:`bool`): Use ascii to print the tree, default is ``False``.
        - show_node_id (:obj:`bool`): Show node id of the tree, default is ``True``.
        - file: Output file of this print procedure, default is ``None`` which means to stdout.
    """
    print_to_file = partial(builtins.print, file=file)
    # id of an already printed node --> path of its first appearance.
    node_ids = {}
    # Glyphs: vertical guide, horizontal dash, branch of a middle child, branch of the last child.
    if ascii_:
        _PIPE, _DASH, _TEE, _ELBOW = '|', '-', '+', '+'
    else:
        _PIPE, _DASH, _TEE, _ELBOW = '\u2502', '\u2500', '\u251c', '\u2514'

    def _print_layer(node, path: Tuple[str, ...], prefixes: Tuple[str, ...],
                     current_key: Optional[str] = None, is_last_key: bool = True):
        def _emit(*args, pid: Optional[int] = -1, **kwargs):
            # ``pid`` selects which accumulated prefix is written before the text;
            # ``None`` means the text continues the current line without a prefix.
            if pid is not None:
                print_to_file(prefixes[pid], end='')
            print_to_file(*args, **kwargs)

        if isinstance(node, TreeValue):
            _node_id = id(node._detach())
            if show_node_id:
                _content = f'<{node.__class__.__name__} {hex(_node_id)}>'
            else:
                _content = f'<{node.__class__.__name__}>'
            if _node_id in node_ids:
                # Already printed elsewhere: refer to the first path and do not descend again.
                _str_old_path = '.'.join(('<root>', *node_ids[_node_id]))
                _content = f'{_content}{os.linesep}(The same address as {_str_old_path})'
                _need_iter = False
            else:
                node_ids[_node_id] = path
                _need_iter = True
        else:
            _content = repr_(node)
            _need_iter = False

        if current_key:
            _key_arrow = f'{current_key} --> '
            # Continuation lines keep the vertical guide only when children will follow below.
            _appended_prefix = (_PIPE if _need_iter and len(node) > 0 else ' ') + ' ' * (len(_key_arrow) - 1)
            # ``''.splitlines()`` is empty, which would drop the key completely;
            # always print at least one (possibly empty) line for the key.
            _lines = _content.splitlines() or ['']
            for index, line in enumerate(_lines):
                if index == 0:
                    _emit(f'{_TEE if not is_last_key else _ELBOW}{_DASH * 2} {_key_arrow}', pid=-2, end='')
                else:
                    _emit(_appended_prefix, end='')
                _emit(line, pid=None)
        else:
            _emit(_content)

        if _need_iter:
            _length = len(node)
            for index, (key, value) in enumerate(sorted(node)):
                _is_last_line = index + 1 >= _length
                # Children of the last entry get a blank guide, the others a vertical one.
                _new_prefixes = (*prefixes, prefixes[-1] + f'{_PIPE if not _is_last_line else " "}   ')
                _new_path = (*path, key)
                _print_layer(value, _new_path, _new_prefixes, key, _is_last_line)

    if isinstance(tree, TreeValue):
        # Two empty prefixes so that both ``prefixes[-1]`` and ``prefixes[-2]`` exist at the root.
        _print_layer(tree, (), ('', '',))
    else:
        print_to_file(repr_(tree))


class BaseTreeStruct(general_tree_value()):
    """
    Overview:
        Base structure of all the trees in ``treetensor``.
    """

    def __repr__(self):
        """
        Get the tree-based representation format of this object.

        The text is produced by :func:`print_tree` (with ``repr`` as the representation
        function and unicode line glyphs) written into an in-memory buffer.

        Examples::

            >>> from treetensor.common import Object
            >>> repr(Object(1))  # Object is subclass of BaseTreeStruct
            '1'

            >>> repr(Object({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}}))
            '<Object 0x7fe00b121220>\n├── a --> 1\n├── b --> 2\n└── x --> <Object 0x7fe00b121c10>\n    ├── c --> 3\n    └── d --> 4\n'

            >>> Object({'a': 1, 'b': 2, 'x': {'c': 3, 'd': 4}})
            <Object 0x7fe00b1271c0>
            ├── a --> 1
            ├── b --> 2
            └── x --> <Object 0x7fe00b127910>
                ├── c --> 3
                └── d --> 4
        """
        with io.StringIO() as sfile:
            print_tree(self, repr_=repr, ascii_=False, file=sfile)
            return sfile.getvalue()

    def __str__(self):
        """
        The same as :py:meth:`BaseTreeStruct.__repr__`.
        """
        return self.__repr__()


def clsmeta(func, allow_dict: bool = False) -> Type[type]:
    """
    Overview:
        Create a metaclass based on generating function.

        Used in :py:class:`treetensor.common.Object`,
        :py:class:`treetensor.torch.Tensor` and :py:class:`treetensor.torch.Size`.
        The metaclass customizes how instances of those classes are constructed.

    Arguments:
        - func: Generating function.
        - allow_dict (:obj:`bool`): Auto transform dictionary to :py:class:`treevalue.TreeValue` class, \
                                    default is ``False``.
    Returns:
        - metaclass (:obj:`Type[type]`): Metaclass for creating a new class.
    """

    # Private marker type: a result of this type came out of the treelized ``func``
    # and still has to be converted into the class being instantiated.
    class _TempTreeValue(TreeValue):
        pass

    def _mapping_func(_, x):
        # Wrap raw storages (and dicts when ``allow_dict`` is set) as ``TreeValue``;
        # anything else, including existing ``TreeValue`` objects, is passed through.
        if isinstance(x, TreeValue):
            return x
        elif isinstance(x, TreeStorage):
            return TreeValue(x)
        elif allow_dict and isinstance(x, dict):
            return TreeValue(x)
        else:
            return x

    # NOTE(review): ``post_process``, ``args_mapping`` and ``replaceable_partial`` are defined
    # outside this file; presumably this builds a ``func_treelize`` variant which applies
    # ``_mapping_func`` to the call arguments and returns ``_TempTreeValue`` trees - confirm.
    func_treelize = post_process(post_process(args_mapping(_mapping_func)))(
        replaceable_partial(original_func_treelize, return_type=_TempTreeValue)
    )

    _wrapped_func = func_treelize()(func)

    class _MetaClass(type):
        def __call__(cls, data, *args, **kwargs):
            if isinstance(data, TreeStorage):
                # A raw storage is handed to the normal constructor directly.
                return type.__call__(cls, data)
            elif isinstance(data, cls) and not args and not kwargs:
                # Already an instance of the target class: return it as is.
                return data

            _result = _wrapped_func(data, *args, **kwargs)
            if isinstance(_result, _TempTreeValue):
                # Tree-shaped result: rebuild it as an instance of ``cls``.
                return type.__call__(cls, _result)
            else:
                return _result

    return _MetaClass


def _auto_tree_func(t, cls):
    """
    Convert tree ``t`` to the first matching tree type described by ``cls``.

    ``cls`` is an iterable of ``(key, value)`` pairs. ``key`` is either a type
    (checked with ``isinstance``) or a callable predicate; ``value`` is the tree type
    used when every mapped value of the tree satisfies the check. ``t`` is first
    converted to ``Object``, and that converted tree is returned when no pair matches.

    :raises TypeError: If a ``key`` is neither a type nor a callable.
    """
    from .object import Object

    object_tree = typetrans(t, return_type=Object)
    for matcher, target_type in cls:
        if isinstance(matcher, type):
            def _accepts(x):
                return isinstance(x, matcher)
        elif callable(matcher):
            # Wrapped (not passed directly) so the mapped function always takes one argument.
            def _accepts(x):
                return matcher(x)
        else:
            raise TypeError(f'Unknown type of prediction - {matcher!r}.')

        # The closure is consumed right here, within the same loop iteration.
        if object_tree.map(_accepts).all():
            return typetrans(object_tree, return_type=target_type)

    return object_tree


# noinspection PyArgumentList
def auto_tree(v, cls):
    """
    Overview:
        Convert every :class:`TreeValue` found inside ``v`` with ``cls``.

        Tuples, lists, sets and dicts are walked recursively and rebuilt with their
        original container type (named tuples included); any other object is
        returned unchanged.

    Arguments:
        - v: Object to be processed.
        - cls: How to convert a tree. One of: a subclass of :class:`TreeValue` (trees are \
            converted with ``typetrans``), a list or tuple of ``(type or predicate, tree type)`` \
            pairs (handled by ``_auto_tree_func``), or any callable applied to each tree.

    Returns:
        - result: Object with the same structure as ``v`` and all the trees converted.

    Raises:
        - TypeError: If ``cls`` is none of the supported forms.
    """
    # Normalize ``cls`` into a callable; on recursive calls it already is one.
    if isinstance(cls, type) and issubclass(cls, TreeValue):
        cls = partial(typetrans, return_type=cls)
    elif isinstance(cls, (list, tuple)):
        cls = partial(_auto_tree_func, cls=cls)
    elif not callable(cls):
        raise TypeError(f'Unknown type of cls - {repr(cls)}.')

    if isinstance(v, TreeValue):
        return cls(v)
    elif isinstance(v, (tuple, list, set)):
        items = (auto_tree(item, cls) for item in v)
        if isinstance(v, tuple) and hasattr(type(v), '_make'):
            # Named tuples take their fields as separate positional arguments,
            # so they have to be rebuilt through ``_make`` instead of ``type(v)(iterable)``.
            return type(v)._make(items)
        return type(v)(items)
    elif isinstance(v, dict):
        return type(v)({key: auto_tree(value, cls) for key, value in v.items()})
    else:
        return v