提交 a44c0b66 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): optimize TensorMethod and Tensor

上级 a5a6704c
from treevalue import tree_class
from treevalue import method_treelize, TreeValue
from treevalue.utils import post_process
from .tensor import Tensor
from ..base import Torch
from ..base import Torch, auto_torch
from ..tensor import Tensor
from ...utils import replaceable_partial
auto_tensor = replaceable_partial(auto_torch, cls=Tensor)
@tree_class(return_type=Tensor)
class TensorMethod(Torch):
pass
@post_process(auto_tensor)
@method_treelize(return_type=TreeValue, rise=True)
def __call__(self, *args, **kwargs):
return self(*args, **kwargs)
from functools import wraps
import numpy as np
import torch
from treevalue import method_treelize, TreeValue
......@@ -64,17 +66,24 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
def __get_attr(self, key):
return getattr(self, key)
def _attr_extern(self, key):
tree = self.__get_attr(key)
if tree.map(lambda x: torch.is_tensor(x)).all():
type_ = Tensor
elif tree.map(lambda x: callable(x)).all():
def _attr_extern(self, name):
tree = self.__get_attr(name)
if hasattr(torch.Tensor, name) and not name.startswith('_') \
and callable(getattr(torch.Tensor, name)):
from .attr import TensorMethod
type_ = TensorMethod
else:
type_ = Object
tree = tree.type(TensorMethod)
return tree.type(type_)
@wraps(getattr(torch, name), assigned=('__name__', '__qualname__'), updated=())
def _new_func(*args, **kwargs):
result = tree(*args, **kwargs)
return self if name.endswith('_') else result
_new_func.__self__ = self
return _new_func
elif tree.map(lambda x: torch.is_tensor(x)).all():
return tree.type(Tensor)
else:
return tree
@doc_from_base()
@method_treelize(return_type=ndarray)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册