提交 ba8e6f92 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): upgrade auto function

上级 a44c0b66
......@@ -19,17 +19,21 @@ func_treelize = post_process(post_process(args_mapping(
doc_from_base = replaceable_partial(original_doc_from_base, base=torch)
auto_tensor = replaceable_partial(auto_torch, cls=Tensor)
_funcs_module = '.'.join(__name__.split('.')[:-1])
def get_func_from_torch(name):
func = getattr(torch, name)
return_self_dec = return_self if func.__name__.endswith("_") else (lambda x: x)
@func_treelize()
@wraps(func)
@doc_from_base()
@return_self_dec
@post_process(auto_tensor)
@func_treelize(return_type=TreeValue, rise=True)
@wraps(func, assigned=('__name__',), updated=())
def _new_func(*args, **kwargs):
return func(*args, **kwargs)
if func.__name__.endswith("_"):
_new_func = return_self(_new_func)
_new_func = doc_from_base()(_new_func)
_new_func.__qualname__ = _new_func.__name__
_new_func.__module__ = _funcs_module
return _new_func
from functools import wraps
from types import MethodType
import numpy as np
import torch
......@@ -25,10 +26,33 @@ def _to_tensor(*args, **kwargs):
return torch.tensor(*args, **kwargs)
# noinspection PyMethodParameters
class _TensorMeta(clsmeta(_to_tensor, allow_dict=True)):
def __getattr__(cls, name):
if hasattr(torch.Tensor, name) and not name.startswith('_') \
and callable(getattr(torch.Tensor, name)):
_origin_func = getattr(torch.Tensor, name)
return_self_deco = return_self if name.endswith('_') else (lambda x: x)
@doc_from_base()
@return_self_deco
@post_process(lambda r: replaceable_partial(auto_torch, cls=cls)(r))
@method_treelize(return_type=TreeValue, rise=True)
@wraps(_origin_func, assigned=('__name__',), updated=())
def _new_func(*args, **kwargs):
return _origin_func(*args, **kwargs)
_new_func.__qualname__ = f'{cls.__name__}.{name}'
_new_func.__module__ = cls.__module__
return _new_func
else:
raise AttributeError(f"type object {repr(cls.__name__)} has no attribute {repr(name)}")
# noinspection PyTypeChecker
@current_names()
@class_autoremove
class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
class Tensor(Torch, metaclass=_TensorMeta):
# noinspection PyUnusedLocal
def __init__(self, data, *args, **kwargs):
"""
......@@ -67,23 +91,15 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
return getattr(self, key)
def _attr_extern(self, name):
tree = self.__get_attr(name)
if hasattr(torch.Tensor, name) and not name.startswith('_') \
and callable(getattr(torch.Tensor, name)):
from .attr import TensorMethod
tree = tree.type(TensorMethod)
@wraps(getattr(torch, name), assigned=('__name__', '__qualname__'), updated=())
def _new_func(*args, **kwargs):
result = tree(*args, **kwargs)
return self if name.endswith('_') else result
_new_func.__self__ = self
return _new_func
elif tree.map(lambda x: torch.is_tensor(x)).all():
return tree.type(Tensor)
return MethodType(getattr(self.__class__, name), self)
else:
return tree
tree = self.__get_attr(name)
if tree.map(lambda x: torch.is_tensor(x)).all():
return tree.type(Tensor)
else:
return tree
@doc_from_base()
@method_treelize(return_type=ndarray)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册