提交 007d22dd 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): simplify unit test

上级 fbfdb128
...@@ -8,7 +8,7 @@ _DOC_FROM_TAG = '__doc_from__' ...@@ -8,7 +8,7 @@ _DOC_FROM_TAG = '__doc_from__'
if __name__ == '__main__': if __name__ == '__main__':
_torch_version = torch.__version__ _torch_version = torch.__version__
print_title(ttorch.funcs.__name__, levelc='=') print_title(ttorch.funcs.__name__, levelc='=')
current_module(ttorch.funcs.__name__) current_module(ttorch.__name__)
for _name in sorted(ttorch.funcs.__all__): for _name in sorted(ttorch.funcs.__all__):
_item = getattr(ttorch.funcs, _name) _item = getattr(ttorch.funcs, _name)
......
import pytest import pytest
import torch import torch
from treevalue import TreeValue
import treetensor.torch as ttorch import treetensor.torch as ttorch
...@@ -36,14 +35,14 @@ class TestTorchFuncs: ...@@ -36,14 +35,14 @@ class TestTorchFuncs:
})).all() })).all()
def test_zeros(self): def test_zeros(self):
assert ttorch.all(ttorch.zeros((2, 3)) == torch.zeros(2, 3)) assert ttorch.all(ttorch.zeros(2, 3) == torch.zeros(2, 3))
assert ttorch.all(ttorch.zeros(TreeValue({ assert ttorch.all(ttorch.zeros({
'a': (2, 3), 'a': (2, 3),
'b': (5, 6), 'b': (5, 6),
'x': { 'x': {
'c': (2, 3, 4), 'c': (2, 3, 4),
} }
})) == ttorch.Tensor({ }) == ttorch.Tensor({
'a': torch.zeros(2, 3), 'a': torch.zeros(2, 3),
'b': torch.zeros(5, 6), 'b': torch.zeros(5, 6),
'x': { 'x': {
...@@ -57,14 +56,14 @@ class TestTorchFuncs: ...@@ -57,14 +56,14 @@ class TestTorchFuncs:
torch.tensor([[0, 0, 0], [0, 0, 0]]), torch.tensor([[0, 0, 0], [0, 0, 0]]),
) )
assert ttorch.all( assert ttorch.all(
ttorch.zeros_like(({ ttorch.zeros_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]), 'b': torch.tensor([1, 2, 3, 4]),
'x': { 'x': {
'c': torch.tensor([5, 6, 7]), 'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]), 'd': torch.tensor([[[8, 9]]]),
} }
})) == ttorch.Tensor({ }) == ttorch.Tensor({
'a': torch.tensor([[0, 0, 0], [0, 0, 0]]), 'a': torch.tensor([[0, 0, 0], [0, 0, 0]]),
'b': torch.tensor([0, 0, 0, 0]), 'b': torch.tensor([0, 0, 0, 0]),
'x': { 'x': {
...@@ -75,14 +74,14 @@ class TestTorchFuncs: ...@@ -75,14 +74,14 @@ class TestTorchFuncs:
) )
def test_ones(self): def test_ones(self):
assert ttorch.all(ttorch.ones((2, 3)) == torch.ones(2, 3)) assert ttorch.all(ttorch.ones(2, 3) == torch.ones(2, 3))
assert ttorch.all(ttorch.ones(TreeValue({ assert ttorch.all(ttorch.ones({
'a': (2, 3), 'a': (2, 3),
'b': (5, 6), 'b': (5, 6),
'x': { 'x': {
'c': (2, 3, 4), 'c': (2, 3, 4),
} }
})) == ttorch.Tensor({ }) == ttorch.Tensor({
'a': torch.ones(2, 3), 'a': torch.ones(2, 3),
'b': torch.ones(5, 6), 'b': torch.ones(5, 6),
'x': { 'x': {
...@@ -96,14 +95,14 @@ class TestTorchFuncs: ...@@ -96,14 +95,14 @@ class TestTorchFuncs:
torch.tensor([[1, 1, 1], [1, 1, 1]]) torch.tensor([[1, 1, 1], [1, 1, 1]])
) )
assert ttorch.all( assert ttorch.all(
ttorch.ones_like(({ ttorch.ones_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]), 'b': torch.tensor([1, 2, 3, 4]),
'x': { 'x': {
'c': torch.tensor([5, 6, 7]), 'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]), 'd': torch.tensor([[[8, 9]]]),
} }
})) == ttorch.Tensor({ }) == ttorch.Tensor({
'a': torch.tensor([[1, 1, 1], [1, 1, 1]]), 'a': torch.tensor([[1, 1, 1], [1, 1, 1]]),
'b': torch.tensor([1, 1, 1, 1]), 'b': torch.tensor([1, 1, 1, 1]),
'x': { 'x': {
...@@ -114,18 +113,18 @@ class TestTorchFuncs: ...@@ -114,18 +113,18 @@ class TestTorchFuncs:
) )
def test_randn(self): def test_randn(self):
_target = ttorch.randn((200, 300)) _target = ttorch.randn(200, 300)
assert -0.02 <= _target.view(60000).mean().tolist() <= 0.02 assert -0.02 <= _target.view(60000).mean().tolist() <= 0.02
assert 0.98 <= _target.view(60000).std().tolist() <= 1.02 assert 0.98 <= _target.view(60000).std().tolist() <= 1.02
assert _target.shape == torch.Size([200, 300]) assert _target.shape == torch.Size([200, 300])
_target = ttorch.randn(TreeValue({ _target = ttorch.randn({
'a': (2, 3), 'a': (2, 3),
'b': (5, 6), 'b': (5, 6),
'x': { 'x': {
'c': (2, 3, 4), 'c': (2, 3, 4),
} }
})) })
assert _target.shape == ttorch.Size({ assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]), 'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]), 'b': torch.Size([5, 6]),
...@@ -140,14 +139,14 @@ class TestTorchFuncs: ...@@ -140,14 +139,14 @@ class TestTorchFuncs:
assert 0.98 <= _target.view(60000).std().tolist() <= 1.02 assert 0.98 <= _target.view(60000).std().tolist() <= 1.02
assert _target.shape == torch.Size([200, 300]) assert _target.shape == torch.Size([200, 300])
_target = ttorch.randn_like(({ _target = ttorch.randn_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32),
'b': torch.tensor([1, 2, 3, 4], dtype=torch.float32), 'b': torch.tensor([1, 2, 3, 4], dtype=torch.float32),
'x': { 'x': {
'c': torch.tensor([5, 6, 7], dtype=torch.float32), 'c': torch.tensor([5, 6, 7], dtype=torch.float32),
'd': torch.tensor([[[8, 9]]], dtype=torch.float32), 'd': torch.tensor([[[8, 9]]], dtype=torch.float32),
} }
})) })
assert _target.shape == ttorch.Size({ assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]), 'a': torch.Size([2, 3]),
'b': torch.Size([4]), 'b': torch.Size([4]),
...@@ -158,13 +157,13 @@ class TestTorchFuncs: ...@@ -158,13 +157,13 @@ class TestTorchFuncs:
}) })
def test_randint(self): def test_randint(self):
_target = ttorch.randint(-10, 10, TreeValue({ _target = ttorch.randint(-10, 10, {
'a': (2, 3), 'a': (2, 3),
'b': (5, 6), 'b': (5, 6),
'x': { 'x': {
'c': (2, 3, 4), 'c': (2, 3, 4),
} }
})) })
assert ttorch.all(_target < 10) assert ttorch.all(_target < 10)
assert ttorch.all(-10 <= _target) assert ttorch.all(-10 <= _target)
assert _target.shape == ttorch.Size({ assert _target.shape == ttorch.Size({
...@@ -175,13 +174,13 @@ class TestTorchFuncs: ...@@ -175,13 +174,13 @@ class TestTorchFuncs:
} }
}) })
_target = ttorch.randint(10, TreeValue({ _target = ttorch.randint(10, {
'a': (2, 3), 'a': (2, 3),
'b': (5, 6), 'b': (5, 6),
'x': { 'x': {
'c': (2, 3, 4), 'c': (2, 3, 4),
} }
})) })
assert ttorch.all(_target < 10) assert ttorch.all(_target < 10)
assert ttorch.all(0 <= _target) assert ttorch.all(0 <= _target)
assert _target.shape == ttorch.Size({ assert _target.shape == ttorch.Size({
...@@ -193,14 +192,14 @@ class TestTorchFuncs: ...@@ -193,14 +192,14 @@ class TestTorchFuncs:
}) })
def test_randint_like(self): def test_randint_like(self):
_target = ttorch.randint_like(({ _target = ttorch.randint_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]), 'b': torch.tensor([1, 2, 3, 4]),
'x': { 'x': {
'c': torch.tensor([5, 6, 7]), 'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]), 'd': torch.tensor([[[8, 9]]]),
} }
}), -10, 10) }, -10, 10)
assert ttorch.all(_target < 10) assert ttorch.all(_target < 10)
assert ttorch.all(-10 <= _target) assert ttorch.all(-10 <= _target)
assert _target.shape == ttorch.Size({ assert _target.shape == ttorch.Size({
...@@ -212,14 +211,14 @@ class TestTorchFuncs: ...@@ -212,14 +211,14 @@ class TestTorchFuncs:
} }
}) })
_target = ttorch.randint_like(({ _target = ttorch.randint_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]), 'b': torch.tensor([1, 2, 3, 4]),
'x': { 'x': {
'c': torch.tensor([5, 6, 7]), 'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]), 'd': torch.tensor([[[8, 9]]]),
} }
}), 10) }, 10)
assert ttorch.all(_target < 10) assert ttorch.all(_target < 10)
assert ttorch.all(0 <= _target) assert ttorch.all(0 <= _target)
assert _target.shape == ttorch.Size({ assert _target.shape == ttorch.Size({
...@@ -232,13 +231,13 @@ class TestTorchFuncs: ...@@ -232,13 +231,13 @@ class TestTorchFuncs:
}) })
def test_full(self): def test_full(self):
_target = ttorch.full(TreeValue({ _target = ttorch.full({
'a': (2, 3), 'a': (2, 3),
'b': (5, 6), 'b': (5, 6),
'x': { 'x': {
'c': (2, 3, 4), 'c': (2, 3, 4),
} }
}), 233) }, 233)
assert ttorch.all(_target == 233) assert ttorch.all(_target == 233)
assert _target.shape == ttorch.Size({ assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]), 'a': torch.Size([2, 3]),
...@@ -249,14 +248,14 @@ class TestTorchFuncs: ...@@ -249,14 +248,14 @@ class TestTorchFuncs:
}) })
def test_full_like(self): def test_full_like(self):
_target = ttorch.full_like(({ _target = ttorch.full_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]), 'b': torch.tensor([1, 2, 3, 4]),
'x': { 'x': {
'c': torch.tensor([5, 6, 7]), 'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]), 'd': torch.tensor([[[8, 9]]]),
} }
}), 233) }, 233)
assert ttorch.all(_target == 233) assert ttorch.all(_target == 233)
assert _target.shape == ttorch.Size({ assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]), 'a': torch.Size([2, 3]),
...@@ -268,13 +267,13 @@ class TestTorchFuncs: ...@@ -268,13 +267,13 @@ class TestTorchFuncs:
}) })
def test_empty(self): def test_empty(self):
_target = ttorch.empty(TreeValue({ _target = ttorch.empty({
'a': (2, 3), 'a': (2, 3),
'b': (5, 6), 'b': (5, 6),
'x': { 'x': {
'c': (2, 3, 4), 'c': (2, 3, 4),
} }
})) })
assert _target.shape == ttorch.Size({ assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]), 'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]), 'b': torch.Size([5, 6]),
...@@ -284,14 +283,14 @@ class TestTorchFuncs: ...@@ -284,14 +283,14 @@ class TestTorchFuncs:
}) })
def test_empty_like(self): def test_empty_like(self):
_target = ttorch.empty_like(({ _target = ttorch.empty_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]), 'b': torch.tensor([1, 2, 3, 4]),
'x': { 'x': {
'c': torch.tensor([5, 6, 7]), 'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]), 'd': torch.tensor([[[8, 9]]]),
} }
})) })
assert _target.shape == ttorch.Size({ assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]), 'a': torch.Size([2, 3]),
'b': torch.Size([4]), 'b': torch.Size([4]),
...@@ -317,26 +316,26 @@ class TestTorchFuncs: ...@@ -317,26 +316,26 @@ class TestTorchFuncs:
assert r3 == torch.tensor(False) assert r3 == torch.tensor(False)
assert not r3 assert not r3
r4 = ttorch.all(({ r4 = ttorch.all({
'a': torch.tensor([True, True, True]), 'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, True]), 'b': torch.tensor([True, True, True]),
})).all() }).all()
assert torch.is_tensor(r4) assert torch.is_tensor(r4)
assert r4 == torch.tensor(True) assert r4 == torch.tensor(True)
assert r4 assert r4
r5 = ttorch.all(({ r5 = ttorch.all({
'a': torch.tensor([True, True, True]), 'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]), 'b': torch.tensor([True, True, False]),
})).all() }).all()
assert torch.is_tensor(r5) assert torch.is_tensor(r5)
assert r5 == torch.tensor(False) assert r5 == torch.tensor(False)
assert not r5 assert not r5
r6 = ttorch.all(({ r6 = ttorch.all({
'a': torch.tensor([False, False, False]), 'a': torch.tensor([False, False, False]),
'b': torch.tensor([False, False, False]), 'b': torch.tensor([False, False, False]),
})).all() }).all()
assert torch.is_tensor(r6) assert torch.is_tensor(r6)
assert r6 == torch.tensor(False) assert r6 == torch.tensor(False)
assert not r6 assert not r6
...@@ -357,26 +356,26 @@ class TestTorchFuncs: ...@@ -357,26 +356,26 @@ class TestTorchFuncs:
assert r3 == torch.tensor(False) assert r3 == torch.tensor(False)
assert not r3 assert not r3
r4 = ttorch.any(({ r4 = ttorch.any({
'a': torch.tensor([True, True, True]), 'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, True]), 'b': torch.tensor([True, True, True]),
})).all() }).all()
assert torch.is_tensor(r4) assert torch.is_tensor(r4)
assert r4 == torch.tensor(True) assert r4 == torch.tensor(True)
assert r4 assert r4
r5 = ttorch.any(({ r5 = ttorch.any({
'a': torch.tensor([True, True, True]), 'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]), 'b': torch.tensor([True, True, False]),
})).all() }).all()
assert torch.is_tensor(r5) assert torch.is_tensor(r5)
assert r5 == torch.tensor(True) assert r5 == torch.tensor(True)
assert r5 assert r5
r6 = ttorch.any(({ r6 = ttorch.any({
'a': torch.tensor([False, False, False]), 'a': torch.tensor([False, False, False]),
'b': torch.tensor([False, False, False]), 'b': torch.tensor([False, False, False]),
})).all() }).all()
assert torch.is_tensor(r6) assert torch.is_tensor(r6)
assert r6 == torch.tensor(False) assert r6 == torch.tensor(False)
assert not r6 assert not r6
...@@ -387,17 +386,17 @@ class TestTorchFuncs: ...@@ -387,17 +386,17 @@ class TestTorchFuncs:
assert ttorch.eq(torch.tensor([1, 1, 1]), 1).all() assert ttorch.eq(torch.tensor([1, 1, 1]), 1).all()
assert not ttorch.eq(torch.tensor([1, 1, 2]), 1).all() assert not ttorch.eq(torch.tensor([1, 1, 2]), 1).all()
assert ttorch.eq(({ assert ttorch.eq({
'a': torch.tensor([1, 2, 3]), 'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]), 'b': torch.tensor([4, 5, 6]),
}), ({ }, ({
'a': torch.tensor([1, 2, 3]), 'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]), 'b': torch.tensor([4, 5, 6]),
})).all() })).all()
assert not ttorch.eq(({ assert not ttorch.eq({
'a': torch.tensor([1, 2, 3]), 'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]), 'b': torch.tensor([4, 5, 6]),
}), ({ }, ({
'a': torch.tensor([1, 2, 3]), 'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 5]), 'b': torch.tensor([4, 5, 5]),
})).all() })).all()
...@@ -411,20 +410,20 @@ class TestTorchFuncs: ...@@ -411,20 +410,20 @@ class TestTorchFuncs:
assert isinstance(p2, bool) assert isinstance(p2, bool)
assert not p2 assert not p2
p3 = ttorch.equal(({ p3 = ttorch.equal({
'a': torch.tensor([1, 2, 3]), 'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]), 'b': torch.tensor([4, 5, 6]),
}), ({ }, ({
'a': torch.tensor([1, 2, 3]), 'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]), 'b': torch.tensor([4, 5, 6]),
})) }))
assert isinstance(p3, bool) assert isinstance(p3, bool)
assert p3 assert p3
p4 = ttorch.equal(({ p4 = ttorch.equal({
'a': torch.tensor([1, 2, 3]), 'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]), 'b': torch.tensor([4, 5, 6]),
}), ({ }, ({
'a': torch.tensor([1, 2, 3]), 'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 5]), 'b': torch.tensor([4, 5, 5]),
})) }))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册