提交 cdc105f5 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): apply for binary treevalue

上级 ee203a17
......@@ -7,7 +7,7 @@ from typing import Type
from treevalue import func_treelize as original_func_treelize
from treevalue import general_tree_value, TreeValue, typetrans
from treevalue.tree.common import BaseTree
from treevalue.tree.common import TreeStorage
from treevalue.tree.tree.tree import get_data_property
from treevalue.utils import post_process
......@@ -49,7 +49,7 @@ def print_tree(tree: TreeValue, repr_: Callable = str,
_need_iter = True
if isinstance(node, TreeValue):
_node_id = id(get_data_property(node).actual())
_node_id = id(get_data_property(node))
if show_node_id:
_content = f'<{node.__class__.__name__} {hex(_node_id)}>'
else:
......@@ -152,7 +152,7 @@ def clsmeta(func, allow_dict: bool = False) -> Type[type]:
def _mapping_func(_, x):
if isinstance(x, TreeValue):
return x
elif isinstance(x, BaseTree):
elif isinstance(x, TreeStorage):
return TreeValue(x)
elif allow_dict and isinstance(x, dict):
return TreeValue(x)
......@@ -167,7 +167,7 @@ def clsmeta(func, allow_dict: bool = False) -> Type[type]:
class _MetaClass(type):
def __call__(cls, data, *args, **kwargs):
if isinstance(data, BaseTree):
if isinstance(data, TreeStorage):
return type.__call__(cls, data)
elif isinstance(data, cls) and not args and not kwargs:
return data
......
......@@ -3,7 +3,7 @@ import builtins
import numpy as np
from treevalue import TreeValue
from treevalue import func_treelize as original_func_treelize
from treevalue.tree.common import BaseTree
from treevalue.tree.common import TreeStorage
from treevalue.utils import post_process
from .array import ndarray
......@@ -16,7 +16,7 @@ __all__ = [
]
func_treelize = post_process(post_process(args_mapping(
lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))(
lambda i, x: TreeValue(x) if isinstance(x, (dict, TreeStorage, TreeValue)) else x)))(
replaceable_partial(original_func_treelize, return_type=ndarray)
)
get_func_from_numpy = module_func_loader(np, ndarray,
......
import torch
from treevalue import TreeValue
from treevalue import func_treelize as original_func_treelize
from treevalue.tree.common import BaseTree
from treevalue.tree.common import TreeStorage
from ..tensor import Tensor
from ...common import auto_tree, module_func_loader
......
import torch
from treevalue import TreeValue
from treevalue.tree.common import BaseTree
from treevalue.tree.common import TreeStorage
from .base import doc_from_base, func_treelize
from ...utils import args_mapping
......@@ -15,7 +15,7 @@ __all__ = [
'empty', 'empty_like',
]
args_treelize = args_mapping(lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)
args_treelize = args_mapping(lambda i, x: TreeValue(x) if isinstance(x, (dict, TreeStorage, TreeValue)) else x)
@doc_from_base()
......
......@@ -3,7 +3,7 @@ from functools import wraps
import torch
from treevalue import TreeValue
from treevalue import func_treelize as original_func_treelize
from treevalue.tree.common import BaseTree
from treevalue.tree.common import TreeStorage
from treevalue.utils import post_process
from .base import Torch
......@@ -12,7 +12,7 @@ from ..utils import doc_from_base as original_doc_from_base
from ..utils import replaceable_partial, current_names, args_mapping
func_treelize = post_process(post_process(args_mapping(
lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))(
lambda i, x: TreeValue(x) if isinstance(x, (dict, TreeStorage, TreeValue)) else x)))(
replaceable_partial(original_func_treelize)
)
doc_from_base = replaceable_partial(original_doc_from_base, base=torch.Size)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册