提交 7c15d9f7 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): add compatiable to lower versions of torch

上级 f1799c33
......@@ -23,4 +23,3 @@ current_names
.. autofunction:: current_names
......@@ -7,3 +7,4 @@ treetensor.utils
clazz
doc
func
reflection
Reflection Utils
============================
.. py:currentmodule:: treetensor.utils
.. automodule:: treetensor.utils.reflection
removed
--------------------
.. autofunction:: removed
class_autoremove
--------------------
.. autofunction:: class_autoremove
module_autoremove
--------------------
.. autofunction:: module_autoremove
from .mark import choose_mark_with_existence_check
import pytest
_TEST_PREFIX = 'test_'
def choose_mark_with_existence_check(base, name: str = None):
def _decorator(func):
_name = name or func.__name__[len(_TEST_PREFIX):]
_mark = pytest.mark.unittest if hasattr(base, _name) else pytest.mark.ignore
func = _mark(func)
return func
return _decorator
import pytest
import torch
import treetensor.torch as ttorch
from treetensor.utils import replaceable_partial
from ..tests import choose_mark_with_existence_check
choose_mark = replaceable_partial(choose_mark_with_existence_check, base=ttorch)
# noinspection DuplicatedCode,PyUnresolvedReferences
@pytest.mark.unittest
class TestTorchFuncs:
@choose_mark()
def test_tensor(self):
t1 = ttorch.tensor(True)
assert isinstance(t1, torch.Tensor)
......@@ -34,6 +37,7 @@ class TestTorchFuncs:
}
})).all()
@choose_mark()
def test_zeros(self):
assert ttorch.all(ttorch.zeros(2, 3) == torch.zeros(2, 3))
assert ttorch.all(ttorch.zeros({
......@@ -50,6 +54,7 @@ class TestTorchFuncs:
}
}))
@choose_mark()
def test_zeros_like(self):
assert ttorch.all(
ttorch.zeros_like(torch.tensor([[1, 2, 3], [4, 5, 6]])) ==
......@@ -73,6 +78,7 @@ class TestTorchFuncs:
})
)
@choose_mark()
def test_ones(self):
assert ttorch.all(ttorch.ones(2, 3) == torch.ones(2, 3))
assert ttorch.all(ttorch.ones({
......@@ -89,6 +95,7 @@ class TestTorchFuncs:
}
}))
@choose_mark()
def test_ones_like(self):
assert ttorch.all(
ttorch.ones_like(torch.tensor([[1, 2, 3], [4, 5, 6]])) ==
......@@ -112,6 +119,7 @@ class TestTorchFuncs:
})
)
@choose_mark()
def test_randn(self):
_target = ttorch.randn(200, 300)
assert -0.02 <= _target.view(60000).mean().tolist() <= 0.02
......@@ -133,6 +141,7 @@ class TestTorchFuncs:
}
})
@choose_mark()
def test_randn_like(self):
_target = ttorch.randn_like(torch.ones(200, 300))
assert -0.02 <= _target.view(60000).mean().tolist() <= 0.02
......@@ -156,6 +165,7 @@ class TestTorchFuncs:
}
})
@choose_mark()
def test_randint(self):
_target = ttorch.randint(-10, 10, {
'a': (2, 3),
......@@ -191,6 +201,7 @@ class TestTorchFuncs:
}
})
@choose_mark()
def test_randint_like(self):
_target = ttorch.randint_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
......@@ -230,6 +241,7 @@ class TestTorchFuncs:
}
})
@choose_mark()
def test_full(self):
_target = ttorch.full({
'a': (2, 3),
......@@ -247,6 +259,7 @@ class TestTorchFuncs:
}
})
@choose_mark()
def test_full_like(self):
_target = ttorch.full_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
......@@ -266,6 +279,7 @@ class TestTorchFuncs:
}
})
@choose_mark()
def test_empty(self):
_target = ttorch.empty({
'a': (2, 3),
......@@ -282,6 +296,7 @@ class TestTorchFuncs:
}
})
@choose_mark()
def test_empty_like(self):
_target = ttorch.empty_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
......@@ -300,6 +315,7 @@ class TestTorchFuncs:
}
})
@choose_mark()
def test_all(self):
r1 = ttorch.all(torch.tensor([True, True, True]))
assert torch.is_tensor(r1)
......@@ -340,6 +356,7 @@ class TestTorchFuncs:
assert r6 == torch.tensor(False)
assert not r6
@choose_mark()
def test_any(self):
r1 = ttorch.any(torch.tensor([True, True, True]))
assert torch.is_tensor(r1)
......@@ -380,6 +397,7 @@ class TestTorchFuncs:
assert r6 == torch.tensor(False)
assert not r6
@choose_mark()
def test_eq(self):
assert ttorch.eq(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3])).all()
assert not ttorch.eq(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 2])).all()
......@@ -401,6 +419,7 @@ class TestTorchFuncs:
'b': torch.tensor([4, 5, 5]),
})).all()
@choose_mark()
def test_ne(self):
assert (ttorch.ne(
torch.tensor([[1, 2], [3, 4]]),
......@@ -422,6 +441,7 @@ class TestTorchFuncs:
'b': [True, True, False],
})).all()
@choose_mark()
def test_lt(self):
assert (ttorch.lt(
torch.tensor([[1, 2], [3, 4]]),
......@@ -443,6 +463,7 @@ class TestTorchFuncs:
'b': [True, False, False],
})).all()
@choose_mark()
def test_le(self):
assert (ttorch.le(
torch.tensor([[1, 2], [3, 4]]),
......@@ -464,6 +485,7 @@ class TestTorchFuncs:
'b': [True, False, True],
})).all()
@choose_mark()
def test_gt(self):
assert (ttorch.gt(
torch.tensor([[1, 2], [3, 4]]),
......@@ -485,6 +507,7 @@ class TestTorchFuncs:
'b': [False, True, False],
})).all()
@choose_mark()
def test_ge(self):
assert (ttorch.ge(
torch.tensor([[1, 2], [3, 4]]),
......@@ -506,6 +529,7 @@ class TestTorchFuncs:
'b': [False, True, True],
})).all()
@choose_mark()
def test_equal(self):
p1 = ttorch.equal(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3]))
assert isinstance(p1, bool)
......@@ -535,6 +559,7 @@ class TestTorchFuncs:
assert isinstance(p4, bool)
assert not p4
@choose_mark()
def test_min(self):
t1 = ttorch.min(torch.tensor([1.0, 2.0, 1.5]))
assert isinstance(t1, torch.Tensor)
......@@ -548,6 +573,7 @@ class TestTorchFuncs:
'b': {'x': 0.9},
})
@choose_mark()
def test_max(self):
t1 = ttorch.max(torch.tensor([1.0, 2.0, 1.5]))
assert isinstance(t1, torch.Tensor)
......@@ -561,6 +587,7 @@ class TestTorchFuncs:
'b': {'x': 2.5, }
})
@choose_mark()
def test_sum(self):
assert ttorch.sum(torch.tensor([1.0, 2.0, 1.5])) == torch.tensor(4.5)
assert ttorch.sum(ttorch.tensor({
......@@ -568,6 +595,7 @@ class TestTorchFuncs:
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})) == torch.tensor(11.0)
@choose_mark()
def test_clone(self):
t1 = ttorch.clone(torch.tensor([1.0, 2.0, 1.5]))
assert isinstance(t1, torch.Tensor)
......@@ -582,6 +610,7 @@ class TestTorchFuncs:
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})).all()
@choose_mark()
def test_dot(self):
t1 = ttorch.dot(torch.tensor([1, 2]), torch.tensor([2, 3]))
assert isinstance(t1, torch.Tensor)
......@@ -599,6 +628,7 @@ class TestTorchFuncs:
)
assert (t2 == ttorch.tensor({'a': 38, 'b': {'x': 11}})).all()
@choose_mark()
def test_matmul(self):
t1 = ttorch.matmul(
torch.tensor([[1, 2], [3, 4]]),
......@@ -622,6 +652,7 @@ class TestTorchFuncs:
'b': {'x': 40}
})).all()
@choose_mark()
def test_mm(self):
t1 = ttorch.mm(
torch.tensor([[1, 2], [3, 4]]),
......@@ -645,6 +676,7 @@ class TestTorchFuncs:
'b': {'x': [[44, 32], [80, 59]]},
})).all()
@choose_mark()
def test_isfinite(self):
t1 = ttorch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
assert isinstance(t1, torch.Tensor)
......@@ -659,6 +691,7 @@ class TestTorchFuncs:
'b': {'x': [[True, False, True], [False, True, False]]},
}))
@choose_mark()
def test_isinf(self):
t1 = ttorch.isinf(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
assert isinstance(t1, torch.Tensor)
......@@ -673,6 +706,7 @@ class TestTorchFuncs:
'b': {'x': [[False, True, False], [True, False, False]]},
}))
@choose_mark()
def test_isnan(self):
t1 = ttorch.isnan(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
assert isinstance(t1, torch.Tensor)
......@@ -687,6 +721,7 @@ class TestTorchFuncs:
'b': {'x': [[False, False, False], [False, False, True]]},
})).all()
@choose_mark()
def test_abs(self):
t1 = ttorch.abs(ttorch.tensor([12, 0, -3]))
assert isinstance(t1, torch.Tensor)
......@@ -701,6 +736,7 @@ class TestTorchFuncs:
'b': {'x': [[3, 1], [0, 2]]},
})).all()
@choose_mark()
def test_abs_(self):
t1 = ttorch.tensor([12, 0, -3])
assert isinstance(t1, torch.Tensor)
......@@ -721,6 +757,7 @@ class TestTorchFuncs:
'b': {'x': [[3, 1], [0, 2]]},
})).all()
@choose_mark()
def test_clamp(self):
t1 = ttorch.clamp(ttorch.tensor([-1.7120, 0.1734, -0.0478, 2.0922]), min=-0.5, max=0.5)
assert isinstance(t1, torch.Tensor)
......@@ -736,6 +773,7 @@ class TestTorchFuncs:
[0.0489, -0.5000, -0.5000]]},
})) < 1e-6).all()
@choose_mark()
def test_clamp_(self):
t1 = ttorch.tensor([-1.7120, 0.1734, -0.0478, 2.0922])
t1r = ttorch.clamp_(t1, min=-0.5, max=0.5)
......@@ -755,6 +793,7 @@ class TestTorchFuncs:
[0.0489, -0.5000, -0.5000]]},
})) < 1e-6).all()
@choose_mark()
def test_sign(self):
t1 = ttorch.sign(ttorch.tensor([12, 0, -3]))
assert isinstance(t1, torch.Tensor)
......@@ -770,6 +809,7 @@ class TestTorchFuncs:
[0, -1]]},
})).all()
@choose_mark()
def test_round(self):
t1 = ttorch.round(ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]))
assert isinstance(t1, torch.Tensor)
......@@ -787,6 +827,7 @@ class TestTorchFuncs:
[-5., -2., 3.]]},
})) < 1e-6).all()
@choose_mark()
def test_round_(self):
t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]])
t1r = ttorch.round_(t1)
......@@ -808,6 +849,7 @@ class TestTorchFuncs:
[-5., -2., 3.]]},
})) < 1e-6).all()
@choose_mark()
def test_floor(self):
t1 = ttorch.floor(ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]))
assert isinstance(t1, torch.Tensor)
......@@ -825,6 +867,7 @@ class TestTorchFuncs:
[-5., -2., 2.]]},
})) < 1e-6).all()
@choose_mark()
def test_floor_(self):
t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]])
t1r = ttorch.floor_(t1)
......@@ -846,6 +889,7 @@ class TestTorchFuncs:
[-5., -2., 2.]]},
})) < 1e-6).all()
@choose_mark()
def test_ceil(self):
t1 = ttorch.ceil(ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]))
assert isinstance(t1, torch.Tensor)
......@@ -863,6 +907,7 @@ class TestTorchFuncs:
[-4., -2., 3.]]},
})) < 1e-6).all()
@choose_mark()
def test_ceil_(self):
t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]])
t1r = ttorch.ceil_(t1)
......@@ -884,6 +929,7 @@ class TestTorchFuncs:
[-4., -2., 3.]]},
})) < 1e-6).all()
@choose_mark()
def test_sigmoid(self):
t1 = ttorch.sigmoid(ttorch.tensor([1.0, 2.0, -1.5]))
assert isinstance(t1, torch.Tensor)
......@@ -899,6 +945,7 @@ class TestTorchFuncs:
[0.0759, 0.5622]]},
})) < 1e-4).all()
@choose_mark()
def test_sigmoid_(self):
t1 = ttorch.tensor([1.0, 2.0, -1.5])
t1r = ttorch.sigmoid_(t1)
......
import pytest
import torch
from treevalue import func_treelize, typetrans, TreeValue
from treevalue import typetrans, TreeValue
import treetensor.torch as ttorch
from treetensor.common import Object
from treetensor.utils import replaceable_partial
from ..tests import choose_mark_with_existence_check
_all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y)
choose_mark = replaceable_partial(choose_mark_with_existence_check, base=ttorch.Size)
@pytest.mark.unittest
class TestTorchSize:
def test_init(self):
@choose_mark()
def test___init__(self):
t1 = ttorch.Size([1, 2, 3])
assert isinstance(t1, torch.Size)
assert t1 == torch.Size([1, 2, 3])
......@@ -27,6 +29,7 @@ class TestTorchSize:
'c': torch.Size([5]),
})
@choose_mark()
def test_numel(self):
assert ttorch.Size({
'a': [1, 2, 3],
......@@ -34,6 +37,7 @@ class TestTorchSize:
'c': [5],
}).numel() == 23
@choose_mark()
def test_index(self):
assert ttorch.Size({
'a': [1, 2, 3],
......@@ -52,6 +56,7 @@ class TestTorchSize:
'c': [5],
}).index(100)
@choose_mark()
def test_count(self):
assert ttorch.Size({
'a': [1, 2, 3],
......
import numpy as np
import pytest
import torch
from treevalue import func_treelize, typetrans, TreeValue
import treetensor.numpy as tnp
import treetensor.torch as ttorch
from treetensor.common import Object
from treetensor.utils import replaceable_partial
from ..tests import choose_mark_with_existence_check
_all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y)
choose_mark = replaceable_partial(choose_mark_with_existence_check, base=ttorch.Tensor)
# noinspection PyUnresolvedReferences,DuplicatedCode
@pytest.mark.unittest
class TestTorchTensor:
_DEMO_1 = ttorch.Tensor({
'a': [[1, 2, 3], [4, 5, 6]],
......@@ -31,7 +32,8 @@ class TestTorchTensor:
}
})
def test_init(self):
@choose_mark()
def test___init__(self):
assert (ttorch.Tensor([1, 2, 3]) == torch.tensor([1, 2, 3])).all()
assert (ttorch.Tensor([1, 2, 3], dtype=torch.float32) == torch.FloatTensor([1, 2, 3])).all()
assert (self._DEMO_1 == typetrans(TreeValue({
......@@ -43,9 +45,11 @@ class TestTorchTensor:
}
}), ttorch.Tensor)).all()
@choose_mark()
def test_numel(self):
assert self._DEMO_1.numel() == 18
@choose_mark()
def test_numpy(self):
assert tnp.all(self._DEMO_1.numpy() == tnp.ndarray({
'a': np.array([[1, 2, 3], [4, 5, 6]]),
......@@ -56,10 +60,12 @@ class TestTorchTensor:
}
}))
@choose_mark()
def test_cpu(self):
assert ttorch.all(self._DEMO_1.cpu() == self._DEMO_1)
assert _all_is(self._DEMO_1.cpu(), self._DEMO_1).reduce(lambda **kws: all(kws.values()))
@choose_mark()
def test_to(self):
assert ttorch.all(self._DEMO_1.to(torch.float32) == ttorch.Tensor({
'a': torch.FloatTensor([[1, 2, 3], [4, 5, 6]]),
......@@ -70,6 +76,7 @@ class TestTorchTensor:
}
}))
@choose_mark()
def test_all(self):
t1 = ttorch.Tensor({
'a': [True, True],
......@@ -87,6 +94,7 @@ class TestTorchTensor:
assert t2.dtype == torch.bool
assert not t2
@choose_mark()
def test_tolist(self):
assert self._DEMO_1.tolist() == Object({
'a': [[1, 2, 3], [4, 5, 6]],
......@@ -97,6 +105,7 @@ class TestTorchTensor:
}
})
@choose_mark()
def test_any(self):
t1 = ttorch.Tensor({
'a': [True, False],
......@@ -114,6 +123,7 @@ class TestTorchTensor:
assert t2.dtype == torch.bool
assert not t2
@choose_mark()
def test_max(self):
t1 = ttorch.Tensor({
'a': [1, 2],
......@@ -122,6 +132,7 @@ class TestTorchTensor:
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == 3
@choose_mark()
def test_min(self):
t1 = ttorch.Tensor({
'a': [1, 2],
......@@ -130,6 +141,7 @@ class TestTorchTensor:
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == -1
@choose_mark()
def test_sum(self):
t1 = ttorch.Tensor({
'a': [1, 2],
......@@ -138,7 +150,8 @@ class TestTorchTensor:
assert isinstance(t1, torch.Tensor)
assert t1.tolist() == 7
def test_eq(self):
@choose_mark()
def test___eq__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
......@@ -150,7 +163,8 @@ class TestTorchTensor:
'b': {'x': [[False, True], [False, False]]}
})).all()
def test_ne(self):
@choose_mark()
def test___ne__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
......@@ -162,7 +176,8 @@ class TestTorchTensor:
'b': {'x': [[True, False], [True, True]]}
})).all()
def test_lt(self):
@choose_mark()
def test___lt__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
......@@ -174,7 +189,8 @@ class TestTorchTensor:
'b': {'x': [[False, False], [True, False]]}
})).all()
def test_le(self):
@choose_mark()
def test___le__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
......@@ -186,7 +202,8 @@ class TestTorchTensor:
'b': {'x': [[False, True], [True, False]]}
})).all()
def test_gt(self):
@choose_mark()
def test___gt__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
......@@ -198,7 +215,8 @@ class TestTorchTensor:
'b': {'x': [[True, False], [False, True]]}
})).all()
def test_ge(self):
@choose_mark()
def test___ge__(self):
assert ((ttorch.Tensor({
'a': [1, 2],
'b': {'x': [[0, 3], [2, -1]]}
......@@ -210,6 +228,7 @@ class TestTorchTensor:
'b': {'x': [[True, True], [False, True]]}
})).all()
@choose_mark()
def test_clone(self):
t1 = ttorch.tensor([1.0, 2.0, 1.5]).clone()
assert isinstance(t1, torch.Tensor)
......@@ -224,6 +243,7 @@ class TestTorchTensor:
'b': {'x': [[1.8, 0.9], [1.3, 2.5]]},
})).all()
@choose_mark()
def test_dot(self):
t1 = torch.tensor([1, 2]).dot(torch.tensor([2, 3]))
assert isinstance(t1, torch.Tensor)
......@@ -240,6 +260,7 @@ class TestTorchTensor:
)
assert (t2 == ttorch.tensor({'a': 38, 'b': {'x': 11}})).all()
@choose_mark()
def test_matmul(self):
t1 = torch.tensor([[1, 2], [3, 4]]).matmul(
torch.tensor([[5, 6], [7, 2]]),
......@@ -261,6 +282,7 @@ class TestTorchTensor:
'b': {'x': 40}
})).all()
@choose_mark()
def test_mm(self):
t1 = torch.tensor([[1, 2], [3, 4]]).mm(
torch.tensor([[5, 6], [7, 2]]),
......@@ -282,6 +304,7 @@ class TestTorchTensor:
'b': {'x': [[44, 32], [80, 59]]},
})).all()
@choose_mark()
def test_isfinite(self):
t1 = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isfinite()
assert isinstance(t1, torch.Tensor)
......@@ -296,6 +319,7 @@ class TestTorchTensor:
'b': {'x': [[True, False, True], [False, True, False]]},
}))
@choose_mark()
def test_isinf(self):
t1 = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isinf()
assert isinstance(t1, torch.Tensor)
......@@ -310,6 +334,7 @@ class TestTorchTensor:
'b': {'x': [[False, True, False], [True, False, False]]},
}))
@choose_mark()
def test_isnan(self):
t1 = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isnan()
assert isinstance(t1, torch.Tensor)
......@@ -324,6 +349,7 @@ class TestTorchTensor:
'b': {'x': [[False, False, False], [False, False, True]]},
})).all()
@choose_mark()
def test_abs(self):
t1 = ttorch.tensor([12, 0, -3]).abs()
assert isinstance(t1, torch.Tensor)
......@@ -338,6 +364,7 @@ class TestTorchTensor:
'b': {'x': [[3, 1], [0, 2]]},
})).all()
@choose_mark()
def test_abs_(self):
t1 = ttorch.tensor([12, 0, -3])
t1r = t1.abs_()
......@@ -356,6 +383,7 @@ class TestTorchTensor:
'b': {'x': [[3, 1], [0, 2]]},
})).all()
@choose_mark()
def test_clamp(self):
t1 = ttorch.tensor([-1.7120, 0.1734, -0.0478, 2.0922]).clamp(min=-0.5, max=0.5)
assert isinstance(t1, torch.Tensor)
......@@ -371,6 +399,7 @@ class TestTorchTensor:
[0.0489, -0.5000, -0.5000]]},
})) < 1e-6).all()
@choose_mark()
def test_clamp_(self):
t1 = ttorch.tensor([-1.7120, 0.1734, -0.0478, 2.0922])
t1r = t1.clamp_(min=-0.5, max=0.5)
......@@ -390,6 +419,7 @@ class TestTorchTensor:
[0.0489, -0.5000, -0.5000]]},
})) < 1e-6).all()
@choose_mark()
def test_sign(self):
t1 = ttorch.tensor([12, 0, -3]).sign()
assert isinstance(t1, torch.Tensor)
......@@ -405,6 +435,7 @@ class TestTorchTensor:
[0, -1]]},
})).all()
@choose_mark()
def test_round(self):
t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]).round()
assert isinstance(t1, torch.Tensor)
......@@ -422,6 +453,7 @@ class TestTorchTensor:
[-5., -2., 3.]]},
})) < 1e-6).all()
@choose_mark()
def test_round_(self):
t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]])
t1r = t1.round_()
......@@ -443,6 +475,7 @@ class TestTorchTensor:
[-5., -2., 3.]]},
})) < 1e-6).all()
@choose_mark()
def test_floor(self):
t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]).floor()
assert isinstance(t1, torch.Tensor)
......@@ -460,6 +493,7 @@ class TestTorchTensor:
[-5., -2., 2.]]},
})) < 1e-6).all()
@choose_mark()
def test_floor_(self):
t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]])
t1r = t1.floor_()
......@@ -481,6 +515,7 @@ class TestTorchTensor:
[-5., -2., 2.]]},
})) < 1e-6).all()
@choose_mark()
def test_ceil(self):
t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]).ceil()
assert isinstance(t1, torch.Tensor)
......@@ -498,6 +533,7 @@ class TestTorchTensor:
[-4., -2., 3.]]},
})) < 1e-6).all()
@choose_mark()
def test_ceil_(self):
t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]])
t1r = t1.ceil_()
......@@ -519,6 +555,7 @@ class TestTorchTensor:
[-4., -2., 3.]]},
})) < 1e-6).all()
@choose_mark()
def test_sigmoid(self):
t1 = ttorch.tensor([1.0, 2.0, -1.5]).sigmoid()
assert isinstance(t1, torch.Tensor)
......@@ -534,6 +571,7 @@ class TestTorchTensor:
[0.0759, 0.5622]]},
})) < 1e-4).all()
@choose_mark()
def test_sigmoid_(self):
t1 = ttorch.tensor([1.0, 2.0, -1.5])
t1r = t1.sigmoid_()
......
import builtins
import sys
import torch
from treevalue import TreeValue
......@@ -8,7 +9,8 @@ from treevalue.utils import post_process
from .tensor import Tensor, tireduce
from ..common import Object, ireduce, return_self
from ..utils import replaceable_partial, doc_from, args_mapping
from ..utils import doc_from_base as original_doc_from_base
from ..utils import replaceable_partial, args_mapping, module_autoremove
__all__ = [
'zeros', 'zeros_like',
......@@ -31,9 +33,10 @@ func_treelize = post_process(post_process(args_mapping(
lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))(
replaceable_partial(original_func_treelize, return_type=Tensor)
)
doc_from_base = replaceable_partial(original_doc_from_base, base=torch)
@doc_from(torch.zeros)
@doc_from_base()
@func_treelize()
def zeros(*args, **kwargs):
"""
......@@ -58,7 +61,7 @@ def zeros(*args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.zeros_like)
@doc_from_base()
@func_treelize()
def zeros_like(input, *args, **kwargs):
"""
......@@ -85,7 +88,7 @@ def zeros_like(input, *args, **kwargs):
return torch.zeros_like(input, *args, **kwargs)
@doc_from(torch.randn)
@doc_from_base()
@func_treelize()
def randn(*args, **kwargs):
"""
......@@ -111,7 +114,7 @@ def randn(*args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.randn_like)
@doc_from_base()
@func_treelize()
def randn_like(input, *args, **kwargs):
"""
......@@ -139,7 +142,7 @@ def randn_like(input, *args, **kwargs):
return torch.randn_like(input, *args, **kwargs)
@doc_from(torch.randint)
@doc_from_base()
@func_treelize()
def randint(*args, **kwargs):
"""
......@@ -165,7 +168,7 @@ def randint(*args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.randint_like)
@doc_from_base()
@func_treelize()
def randint_like(input, *args, **kwargs):
"""
......@@ -193,7 +196,7 @@ def randint_like(input, *args, **kwargs):
return torch.randint_like(input, *args, **kwargs)
@doc_from(torch.ones)
@doc_from_base()
@func_treelize()
def ones(*args, **kwargs):
"""
......@@ -218,7 +221,7 @@ def ones(*args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.ones_like)
@doc_from_base()
@func_treelize()
def ones_like(input, *args, **kwargs):
"""
......@@ -245,7 +248,7 @@ def ones_like(input, *args, **kwargs):
return torch.ones_like(input, *args, **kwargs)
@doc_from(torch.full)
@doc_from_base()
@func_treelize()
def full(*args, **kwargs):
"""
......@@ -270,7 +273,7 @@ def full(*args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.full_like)
@doc_from_base()
@func_treelize()
def full_like(input, *args, **kwargs):
"""
......@@ -298,7 +301,7 @@ def full_like(input, *args, **kwargs):
return torch.full_like(input, *args, **kwargs)
@doc_from(torch.empty)
@doc_from_base()
@func_treelize()
def empty(*args, **kwargs):
"""
......@@ -324,7 +327,7 @@ def empty(*args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.empty_like)
@doc_from_base()
@func_treelize()
def empty_like(input, *args, **kwargs):
"""
......@@ -353,7 +356,7 @@ def empty_like(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.all)
@doc_from_base()
@tireduce(torch.all)
@func_treelize(return_type=Object)
def all(input, *args, **kwargs):
......@@ -392,7 +395,7 @@ def all(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.any)
@doc_from_base()
@tireduce(torch.any)
@func_treelize(return_type=Object)
def any(input, *args, **kwargs):
......@@ -431,7 +434,7 @@ def any(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.min)
@doc_from_base()
@tireduce(torch.min)
@func_treelize(return_type=Object)
def min(input, *args, **kwargs):
......@@ -470,7 +473,7 @@ def min(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.max)
@doc_from_base()
@tireduce(torch.max)
@func_treelize(return_type=Object)
def max(input, *args, **kwargs):
......@@ -509,7 +512,7 @@ def max(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.sum)
@doc_from_base()
@tireduce(torch.sum)
@func_treelize(return_type=Object)
def sum(input, *args, **kwargs):
......@@ -548,7 +551,7 @@ def sum(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.eq)
@doc_from_base()
@func_treelize()
def eq(input, other, *args, **kwargs):
"""
......@@ -584,7 +587,7 @@ def eq(input, other, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.ne)
@doc_from_base()
@func_treelize()
def ne(input, other, *args, **kwargs):
"""
......@@ -620,7 +623,7 @@ def ne(input, other, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.lt)
@doc_from_base()
@func_treelize()
def lt(input, other, *args, **kwargs):
"""
......@@ -656,7 +659,7 @@ def lt(input, other, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.le)
@doc_from_base()
@func_treelize()
def le(input, other, *args, **kwargs):
"""
......@@ -692,7 +695,7 @@ def le(input, other, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.gt)
@doc_from_base()
@func_treelize()
def gt(input, other, *args, **kwargs):
"""
......@@ -728,7 +731,7 @@ def gt(input, other, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.ge)
@doc_from_base()
@func_treelize()
def ge(input, other, *args, **kwargs):
"""
......@@ -764,7 +767,7 @@ def ge(input, other, *args, **kwargs):
# noinspection PyShadowingBuiltins,PyArgumentList
@doc_from(torch.equal)
@doc_from_base()
@ireduce(builtins.all)
@func_treelize()
def equal(input, other):
......@@ -796,7 +799,7 @@ def equal(input, other):
return torch.equal(input, other)
@doc_from(torch.tensor)
@doc_from_base()
@func_treelize()
def tensor(*args, **kwargs):
"""
......@@ -823,7 +826,7 @@ def tensor(*args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.clone)
@doc_from_base()
@func_treelize()
def clone(input, *args, **kwargs):
"""
......@@ -853,7 +856,7 @@ def clone(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.dot)
@doc_from_base()
@func_treelize()
def dot(input, other, *args, **kwargs):
"""
......@@ -885,7 +888,7 @@ def dot(input, other, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.matmul)
@doc_from_base()
@func_treelize()
def matmul(input, other, *args, **kwargs):
"""
......@@ -922,7 +925,7 @@ def matmul(input, other, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.mm)
@doc_from_base()
@func_treelize()
def mm(input, mat2, *args, **kwargs):
"""
......@@ -960,7 +963,7 @@ def mm(input, mat2, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.isfinite)
@doc_from_base()
@func_treelize()
def isfinite(input):
"""
......@@ -988,7 +991,7 @@ def isfinite(input):
# noinspection PyShadowingBuiltins
@doc_from(torch.isinf)
@doc_from_base()
@func_treelize()
def isinf(input):
"""
......@@ -1016,7 +1019,7 @@ def isinf(input):
# noinspection PyShadowingBuiltins
@doc_from(torch.isnan)
@doc_from_base()
@func_treelize()
def isnan(input):
"""
......@@ -1044,7 +1047,7 @@ def isnan(input):
# noinspection PyShadowingBuiltins
@doc_from(torch.abs)
@doc_from_base()
@func_treelize()
def abs(input, *args, **kwargs):
"""
......@@ -1071,7 +1074,7 @@ def abs(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.abs_)
@doc_from_base()
@return_self
@func_treelize()
def abs_(input):
......@@ -1103,7 +1106,7 @@ def abs_(input):
# noinspection PyShadowingBuiltins
@doc_from(torch.clamp)
@doc_from_base()
@func_treelize()
def clamp(input, *args, **kwargs):
"""
......@@ -1130,7 +1133,7 @@ def clamp(input, *args, **kwargs):
# noinspection PyShadowingBuiltins,PyUnresolvedReferences
@doc_from(torch.clamp_)
@doc_from_base()
@return_self
@func_treelize()
def clamp_(input, *args, **kwargs):
......@@ -1162,7 +1165,7 @@ def clamp_(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.sign)
@doc_from_base()
@func_treelize()
def sign(input, *args, **kwargs):
"""
......@@ -1189,7 +1192,7 @@ def sign(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.round)
@doc_from_base()
@func_treelize()
def round(input, *args, **kwargs):
"""
......@@ -1219,7 +1222,7 @@ def round(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.round_)
@doc_from_base()
@return_self
@func_treelize()
def round_(input):
......@@ -1253,7 +1256,7 @@ def round_(input):
# noinspection PyShadowingBuiltins
@doc_from(torch.floor)
@doc_from_base()
@func_treelize()
def floor(input, *args, **kwargs):
"""
......@@ -1283,7 +1286,7 @@ def floor(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.floor_)
@doc_from_base()
@return_self
@func_treelize()
def floor_(input):
......@@ -1317,7 +1320,7 @@ def floor_(input):
# noinspection PyShadowingBuiltins
@doc_from(torch.ceil)
@doc_from_base()
@func_treelize()
def ceil(input, *args, **kwargs):
"""
......@@ -1347,7 +1350,7 @@ def ceil(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.ceil_)
@doc_from_base()
@return_self
@func_treelize()
def ceil_(input):
......@@ -1381,7 +1384,7 @@ def ceil_(input):
# noinspection PyShadowingBuiltins
@doc_from(torch.sigmoid)
@doc_from_base()
@func_treelize()
def sigmoid(input, *args, **kwargs):
"""
......@@ -1408,7 +1411,7 @@ def sigmoid(input, *args, **kwargs):
# noinspection PyShadowingBuiltins
@doc_from(torch.sigmoid_)
@doc_from_base()
@return_self
@func_treelize()
def sigmoid_(input):
......@@ -1437,3 +1440,6 @@ def sigmoid_(input):
[0.0759, 0.5622]])
"""
return torch.sigmoid_(input)
sys.modules[__name__] = module_autoremove(sys.modules[__name__])
......@@ -8,12 +8,14 @@ from treevalue.utils import post_process
from .base import Torch
from ..common import Object, clsmeta, ireduce
from ..utils import replaceable_partial, doc_from, current_names, args_mapping
from ..utils import doc_from_base as original_doc_from_base
from ..utils import replaceable_partial, current_names, args_mapping
func_treelize = post_process(post_process(args_mapping(
lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))(
replaceable_partial(original_func_treelize)
)
doc_from_base = replaceable_partial(original_doc_from_base, base=torch.Size)
__all__ = [
'Size'
......@@ -69,7 +71,7 @@ class Size(Torch, metaclass=clsmeta(torch.Size, allow_dict=True)):
"""
super(Torch, self).__init__(data)
@doc_from(torch.Size.numel)
@doc_from_base()
@ireduce(sum)
@func_treelize(return_type=Object)
def numel(self: torch.Size) -> Object:
......@@ -88,7 +90,7 @@ class Size(Torch, metaclass=clsmeta(torch.Size, allow_dict=True)):
"""
return self.numel()
@doc_from(torch.Size.index)
@doc_from_base()
@_post_index
@func_treelize(return_type=Object)
def index(self: torch.Size, value, *args, **kwargs) -> Object:
......@@ -120,7 +122,7 @@ class Size(Torch, metaclass=clsmeta(torch.Size, allow_dict=True)):
except ValueError:
return None
@doc_from(torch.Size.count)
@doc_from_base()
@ireduce(sum)
@func_treelize(return_type=Object)
def count(self: torch.Size, *args, **kwargs) -> Object:
......
......@@ -7,7 +7,8 @@ from .base import Torch
from .size import Size
from ..common import Object, ireduce, clsmeta, return_self
from ..numpy import ndarray
from ..utils import current_names, doc_from
from ..utils import current_names, class_autoremove, replaceable_partial
from ..utils import doc_from_base as original_doc_from_base
__all__ = [
'Tensor'
......@@ -15,6 +16,7 @@ __all__ = [
_reduce_tensor_wrap = pre_process(lambda it: ((torch.tensor([*it]),), {}))
tireduce = pre_process(lambda rfunc: ((_reduce_tensor_wrap(rfunc),), {}))(ireduce)
doc_from_base = replaceable_partial(original_doc_from_base, base=torch.Tensor)
def _to_tensor(*args, **kwargs):
......@@ -29,6 +31,7 @@ def _to_tensor(*args, **kwargs):
# noinspection PyTypeChecker
@current_names()
@class_autoremove
class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
# noinspection PyUnusedLocal
def __init__(self, data, *args, **kwargs):
......@@ -63,7 +66,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
super(Torch, self).__init__(data)
@doc_from(torch.Tensor.numpy)
@doc_from_base()
@method_treelize(return_type=ndarray)
def numpy(self: torch.Tensor) -> np.ndarray:
"""
......@@ -73,7 +76,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.numpy()
@doc_from(torch.Tensor.tolist)
@doc_from_base()
@method_treelize(return_type=Object)
def tolist(self: torch.Tensor):
"""
......@@ -96,7 +99,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.tolist()
@doc_from(torch.Tensor.cpu)
@doc_from_base()
@method_treelize()
def cpu(self: torch.Tensor, *args, **kwargs):
"""
......@@ -107,7 +110,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.cpu(*args, **kwargs)
@doc_from(torch.Tensor.cuda)
@doc_from_base()
@method_treelize()
def cuda(self: torch.Tensor, *args, **kwargs):
"""
......@@ -118,7 +121,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.cuda(*args, **kwargs)
@doc_from(torch.Tensor.to)
@doc_from_base()
@method_treelize()
def to(self: torch.Tensor, *args, **kwargs):
"""
......@@ -142,7 +145,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.to(*args, **kwargs)
@doc_from(torch.Tensor.numel)
@doc_from_base()
@ireduce(sum)
@method_treelize(return_type=Object)
def numel(self: torch.Tensor):
......@@ -152,7 +155,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
return self.numel()
@property
@doc_from(torch.Tensor.shape)
@doc_from_base()
@method_treelize(return_type=Size)
def shape(self: torch.Tensor):
"""
......@@ -174,7 +177,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
return self.shape
# noinspection PyArgumentList
@doc_from(torch.Tensor.all)
@doc_from_base()
@tireduce(torch.all)
@method_treelize(return_type=Object)
def all(self: torch.Tensor, *args, **kwargs) -> bool:
......@@ -184,7 +187,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
return self.all(*args, **kwargs)
# noinspection PyArgumentList
@doc_from(torch.Tensor.any)
@doc_from_base()
@tireduce(torch.any)
@method_treelize(return_type=Object)
def any(self: torch.Tensor, *args, **kwargs) -> bool:
......@@ -193,7 +196,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.any(*args, **kwargs)
@doc_from(torch.Tensor.max)
@doc_from_base()
@tireduce(torch.max)
@method_treelize(return_type=Object)
def max(self: torch.Tensor, *args, **kwargs):
......@@ -202,7 +205,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.max(*args, **kwargs)
@doc_from(torch.Tensor.min)
@doc_from_base()
@tireduce(torch.min)
@method_treelize(return_type=Object)
def min(self: torch.Tensor, *args, **kwargs):
......@@ -211,7 +214,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.min(*args, **kwargs)
@doc_from(torch.Tensor.sum)
@doc_from_base()
@tireduce(torch.sum)
@method_treelize(return_type=Object)
def sum(self: torch.Tensor, *args, **kwargs):
......@@ -262,7 +265,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self >= other
@doc_from(torch.Tensor.clone)
@doc_from_base()
@method_treelize()
def clone(self, *args, **kwargs):
"""
......@@ -270,7 +273,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.clone(*args, **kwargs)
@doc_from(torch.Tensor.dot)
@doc_from_base()
@method_treelize()
def dot(self, other, *args, **kwargs):
"""
......@@ -278,7 +281,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.dot(other, *args, **kwargs)
@doc_from(torch.Tensor.mm)
@doc_from_base()
@method_treelize()
def mm(self, mat2, *args, **kwargs):
"""
......@@ -286,7 +289,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.mm(mat2, *args, **kwargs)
@doc_from(torch.Tensor.matmul)
@doc_from_base()
@method_treelize()
def matmul(self, tensor2, *args, **kwargs):
"""
......@@ -294,7 +297,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.matmul(tensor2, *args, **kwargs)
@doc_from(torch.Tensor.isfinite)
@doc_from_base()
@method_treelize()
def isfinite(self):
"""
......@@ -302,7 +305,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.isfinite()
@doc_from(torch.Tensor.isinf)
@doc_from_base()
@method_treelize()
def isinf(self):
"""
......@@ -310,7 +313,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.isinf()
@doc_from(torch.Tensor.isnan)
@doc_from_base()
@method_treelize()
def isnan(self):
"""
......@@ -318,7 +321,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.isnan()
@doc_from(torch.Tensor.abs)
@doc_from_base()
@method_treelize()
def abs(self, *args, **kwargs):
"""
......@@ -326,7 +329,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.abs(*args, **kwargs)
@doc_from(torch.Tensor.abs_)
@doc_from_base()
@return_self
@method_treelize()
def abs_(self, *args, **kwargs):
......@@ -335,7 +338,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.abs_(*args, **kwargs)
@doc_from(torch.Tensor.clamp)
@doc_from_base()
@method_treelize()
def clamp(self, *args, **kwargs):
"""
......@@ -343,7 +346,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.clamp(*args, **kwargs)
@doc_from(torch.Tensor.clamp_)
@doc_from_base()
@return_self
@method_treelize()
def clamp_(self, *args, **kwargs):
......@@ -352,7 +355,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.clamp_(*args, **kwargs)
@doc_from(torch.Tensor.sign)
@doc_from_base()
@method_treelize()
def sign(self, *args, **kwargs):
"""
......@@ -360,7 +363,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.sign(*args, **kwargs)
@doc_from(torch.Tensor.sigmoid)
@doc_from_base()
@method_treelize()
def sigmoid(self, *args, **kwargs):
"""
......@@ -368,7 +371,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.sigmoid(*args, **kwargs)
@doc_from(torch.Tensor.sigmoid_)
@doc_from_base()
@return_self
@method_treelize()
def sigmoid_(self, *args, **kwargs):
......@@ -377,7 +380,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.sigmoid_(*args, **kwargs)
@doc_from(torch.Tensor.floor)
@doc_from_base()
@method_treelize()
def floor(self, *args, **kwargs):
"""
......@@ -385,7 +388,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.floor(*args, **kwargs)
@doc_from(torch.Tensor.floor_)
@doc_from_base()
@return_self
@method_treelize()
def floor_(self, *args, **kwargs):
......@@ -394,7 +397,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.floor_(*args, **kwargs)
@doc_from(torch.Tensor.ceil)
@doc_from_base()
@method_treelize()
def ceil(self, *args, **kwargs):
"""
......@@ -402,7 +405,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.ceil(*args, **kwargs)
@doc_from(torch.Tensor.ceil_)
@doc_from_base()
@return_self
@method_treelize()
def ceil_(self, *args, **kwargs):
......@@ -411,7 +414,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.ceil_(*args, **kwargs)
@doc_from(torch.Tensor.round)
@doc_from_base()
@method_treelize()
def round(self, *args, **kwargs):
"""
......@@ -419,7 +422,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
"""
return self.round(*args, **kwargs)
@doc_from(torch.Tensor.round_)
@doc_from_base()
@return_self
@method_treelize()
def round_(self, *args, **kwargs):
......
from .clazz import *
from .doc import *
from .func import *
from .reflection import *
"""
Documentation Decorators.
"""
from .reflection import removed
__all__ = [
'doc_from',
'doc_from', 'doc_from_base',
]
_DOC_FROM_TAG = '__doc_from__'
......@@ -14,3 +16,15 @@ def doc_from(src):
return obj
return _decorator
def doc_from_base(base, name: str = None):
def _decorator(func):
_name = name or func.__name__
if hasattr(base, _name):
func = doc_from(getattr(base, _name))(func)
else:
func = removed(func)
return func
return _decorator
from types import ModuleType
__all__ = [
'removed', 'class_autoremove', 'module_autoremove',
]
_REMOVED_TAG = '__removed__'
def removed(obj):
"""
Overview:
Add ``__removed__`` attribute to the given object.
The given ``object`` will be marked as removed, will be removed when
:func:`class_autoremove` or :func:`module_autoremove` is used.
Arguments:
- obj: Given object to be marked.
Returns:
- marked: Marked object.
"""
setattr(obj, _REMOVED_TAG, True)
return obj
def _is_removed(obj) -> bool:
return not not getattr(obj, _REMOVED_TAG, False)
def class_autoremove(cls: type) -> type:
"""
Overview:
Remove the items which are marked as removed in the given ``cls``.
Arguments:
- cls (:obj:`type`): Given class.
Returns:
- marked (:obj:`type`): Marked class.
Examples::
>>> @class_autoremove
>>> class MyClass:
>>> pass
"""
for _name in dir(cls):
if _is_removed(getattr(cls, _name)):
delattr(cls, _name)
return cls
def module_autoremove(module: ModuleType):
"""
Overview:
Remove the items which are marked as removed in the given ``module``.
Arguments:
- module (:obj:`ModuleType`): Given module.
Returns:
- marked (:obj:`ModuleType`): Marked module.
Examples::
>>> # At the imports' part
>>> import sys
>>>
>>> # At the very bottom of the module
>>> sys.modules[__name__] = module_autoremove(sys.modules[__name__])
>>>
"""
if hasattr(module, '__all__'):
names = getattr(module, '__all__')
def names_postprocess(new_names):
_names = getattr(module, '__all__')
_names[:] = new_names[:]
setattr(module, '__all__', _names)
else:
names = dir(module)
# noinspection PyUnusedLocal
def names_postprocess(new_names):
pass
_new_names = []
for _name in names:
if _is_removed(getattr(module, _name)):
delattr(module, _name)
else:
_new_names.append(_name)
names_postprocess(_new_names)
return module
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册