diff --git a/treetensor/torch/tensor/attr.py b/treetensor/torch/tensor/attr.py index fef8db77d026ddd461947b12779ccdfd20952fdc..0891806fb71acb52fa7d80c97b358451adf6573f 100644 --- a/treetensor/torch/tensor/attr.py +++ b/treetensor/torch/tensor/attr.py @@ -1,9 +1,15 @@ -from treevalue import tree_class +from treevalue import method_treelize, TreeValue +from treevalue.utils import post_process -from .tensor import Tensor -from ..base import Torch +from ..base import Torch, auto_torch +from ..tensor import Tensor +from ...utils import replaceable_partial + +auto_tensor = replaceable_partial(auto_torch, cls=Tensor) -@tree_class(return_type=Tensor) class TensorMethod(Torch): - pass + @post_process(auto_tensor) + @method_treelize(return_type=TreeValue, rise=True) + def __call__(self, *args, **kwargs): + return self(*args, **kwargs) diff --git a/treetensor/torch/tensor/tensor.py b/treetensor/torch/tensor/tensor.py index ba8482de5a108196ff42b28512a54aa81d98ab91..ca0ea8c20a21da0fe6e4156a247510f9795c35be 100644 --- a/treetensor/torch/tensor/tensor.py +++ b/treetensor/torch/tensor/tensor.py @@ -1,3 +1,5 @@ +from functools import wraps + import numpy as np import torch from treevalue import method_treelize, TreeValue @@ -64,17 +66,24 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): def __get_attr(self, key): return getattr(self, key) - def _attr_extern(self, key): - tree = self.__get_attr(key) - if tree.map(lambda x: torch.is_tensor(x)).all(): - type_ = Tensor - elif tree.map(lambda x: callable(x)).all(): + def _attr_extern(self, name): + tree = self.__get_attr(name) + if hasattr(torch.Tensor, name) and not name.startswith('_') \ + and callable(getattr(torch.Tensor, name)): from .attr import TensorMethod - type_ = TensorMethod - else: - type_ = Object + tree = tree.type(TensorMethod) - return tree.type(type_) + @wraps(getattr(torch, name), assigned=('__name__', '__qualname__'), updated=()) + def _new_func(*args, **kwargs): + result = tree(*args, **kwargs) + return self if name.endswith('_') else result + + _new_func.__self__ = self + return _new_func + elif tree.map(lambda x: torch.is_tensor(x)).all(): + return tree.type(Tensor) + else: + return tree @doc_from_base() @method_treelize(return_type=ndarray)