提交 4ebeb305 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): upgrade auto system, more feature

上级 ba8e6f92
from .mark import choose_mark_with_existence_check
from .mark import choose_mark_with_existence_check, get_mark_with_existence_check
......@@ -3,12 +3,15 @@ import pytest
_TEST_PREFIX = 'test_'
def get_mark_with_existence_check(base, name):
_mark = pytest.mark.unittest if hasattr(base, name) else pytest.mark.ignore
return _mark
def choose_mark_with_existence_check(base, name: str = None):
def _decorator(func):
_name = name or func.__name__[len(_TEST_PREFIX):]
_mark = pytest.mark.unittest if hasattr(base, _name) else pytest.mark.ignore
func = _mark(func)
return func
_mark = get_mark_with_existence_check(base, _name)
return _mark(func)
return _decorator
from .funcs import *
from .tensor import *
from .test_module import TestTorchModule
from .test_size import TestTorchSize
from .test_auto import TestTorchFuncsAuto
from .test_autograd import TestTorchFuncsAutograd
from .test_comparison import TestTorchFuncsComparison
from .test_construct import TestTorchFuncsConstruct
......
import treetensor.torch as ttorch
from treetensor.utils import replaceable_partial
from ...tests import choose_mark_with_existence_check
from ...tests import choose_mark_with_existence_check, get_mark_with_existence_check
get_mark = replaceable_partial(get_mark_with_existence_check, base=ttorch)
choose_mark = replaceable_partial(choose_mark_with_existence_check, base=ttorch)
from math import nan
import pytest
import treetensor.torch as ttorch
from .base import get_mark
def func_not_implemented(name: str, need_exist=True):
mark_0 = pytest.mark.unittest if name not in ttorch.funcs.__all__ else pytest.mark.ignore
mark_1 = get_mark(name=name)
need_test = all(map(lambda x: x.name == 'unittest',
[mark_0, *((mark_1,) if need_exist else ())]))
return pytest.mark.unittest if need_test else pytest.mark.ignore
# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchFuncsAuto:
@func_not_implemented('arctanh')
def test_u_arctanh(self):
tt = ttorch.tensor({
'a': [[0.8479, 1.0074, 0.2725],
[1.1674, 1.0784, 0.0655]],
'b': {'x': [[0.2644, 0.7268, 0.2781, 0.6469],
[2.0015, 0.4448, 0.8814, 1.0063],
[0.1847, 0.5864, 0.4417, 0.2117]]},
})
ttc = tt.clone()
assert ttorch.isclose(ttorch.arctanh(tt), ttorch.tensor({
'a': [[1.2487, nan, 0.2796],
[nan, nan, 0.0656]],
'b': {'x': [[0.2708, 0.9219, 0.2857, 0.7699],
[nan, 0.4782, 1.3821, nan],
[0.1868, 0.6722, 0.4743, 0.2150]]}
}), atol=1e-4, equal_nan=True).all()
assert ttorch.isclose(tt, ttc, atol=1e-4, equal_nan=True).all()
@func_not_implemented('arctanh_')
def test_u_arctanh_(self):
tt = ttorch.tensor({
'a': [[0.8479, 1.0074, 0.2725],
[1.1674, 1.0784, 0.0655]],
'b': {'x': [[0.2644, 0.7268, 0.2781, 0.6469],
[2.0015, 0.4448, 0.8814, 1.0063],
[0.1847, 0.5864, 0.4417, 0.2117]]},
})
ttc = tt.clone()
ttr = ttorch.arctanh_(tt)
assert ttr is tt
assert ttorch.isclose(ttr, ttorch.tensor({
'a': [[1.2487, nan, 0.2796],
[nan, nan, 0.0656]],
'b': {'x': [[0.2708, 0.9219, 0.2857, 0.7699],
[nan, 0.4782, 1.3821, nan],
[0.1868, 0.6722, 0.4743, 0.2150]]}
}), atol=1e-4, equal_nan=True).all()
assert not ttorch.isclose(tt, ttc, atol=1e-4, equal_nan=True).all()
from .test_auto import TestTorchTensorAuto
from .test_autograd import TestTorchTensorAutograd
from .test_clazz import TestTorchTensorClass
from .test_comparison import TestTorchTensorComparison
......
import treetensor.torch as ttorch
from treetensor.utils import replaceable_partial
from ...tests import choose_mark_with_existence_check
from ...tests import choose_mark_with_existence_check, get_mark_with_existence_check
get_mark = replaceable_partial(get_mark_with_existence_check, base=ttorch.Tensor)
choose_mark = replaceable_partial(choose_mark_with_existence_check, base=ttorch.Tensor)
from math import nan
from types import FunctionType, MethodType
import pytest
import treetensor.torch as ttorch
from treetensor.common import Object
from .base import get_mark
def method_not_implemented(name: str, need_exist=True):
mark_0 = pytest.mark.unittest if name not in ttorch.funcs.__all__ else pytest.mark.ignore
mark_1 = get_mark(name=name)
need_test = all(map(lambda x: x.name == 'unittest',
[mark_0, *((mark_1,) if need_exist else ())]))
return pytest.mark.unittest if need_test else pytest.mark.ignore
# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchTensorAuto:
@method_not_implemented('arctanh')
def test_u_arctanh(self):
tt = ttorch.tensor({
'a': [[0.8479, 1.0074, 0.2725],
[1.1674, 1.0784, 0.0655]],
'b': {'x': [[0.2644, 0.7268, 0.2781, 0.6469],
[2.0015, 0.4448, 0.8814, 1.0063],
[0.1847, 0.5864, 0.4417, 0.2117]]},
})
ttc = tt.clone()
assert isinstance(ttorch.Tensor.arctanh, FunctionType)
assert isinstance(tt.arctanh, MethodType)
assert ttorch.isclose(tt.arctanh(), ttorch.tensor({
'a': [[1.2487, nan, 0.2796],
[nan, nan, 0.0656]],
'b': {'x': [[0.2708, 0.9219, 0.2857, 0.7699],
[nan, 0.4782, 1.3821, nan],
[0.1868, 0.6722, 0.4743, 0.2150]]}
}), atol=1e-4, equal_nan=True).all()
assert ttorch.isclose(tt, ttc, atol=1e-4, equal_nan=True).all()
@method_not_implemented('arctanh_')
def test_u_arctanh_(self):
tt = ttorch.tensor({
'a': [[0.8479, 1.0074, 0.2725],
[1.1674, 1.0784, 0.0655]],
'b': {'x': [[0.2644, 0.7268, 0.2781, 0.6469],
[2.0015, 0.4448, 0.8814, 1.0063],
[0.1847, 0.5864, 0.4417, 0.2117]]},
})
ttc = tt.clone()
assert isinstance(ttorch.Tensor.arctanh_, FunctionType)
assert isinstance(tt.arctanh_, MethodType)
ttr = tt.arctanh_()
assert ttr is tt
assert ttorch.isclose(ttr, ttorch.tensor({
'a': [[1.2487, nan, 0.2796],
[nan, nan, 0.0656]],
'b': {'x': [[0.2708, 0.9219, 0.2857, 0.7699],
[nan, 0.4782, 1.3821, nan],
[0.1868, 0.6722, 0.4743, 0.2150]]}
}), atol=1e-4, equal_nan=True).all()
assert not ttorch.isclose(tt, ttc, atol=1e-4, equal_nan=True).all()
@method_not_implemented('not_found', need_exist=False)
def test_u_not_found(self):
with pytest.raises(AttributeError):
_ = ttorch.Tensor.not_found
with pytest.raises(AttributeError):
_ = ttorch.tensor({'a': [1, 2.0]}).not_found
@method_not_implemented('is_cuda', need_exist=False)
def test_u_is_cuda(self):
tt = ttorch.tensor({
'a': [[0.8479, 1.0074, 0.2725],
[1.1674, 1.0784, 0.0655]],
'b': {'x': [[0.2644, 0.7268, 0.2781, 0.6469],
[2.0015, 0.4448, 0.8814, 1.0063],
[0.1847, 0.5864, 0.4417, 0.2117]]},
})
icuda = tt.is_cuda
assert isinstance(icuda, Object)
assert icuda == Object({
'a': False, 'b': {'x': False}
})
import pytest
import torch
import treetensor.torch as ttorch
from treetensor.utils import replaceable_partial
from ..tests import choose_mark_with_existence_check
choose_mark = replaceable_partial(choose_mark_with_existence_check, base=ttorch.Size)
@pytest.mark.unittest
class TestTorchModule:
def test_float32(self):
assert ttorch.float32 is torch.float32
def test_has_cuda(self):
assert ttorch.has_cuda == torch.has_cuda
def test_fxxk(self):
with pytest.raises(AttributeError):
_ = ttorch.fxxk
def test___all__(self):
assert 'has_cuda' not in ttorch.__all__
assert 'float32' not in ttorch.__all__
assert 'fxxk' not in ttorch.__all__
def test_dir(self):
assert 'has_cuda' not in dir(ttorch)
assert 'float32' not in dir(ttorch)
assert 'fxxk' not in dir(ttorch)
......@@ -51,9 +51,6 @@ class _Module(ModuleType):
else:
raise AttributeError(f'Attribute {repr(name)} not found in {repr(__name__)}.')
def __hasattr__(self, name):
return name in self.__all__
def __dir__(self) -> Iterable[str]:
return self.__all__
......
......@@ -2,50 +2,92 @@ from functools import wraps
from types import MethodType
import numpy as np
import torch
import torch as pytorch
from treevalue import method_treelize, TreeValue
from treevalue.utils import post_process
from ..base import Torch, auto_torch, rmreduce, post_reduce, auto_reduce
from ..size import Size
from ...common import Object, ireduce, clsmeta, return_self
from ...numpy import ndarray
from ...utils import current_names, class_autoremove, replaceable_partial
from ...utils import doc_from_base as original_doc_from_base
from .base import Torch, auto_torch, rmreduce, post_reduce, auto_reduce
from .size import Size
from ..common import Object, ireduce, clsmeta, return_self
from ..numpy import ndarray
from ..utils import current_names, class_autoremove, replaceable_partial
from ..utils import doc_from_base as original_doc_from_base
doc_from_base = replaceable_partial(original_doc_from_base, base=torch.Tensor)
__all__ = [
'Tensor'
]
doc_from_base = replaceable_partial(original_doc_from_base, base=pytorch.Tensor)
def _to_tensor(*args, **kwargs):
if (len(args) == 1 and not kwargs) or \
(not args and set(kwargs.keys()) == {'data'}):
data = args[0] if len(args) == 1 else kwargs['data']
if isinstance(data, torch.Tensor):
if isinstance(data, pytorch.Tensor):
return data
return torch.tensor(*args, **kwargs)
return pytorch.tensor(*args, **kwargs)
# noinspection PyMethodParameters
class _TensorMeta(clsmeta(_to_tensor, allow_dict=True)):
def __getattr__(cls, name):
if hasattr(torch.Tensor, name) and not name.startswith('_') \
and callable(getattr(torch.Tensor, name)):
_origin_func = getattr(torch.Tensor, name)
class _TorchProxy:
def __init__(self, cls):
self.__torch_funcs = {}
self.__cls = cls
def __getattr__(self, name):
if name in self.__torch_funcs.keys():
return self.__torch_funcs[name]
elif hasattr(pytorch.Tensor, name) and not name.startswith('_') \
and callable(getattr(pytorch.Tensor, name)):
_origin_func = getattr(pytorch.Tensor, name)
return_self_deco = return_self if name.endswith('_') else (lambda x: x)
@doc_from_base()
@return_self_deco
@post_process(lambda r: replaceable_partial(auto_torch, cls=cls)(r))
@post_process(lambda r: replaceable_partial(auto_torch, cls=self.__cls)(r))
@method_treelize(return_type=TreeValue, rise=True)
@wraps(_origin_func, assigned=('__name__',), updated=())
def _new_func(*args, **kwargs):
return _origin_func(*args, **kwargs)
_new_func.__qualname__ = f'{cls.__name__}.{name}'
_new_func.__module__ = cls.__module__
_new_func.__qualname__ = f'{self.__cls.__name__}.{name}'
_new_func.__module__ = self.__cls.__module__
self.__torch_funcs[name] = _new_func
return _new_func
else:
raise AttributeError(f'Function {repr(name)} not found in {repr(pytorch)}')
class _InstanceTorchProxy:
def __init__(self, proxy, s):
self.__proxy = proxy
self.__self = s
def __getattr__(self, name):
return MethodType(getattr(self.__proxy, name), self.__self)
class _BaseTensorMeta(clsmeta(_to_tensor, allow_dict=True)):
pass
# noinspection PyMethodParameters
class _TensorMeta(_BaseTensorMeta):
def __init__(cls, *args, **kwargs):
_BaseTensorMeta.__init__(cls, *args, **kwargs)
cls.__torch_proxy = None
@property
def torch(cls):
if not cls.__torch_proxy:
cls.__torch_proxy = _TorchProxy(cls)
return cls.__torch_proxy
def __getattr__(cls, name):
try:
return cls.torch.__getattr__(name)
except AttributeError:
raise AttributeError(f"type object {repr(cls.__name__)} has no attribute {repr(name)}")
......@@ -63,13 +105,13 @@ class Tensor(Torch, metaclass=_TensorMeta):
>>> import torch
>>> import treetensor.torch as ttorch
>>> torch.Tensor([1, 2, 3]) # in torch.Tensor, default type is float32
>>> pytorch.Tensor([1, 2, 3]) # in torch.Tensor, default type is float32
tensor([1., 2., 3.])
>>> ttorch.Tensor([1, 2, 3]) # a native Tensor object, its type is auto detected with torch.tensor
tensor([1, 2, 3])
>>> ttorch.Tensor([1, 2, 3], dtype=torch.float32) # with float32 type
>>> ttorch.Tensor([1, 2, 3], dtype=pytorch.float32) # with float32 type
tensor([1., 2., 3.])
>>> ttorch.Tensor({
......@@ -91,19 +133,47 @@ class Tensor(Torch, metaclass=_TensorMeta):
return getattr(self, key)
def _attr_extern(self, name):
if hasattr(torch.Tensor, name) and not name.startswith('_') \
and callable(getattr(torch.Tensor, name)):
return MethodType(getattr(self.__class__, name), self)
else:
try:
return getattr(self.torch, name)
except AttributeError:
tree = self.__get_attr(name)
if tree.map(lambda x: torch.is_tensor(x)).all():
if tree.map(lambda x: pytorch.is_tensor(x)).all():
return tree.type(Tensor)
else:
return tree
@property
def torch(self):
"""
Returns a proxy to get the auto generated function from ``torch``.
.. note::
This is useful when some of the method of ``TreeValue`` and ``Tensor``
have the same name, such as ``view``, you can use :meth:`torch.Tensor.view`
in this way.
>>> import treetensor.torch as ttorch
>>> t = ttorch.randn({'a': (2, 3), 'b': {'x': (3, 4)}})
<Tensor 0x7ff1b4720340>
├── a --> tensor([[-0.8067, -0.7860, 2.1065],
│ [ 1.8428, -0.0960, 0.9911]])
└── b --> <Tensor 0x7ff1b4720400>
└── x --> tensor([[ 1.5568, -0.8541, -0.1199, 1.1190],
[-0.7324, 2.7439, 1.0143, -0.0680],
[ 0.0344, 1.2085, -1.3644, 2.9778]])
>>> t.torch.view((-1, )) # t.view is a method of treevalue
<Tensor 0x7ff1b4725610>
├── a --> tensor([-0.8067, -0.7860, 2.1065, 1.8428, -0.0960, 0.9911])
└── b --> <Tensor 0x7ff1b4725400>
└── x --> tensor([ 1.5568, -0.8541, -0.1199, 1.1190, -0.7324, 2.7439, 1.0143, -0.0680,
0.0344, 1.2085, -1.3644, 2.9778])
"""
return _InstanceTorchProxy(self.__class__.torch, self)
@doc_from_base()
@method_treelize(return_type=ndarray)
def numpy(self: torch.Tensor) -> np.ndarray:
def numpy(self: pytorch.Tensor) -> np.ndarray:
"""
Returns ``self`` tree tensor as a NumPy ``ndarray``.
This tensor and the returned :class:`treetensor.numpy.ndarray` share the same underlying storage.
......@@ -113,7 +183,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
@doc_from_base()
@method_treelize(return_type=Object)
def tolist(self: torch.Tensor):
def tolist(self: pytorch.Tensor):
"""
Get the dump result of tree tensor.
......@@ -136,7 +206,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
@doc_from_base()
@method_treelize()
def cpu(self: torch.Tensor, *args, **kwargs):
def cpu(self: pytorch.Tensor, *args, **kwargs):
"""
Returns a copy of this tree tensor in CPU memory.
......@@ -147,7 +217,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
@doc_from_base()
@method_treelize()
def cuda(self: torch.Tensor, *args, **kwargs):
def cuda(self: pytorch.Tensor, *args, **kwargs):
"""
Returns a copy of this tree tensor in CUDA memory.
......@@ -158,7 +228,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
@doc_from_base()
@method_treelize()
def to(self: torch.Tensor, *args, **kwargs):
def to(self: pytorch.Tensor, *args, **kwargs):
"""
Turn the original tree tensor to another format.
......@@ -169,7 +239,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
>>> ttorch.tensor({
... 'a': [[1, 11], [2, 22], [3, 33]],
... 'b': {'x': [[4, 5], [6, 7]]},
... }).to(torch.float64)
... }).to(pytorch.float64)
<Tensor 0x7ff363bb6518>
├── a --> tensor([[ 1., 11.],
│ [ 2., 22.],
......@@ -183,7 +253,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
@doc_from_base()
@ireduce(sum)
@method_treelize(return_type=Object)
def numel(self: torch.Tensor):
def numel(self: pytorch.Tensor):
"""
See :func:`treetensor.torch.numel`
"""
......@@ -192,7 +262,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
@property
@doc_from_base()
@method_treelize(return_type=Size)
def shape(self: torch.Tensor):
def shape(self: pytorch.Tensor):
"""
Get the size of the tensors in the tree.
......@@ -323,7 +393,7 @@ class Tensor(Torch, metaclass=_TensorMeta):
return self.detach_()
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.all)
@post_reduce(pytorch.all)
@method_treelize(return_type=Object)
def __all_r(self, *args, **kwargs):
return self
......@@ -331,19 +401,19 @@ class Tensor(Torch, metaclass=_TensorMeta):
# noinspection PyShadowingBuiltins
@method_treelize()
def __all_nr(self, *args, **kwargs):
return torch.all(self, *args, **kwargs)
return pytorch.all(self, *args, **kwargs)
# noinspection PyArgumentList
@doc_from_base()
@auto_reduce(__all_r, __all_nr)
def all(self: torch.Tensor, *args, reduce=None, **kwargs) -> bool:
def all(self: pytorch.Tensor, *args, reduce=None, **kwargs) -> bool:
"""
See :func:`treetensor.torch.all`
"""
pass # pragma: no cover
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.any)
@post_reduce(pytorch.any)
@method_treelize(return_type=Object)
def __any_r(self, *args, **kwargs):
return self
......@@ -351,19 +421,19 @@ class Tensor(Torch, metaclass=_TensorMeta):
# noinspection PyShadowingBuiltins
@method_treelize()
def __any_nr(self, *args, **kwargs):
return torch.any(self, *args, **kwargs)
return pytorch.any(self, *args, **kwargs)
# noinspection PyArgumentList
@doc_from_base()
@auto_reduce(__any_r, __any_nr)
def any(self: torch.Tensor, *args, reduce=None, **kwargs) -> bool:
def any(self: pytorch.Tensor, *args, reduce=None, **kwargs) -> bool:
"""
See :func:`treetensor.torch.any`
"""
pass # pragma: no cover
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.max)
@post_reduce(pytorch.max)
@method_treelize(return_type=Object)
def __max_r(self, *args, **kwargs):
return self
......@@ -372,18 +442,18 @@ class Tensor(Torch, metaclass=_TensorMeta):
@post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r))
@method_treelize(return_type=TreeValue, rise=True)
def __max_nr(self, *args, **kwargs):
return torch.max(self, *args, **kwargs)
return pytorch.max(self, *args, **kwargs)
@doc_from_base()
@auto_reduce(__max_r, __max_nr)
def max(self: torch.Tensor, *args, reduce=None, **kwargs):
def max(self: pytorch.Tensor, *args, reduce=None, **kwargs):
"""
See :func:`treetensor.torch.max`
"""
pass # pragma: no cover
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.min)
@post_reduce(pytorch.min)
@method_treelize(return_type=Object)
def __min_r(self, *args, **kwargs):
return self
......@@ -392,18 +462,18 @@ class Tensor(Torch, metaclass=_TensorMeta):
@post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r))
@method_treelize(return_type=TreeValue, rise=True)
def __min_nr(self, *args, **kwargs):
return torch.min(self, *args, **kwargs)
return pytorch.min(self, *args, **kwargs)
@doc_from_base()
@auto_reduce(__min_r, __min_nr)
def min(self: torch.Tensor, *args, reduce=None, **kwargs):
def min(self: pytorch.Tensor, *args, reduce=None, **kwargs):
"""
See :func:`treetensor.torch.min`
"""
pass # pragma: no cover
# noinspection PyShadowingBuiltins,PyUnusedLocal
@post_reduce(torch.sum)
@post_reduce(pytorch.sum)
@method_treelize(return_type=Object)
def __sum_r(self, *args, **kwargs):
return self
......@@ -412,11 +482,11 @@ class Tensor(Torch, metaclass=_TensorMeta):
@post_process(lambda r: replaceable_partial(auto_torch, cls=Tensor)(r))
@method_treelize(return_type=TreeValue, rise=True)
def __sum_nr(self, *args, **kwargs):
return torch.sum(self, *args, **kwargs)
return pytorch.sum(self, *args, **kwargs)
@doc_from_base()
@auto_reduce(__sum_r, __sum_nr)
def sum(self: torch.Tensor, *args, reduce=None, **kwargs):
def sum(self: pytorch.Tensor, *args, reduce=None, **kwargs):
"""
See :func:`treetensor.torch.sum`
"""
......@@ -933,12 +1003,12 @@ class Tensor(Torch, metaclass=_TensorMeta):
@rmreduce()
@method_treelize(return_type=Object)
def __masked_select_r(self, mask, *args, **kwargs):
return torch.masked_select(self, mask, *args, **kwargs)
return pytorch.masked_select(self, mask, *args, **kwargs)
# noinspection PyShadowingBuiltins
@method_treelize()
def __masked_select_nr(self, mask, *args, **kwargs):
return torch.masked_select(self, mask, *args, **kwargs)
return pytorch.masked_select(self, mask, *args, **kwargs)
# noinspection PyUnusedLocal,PyMethodParameters,PyMethodMayBeStatic
def __ms_determine(mask, *args, out=None, **kwargs):
......@@ -958,14 +1028,14 @@ class Tensor(Torch, metaclass=_TensorMeta):
pass # pragma: no cover
# noinspection PyUnusedLocal
@post_reduce(torch.std)
@post_reduce(pytorch.std)
@method_treelize(return_type=Object)
def __std_r(self, *args, **kwargs):
return self
@method_treelize()
def __std_nr(self, *args, **kwargs):
return torch.std(self, *args, **kwargs)
return pytorch.std(self, *args, **kwargs)
@doc_from_base()
@auto_reduce(__std_r, __std_nr)
......@@ -977,14 +1047,14 @@ class Tensor(Torch, metaclass=_TensorMeta):
pass # pragma: no cover
# noinspection PyUnusedLocal
@post_reduce(torch.mean)
@post_reduce(pytorch.mean)
@method_treelize(return_type=Object)
def __mean_r(self, *args, **kwargs):
return self
@method_treelize()
def __mean_nr(self, *args, **kwargs):
return torch.mean(self, *args, **kwargs)
return pytorch.mean(self, *args, **kwargs)
@doc_from_base()
@auto_reduce(__mean_r, __mean_nr)
......
from .tensor import Tensor
__all__ = [
'Tensor'
]
from treevalue import method_treelize, TreeValue
from treevalue.utils import post_process
from ..base import Torch, auto_torch
from ..tensor import Tensor
from ...utils import replaceable_partial
auto_tensor = replaceable_partial(auto_torch, cls=Tensor)
class TensorMethod(Torch):
@post_process(auto_tensor)
@method_treelize(return_type=TreeValue, rise=True)
def __call__(self, *args, **kwargs):
return self(*args, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册