From cdc105f56e94d561e81204b3984e04a4c1abe92b Mon Sep 17 00:00:00 2001 From: HansBug Date: Thu, 14 Oct 2021 10:06:38 +0800 Subject: [PATCH] dev(hansbug): apply for binary treevalue --- treetensor/common/trees.py | 8 ++++---- treetensor/numpy/funcs.py | 4 ++-- treetensor/torch/funcs/base.py | 2 +- treetensor/torch/funcs/construct.py | 4 ++-- treetensor/torch/size.py | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/treetensor/common/trees.py b/treetensor/common/trees.py index e23aca288..6fa56ad12 100644 --- a/treetensor/common/trees.py +++ b/treetensor/common/trees.py @@ -7,7 +7,7 @@ from typing import Type from treevalue import func_treelize as original_func_treelize from treevalue import general_tree_value, TreeValue, typetrans -from treevalue.tree.common import BaseTree +from treevalue.tree.common import TreeStorage from treevalue.tree.tree.tree import get_data_property from treevalue.utils import post_process @@ -49,7 +49,7 @@ def print_tree(tree: TreeValue, repr_: Callable = str, _need_iter = True if isinstance(node, TreeValue): - _node_id = id(get_data_property(node).actual()) + _node_id = id(get_data_property(node)) if show_node_id: _content = f'<{node.__class__.__name__} {hex(_node_id)}>' else: @@ -152,7 +152,7 @@ def clsmeta(func, allow_dict: bool = False) -> Type[type]: def _mapping_func(_, x): if isinstance(x, TreeValue): return x - elif isinstance(x, BaseTree): + elif isinstance(x, TreeStorage): return TreeValue(x) elif allow_dict and isinstance(x, dict): return TreeValue(x) @@ -167,7 +167,7 @@ def clsmeta(func, allow_dict: bool = False) -> Type[type]: class _MetaClass(type): def __call__(cls, data, *args, **kwargs): - if isinstance(data, BaseTree): + if isinstance(data, TreeStorage): return type.__call__(cls, data) elif isinstance(data, cls) and not args and not kwargs: return data diff --git a/treetensor/numpy/funcs.py b/treetensor/numpy/funcs.py index d7845dc37..23326466e 100644 --- a/treetensor/numpy/funcs.py +++ b/treetensor/numpy/funcs.py @@ -3,7 +3,7 @@ import builtins import numpy as np from treevalue import TreeValue from treevalue import func_treelize as original_func_treelize -from treevalue.tree.common import BaseTree +from treevalue.tree.common import TreeStorage from treevalue.utils import post_process from .array import ndarray @@ -16,7 +16,7 @@ __all__ = [ ] func_treelize = post_process(post_process(args_mapping( - lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))( + lambda i, x: TreeValue(x) if isinstance(x, (dict, TreeStorage, TreeValue)) else x)))( replaceable_partial(original_func_treelize, return_type=ndarray) ) get_func_from_numpy = module_func_loader(np, ndarray, diff --git a/treetensor/torch/funcs/base.py b/treetensor/torch/funcs/base.py index f75ac9ff4..c1098e195 100644 --- a/treetensor/torch/funcs/base.py +++ b/treetensor/torch/funcs/base.py @@ -1,7 +1,7 @@ import torch from treevalue import TreeValue from treevalue import func_treelize as original_func_treelize -from treevalue.tree.common import BaseTree +from treevalue.tree.common import TreeStorage from ..tensor import Tensor from ...common import auto_tree, module_func_loader diff --git a/treetensor/torch/funcs/construct.py b/treetensor/torch/funcs/construct.py index 9ecaa525b..64e9b1e92 100644 --- a/treetensor/torch/funcs/construct.py +++ b/treetensor/torch/funcs/construct.py @@ -1,6 +1,6 @@ import torch from treevalue import TreeValue -from treevalue.tree.common import BaseTree +from treevalue.tree.common import TreeStorage from .base import doc_from_base, func_treelize from ...utils import args_mapping @@ -15,7 +15,7 @@ __all__ = [ 'empty', 'empty_like', ] -args_treelize = args_mapping(lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x) +args_treelize = args_mapping(lambda i, x: TreeValue(x) if isinstance(x, (dict, TreeStorage, TreeValue)) else x) @doc_from_base() diff --git a/treetensor/torch/size.py b/treetensor/torch/size.py index 31295744f..9709dee1b 100644 --- a/treetensor/torch/size.py +++ b/treetensor/torch/size.py @@ -3,7 +3,7 @@ from functools import wraps import torch from treevalue import TreeValue from treevalue import func_treelize as original_func_treelize -from treevalue.tree.common import BaseTree +from treevalue.tree.common import TreeStorage from treevalue.utils import post_process from .base import Torch @@ -12,7 +12,7 @@ from ..utils import doc_from_base as original_doc_from_base from ..utils import replaceable_partial, current_names, args_mapping func_treelize = post_process(post_process(args_mapping( - lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))( + lambda i, x: TreeValue(x) if isinstance(x, (dict, TreeStorage, TreeValue)) else x)))( replaceable_partial(original_func_treelize) ) doc_from_base = replaceable_partial(original_doc_from_base, base=torch.Size) -- GitLab