diff --git a/treetensor/common/__init__.py b/treetensor/common/__init__.py index 6163b00278e6a7c385dc9415e67b7b8cc4f21350..04448872d945440151158e1a1c1058c4ab71a26c 100644 --- a/treetensor/common/__init__.py +++ b/treetensor/common/__init__.py @@ -1,3 +1,5 @@ +from .module import * from .object import * +from .proxy import * from .trees import * from .wrappers import * diff --git a/treetensor/common/module.py b/treetensor/common/module.py new file mode 100644 index 0000000000000000000000000000000000000000..f840259879c5a891076b218a9ab798d6e465309c --- /dev/null +++ b/treetensor/common/module.py @@ -0,0 +1,43 @@ +from functools import wraps +from typing import Type + +from treevalue import TreeValue +from treevalue import func_treelize as original_func_treelize +from treevalue.tree.common import BaseTree +from treevalue.utils import post_process + +from .trees import auto_tree +from .wrappers import return_self +from ..utils import doc_from_base as original_doc_from_base +from ..utils import replaceable_partial, args_mapping + +__all__ = [ + 'module_func_loader', +] + + +def module_func_loader(base, cls: Type[TreeValue], module_name: str): + func_treelize = post_process(post_process(args_mapping( + lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))( + replaceable_partial(original_func_treelize, return_type=cls) + ) + doc_from_base = replaceable_partial(original_doc_from_base, base=base) + auto_tree_cls = replaceable_partial(auto_tree, cls=cls) + + def _load_func(name): + func = getattr(base, name) + return_self_dec = return_self if func.__name__.endswith("_") else (lambda x: x) + + @doc_from_base() + @return_self_dec + @post_process(auto_tree_cls) + @func_treelize(return_type=TreeValue, rise=True) + @wraps(func, assigned=('__name__',), updated=()) + def _new_func(*args, **kwargs): + return func(*args, **kwargs) + + _new_func.__qualname__ = _new_func.__name__ + _new_func.__module__ = module_name + return _new_func + + return _load_func diff --git a/treetensor/common/proxy.py b/treetensor/common/proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..216688e85b2c85456821e74fb4b60562360ec755 --- /dev/null +++ b/treetensor/common/proxy.py @@ -0,0 +1,56 @@ +from functools import wraps +from types import MethodType + +from treevalue import method_treelize, TreeValue +from treevalue.utils import post_process + +from .trees import auto_tree +from .wrappers import return_self +from ..utils import doc_from_base as original_doc_from_base +from ..utils import replaceable_partial + +__all__ = [ + 'get_tree_proxy', +] + + +def get_tree_proxy(base): + doc_from_base = replaceable_partial(original_doc_from_base, base=base) + + class _TreeClassProxy: + def __init__(self, cls): + self.__torch_funcs = {} + self.__cls = cls + + def __getattr__(self, name): + if name in self.__torch_funcs.keys(): + return self.__torch_funcs[name] + elif hasattr(base, name) and not name.startswith('_') \ + and callable(getattr(base, name)): + _origin_func = getattr(base, name) + return_self_deco = return_self if name.endswith('_') else (lambda x: x) + + @doc_from_base() + @return_self_deco + @post_process(lambda r: replaceable_partial(auto_tree, cls=self.__cls)(r)) + @method_treelize(return_type=TreeValue, rise=True) + @wraps(_origin_func, assigned=('__name__',), updated=()) + def _new_func(*args, **kwargs): + return _origin_func(*args, **kwargs) + + _new_func.__qualname__ = f'{self.__cls.__name__}.{name}' + _new_func.__module__ = self.__cls.__module__ + self.__torch_funcs[name] = _new_func + return _new_func + else: + raise AttributeError(f'Function {repr(name)} not found in {repr(base)}') + + class _TreeInstanceProxy: + def __init__(self, proxy, s): + self.__proxy = proxy + self.__self = s + + def __getattr__(self, name): + return MethodType(getattr(self.__proxy, name), self.__self) + + return _TreeClassProxy, _TreeInstanceProxy diff --git a/treetensor/common/trees.py b/treetensor/common/trees.py index 4a02ec3a79a5e6d7bac7188d497b831fd6e47009..b7def134439acfa49c6d9b8faab96edab310ad18 100644 --- a/treetensor/common/trees.py +++ b/treetensor/common/trees.py @@ -6,7 +6,7 @@ from typing import Optional, Tuple, Callable from typing import Type from treevalue import func_treelize as original_func_treelize -from treevalue import general_tree_value, TreeValue +from treevalue import general_tree_value, TreeValue, typetrans from treevalue.tree.common import BaseTree from treevalue.tree.tree.tree import get_data_property from treevalue.utils import post_process @@ -15,7 +15,7 @@ from ..utils import replaceable_partial, args_mapping __all__ = [ 'BaseTreeStruct', - 'print_tree', 'clsmeta', + 'print_tree', 'clsmeta', 'auto_tree', ] @@ -177,3 +177,18 @@ def clsmeta(func, allow_dict: bool = False) -> Type[type]: return _result return _MetaClass + + +# noinspection PyArgumentList +def auto_tree(v, cls): + if isinstance(cls, type) and issubclass(cls, TreeValue): + cls = partial(typetrans, return_type=cls) + + if isinstance(v, TreeValue): + return cls(v) + elif isinstance(v, (tuple, list, set)): + return type(v)((auto_tree(item, cls) for item in v)) + elif isinstance(v, dict): + return type(v)({key: auto_tree(value, cls) for key, value in v.items()}) + else: + return v diff --git a/treetensor/numpy/__init__.py b/treetensor/numpy/__init__.py index 46aa5d19273abba6041678241362e61c1f9cba4d..ddd0dbe373e1557a3a6c7be1c340f624b364f3a7 100644 --- a/treetensor/numpy/__init__.py +++ b/treetensor/numpy/__init__.py @@ -1,9 +1,56 @@ +import builtins +from types import ModuleType, FunctionType, BuiltinFunctionType +from typing import Iterable + +import numpy as np + from .array import * from .array import __all__ as _array_all from .funcs import * from .funcs import __all__ as _funcs_all +from .funcs import get_func_from_numpy +from ..config.meta import __VERSION__ __all__ = [ *_funcs_all, *_array_all, ] + +_basic_types = ( + builtins.bool, builtins.bytearray, builtins.bytes, builtins.complex, builtins.dict, + builtins.float, builtins.frozenset, builtins.int, builtins.list, builtins.range, builtins.set, + builtins.slice, builtins.str, builtins.tuple, +) +_np_all = set(np.__all__) + + +class _Module(ModuleType): + def __init__(self, module): + ModuleType.__init__(self, module.__name__) + + for name in filter(lambda x: x.startswith('__') and x.endswith('__'), dir(module)): + setattr(self, name, getattr(module, name)) + self.__origin__ = module + self.__numpy_version__ = np.__version__ + self.__version__ = __VERSION__ + + def __getattr__(self, name): + if (name in self.__all__) or \ + (hasattr(self.__origin__, name) and isinstance(getattr(self.__origin__, name), ModuleType)): + return getattr(self.__origin__, name) + else: + item = getattr(np, name) + if isinstance(item, (FunctionType, BuiltinFunctionType)) and not name.startswith('_'): + return get_func_from_numpy(name) + elif isinstance(item, _basic_types) and name in _np_all: + return item + else: + raise AttributeError(f'Attribute {repr(name)} not found in {repr(__name__)}.') + + def __dir__(self) -> Iterable[str]: + return self.__all__ + + +import sys + +sys.modules[__name__] = _Module(sys.modules[__name__]) diff --git a/treetensor/numpy/array.py b/treetensor/numpy/array.py index 7cbcc9a8a1ba078aadb0d4b6af5ebee7b70fa87c..142c9587fc1a175db2670f1d702a179806db2bfe 100644 --- a/treetensor/numpy/array.py +++ b/treetensor/numpy/array.py @@ -1,52 +1,95 @@ -import numpy as np +import numpy from treevalue import method_treelize from .base import TreeNumpy -from ..common import Object, ireduce +from ..common import Object, ireduce, clsmeta, get_tree_proxy from ..utils import current_names __all__ = [ 'ndarray' ] +_ArrayProxy, _InstanceArrayProxy = get_tree_proxy(numpy.ndarray) + + +class _BaseArrayMeta(clsmeta(numpy.asarray, allow_dict=True)): + pass + + +# noinspection PyMethodParameters +class _ArrayMeta(_BaseArrayMeta): + def __init__(cls, *args, **kwargs): + _BaseArrayMeta.__init__(cls, *args, **kwargs) + cls.__proxy = None + + @property + def np(cls): + if not cls.__proxy: + cls.__proxy = _ArrayProxy(cls) + return cls.__proxy + + def __getattr__(cls, name): + try: + return cls.np.__getattr__(name) + except AttributeError: + raise AttributeError(f"type object {repr(cls.__name__)} has no attribute {repr(name)}") + # noinspection PyPep8Naming @current_names() -class ndarray(TreeNumpy): +class ndarray(TreeNumpy, metaclass=_ArrayMeta): """ Overview: Real numpy tree. """ @method_treelize(return_type=Object) - def tolist(self: np.ndarray): + def __get_attr(self, key): + return getattr(self, key) + + def _attr_extern(self, name): + try: + return getattr(self.np, name) + except AttributeError: + tree = self.__get_attr(name) + if tree.map(lambda x: isinstance(x, numpy.ndarray)).all(): + return tree.type(ndarray) + else: + return tree + + @property + def np(self): + return _InstanceArrayProxy(self.__class__.np, self) + + @method_treelize(return_type=Object) + def tolist(self: numpy.ndarray): return self.tolist() @property @ireduce(sum) @method_treelize(return_type=Object) - def size(self: np.ndarray) -> int: + def size(self: numpy.ndarray) -> int: return self.size @property @ireduce(sum) @method_treelize(return_type=Object) - def nbytes(self: np.ndarray) -> int: + def nbytes(self: numpy.ndarray) -> int: return self.nbytes @ireduce(sum) @method_treelize(return_type=Object) - def sum(self: np.ndarray, *args, **kwargs): + def sum(self: numpy.ndarray, *args, **kwargs): return self.sum(*args, **kwargs) @ireduce(all) @method_treelize(return_type=Object) - def all(self: np.ndarray, *args, **kwargs): + def all(self: numpy.ndarray, *args, **kwargs): return self.all(*args, **kwargs) @ireduce(any) @method_treelize(return_type=Object) - def any(self: np.ndarray, *args, **kwargs): + def any(self: numpy.ndarray, *args, **kwargs): return self.any(*args, **kwargs) @method_treelize() diff --git a/treetensor/numpy/funcs.py b/treetensor/numpy/funcs.py index c4f5a522b86dbe80b7b029d555d29225ef64bdc0..068b0371102d00546616b6b721174760da70f803 100644 --- a/treetensor/numpy/funcs.py +++ b/treetensor/numpy/funcs.py @@ -3,10 +3,11 @@ import builtins import numpy as np from treevalue import TreeValue from treevalue import func_treelize as original_func_treelize +from treevalue.tree.common import BaseTree from treevalue.utils import post_process from .array import ndarray -from ..common import ireduce, Object +from ..common import ireduce, Object, module_func_loader from ..utils import replaceable_partial, doc_from, args_mapping __all__ = [ @@ -15,9 +16,10 @@ __all__ = [ ] func_treelize = post_process(post_process(args_mapping( - lambda i, x: TreeValue(x) if isinstance(x, (dict, TreeValue)) else x)))( + lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))( replaceable_partial(original_func_treelize, return_type=ndarray) ) +get_func_from_numpy = module_func_loader(np, ndarray, __name__) @doc_from(np.all) diff --git a/treetensor/torch/base/torch.py b/treetensor/torch/base/torch.py index c195b233d0db8313d8121fbf365f7ff9a8a45185..b2bbf2ca7b62bb68df61810eac6739748b4986fb 100644 --- a/treetensor/torch/base/torch.py +++ b/treetensor/torch/base/torch.py @@ -1,23 +1,9 @@ -from typing import Type - -from treevalue import TreeValue, typetrans - from ...common import BaseTreeStruct -__all__ = ['Torch', 'auto_torch'] +__all__ = [ + 'Torch' +] class Torch(BaseTreeStruct): pass - - -# noinspection PyArgumentList -def auto_torch(v, cls: Type[Torch]): - if isinstance(v, TreeValue): - return typetrans(v, cls) - elif isinstance(v, (tuple, list, set)): - return type(v)((auto_torch(item, cls) for item in v)) - elif isinstance(v, dict): - return type(v)({key: auto_torch(value, cls) for key, value in v.items()}) - else: - return v diff --git a/treetensor/torch/funcs/base.py b/treetensor/torch/funcs/base.py index dae8e9d13b8933a7f2f4da377c7b90ea444730e5..1eab33875aa005c939e786972ba4f8e80ad0018b 100644 --- a/treetensor/torch/funcs/base.py +++ b/treetensor/torch/funcs/base.py @@ -1,14 +1,11 @@ -from functools import wraps - import torch from treevalue import TreeValue from treevalue import func_treelize as original_func_treelize from treevalue.tree.common import BaseTree from treevalue.utils import post_process -from ..base import auto_torch from ..tensor import Tensor -from ...common import return_self +from ...common import auto_tree, module_func_loader from ...utils import doc_from_base as original_doc_from_base from ...utils import replaceable_partial, args_mapping @@ -17,23 +14,5 @@ func_treelize = post_process(post_process(args_mapping( replaceable_partial(original_func_treelize, return_type=Tensor) ) doc_from_base = replaceable_partial(original_doc_from_base, base=torch) -auto_tensor = replaceable_partial(auto_torch, cls=Tensor) - -_funcs_module = '.'.join(__name__.split('.')[:-1]) - - -def get_func_from_torch(name): - func = getattr(torch, name) - return_self_dec = return_self if func.__name__.endswith("_") else (lambda x: x) - - @doc_from_base() - @return_self_dec - @post_process(auto_tensor) - @func_treelize(return_type=TreeValue, rise=True) - @wraps(func, assigned=('__name__',), updated=()) - def _new_func(*args, **kwargs): - return func(*args, **kwargs) - - _new_func.__qualname__ = _new_func.__name__ - _new_func.__module__ = _funcs_module - return _new_func +auto_tensor = replaceable_partial(auto_tree, cls=Tensor) +get_func_from_torch = module_func_loader(torch, Tensor, '.'.join(__name__.split('.')[:-1])) diff --git a/treetensor/torch/size.py b/treetensor/torch/size.py index 2e86cfb9b54da89daa6a972beea4f0ef563b996a..31295744fa78c0a47ecbf010cab0ebdb30c62679 100644 --- a/treetensor/torch/size.py +++ b/treetensor/torch/size.py @@ -54,7 +54,7 @@ class Size(Torch, metaclass=clsmeta(torch.Size, allow_dict=True)): Examples:: >>> import torch - >>> import treetensor.torch as ttorch + >>> import treetensor.numpy as ttorch >>> ttorch.Size([1, 2, 3]) torch.Size([1, 2, 3]) @@ -81,7 +81,7 @@ class Size(Torch, metaclass=clsmeta(torch.Size, allow_dict=True)): Example:: >>> import torch - >>> import treetensor.torch as ttorch + >>> import treetensor.numpy as ttorch >>> ttorch.Size({ ... 'a': [1, 2], ... 'b': {'x': [3, 2, 4]}, @@ -99,7 +99,7 @@ class Size(Torch, metaclass=clsmeta(torch.Size, allow_dict=True)): Example:: >>> import torch - >>> import treetensor.torch as ttorch + >>> import treetensor.numpy as ttorch >>> ttorch.Size({ ... 'a': [1, 2], ... 'b': {'x': [3, 2, 4]}, @@ -132,7 +132,7 @@ class Size(Torch, metaclass=clsmeta(torch.Size, allow_dict=True)): Example:: >>> import torch - >>> import treetensor.torch as ttorch + >>> import treetensor.numpy as ttorch >>> ttorch.Size({ ... 'a': [1, 2], ... 'b': {'x': [3, 2, 4]}, diff --git a/treetensor/torch/tensor.py b/treetensor/torch/tensor.py index 99754b36b7737d05b462a9f668f92871e5b48929..b6b970681da98b9ef5e9c38fd958422d7571305f 100644 --- a/treetensor/torch/tensor.py +++ b/treetensor/torch/tensor.py @@ -1,14 +1,11 @@ -from functools import wraps -from types import MethodType - import numpy as np import torch as pytorch from treevalue import method_treelize, TreeValue from treevalue.utils import post_process -from .base import Torch, auto_torch, rmreduce, post_reduce, auto_reduce +from .base import Torch, rmreduce, post_reduce, auto_reduce from .size import Size -from ..common import Object, ireduce, clsmeta, return_self +from ..common import Object, ireduce, clsmeta, return_self, auto_tree, get_tree_proxy from ..numpy import ndarray from ..utils import current_names, class_autoremove, replaceable_partial from ..utils import doc_from_base as original_doc_from_base @@ -18,6 +15,7 @@ __all__ = [ ] doc_from_base = replaceable_partial(original_doc_from_base, base=pytorch.Tensor) +_TorchProxy, _InstanceTorchProxy = get_tree_proxy(pytorch.Tensor) def _to_tensor(*args, **kwargs): @@ -30,44 +28,6 @@ def _to_tensor(*args, **kwargs): return pytorch.tensor(*args, **kwargs) -class _TorchProxy: - def __init__(self, cls): - self.__torch_funcs = {} - self.__cls = cls - - def __getattr__(self, name): - if name in self.__torch_funcs.keys(): - return self.__torch_funcs[name] - elif hasattr(pytorch.Tensor, name) and not name.startswith('_') \ - and callable(getattr(pytorch.Tensor, name)): - _origin_func = getattr(pytorch.Tensor, name) - return_self_deco = return_self if name.endswith('_') else (lambda x: x) - - @doc_from_base() - @return_self_deco - @post_process(lambda r: replaceable_partial(auto_torch, cls=self.__cls)(r)) - @method_treelize(return_type=TreeValue, rise=True) - @wraps(_origin_func, assigned=('__name__',), updated=()) - def _new_func(*args, **kwargs): - return _origin_func(*args, **kwargs) - - _new_func.__qualname__ = f'{self.__cls.__name__}.{name}' - _new_func.__module__ = self.__cls.__module__ - self.__torch_funcs[name] = _new_func - return _new_func - else: - raise AttributeError(f'Function {repr(name)} not found in {repr(pytorch)}') - - -class _InstanceTorchProxy: - def __init__(self, proxy, s): - self.__proxy = proxy - self.__self = s - - def __getattr__(self, name): - return MethodType(getattr(self.__proxy, name), self.__self) - - class _BaseTensorMeta(clsmeta(_to_tensor, allow_dict=True)): pass @@ -76,13 +36,13 @@ class _BaseTensorMeta(clsmeta(_to_tensor, allow_dict=True)): class _TensorMeta(_BaseTensorMeta): def __init__(cls, *args, **kwargs): _BaseTensorMeta.__init__(cls, *args, **kwargs) - cls.__torch_proxy = None + cls.__proxy = None @property def torch(cls): - if not cls.__torch_proxy: - cls.__torch_proxy = _TorchProxy(cls) - return cls.__torch_proxy + if not cls.__proxy: + cls.__proxy = _TorchProxy(cls) + return cls.__proxy def __getattr__(cls, name): try: @@ -439,7 +399,7 @@ class Tensor(Torch, metaclass=_TensorMeta): return self # noinspection PyShadowingBuiltins - @post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r)) + @post_process(lambda r: replaceable_partial(auto_tree, cls=Tensor)(r)) @method_treelize(return_type=TreeValue, rise=True) def __max_nr(self, *args, **kwargs): return pytorch.max(self, *args, **kwargs) @@ -459,7 +419,7 @@ class Tensor(Torch, metaclass=_TensorMeta): return self # noinspection PyShadowingBuiltins - @post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r)) + @post_process(lambda r: replaceable_partial(auto_tree, cls=Tensor)(r)) @method_treelize(return_type=TreeValue, rise=True) def __min_nr(self, *args, **kwargs): return pytorch.min(self, *args, **kwargs) @@ -479,7 +439,7 @@ class Tensor(Torch, metaclass=_TensorMeta): return self # noinspection PyShadowingBuiltins - @post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r)) + @post_process(lambda r: replaceable_partial(auto_tree, cls=Tensor)(r)) @method_treelize(return_type=TreeValue, rise=True) def __sum_nr(self, *args, **kwargs): return pytorch.sum(self, *args, **kwargs) @@ -922,7 +882,7 @@ class Tensor(Torch, metaclass=_TensorMeta): return self.log10_(*args, **kwargs) @doc_from_base() - @post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r)) + @post_process(lambda r: replaceable_partial(auto_tree, cls=Tensor)(r)) @method_treelize(return_type=TreeValue, rise=True) def split(self, split_size, *args, **kwargs): """ @@ -931,7 +891,7 @@ class Tensor(Torch, metaclass=_TensorMeta): return self.split(split_size, *args, **kwargs) @doc_from_base() - @post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r)) + @post_process(lambda r: replaceable_partial(auto_tree, cls=Tensor)(r)) @method_treelize(return_type=TreeValue, rise=True) def chunk(self, chunks, *args, **kwargs): """