提交 9a61cbec 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): add auto attr into Tensor class

上级 a1957a0a
from .tensor import Tensor
__all__ = [
'Tensor'
]
from treevalue import tree_class
from .tensor import Tensor
from ..base import Torch
@tree_class(return_type=Tensor)
class TensorMethod(Torch):
pass
......@@ -3,16 +3,12 @@ import torch
from treevalue import method_treelize, TreeValue
from treevalue.utils import post_process
from .base import Torch, auto_torch, rmreduce, post_reduce, auto_reduce
from .size import Size
from ..common import Object, ireduce, clsmeta, return_self
from ..numpy import ndarray
from ..utils import current_names, class_autoremove, replaceable_partial
from ..utils import doc_from_base as original_doc_from_base
__all__ = [
'Tensor'
]
from ..base import Torch, auto_torch, rmreduce, post_reduce, auto_reduce
from ..size import Size
from ...common import Object, ireduce, clsmeta, return_self
from ...numpy import ndarray
from ...utils import current_names, class_autoremove, replaceable_partial
from ...utils import doc_from_base as original_doc_from_base
doc_from_base = replaceable_partial(original_doc_from_base, base=torch.Tensor)
......@@ -64,6 +60,22 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
super(Torch, self).__init__(data)
@method_treelize(return_type=Object)
def __get_attr(self, key):
return getattr(self, key)
def _attr_extern(self, key):
tree = self.__get_attr(key)
if tree.map(lambda x: torch.is_tensor(x)).all():
type_ = Tensor
elif tree.map(lambda x: callable(x)).all():
from .attr import TensorMethod
type_ = TensorMethod
else:
type_ = Object
return tree.type(type_)
@doc_from_base()
@method_treelize(return_type=ndarray)
def numpy(self: torch.Tensor) -> np.ndarray:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册