From 4044bf852eebbceba7566915e4487999a2fd7de0 Mon Sep 17 00:00:00 2001 From: HansBug Date: Thu, 30 Sep 2021 15:39:30 +0800 Subject: [PATCH] dev(hansbug): upgrade module and proxy system --- treetensor/common/module.py | 9 ++++++--- treetensor/common/proxy.py | 10 +++++++--- treetensor/common/trees.py | 22 ++++++++++++++++++++++ treetensor/numpy/funcs.py | 3 ++- treetensor/torch/funcs/base.py | 5 +++-- treetensor/torch/tensor.py | 15 ++++++++++----- 6 files changed, 50 insertions(+), 14 deletions(-) diff --git a/treetensor/common/module.py b/treetensor/common/module.py index f84025987..f19e6673a 100644 --- a/treetensor/common/module.py +++ b/treetensor/common/module.py @@ -1,3 +1,4 @@ +import inspect from functools import wraps from typing import Type @@ -16,13 +17,15 @@ __all__ = [ ] -def module_func_loader(base, cls: Type[TreeValue], module_name: str): +def module_func_loader(base, cls: Type[TreeValue], cls_mapper=None): func_treelize = post_process(post_process(args_mapping( lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))( replaceable_partial(original_func_treelize, return_type=cls) ) doc_from_base = replaceable_partial(original_doc_from_base, base=base) - auto_tree_cls = replaceable_partial(auto_tree, cls=cls) + outer_frame = inspect.currentframe().f_back + outer_module = outer_frame.f_globals.get('__name__', None) + auto_tree_cls = replaceable_partial(auto_tree, cls=cls_mapper or cls) def _load_func(name): func = getattr(base, name) @@ -37,7 +40,7 @@ def module_func_loader(base, cls: Type[TreeValue], module_name: str): return func(*args, **kwargs) _new_func.__qualname__ = _new_func.__name__ - _new_func.__module__ = module_name + _new_func.__module__ = outer_module return _new_func return _load_func diff --git a/treetensor/common/proxy.py b/treetensor/common/proxy.py index 216688e85..afdd4aa34 100644 --- a/treetensor/common/proxy.py +++ b/treetensor/common/proxy.py @@ -1,3 +1,4 @@ +import inspect from functools import wraps from types import MethodType @@ -14,8 +15,10 @@ __all__ = [ ] -def get_tree_proxy(base): +def get_tree_proxy(base, cls_mapper=None): doc_from_base = replaceable_partial(original_doc_from_base, base=base) + outer_frame = inspect.currentframe().f_back + outer_module = outer_frame.f_globals.get('__name__', None) class _TreeClassProxy: def __init__(self, cls): @@ -29,17 +32,18 @@ def get_tree_proxy(base): and callable(getattr(base, name)): _origin_func = getattr(base, name) return_self_deco = return_self if name.endswith('_') else (lambda x: x) + auto_tree_cls = replaceable_partial(auto_tree, cls=cls_mapper or self.__cls) @doc_from_base() @return_self_deco - @post_process(lambda r: replaceable_partial(auto_tree, cls=self.__cls)(r)) + @post_process(auto_tree_cls) @method_treelize(return_type=TreeValue, rise=True) @wraps(_origin_func, assigned=('__name__',), updated=()) def _new_func(*args, **kwargs): return _origin_func(*args, **kwargs) _new_func.__qualname__ = f'{self.__cls.__name__}.{name}' - _new_func.__module__ = self.__cls.__module__ + _new_func.__module__ = outer_module self.__torch_funcs[name] = _new_func return _new_func else: diff --git a/treetensor/common/trees.py b/treetensor/common/trees.py index b7def1344..d88ba4ece 100644 --- a/treetensor/common/trees.py +++ b/treetensor/common/trees.py @@ -179,10 +179,32 @@ def clsmeta(func, allow_dict: bool = False) -> Type[type]: return _MetaClass +def _auto_tree_func(t, cls): + from .object import Object + t = typetrans(t, return_type=Object) + for key, value in cls: + if isinstance(key, type): + predict = lambda x: isinstance(x, key) + elif callable(key): + predict = lambda x: key(x) + else: + raise TypeError(f'Unknown type of prediction - {repr(key)}.') + + if t.map(predict).all(): + return typetrans(t, return_type=value) + return t + + # noinspection PyArgumentList def auto_tree(v, cls): if isinstance(cls, type) and issubclass(cls, TreeValue): cls = partial(typetrans, return_type=cls) + elif isinstance(cls, (list, tuple)): + cls = partial(_auto_tree_func, cls=cls) + elif callable(cls): + pass + else: + raise TypeError(f'Unknown type of cls - {repr(cls)}.') if isinstance(v, TreeValue): return cls(v) diff --git a/treetensor/numpy/funcs.py b/treetensor/numpy/funcs.py index 068b03711..d7845dc37 100644 --- a/treetensor/numpy/funcs.py +++ b/treetensor/numpy/funcs.py @@ -19,7 +19,8 @@ func_treelize = post_process(post_process(args_mapping( lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))( replaceable_partial(original_func_treelize, return_type=ndarray) ) -get_func_from_numpy = module_func_loader(np, ndarray, __name__) +get_func_from_numpy = module_func_loader(np, ndarray, + [(np.ndarray, ndarray)]) @doc_from(np.all) diff --git a/treetensor/torch/funcs/base.py b/treetensor/torch/funcs/base.py index 1eab33875..f1db285f5 100644 --- a/treetensor/torch/funcs/base.py +++ b/treetensor/torch/funcs/base.py @@ -14,5 +14,6 @@ func_treelize = post_process(post_process(args_mapping( replaceable_partial(original_func_treelize, return_type=Tensor) ) doc_from_base = replaceable_partial(original_doc_from_base, base=torch) -auto_tensor = replaceable_partial(auto_tree, cls=Tensor) -get_func_from_torch = module_func_loader(torch, Tensor, '.'.join(__name__.split('.')[:-1])) +auto_tensor = replaceable_partial(auto_tree, cls=[(torch.is_tensor, Tensor)]) +get_func_from_torch = module_func_loader(torch, Tensor, + [(torch.is_tensor, Tensor)]) diff --git a/treetensor/torch/tensor.py b/treetensor/torch/tensor.py index b6b970681..b0d2360f6 100644 --- a/treetensor/torch/tensor.py +++ b/treetensor/torch/tensor.py @@ -55,6 +55,11 @@ class _TensorMeta(_BaseTensorMeta): @current_names() @class_autoremove class Tensor(Torch, metaclass=_TensorMeta): + __auto_tensor = lambda x: replaceable_partial( + auto_tree, + cls=[(pytorch.is_tensor, Tensor)] + )(x) + # noinspection PyUnusedLocal def __init__(self, data, *args, **kwargs): """ @@ -399,7 +404,7 @@ class Tensor(Torch, metaclass=_TensorMeta): return self # noinspection PyShadowingBuiltins - @post_process(lambda r: replaceable_partial(auto_tree, cls=Tensor)(r)) + @post_process(__auto_tensor) @method_treelize(return_type=TreeValue, rise=True) def __max_nr(self, *args, **kwargs): return pytorch.max(self, *args, **kwargs) @@ -419,7 +424,7 @@ class Tensor(Torch, metaclass=_TensorMeta): return self # noinspection PyShadowingBuiltins - @post_process(lambda r: replaceable_partial(auto_tree, cls=Tensor)(r)) + @post_process(__auto_tensor) @method_treelize(return_type=TreeValue, rise=True) def __min_nr(self, *args, **kwargs): return pytorch.min(self, *args, **kwargs) @@ -439,7 +444,7 @@ class Tensor(Torch, metaclass=_TensorMeta): return self # noinspection PyShadowingBuiltins - @post_process(lambda r: replaceable_partial(auto_tree, cls=Tensor)(r)) + @post_process(__auto_tensor) @method_treelize(return_type=TreeValue, rise=True) def __sum_nr(self, *args, **kwargs): return pytorch.sum(self, *args, **kwargs) @@ -882,7 +887,7 @@ class Tensor(Torch, metaclass=_TensorMeta): return self.log10_(*args, **kwargs) @doc_from_base() - @post_process(lambda r: replaceable_partial(auto_tree, cls=Tensor)(r)) + @post_process(__auto_tensor) @method_treelize(return_type=TreeValue, rise=True) def split(self, split_size, *args, **kwargs): """ @@ -891,7 +896,7 @@ class Tensor(Torch, metaclass=_TensorMeta): return self.split(split_size, *args, **kwargs) @doc_from_base() - @post_process(lambda r: replaceable_partial(auto_tree, cls=Tensor)(r)) + @post_process(__auto_tensor) @method_treelize(return_type=TreeValue, rise=True) def chunk(self, chunks, *args, **kwargs): """ -- GitLab