提交 4044bf85 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): upgrade module and proxy system

上级 4f09f8d3
import inspect
from functools import wraps
from typing import Type
......@@ -16,13 +17,15 @@ __all__ = [
]
def module_func_loader(base, cls: Type[TreeValue], module_name: str):
def module_func_loader(base, cls: Type[TreeValue], cls_mapper=None):
func_treelize = post_process(post_process(args_mapping(
lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))(
replaceable_partial(original_func_treelize, return_type=cls)
)
doc_from_base = replaceable_partial(original_doc_from_base, base=base)
auto_tree_cls = replaceable_partial(auto_tree, cls=cls)
outer_frame = inspect.currentframe().f_back
outer_module = outer_frame.f_globals.get('__name__', None)
auto_tree_cls = replaceable_partial(auto_tree, cls=cls_mapper or cls)
def _load_func(name):
func = getattr(base, name)
......@@ -37,7 +40,7 @@ def module_func_loader(base, cls: Type[TreeValue], module_name: str):
return func(*args, **kwargs)
_new_func.__qualname__ = _new_func.__name__
_new_func.__module__ = module_name
_new_func.__module__ = outer_module
return _new_func
return _load_func
import inspect
from functools import wraps
from types import MethodType
......@@ -14,8 +15,10 @@ __all__ = [
]
def get_tree_proxy(base):
def get_tree_proxy(base, cls_mapper=None):
doc_from_base = replaceable_partial(original_doc_from_base, base=base)
outer_frame = inspect.currentframe().f_back
outer_module = outer_frame.f_globals.get('__name__', None)
class _TreeClassProxy:
def __init__(self, cls):
......@@ -29,17 +32,18 @@ def get_tree_proxy(base):
and callable(getattr(base, name)):
_origin_func = getattr(base, name)
return_self_deco = return_self if name.endswith('_') else (lambda x: x)
auto_tree_cls = replaceable_partial(auto_tree, cls=cls_mapper or self.__cls)
@doc_from_base()
@return_self_deco
@post_process(lambda r: replaceable_partial(auto_tree, cls=self.__cls)(r))
@post_process(auto_tree_cls)
@method_treelize(return_type=TreeValue, rise=True)
@wraps(_origin_func, assigned=('__name__',), updated=())
def _new_func(*args, **kwargs):
return _origin_func(*args, **kwargs)
_new_func.__qualname__ = f'{self.__cls.__name__}.{name}'
_new_func.__module__ = self.__cls.__module__
_new_func.__module__ = outer_module
self.__torch_funcs[name] = _new_func
return _new_func
else:
......
......@@ -179,10 +179,32 @@ def clsmeta(func, allow_dict: bool = False) -> Type[type]:
return _MetaClass
def _auto_tree_func(t, cls):
from .object import Object
t = typetrans(t, return_type=Object)
for key, value in cls:
if isinstance(key, type):
predict = lambda x: isinstance(x, key)
elif callable(key):
predict = lambda x: key(x)
else:
raise TypeError(f'Unknown type of prediction - {repr(key)}.')
if t.map(predict).all():
return typetrans(t, return_type=value)
return t
# noinspection PyArgumentList
def auto_tree(v, cls):
if isinstance(cls, type) and issubclass(cls, TreeValue):
cls = partial(typetrans, return_type=cls)
elif isinstance(cls, (list, tuple)):
cls = partial(_auto_tree_func, cls=cls)
elif callable(cls):
pass
else:
raise TypeError(f'Unknown type of cls - {repr(cls)}.')
if isinstance(v, TreeValue):
return cls(v)
......
......@@ -19,7 +19,8 @@ func_treelize = post_process(post_process(args_mapping(
lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))(
replaceable_partial(original_func_treelize, return_type=ndarray)
)
get_func_from_numpy = module_func_loader(np, ndarray, __name__)
get_func_from_numpy = module_func_loader(np, ndarray,
[(np.ndarray, ndarray)])
@doc_from(np.all)
......
......@@ -14,5 +14,6 @@ func_treelize = post_process(post_process(args_mapping(
replaceable_partial(original_func_treelize, return_type=Tensor)
)
doc_from_base = replaceable_partial(original_doc_from_base, base=torch)
auto_tensor = replaceable_partial(auto_tree, cls=Tensor)
get_func_from_torch = module_func_loader(torch, Tensor, '.'.join(__name__.split('.')[:-1]))
auto_tensor = replaceable_partial(auto_tree, cls=[(torch.is_tensor, Tensor)])
get_func_from_torch = module_func_loader(torch, Tensor,
[(torch.is_tensor, Tensor)])
......@@ -55,6 +55,11 @@ class _TensorMeta(_BaseTensorMeta):
@current_names()
@class_autoremove
class Tensor(Torch, metaclass=_TensorMeta):
__auto_tensor = lambda x: replaceable_partial(
auto_tree,
cls=[(pytorch.is_tensor, Tensor)]
)(x)
# noinspection PyUnusedLocal
def __init__(self, data, *args, **kwargs):
"""
......@@ -399,7 +404,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
return self
# noinspection PyShadowingBuiltins
@post_process(lambda r: replaceable_partial(auto_tree, cls=Tensor)(r))
@post_process(__auto_tensor)
@method_treelize(return_type=TreeValue, rise=True)
def __max_nr(self, *args, **kwargs):
return pytorch.max(self, *args, **kwargs)
......@@ -419,7 +424,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
return self
# noinspection PyShadowingBuiltins
@post_process(lambda r: replaceable_partial(auto_tree, cls=Tensor)(r))
@post_process(__auto_tensor)
@method_treelize(return_type=TreeValue, rise=True)
def __min_nr(self, *args, **kwargs):
return pytorch.min(self, *args, **kwargs)
......@@ -439,7 +444,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
return self
# noinspection PyShadowingBuiltins
@post_process(lambda r: replaceable_partial(auto_tree, cls=Tensor)(r))
@post_process(__auto_tensor)
@method_treelize(return_type=TreeValue, rise=True)
def __sum_nr(self, *args, **kwargs):
return pytorch.sum(self, *args, **kwargs)
......@@ -882,7 +887,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
return self.log10_(*args, **kwargs)
@doc_from_base()
@post_process(lambda r: replaceable_partial(auto_tree, cls=Tensor)(r))
@post_process(__auto_tensor)
@method_treelize(return_type=TreeValue, rise=True)
def split(self, split_size, *args, **kwargs):
"""
......@@ -891,7 +896,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
return self.split(split_size, *args, **kwargs)
@doc_from_base()
@post_process(lambda r: replaceable_partial(auto_tree, cls=Tensor)(r))
@post_process(__auto_tensor)
@method_treelize(return_type=TreeValue, rise=True)
def chunk(self, chunks, *args, **kwargs):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册