From 7c15d9f7f59dcf836d3339a0ed49a1a9d16e6471 Mon Sep 17 00:00:00 2001 From: HansBug Date: Wed, 22 Sep 2021 10:59:06 +0800 Subject: [PATCH] dev(hansbug): add compatiable to lower versions of torch --- docs/source/api_doc/utils/clazz.rst | 1 - docs/source/api_doc/utils/index.rst | 1 + docs/source/api_doc/utils/reflection.rst | 25 ++++++ test/tests/__init__.py | 1 + test/tests/mark.py | 14 ++++ test/torch/test_funcs.py | 51 +++++++++++- test/torch/test_size.py | 13 +++- test/torch/test_tensor.py | 56 +++++++++++--- treetensor/torch/funcs.py | 98 +++++++++++++----------- treetensor/torch/size.py | 10 ++- treetensor/torch/tensor.py | 69 +++++++++-------- treetensor/utils/__init__.py | 1 + treetensor/utils/doc.py | 16 +++- treetensor/utils/reflection.py | 96 +++++++++++++++++++++++ 14 files changed, 352 insertions(+), 100 deletions(-) create mode 100644 docs/source/api_doc/utils/reflection.rst create mode 100644 test/tests/__init__.py create mode 100644 test/tests/mark.py create mode 100644 treetensor/utils/reflection.py diff --git a/docs/source/api_doc/utils/clazz.rst b/docs/source/api_doc/utils/clazz.rst index f1105dc47..df1ad65c4 100644 --- a/docs/source/api_doc/utils/clazz.rst +++ b/docs/source/api_doc/utils/clazz.rst @@ -23,4 +23,3 @@ current_names .. autofunction:: current_names - diff --git a/docs/source/api_doc/utils/index.rst b/docs/source/api_doc/utils/index.rst index 498b1e395..3ce38aca7 100644 --- a/docs/source/api_doc/utils/index.rst +++ b/docs/source/api_doc/utils/index.rst @@ -7,3 +7,4 @@ treetensor.utils clazz doc func + reflection diff --git a/docs/source/api_doc/utils/reflection.rst b/docs/source/api_doc/utils/reflection.rst new file mode 100644 index 000000000..01586810e --- /dev/null +++ b/docs/source/api_doc/utils/reflection.rst @@ -0,0 +1,25 @@ +Reflection Utils +============================ + +.. py:currentmodule:: treetensor.utils + +.. automodule:: treetensor.utils.reflection + + +removed +-------------------- + +.. autofunction:: removed + + +class_autoremove +-------------------- + +.. autofunction:: class_autoremove + + +module_autoremove +-------------------- + +.. autofunction:: module_autoremove + diff --git a/test/tests/__init__.py b/test/tests/__init__.py new file mode 100644 index 000000000..a3f7a94ff --- /dev/null +++ b/test/tests/__init__.py @@ -0,0 +1 @@ +from .mark import choose_mark_with_existence_check diff --git a/test/tests/mark.py b/test/tests/mark.py new file mode 100644 index 000000000..a94bb8ba3 --- /dev/null +++ b/test/tests/mark.py @@ -0,0 +1,14 @@ +import pytest + +_TEST_PREFIX = 'test_' + + +def choose_mark_with_existence_check(base, name: str = None): + def _decorator(func): + _name = name or func.__name__[len(_TEST_PREFIX):] + _mark = pytest.mark.unittest if hasattr(base, _name) else pytest.mark.ignore + + func = _mark(func) + return func + + return _decorator diff --git a/test/torch/test_funcs.py b/test/torch/test_funcs.py index 5bf386bce..d0e8e674b 100644 --- a/test/torch/test_funcs.py +++ b/test/torch/test_funcs.py @@ -1,12 +1,15 @@ -import pytest import torch import treetensor.torch as ttorch +from treetensor.utils import replaceable_partial +from ..tests import choose_mark_with_existence_check + +choose_mark = replaceable_partial(choose_mark_with_existence_check, base=ttorch) # noinspection DuplicatedCode,PyUnresolvedReferences -@pytest.mark.unittest class TestTorchFuncs: + @choose_mark() def test_tensor(self): t1 = ttorch.tensor(True) assert isinstance(t1, torch.Tensor) @@ -34,6 +37,7 @@ class TestTorchFuncs: } })).all() + @choose_mark() def test_zeros(self): assert ttorch.all(ttorch.zeros(2, 3) == torch.zeros(2, 3)) assert ttorch.all(ttorch.zeros({ @@ -50,6 +54,7 @@ class TestTorchFuncs: } })) + @choose_mark() def test_zeros_like(self): assert ttorch.all( ttorch.zeros_like(torch.tensor([[1, 2, 3], [4, 5, 6]])) == @@ -73,6 +78,7 @@ class TestTorchFuncs: }) ) + @choose_mark() def test_ones(self): assert ttorch.all(ttorch.ones(2, 3) == torch.ones(2, 3)) assert ttorch.all(ttorch.ones({ @@ -89,6 +95,7 @@ class TestTorchFuncs: } })) + @choose_mark() def test_ones_like(self): assert ttorch.all( ttorch.ones_like(torch.tensor([[1, 2, 3], [4, 5, 6]])) == @@ -112,6 +119,7 @@ class TestTorchFuncs: }) ) + @choose_mark() def test_randn(self): _target = ttorch.randn(200, 300) assert -0.02 <= _target.view(60000).mean().tolist() <= 0.02 @@ -133,6 +141,7 @@ class TestTorchFuncs: } }) + @choose_mark() def test_randn_like(self): _target = ttorch.randn_like(torch.ones(200, 300)) assert -0.02 <= _target.view(60000).mean().tolist() <= 0.02 @@ -156,6 +165,7 @@ class TestTorchFuncs: } }) + @choose_mark() def test_randint(self): _target = ttorch.randint(-10, 10, { 'a': (2, 3), @@ -191,6 +201,7 @@ class TestTorchFuncs: } }) + @choose_mark() def test_randint_like(self): _target = ttorch.randint_like({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), @@ -230,6 +241,7 @@ class TestTorchFuncs: } }) + @choose_mark() def test_full(self): _target = ttorch.full({ 'a': (2, 3), @@ -247,6 +259,7 @@ class TestTorchFuncs: } }) + @choose_mark() def test_full_like(self): _target = ttorch.full_like({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), @@ -266,6 +279,7 @@ class TestTorchFuncs: } }) + @choose_mark() def test_empty(self): _target = ttorch.empty({ 'a': (2, 3), @@ -282,6 +296,7 @@ class TestTorchFuncs: } }) + @choose_mark() def test_empty_like(self): _target = ttorch.empty_like({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), @@ -300,6 +315,7 @@ class TestTorchFuncs: } }) + @choose_mark() def test_all(self): r1 = ttorch.all(torch.tensor([True, True, True])) assert torch.is_tensor(r1) @@ -340,6 +356,7 @@ class TestTorchFuncs: assert r6 == torch.tensor(False) assert not r6 + @choose_mark() def test_any(self): r1 = ttorch.any(torch.tensor([True, True, True])) assert torch.is_tensor(r1) @@ -380,6 +397,7 @@ class TestTorchFuncs: assert r6 == torch.tensor(False) assert not r6 + @choose_mark() def test_eq(self): assert ttorch.eq(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3])).all() assert not ttorch.eq(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 2])).all() @@ -401,6 +419,7 @@ class TestTorchFuncs: 'b': torch.tensor([4, 5, 5]), })).all() + @choose_mark() def test_ne(self): assert (ttorch.ne( torch.tensor([[1, 2], [3, 4]]), @@ -422,6 +441,7 @@ class TestTorchFuncs: 'b': [True, True, False], })).all() + @choose_mark() def test_lt(self): assert (ttorch.lt( torch.tensor([[1, 2], [3, 4]]), @@ -443,6 +463,7 @@ class TestTorchFuncs: 'b': [True, False, False], })).all() + @choose_mark() def test_le(self): assert (ttorch.le( torch.tensor([[1, 2], [3, 4]]), @@ -464,6 +485,7 @@ class TestTorchFuncs: 'b': [True, False, True], })).all() + @choose_mark() def test_gt(self): assert (ttorch.gt( torch.tensor([[1, 2], [3, 4]]), @@ -485,6 +507,7 @@ class TestTorchFuncs: 'b': [False, True, False], })).all() + @choose_mark() def test_ge(self): assert (ttorch.ge( torch.tensor([[1, 2], [3, 4]]), @@ -506,6 +529,7 @@ class TestTorchFuncs: 'b': [False, True, True], })).all() + @choose_mark() def test_equal(self): p1 = ttorch.equal(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3])) assert isinstance(p1, bool) @@ -535,6 +559,7 @@ class TestTorchFuncs: assert isinstance(p4, bool) assert not p4 + @choose_mark() def test_min(self): t1 = ttorch.min(torch.tensor([1.0, 2.0, 1.5])) assert isinstance(t1, torch.Tensor) @@ -548,6 +573,7 @@ class TestTorchFuncs: 'b': {'x': 0.9}, }) + @choose_mark() def test_max(self): t1 = ttorch.max(torch.tensor([1.0, 2.0, 1.5])) assert isinstance(t1, torch.Tensor) @@ -561,6 +587,7 @@ class TestTorchFuncs: 'b': {'x': 2.5, } }) + @choose_mark() def test_sum(self): assert ttorch.sum(torch.tensor([1.0, 2.0, 1.5])) == torch.tensor(4.5) assert ttorch.sum(ttorch.tensor({ @@ -568,6 +595,7 @@ class TestTorchFuncs: 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, })) == torch.tensor(11.0) + @choose_mark() def test_clone(self): t1 = ttorch.clone(torch.tensor([1.0, 2.0, 1.5])) assert isinstance(t1, torch.Tensor) @@ -582,6 +610,7 @@ class TestTorchFuncs: 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, })).all() + @choose_mark() def test_dot(self): t1 = ttorch.dot(torch.tensor([1, 2]), torch.tensor([2, 3])) assert isinstance(t1, torch.Tensor) @@ -599,6 +628,7 @@ class TestTorchFuncs: ) assert (t2 == ttorch.tensor({'a': 38, 'b': {'x': 11}})).all() + @choose_mark() def test_matmul(self): t1 = ttorch.matmul( torch.tensor([[1, 2], [3, 4]]), @@ -622,6 +652,7 @@ class TestTorchFuncs: 'b': {'x': 40} })).all() + @choose_mark() def test_mm(self): t1 = ttorch.mm( torch.tensor([[1, 2], [3, 4]]), @@ -645,6 +676,7 @@ class TestTorchFuncs: 'b': {'x': [[44, 32], [80, 59]]}, })).all() + @choose_mark() def test_isfinite(self): t1 = ttorch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) assert isinstance(t1, torch.Tensor) @@ -659,6 +691,7 @@ class TestTorchFuncs: 'b': {'x': [[True, False, True], [False, True, False]]}, })) + @choose_mark() def test_isinf(self): t1 = ttorch.isinf(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) assert isinstance(t1, torch.Tensor) @@ -673,6 +706,7 @@ class TestTorchFuncs: 'b': {'x': [[False, True, False], [True, False, False]]}, })) + @choose_mark() def test_isnan(self): t1 = ttorch.isnan(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) assert isinstance(t1, torch.Tensor) @@ -687,6 +721,7 @@ class TestTorchFuncs: 'b': {'x': [[False, False, False], [False, False, True]]}, })).all() + @choose_mark() def test_abs(self): t1 = ttorch.abs(ttorch.tensor([12, 0, -3])) assert isinstance(t1, torch.Tensor) @@ -701,6 +736,7 @@ class TestTorchFuncs: 'b': {'x': [[3, 1], [0, 2]]}, })).all() + @choose_mark() def test_abs_(self): t1 = ttorch.tensor([12, 0, -3]) assert isinstance(t1, torch.Tensor) @@ -721,6 +757,7 @@ class TestTorchFuncs: 'b': {'x': [[3, 1], [0, 2]]}, })).all() + @choose_mark() def test_clamp(self): t1 = ttorch.clamp(ttorch.tensor([-1.7120, 0.1734, -0.0478, 2.0922]), min=-0.5, max=0.5) assert isinstance(t1, torch.Tensor) @@ -736,6 +773,7 @@ class TestTorchFuncs: [0.0489, -0.5000, -0.5000]]}, })) < 1e-6).all() + @choose_mark() def test_clamp_(self): t1 = ttorch.tensor([-1.7120, 0.1734, -0.0478, 2.0922]) t1r = ttorch.clamp_(t1, min=-0.5, max=0.5) @@ -755,6 +793,7 @@ class TestTorchFuncs: [0.0489, -0.5000, -0.5000]]}, })) < 1e-6).all() + @choose_mark() def test_sign(self): t1 = ttorch.sign(ttorch.tensor([12, 0, -3])) assert isinstance(t1, torch.Tensor) @@ -770,6 +809,7 @@ class TestTorchFuncs: [0, -1]]}, })).all() + @choose_mark() def test_round(self): t1 = ttorch.round(ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]])) assert isinstance(t1, torch.Tensor) @@ -787,6 +827,7 @@ class TestTorchFuncs: [-5., -2., 3.]]}, })) < 1e-6).all() + @choose_mark() def test_round_(self): t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]) t1r = ttorch.round_(t1) @@ -808,6 +849,7 @@ class TestTorchFuncs: [-5., -2., 3.]]}, })) < 1e-6).all() + @choose_mark() def test_floor(self): t1 = ttorch.floor(ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]])) assert isinstance(t1, torch.Tensor) @@ -825,6 +867,7 @@ class TestTorchFuncs: [-5., -2., 2.]]}, })) < 1e-6).all() + @choose_mark() def test_floor_(self): t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]) t1r = ttorch.floor_(t1) @@ -846,6 +889,7 @@ class TestTorchFuncs: [-5., -2., 2.]]}, })) < 1e-6).all() + @choose_mark() def test_ceil(self): t1 = ttorch.ceil(ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]])) assert isinstance(t1, torch.Tensor) @@ -863,6 +907,7 @@ class TestTorchFuncs: [-4., -2., 3.]]}, })) < 1e-6).all() + @choose_mark() def test_ceil_(self): t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]) t1r = ttorch.ceil_(t1) @@ -884,6 +929,7 @@ class TestTorchFuncs: [-4., -2., 3.]]}, })) < 1e-6).all() + @choose_mark() def test_sigmoid(self): t1 = ttorch.sigmoid(ttorch.tensor([1.0, 2.0, -1.5])) assert isinstance(t1, torch.Tensor) @@ -899,6 +945,7 @@ class TestTorchFuncs: [0.0759, 0.5622]]}, })) < 1e-4).all() + @choose_mark() def test_sigmoid_(self): t1 = ttorch.tensor([1.0, 2.0, -1.5]) t1r = ttorch.sigmoid_(t1) diff --git a/test/torch/test_size.py b/test/torch/test_size.py index bfaa2a14b..45ed9e9cf 100644 --- a/test/torch/test_size.py +++ b/test/torch/test_size.py @@ -1,16 +1,18 @@ import pytest import torch -from treevalue import func_treelize, typetrans, TreeValue +from treevalue import typetrans, TreeValue import treetensor.torch as ttorch from treetensor.common import Object +from treetensor.utils import replaceable_partial +from ..tests import choose_mark_with_existence_check -_all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y) +choose_mark = replaceable_partial(choose_mark_with_existence_check, base=ttorch.Size) -@pytest.mark.unittest class TestTorchSize: - def test_init(self): + @choose_mark() + def test___init__(self): t1 = ttorch.Size([1, 2, 3]) assert isinstance(t1, torch.Size) assert t1 == torch.Size([1, 2, 3]) @@ -27,6 +29,7 @@ class TestTorchSize: 'c': torch.Size([5]), }) + @choose_mark() def test_numel(self): assert ttorch.Size({ 'a': [1, 2, 3], @@ -34,6 +37,7 @@ class TestTorchSize: 'c': [5], }).numel() == 23 + @choose_mark() def test_index(self): assert ttorch.Size({ 'a': [1, 2, 3], @@ -52,6 +56,7 @@ class TestTorchSize: 'c': [5], }).index(100) + @choose_mark() def test_count(self): assert ttorch.Size({ 'a': [1, 2, 3], diff --git a/test/torch/test_tensor.py b/test/torch/test_tensor.py index bfd8f8fe5..f08d76ddc 100644 --- a/test/torch/test_tensor.py +++ b/test/torch/test_tensor.py @@ -1,17 +1,18 @@ import numpy as np -import pytest import torch from treevalue import func_treelize, typetrans, TreeValue import treetensor.numpy as tnp import treetensor.torch as ttorch from treetensor.common import Object +from treetensor.utils import replaceable_partial +from ..tests import choose_mark_with_existence_check _all_is = func_treelize(return_type=ttorch.Tensor)(lambda x, y: x is y) +choose_mark = replaceable_partial(choose_mark_with_existence_check, base=ttorch.Tensor) # noinspection PyUnresolvedReferences,DuplicatedCode -@pytest.mark.unittest class TestTorchTensor: _DEMO_1 = ttorch.Tensor({ 'a': [[1, 2, 3], [4, 5, 6]], @@ -31,7 +32,8 @@ class TestTorchTensor: } }) - def test_init(self): + @choose_mark() + def test___init__(self): assert (ttorch.Tensor([1, 2, 3]) == torch.tensor([1, 2, 3])).all() assert (ttorch.Tensor([1, 2, 3], dtype=torch.float32) == torch.FloatTensor([1, 2, 3])).all() assert (self._DEMO_1 == typetrans(TreeValue({ @@ -43,9 +45,11 @@ class TestTorchTensor: } }), ttorch.Tensor)).all() + @choose_mark() def test_numel(self): assert self._DEMO_1.numel() == 18 + @choose_mark() def test_numpy(self): assert tnp.all(self._DEMO_1.numpy() == tnp.ndarray({ 'a': np.array([[1, 2, 3], [4, 5, 6]]), @@ -56,10 +60,12 @@ class TestTorchTensor: } })) + @choose_mark() def test_cpu(self): assert ttorch.all(self._DEMO_1.cpu() == self._DEMO_1) assert _all_is(self._DEMO_1.cpu(), self._DEMO_1).reduce(lambda **kws: all(kws.values())) + @choose_mark() def test_to(self): assert ttorch.all(self._DEMO_1.to(torch.float32) == ttorch.Tensor({ 'a': torch.FloatTensor([[1, 2, 3], [4, 5, 6]]), @@ -70,6 +76,7 @@ class TestTorchTensor: } })) + @choose_mark() def test_all(self): t1 = ttorch.Tensor({ 'a': [True, True], @@ -87,6 +94,7 @@ class TestTorchTensor: assert t2.dtype == torch.bool assert not t2 + @choose_mark() def test_tolist(self): assert self._DEMO_1.tolist() == Object({ 'a': [[1, 2, 3], [4, 5, 6]], @@ -97,6 +105,7 @@ class TestTorchTensor: } }) + @choose_mark() def test_any(self): t1 = ttorch.Tensor({ 'a': [True, False], @@ -114,6 +123,7 @@ class TestTorchTensor: assert t2.dtype == torch.bool assert not t2 + @choose_mark() def test_max(self): t1 = ttorch.Tensor({ 'a': [1, 2], @@ -122,6 +132,7 @@ class TestTorchTensor: assert isinstance(t1, torch.Tensor) assert t1.tolist() == 3 + @choose_mark() def test_min(self): t1 = ttorch.Tensor({ 'a': [1, 2], @@ -130,6 +141,7 @@ class TestTorchTensor: assert isinstance(t1, torch.Tensor) assert t1.tolist() == -1 + @choose_mark() def test_sum(self): t1 = ttorch.Tensor({ 'a': [1, 2], @@ -138,7 +150,8 @@ class TestTorchTensor: assert isinstance(t1, torch.Tensor) assert t1.tolist() == 7 - def test_eq(self): + @choose_mark() + def test___eq__(self): assert ((ttorch.Tensor({ 'a': [1, 2], 'b': {'x': [[0, 3], [2, -1]]} @@ -150,7 +163,8 @@ class TestTorchTensor: 'b': {'x': [[False, True], [False, False]]} })).all() - def test_ne(self): + @choose_mark() + def test___ne__(self): assert ((ttorch.Tensor({ 'a': [1, 2], 'b': {'x': [[0, 3], [2, -1]]} @@ -162,7 +176,8 @@ class TestTorchTensor: 'b': {'x': [[True, False], [True, True]]} })).all() - def test_lt(self): + @choose_mark() + def test___lt__(self): assert ((ttorch.Tensor({ 'a': [1, 2], 'b': {'x': [[0, 3], [2, -1]]} @@ -174,7 +189,8 @@ class TestTorchTensor: 'b': {'x': [[False, False], [True, False]]} })).all() - def test_le(self): + @choose_mark() + def test___le__(self): assert ((ttorch.Tensor({ 'a': [1, 2], 'b': {'x': [[0, 3], [2, -1]]} @@ -186,7 +202,8 @@ class TestTorchTensor: 'b': {'x': [[False, True], [True, False]]} })).all() - def test_gt(self): + @choose_mark() + def test___gt__(self): assert ((ttorch.Tensor({ 'a': [1, 2], 'b': {'x': [[0, 3], [2, -1]]} @@ -198,7 +215,8 @@ class TestTorchTensor: 'b': {'x': [[True, False], [False, True]]} })).all() - def test_ge(self): + @choose_mark() + def test___ge__(self): assert ((ttorch.Tensor({ 'a': [1, 2], 'b': {'x': [[0, 3], [2, -1]]} @@ -210,6 +228,7 @@ class TestTorchTensor: 'b': {'x': [[True, True], [False, True]]} })).all() + @choose_mark() def test_clone(self): t1 = ttorch.tensor([1.0, 2.0, 1.5]).clone() assert isinstance(t1, torch.Tensor) @@ -224,6 +243,7 @@ class TestTorchTensor: 'b': {'x': [[1.8, 0.9], [1.3, 2.5]]}, })).all() + @choose_mark() def test_dot(self): t1 = torch.tensor([1, 2]).dot(torch.tensor([2, 3])) assert isinstance(t1, torch.Tensor) @@ -240,6 +260,7 @@ class TestTorchTensor: ) assert (t2 == ttorch.tensor({'a': 38, 'b': {'x': 11}})).all() + @choose_mark() def test_matmul(self): t1 = torch.tensor([[1, 2], [3, 4]]).matmul( torch.tensor([[5, 6], [7, 2]]), @@ -261,6 +282,7 @@ class TestTorchTensor: 'b': {'x': 40} })).all() + @choose_mark() def test_mm(self): t1 = torch.tensor([[1, 2], [3, 4]]).mm( torch.tensor([[5, 6], [7, 2]]), @@ -282,6 +304,7 @@ class TestTorchTensor: 'b': {'x': [[44, 32], [80, 59]]}, })).all() + @choose_mark() def test_isfinite(self): t1 = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isfinite() assert isinstance(t1, torch.Tensor) @@ -296,6 +319,7 @@ class TestTorchTensor: 'b': {'x': [[True, False, True], [False, True, False]]}, })) + @choose_mark() def test_isinf(self): t1 = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isinf() assert isinstance(t1, torch.Tensor) @@ -310,6 +334,7 @@ class TestTorchTensor: 'b': {'x': [[False, True, False], [True, False, False]]}, })) + @choose_mark() def test_isnan(self): t1 = torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]).isnan() assert isinstance(t1, torch.Tensor) @@ -324,6 +349,7 @@ class TestTorchTensor: 'b': {'x': [[False, False, False], [False, False, True]]}, })).all() + @choose_mark() def test_abs(self): t1 = ttorch.tensor([12, 0, -3]).abs() assert isinstance(t1, torch.Tensor) @@ -338,6 +364,7 @@ class TestTorchTensor: 'b': {'x': [[3, 1], [0, 2]]}, })).all() + @choose_mark() def test_abs_(self): t1 = ttorch.tensor([12, 0, -3]) t1r = t1.abs_() @@ -356,6 +383,7 @@ class TestTorchTensor: 'b': {'x': [[3, 1], [0, 2]]}, })).all() + @choose_mark() def test_clamp(self): t1 = ttorch.tensor([-1.7120, 0.1734, -0.0478, 2.0922]).clamp(min=-0.5, max=0.5) assert isinstance(t1, torch.Tensor) @@ -371,6 +399,7 @@ class TestTorchTensor: [0.0489, -0.5000, -0.5000]]}, })) < 1e-6).all() + @choose_mark() def test_clamp_(self): t1 = ttorch.tensor([-1.7120, 0.1734, -0.0478, 2.0922]) t1r = t1.clamp_(min=-0.5, max=0.5) @@ -390,6 +419,7 @@ class TestTorchTensor: [0.0489, -0.5000, -0.5000]]}, })) < 1e-6).all() + @choose_mark() def test_sign(self): t1 = ttorch.tensor([12, 0, -3]).sign() assert isinstance(t1, torch.Tensor) @@ -405,6 +435,7 @@ class TestTorchTensor: [0, -1]]}, })).all() + @choose_mark() def test_round(self): t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]).round() assert isinstance(t1, torch.Tensor) @@ -422,6 +453,7 @@ class TestTorchTensor: [-5., -2., 3.]]}, })) < 1e-6).all() + @choose_mark() def test_round_(self): t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]) t1r = t1.round_() @@ -443,6 +475,7 @@ class TestTorchTensor: [-5., -2., 3.]]}, })) < 1e-6).all() + @choose_mark() def test_floor(self): t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]).floor() assert isinstance(t1, torch.Tensor) @@ -460,6 +493,7 @@ class TestTorchTensor: [-5., -2., 2.]]}, })) < 1e-6).all() + @choose_mark() def test_floor_(self): t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]) t1r = t1.floor_() @@ -481,6 +515,7 @@ class TestTorchTensor: [-5., -2., 2.]]}, })) < 1e-6).all() + @choose_mark() def test_ceil(self): t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]).ceil() assert isinstance(t1, torch.Tensor) @@ -498,6 +533,7 @@ class TestTorchTensor: [-4., -2., 3.]]}, })) < 1e-6).all() + @choose_mark() def test_ceil_(self): t1 = ttorch.tensor([[1.2, -1.8], [-2.3, 2.8]]) t1r = t1.ceil_() @@ -519,6 +555,7 @@ class TestTorchTensor: [-4., -2., 3.]]}, })) < 1e-6).all() + @choose_mark() def test_sigmoid(self): t1 = ttorch.tensor([1.0, 2.0, -1.5]).sigmoid() assert isinstance(t1, torch.Tensor) @@ -534,6 +571,7 @@ class TestTorchTensor: [0.0759, 0.5622]]}, })) < 1e-4).all() + @choose_mark() def test_sigmoid_(self): t1 = ttorch.tensor([1.0, 2.0, -1.5]) t1r = t1.sigmoid_() diff --git a/treetensor/torch/funcs.py b/treetensor/torch/funcs.py index f77fbea5b..6278b21d4 100644 --- a/treetensor/torch/funcs.py +++ b/treetensor/torch/funcs.py @@ -1,4 +1,5 @@ import builtins +import sys import torch from treevalue import TreeValue @@ -8,7 +9,8 @@ from treevalue.utils import post_process from .tensor import Tensor, tireduce from ..common import Object, ireduce, return_self -from ..utils import replaceable_partial, doc_from, args_mapping +from ..utils import doc_from_base as original_doc_from_base +from ..utils import replaceable_partial, args_mapping, module_autoremove __all__ = [ 'zeros', 'zeros_like', @@ -31,9 +33,10 @@ func_treelize = post_process(post_process(args_mapping( lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))( replaceable_partial(original_func_treelize, return_type=Tensor) ) +doc_from_base = replaceable_partial(original_doc_from_base, base=torch) -@doc_from(torch.zeros) +@doc_from_base() @func_treelize() def zeros(*args, **kwargs): """ @@ -58,7 +61,7 @@ def zeros(*args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.zeros_like) +@doc_from_base() @func_treelize() def zeros_like(input, *args, **kwargs): """ @@ -85,7 +88,7 @@ def zeros_like(input, *args, **kwargs): return torch.zeros_like(input, *args, **kwargs) -@doc_from(torch.randn) +@doc_from_base() @func_treelize() def randn(*args, **kwargs): """ @@ -111,7 +114,7 @@ def randn(*args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.randn_like) +@doc_from_base() @func_treelize() def randn_like(input, *args, **kwargs): """ @@ -139,7 +142,7 @@ def randn_like(input, *args, **kwargs): return torch.randn_like(input, *args, **kwargs) -@doc_from(torch.randint) +@doc_from_base() @func_treelize() def randint(*args, **kwargs): """ @@ -165,7 +168,7 @@ def randint(*args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.randint_like) +@doc_from_base() @func_treelize() def randint_like(input, *args, **kwargs): """ @@ -193,7 +196,7 @@ def randint_like(input, *args, **kwargs): return torch.randint_like(input, *args, **kwargs) -@doc_from(torch.ones) +@doc_from_base() @func_treelize() def ones(*args, **kwargs): """ @@ -218,7 +221,7 @@ def ones(*args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.ones_like) +@doc_from_base() @func_treelize() def ones_like(input, *args, **kwargs): """ @@ -245,7 +248,7 @@ def ones_like(input, *args, **kwargs): return torch.ones_like(input, *args, **kwargs) -@doc_from(torch.full) +@doc_from_base() @func_treelize() def full(*args, **kwargs): """ @@ -270,7 +273,7 @@ def full(*args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.full_like) +@doc_from_base() @func_treelize() def full_like(input, *args, **kwargs): """ @@ -298,7 +301,7 @@ def full_like(input, *args, **kwargs): return torch.full_like(input, *args, **kwargs) -@doc_from(torch.empty) +@doc_from_base() @func_treelize() def empty(*args, **kwargs): """ @@ -324,7 +327,7 @@ def empty(*args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.empty_like) +@doc_from_base() @func_treelize() def empty_like(input, *args, **kwargs): """ @@ -353,7 +356,7 @@ def empty_like(input, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.all) +@doc_from_base() @tireduce(torch.all) @func_treelize(return_type=Object) def all(input, *args, **kwargs): @@ -392,7 +395,7 @@ def all(input, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.any) +@doc_from_base() @tireduce(torch.any) @func_treelize(return_type=Object) def any(input, *args, **kwargs): @@ -431,7 +434,7 @@ def any(input, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.min) +@doc_from_base() @tireduce(torch.min) @func_treelize(return_type=Object) def min(input, *args, **kwargs): @@ -470,7 +473,7 @@ def min(input, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.max) +@doc_from_base() @tireduce(torch.max) @func_treelize(return_type=Object) def max(input, *args, **kwargs): @@ -509,7 +512,7 @@ def max(input, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.sum) +@doc_from_base() @tireduce(torch.sum) @func_treelize(return_type=Object) def sum(input, *args, **kwargs): @@ -548,7 +551,7 @@ def sum(input, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.eq) +@doc_from_base() @func_treelize() def eq(input, other, *args, **kwargs): """ @@ -584,7 +587,7 @@ def eq(input, other, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.ne) +@doc_from_base() @func_treelize() def ne(input, other, *args, **kwargs): """ @@ -620,7 +623,7 @@ def ne(input, other, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.lt) +@doc_from_base() @func_treelize() def lt(input, other, *args, **kwargs): """ @@ -656,7 +659,7 @@ def lt(input, other, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.le) +@doc_from_base() @func_treelize() def le(input, other, *args, **kwargs): """ @@ -692,7 +695,7 @@ def le(input, other, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.gt) +@doc_from_base() @func_treelize() def gt(input, other, *args, **kwargs): """ @@ -728,7 +731,7 @@ def gt(input, other, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.ge) +@doc_from_base() @func_treelize() def ge(input, other, *args, **kwargs): """ @@ -764,7 +767,7 @@ def ge(input, other, *args, **kwargs): # noinspection PyShadowingBuiltins,PyArgumentList -@doc_from(torch.equal) +@doc_from_base() @ireduce(builtins.all) @func_treelize() def equal(input, other): @@ -796,7 +799,7 @@ def equal(input, other): return torch.equal(input, other) -@doc_from(torch.tensor) +@doc_from_base() @func_treelize() def tensor(*args, **kwargs): """ @@ -823,7 +826,7 @@ def tensor(*args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.clone) +@doc_from_base() @func_treelize() def clone(input, *args, **kwargs): """ @@ -853,7 +856,7 @@ def clone(input, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.dot) +@doc_from_base() @func_treelize() def dot(input, other, *args, **kwargs): """ @@ -885,7 +888,7 @@ def dot(input, other, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.matmul) +@doc_from_base() @func_treelize() def matmul(input, other, *args, **kwargs): """ @@ -922,7 +925,7 @@ def matmul(input, other, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.mm) +@doc_from_base() @func_treelize() def mm(input, mat2, *args, **kwargs): """ @@ -960,7 +963,7 @@ def mm(input, mat2, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.isfinite) +@doc_from_base() @func_treelize() def isfinite(input): """ @@ -988,7 +991,7 @@ def isfinite(input): # noinspection PyShadowingBuiltins -@doc_from(torch.isinf) +@doc_from_base() @func_treelize() def isinf(input): """ @@ -1016,7 +1019,7 @@ def isinf(input): # noinspection PyShadowingBuiltins -@doc_from(torch.isnan) +@doc_from_base() @func_treelize() def isnan(input): """ @@ -1044,7 +1047,7 @@ def isnan(input): # noinspection PyShadowingBuiltins -@doc_from(torch.abs) +@doc_from_base() @func_treelize() def abs(input, *args, **kwargs): """ @@ -1071,7 +1074,7 @@ def abs(input, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.abs_) +@doc_from_base() @return_self @func_treelize() def abs_(input): @@ -1103,7 +1106,7 @@ def abs_(input): # noinspection PyShadowingBuiltins -@doc_from(torch.clamp) +@doc_from_base() @func_treelize() def clamp(input, *args, **kwargs): """ @@ -1130,7 +1133,7 @@ def clamp(input, *args, **kwargs): # noinspection PyShadowingBuiltins,PyUnresolvedReferences -@doc_from(torch.clamp_) +@doc_from_base() @return_self @func_treelize() def clamp_(input, *args, **kwargs): @@ -1162,7 +1165,7 @@ def clamp_(input, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.sign) +@doc_from_base() @func_treelize() def sign(input, *args, **kwargs): """ @@ -1189,7 +1192,7 @@ def sign(input, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.round) +@doc_from_base() @func_treelize() def round(input, *args, **kwargs): """ @@ -1219,7 +1222,7 @@ def round(input, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.round_) +@doc_from_base() @return_self @func_treelize() def round_(input): @@ -1253,7 +1256,7 @@ def round_(input): # noinspection PyShadowingBuiltins -@doc_from(torch.floor) +@doc_from_base() @func_treelize() def floor(input, *args, **kwargs): """ @@ -1283,7 +1286,7 @@ def floor(input, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.floor_) +@doc_from_base() @return_self @func_treelize() def floor_(input): @@ -1317,7 +1320,7 @@ def floor_(input): # noinspection PyShadowingBuiltins -@doc_from(torch.ceil) +@doc_from_base() @func_treelize() def ceil(input, *args, **kwargs): """ @@ -1347,7 +1350,7 @@ def ceil(input, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.ceil_) +@doc_from_base() @return_self @func_treelize() def ceil_(input): @@ -1381,7 +1384,7 @@ def ceil_(input): # noinspection PyShadowingBuiltins -@doc_from(torch.sigmoid) +@doc_from_base() @func_treelize() def sigmoid(input, *args, **kwargs): """ @@ -1408,7 +1411,7 @@ def sigmoid(input, *args, **kwargs): # noinspection PyShadowingBuiltins -@doc_from(torch.sigmoid_) +@doc_from_base() @return_self @func_treelize() def sigmoid_(input): @@ -1437,3 +1440,6 @@ def sigmoid_(input): [0.0759, 0.5622]]) """ return torch.sigmoid_(input) + + +sys.modules[__name__] = module_autoremove(sys.modules[__name__]) diff --git a/treetensor/torch/size.py b/treetensor/torch/size.py index 5203ef4b0..2e86cfb9b 100644 --- a/treetensor/torch/size.py +++ b/treetensor/torch/size.py @@ -8,12 +8,14 @@ from treevalue.utils import post_process from .base import Torch from ..common import Object, clsmeta, ireduce -from ..utils import replaceable_partial, doc_from, current_names, args_mapping +from ..utils import doc_from_base as original_doc_from_base +from ..utils import replaceable_partial, current_names, args_mapping func_treelize = post_process(post_process(args_mapping( lambda i, x: TreeValue(x) if isinstance(x, (dict, BaseTree, TreeValue)) else x)))( replaceable_partial(original_func_treelize) ) +doc_from_base = replaceable_partial(original_doc_from_base, base=torch.Size) __all__ = [ 'Size' @@ -69,7 +71,7 @@ class Size(Torch, metaclass=clsmeta(torch.Size, allow_dict=True)): """ super(Torch, self).__init__(data) - @doc_from(torch.Size.numel) + @doc_from_base() @ireduce(sum) @func_treelize(return_type=Object) def numel(self: torch.Size) -> Object: @@ -88,7 +90,7 @@ class Size(Torch, metaclass=clsmeta(torch.Size, allow_dict=True)): """ return self.numel() - @doc_from(torch.Size.index) + @doc_from_base() @_post_index @func_treelize(return_type=Object) def index(self: torch.Size, value, *args, **kwargs) -> Object: @@ -120,7 +122,7 @@ class Size(Torch, metaclass=clsmeta(torch.Size, allow_dict=True)): except ValueError: return None - @doc_from(torch.Size.count) + @doc_from_base() @ireduce(sum) @func_treelize(return_type=Object) def count(self: torch.Size, *args, **kwargs) -> Object: diff --git a/treetensor/torch/tensor.py b/treetensor/torch/tensor.py index f33a4365b..4dba97b17 100644 --- a/treetensor/torch/tensor.py +++ b/treetensor/torch/tensor.py @@ -7,7 +7,8 @@ from .base import Torch from .size import Size from ..common import Object, ireduce, clsmeta, return_self from ..numpy import ndarray -from ..utils import current_names, doc_from +from ..utils import current_names, class_autoremove, replaceable_partial +from ..utils import doc_from_base as original_doc_from_base __all__ = [ 'Tensor' @@ -15,6 +16,7 @@ __all__ = [ _reduce_tensor_wrap = pre_process(lambda it: ((torch.tensor([*it]),), {})) tireduce = pre_process(lambda rfunc: ((_reduce_tensor_wrap(rfunc),), {}))(ireduce) +doc_from_base = replaceable_partial(original_doc_from_base, base=torch.Tensor) def _to_tensor(*args, **kwargs): @@ -29,6 +31,7 @@ def _to_tensor(*args, **kwargs): # noinspection PyTypeChecker @current_names() +@class_autoremove class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): # noinspection PyUnusedLocal def __init__(self, data, *args, **kwargs): @@ -63,7 +66,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ super(Torch, self).__init__(data) - @doc_from(torch.Tensor.numpy) + @doc_from_base() @method_treelize(return_type=ndarray) def numpy(self: torch.Tensor) -> np.ndarray: """ @@ -73,7 +76,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.numpy() - @doc_from(torch.Tensor.tolist) + @doc_from_base() @method_treelize(return_type=Object) def tolist(self: torch.Tensor): """ @@ -96,7 +99,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.tolist() - @doc_from(torch.Tensor.cpu) + @doc_from_base() @method_treelize() def cpu(self: torch.Tensor, *args, **kwargs): """ @@ -107,7 +110,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.cpu(*args, **kwargs) - @doc_from(torch.Tensor.cuda) + @doc_from_base() @method_treelize() def cuda(self: torch.Tensor, *args, **kwargs): """ @@ -118,7 +121,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.cuda(*args, **kwargs) - @doc_from(torch.Tensor.to) + @doc_from_base() @method_treelize() def to(self: torch.Tensor, *args, **kwargs): """ @@ -142,7 +145,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.to(*args, **kwargs) - @doc_from(torch.Tensor.numel) + @doc_from_base() @ireduce(sum) @method_treelize(return_type=Object) def numel(self: torch.Tensor): @@ -152,7 +155,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): return self.numel() @property - @doc_from(torch.Tensor.shape) + @doc_from_base() @method_treelize(return_type=Size) def shape(self: torch.Tensor): """ @@ -174,7 +177,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): return self.shape # noinspection PyArgumentList - @doc_from(torch.Tensor.all) + @doc_from_base() @tireduce(torch.all) @method_treelize(return_type=Object) def all(self: torch.Tensor, *args, **kwargs) -> bool: @@ -184,7 +187,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): return self.all(*args, **kwargs) # noinspection PyArgumentList - @doc_from(torch.Tensor.any) + @doc_from_base() @tireduce(torch.any) @method_treelize(return_type=Object) def any(self: torch.Tensor, *args, **kwargs) -> bool: @@ -193,7 +196,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.any(*args, **kwargs) - @doc_from(torch.Tensor.max) + @doc_from_base() @tireduce(torch.max) @method_treelize(return_type=Object) def max(self: torch.Tensor, *args, **kwargs): @@ -202,7 +205,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.max(*args, **kwargs) - @doc_from(torch.Tensor.min) + @doc_from_base() @tireduce(torch.min) @method_treelize(return_type=Object) def min(self: torch.Tensor, *args, **kwargs): @@ -211,7 +214,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.min(*args, **kwargs) - @doc_from(torch.Tensor.sum) + @doc_from_base() @tireduce(torch.sum) @method_treelize(return_type=Object) def sum(self: torch.Tensor, *args, **kwargs): @@ -262,7 +265,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self >= other - @doc_from(torch.Tensor.clone) + @doc_from_base() @method_treelize() def clone(self, *args, **kwargs): """ @@ -270,7 +273,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.clone(*args, **kwargs) - @doc_from(torch.Tensor.dot) + @doc_from_base() @method_treelize() def dot(self, other, *args, **kwargs): """ @@ -278,7 +281,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.dot(other, *args, **kwargs) - @doc_from(torch.Tensor.mm) + @doc_from_base() @method_treelize() def mm(self, mat2, *args, **kwargs): """ @@ -286,7 +289,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.mm(mat2, *args, **kwargs) - @doc_from(torch.Tensor.matmul) + @doc_from_base() @method_treelize() def matmul(self, tensor2, *args, **kwargs): """ @@ -294,7 +297,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.matmul(tensor2, *args, **kwargs) - @doc_from(torch.Tensor.isfinite) + @doc_from_base() @method_treelize() def isfinite(self): """ @@ -302,7 +305,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.isfinite() - @doc_from(torch.Tensor.isinf) + @doc_from_base() @method_treelize() def isinf(self): """ @@ -310,7 +313,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.isinf() - @doc_from(torch.Tensor.isnan) + @doc_from_base() @method_treelize() def isnan(self): """ @@ -318,7 +321,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.isnan() - @doc_from(torch.Tensor.abs) + @doc_from_base() @method_treelize() def abs(self, *args, **kwargs): """ @@ -326,7 +329,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.abs(*args, **kwargs) - @doc_from(torch.Tensor.abs_) + @doc_from_base() @return_self @method_treelize() def abs_(self, *args, **kwargs): @@ -335,7 +338,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.abs_(*args, **kwargs) - @doc_from(torch.Tensor.clamp) + @doc_from_base() @method_treelize() def clamp(self, *args, **kwargs): """ @@ -343,7 +346,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.clamp(*args, **kwargs) - @doc_from(torch.Tensor.clamp_) + @doc_from_base() @return_self @method_treelize() def clamp_(self, *args, **kwargs): @@ -352,7 +355,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.clamp_(*args, **kwargs) - @doc_from(torch.Tensor.sign) + @doc_from_base() @method_treelize() def sign(self, *args, **kwargs): """ @@ -360,7 +363,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.sign(*args, **kwargs) - @doc_from(torch.Tensor.sigmoid) + @doc_from_base() @method_treelize() def sigmoid(self, *args, **kwargs): """ @@ -368,7 +371,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.sigmoid(*args, **kwargs) - @doc_from(torch.Tensor.sigmoid_) + @doc_from_base() @return_self @method_treelize() def sigmoid_(self, *args, **kwargs): @@ -377,7 +380,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.sigmoid_(*args, **kwargs) - @doc_from(torch.Tensor.floor) + @doc_from_base() @method_treelize() def floor(self, *args, **kwargs): """ @@ -385,7 +388,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.floor(*args, **kwargs) - @doc_from(torch.Tensor.floor_) + @doc_from_base() @return_self @method_treelize() def floor_(self, *args, **kwargs): @@ -394,7 +397,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.floor_(*args, **kwargs) - @doc_from(torch.Tensor.ceil) + @doc_from_base() @method_treelize() def ceil(self, *args, **kwargs): """ @@ -402,7 +405,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.ceil(*args, **kwargs) - @doc_from(torch.Tensor.ceil_) + @doc_from_base() @return_self @method_treelize() def ceil_(self, *args, **kwargs): @@ -411,7 +414,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.ceil_(*args, **kwargs) - @doc_from(torch.Tensor.round) + @doc_from_base() @method_treelize() def round(self, *args, **kwargs): """ @@ -419,7 +422,7 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)): """ return self.round(*args, **kwargs) - @doc_from(torch.Tensor.round_) + @doc_from_base() @return_self @method_treelize() def round_(self, *args, **kwargs): diff --git a/treetensor/utils/__init__.py b/treetensor/utils/__init__.py index 45e90ec28..b7bb4de24 100644 --- a/treetensor/utils/__init__.py +++ b/treetensor/utils/__init__.py @@ -1,3 +1,4 @@ from .clazz import * from .doc import * from .func import * +from .reflection import * diff --git a/treetensor/utils/doc.py b/treetensor/utils/doc.py index d1b08ddb6..0eb7b103a 100644 --- a/treetensor/utils/doc.py +++ b/treetensor/utils/doc.py @@ -1,8 +1,10 @@ """ Documentation Decorators. """ +from .reflection import removed + __all__ = [ - 'doc_from', + 'doc_from', 'doc_from_base', ] _DOC_FROM_TAG = '__doc_from__' @@ -14,3 +16,15 @@ def doc_from(src): return obj return _decorator + + +def doc_from_base(base, name: str = None): + def _decorator(func): + _name = name or func.__name__ + if hasattr(base, _name): + func = doc_from(getattr(base, _name))(func) + else: + func = removed(func) + return func + + return _decorator diff --git a/treetensor/utils/reflection.py b/treetensor/utils/reflection.py new file mode 100644 index 000000000..265162b72 --- /dev/null +++ b/treetensor/utils/reflection.py @@ -0,0 +1,96 @@ +from types import ModuleType + +__all__ = [ + 'removed', 'class_autoremove', 'module_autoremove', +] + +_REMOVED_TAG = '__removed__' + + +def removed(obj): + """ + Overview: + Add ``__removed__`` attribute to the given object. + The given ``object`` will be marked as removed, will be removed when + :func:`class_autoremove` or :func:`module_autoremove` is used. + + Arguments: + - obj: Given object to be marked. + + Returns: + - marked: Marked object. + """ + setattr(obj, _REMOVED_TAG, True) + return obj + + +def _is_removed(obj) -> bool: + return not not getattr(obj, _REMOVED_TAG, False) + + +def class_autoremove(cls: type) -> type: + """ + Overview: + Remove the items which are marked as removed in the given ``cls``. + + Arguments: + - cls (:obj:`type`): Given class. + + Returns: + - marked (:obj:`type`): Marked class. + + Examples:: + + >>> @class_autoremove + >>> class MyClass: + >>> pass + """ + for _name in dir(cls): + if _is_removed(getattr(cls, _name)): + delattr(cls, _name) + return cls + + +def module_autoremove(module: ModuleType): + """ + Overview: + Remove the items which are marked as removed in the given ``module``. + + Arguments: + - module (:obj:`ModuleType`): Given module. + + Returns: + - marked (:obj:`ModuleType`): Marked module. + + Examples:: + + >>> # At the imports' part + >>> import sys + >>> + >>> # At the very bottom of the module + >>> sys.modules[__name__] = module_autoremove(sys.modules[__name__]) + >>> + """ + if hasattr(module, '__all__'): + names = getattr(module, '__all__') + + def names_postprocess(new_names): + _names = getattr(module, '__all__') + _names[:] = new_names[:] + setattr(module, '__all__', _names) + else: + names = dir(module) + + # noinspection PyUnusedLocal + def names_postprocess(new_names): + pass + + _new_names = [] + for _name in names: + if _is_removed(getattr(module, _name)): + delattr(module, _name) + else: + _new_names.append(_name) + + names_postprocess(_new_names) + return module -- GitLab