From 9a61cbecd3b5e969f384bdea0e190db4c56bd3c7 Mon Sep 17 00:00:00 2001 From: HansBug Date: Tue, 28 Sep 2021 23:28:08 +0800 Subject: [PATCH] dev(hansbug): add auto attr into Tensor class --- treetensor/torch/tensor/__init__.py | 5 ++++ treetensor/torch/tensor/attr.py | 9 +++++++ treetensor/torch/{ => tensor}/tensor.py | 32 +++++++++++++++++-------- 3 files changed, 36 insertions(+), 10 deletions(-) create mode 100644 treetensor/torch/tensor/__init__.py create mode 100644 treetensor/torch/tensor/attr.py rename treetensor/torch/{ => tensor}/tensor.py (97%) diff --git a/treetensor/torch/tensor/__init__.py b/treetensor/torch/tensor/__init__.py new file mode 100644 index 000000000..8e268b47b --- /dev/null +++ b/treetensor/torch/tensor/__init__.py @@ -0,0 +1,5 @@ +from .tensor import Tensor + +__all__ = [ + 'Tensor' +] diff --git a/treetensor/torch/tensor/attr.py b/treetensor/torch/tensor/attr.py new file mode 100644 index 000000000..fef8db77d --- /dev/null +++ b/treetensor/torch/tensor/attr.py @@ -0,0 +1,9 @@ +from treevalue import tree_class + +from .tensor import Tensor +from ..base import Torch + + +@tree_class(return_type=Tensor) +class TensorMethod(Torch): + pass diff --git a/treetensor/torch/tensor.py b/treetensor/torch/tensor/tensor.py similarity index 97% rename from treetensor/torch/tensor.py rename to treetensor/torch/tensor/tensor.py index 0f6d18a68..ba8482de5 100644 --- a/treetensor/torch/tensor.py +++ b/treetensor/torch/tensor/tensor.py @@ -3,16 +3,12 @@ import torch from treevalue import method_treelize, TreeValue from treevalue.utils import post_process -from .base import Torch, auto_torch, rmreduce, post_reduce, auto_reduce -from .size import Size -from ..common import Object, ireduce, clsmeta, return_self -from ..numpy import ndarray -from ..utils import current_names, class_autoremove, replaceable_partial -from ..utils import doc_from_base as original_doc_from_base - -__all__ = [ - 'Tensor' -] +from ..base import Torch, auto_torch, rmreduce, post_reduce, auto_reduce +from ..size import Size +from ...common import Object, ireduce, clsmeta, return_self +from ...numpy import ndarray +from ...utils import current_names, class_autoremove, replaceable_partial +from ...utils import doc_from_base as original_doc_from_base doc_from_base = replaceable_partial(original_doc_from_base, base=torch.Tensor) @@ -64,6 +60,22 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ super(Torch, self).__init__(data) + @method_treelize(return_type=Object) + def __get_attr(self, key): + return getattr(self, key) + + def _attr_extern(self, key): + tree = self.__get_attr(key) + if tree.map(lambda x: torch.is_tensor(x)).all(): + type_ = Tensor + elif tree.map(lambda x: callable(x)).all(): + from .attr import TensorMethod + type_ = TensorMethod + else: + type_ = Object + + return tree.type(type_) + @doc_from_base() @method_treelize(return_type=ndarray) def numpy(self: torch.Tensor) -> np.ndarray: -- GitLab