From f45bd5009e632aa29db95dab9087d49ff00ea476 Mon Sep 17 00:00:00 2001 From: HansBug Date: Thu, 16 Sep 2021 15:46:20 +0800 Subject: [PATCH] dev(hansbug): add __str__ and __repr__ support for Tensor --- treetensor/common/trees.py | 84 ++++++++++++++++++++++++++++++++-- treetensor/torch/funcs.py | 93 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 170 insertions(+), 7 deletions(-) diff --git a/treetensor/common/trees.py b/treetensor/common/trees.py index 2f68444d5..9b528d0c8 100644 --- a/treetensor/common/trees.py +++ b/treetensor/common/trees.py @@ -1,18 +1,96 @@ +import builtins +import io +import os from abc import ABCMeta +from functools import partial +from typing import Optional, Tuple, Callable -from treevalue import general_tree_value +from treevalue import general_tree_value, TreeValue +from treevalue.tree.tree.tree import get_data_property __all__ = [ - 'BaseTreeStruct', "TreeObject", + 'BaseTreeStruct', "TreeObject", 'print_tree', ] +def _tree_title(node: TreeValue): + _tree = get_data_property(node) + return "<{cls} {id}>".format( + cls=node.__class__.__name__, + id=hex(id(_tree.actual())), + ) + + +def print_tree(tree: TreeValue, repr_: Callable = str, ascii_: bool = False, file=None): + print_to_file = partial(builtins.print, file=file) + node_ids = {} + if ascii_: + _HORI, _VECT, _CROS, _SROS = '|', '-', '+', '+' + else: + _HORI, _VECT, _CROS, _SROS = '\u2502', '\u2500', '\u251c', '\u2514' + + def _print_layer(node, path: Tuple[str, ...], prefixes: Tuple[str, ...], + current_key: Optional[str] = None, is_last_key: bool = True): + # noinspection PyShadowingBuiltins + def print(*args, pid: Optional[int] = -1, **kwargs, ): + if pid is not None: + print_to_file(prefixes[pid], end='') + print_to_file(*args, **kwargs) + + _need_iter = True + if isinstance(node, TreeValue): + _node_id = id(get_data_property(node).actual()) + _content = f'<{node.__class__.__name__} {hex(_node_id)}>' + if _node_id in node_ids.keys(): + _str_old_path = '.'.join(('', *node_ids[_node_id])) + _content = f'{_content}{os.linesep}(The same address as {_str_old_path})' + _need_iter = False + else: + node_ids[_node_id] = path + _need_iter = True + else: + _content = repr_(node) + _need_iter = False + + if current_key: + _key_arrow = f'{current_key} --> ' + _appended_prefix = (_HORI if _need_iter and len(node) > 0 else ' ') + ' ' * (len(_key_arrow) - 1) + for index, line in enumerate(_content.splitlines()): + if index == 0: + print(f'{_CROS if not is_last_key else _SROS}{_VECT * 2} {_key_arrow}', pid=-2, end='') + else: + print(_appended_prefix, end='') + print(line, pid=None) + else: + print(_content) + + if _need_iter: + _length = len(node) + for index, (key, value) in enumerate(sorted(node)): + _is_last_line = index + 1 >= _length + _new_prefixes = (*prefixes, prefixes[-1] + f'{_HORI if not _is_last_line else " "} ') + _new_path = (*path, key) + _print_layer(value, _new_path, _new_prefixes, key, _is_last_line) + + if isinstance(tree, TreeValue): + _print_layer(tree, (), ('', '',)) + else: + print(repr_(tree), file=file) + + class BaseTreeStruct(general_tree_value(), metaclass=ABCMeta): """ Overview: Base structure of all the trees in ``treetensor``. """ - pass + + def __repr__(self): + with io.StringIO() as sfile: + print_tree(self, repr_=repr, ascii_=False, file=sfile) + return sfile.getvalue() + + def __str__(self): + return self.__repr__() class TreeObject(BaseTreeStruct): diff --git a/treetensor/torch/funcs.py b/treetensor/torch/funcs.py index ec1fbd8e8..a7a5f8b61 100644 --- a/treetensor/torch/funcs.py +++ b/treetensor/torch/funcs.py @@ -23,8 +23,8 @@ __all__ = [ 'empty', 'empty_like', 'all', 'any', 'min', 'max', 'sum', - 'eq', 'equal', - 'tensor', + 'eq', 'ne', 'lt', 'le', 'gt', 'ge', + 'equal', 'tensor', ] func_treelize = post_process(post_process(args_mapping( @@ -446,15 +446,100 @@ def sum(input, *args, **kwargs): @doc_from(torch.eq) @func_treelize() def eq(input, other, *args, **kwargs): + """ + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.eq( + >>> torch.tensor([[1, 2], [3, 4]]), + >>> torch.tensor([[1, 1], [4, 4]]), + >>> ) + torch.tensor([[ True, False], + [False, True]]) + + >>> ttorch.eq( + >>> ttorch.tensor({ + >>> 'a': [[1, 2], [3, 4]], + >>> 'b': [1.0, 1.5, 2.0], + >>> }), + >>> ttorch.tensor({ + >>> 'a': [[1, 1], [4, 4]], + >>> 'b': [1.3, 1.2, 2.0], + >>> }), + >>> ) + """ return torch.eq(input, other, *args, **kwargs) +# noinspection PyShadowingBuiltins +@doc_from(torch.ne) +@func_treelize() +def ne(input, other, *args, **kwargs): + return torch.ne(input, other, *args, **kwargs) + + +# noinspection PyShadowingBuiltins +@doc_from(torch.lt) +@func_treelize() +def lt(input, other, *args, **kwargs): + return torch.lt(input, other, *args, **kwargs) + + +# noinspection PyShadowingBuiltins +@doc_from(torch.le) +@func_treelize() +def le(input, other, *args, **kwargs): + return torch.le(input, other, *args, **kwargs) + + +# noinspection PyShadowingBuiltins +@doc_from(torch.gt) +@func_treelize() +def gt(input, other, *args, **kwargs): + return torch.gt(input, other, *args, **kwargs) + + +# noinspection PyShadowingBuiltins +@doc_from(torch.ge) +@func_treelize() +def ge(input, other, *args, **kwargs): + return torch.ge(input, other, *args, **kwargs) + + # noinspection PyShadowingBuiltins,PyArgumentList @doc_from(torch.equal) @ireduce(builtins.all) @func_treelize() -def equal(input, other, *args, **kwargs): - return torch.equal(input, other, *args, **kwargs) +def equal(input, other): + """ + In ``treetensor``, you can get the equality of the two tree tensors. + + Examples:: + + >>> import torch + >>> import treetensor.torch as ttorch + >>> ttorch.equal( + >>> torch.tensor([1, 2, 3]), + >>> torch.tensor([1, 2, 3]), + >>> ) # the same as torch.equal + True + + >>> ttorch.equal( + >>> ttorch.tensor({ + >>> 'a': torch.tensor([1, 2, 3]), + >>> 'b': torch.tensor([[4, 5], [6, 7]]), + >>> }), + >>> ttorch.tensor({ + >>> 'a': torch.tensor([1, 2, 3]), + >>> 'b': torch.tensor([[4, 5], [6, 7]]), + >>> }), + >>> ) + True + + """ + return torch.equal(input, other) @doc_from(torch.tensor) -- GitLab