提交 007d22dd 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): simplify unit test

上级 fbfdb128
......@@ -8,7 +8,7 @@ _DOC_FROM_TAG = '__doc_from__'
if __name__ == '__main__':
_torch_version = torch.__version__
print_title(ttorch.funcs.__name__, levelc='=')
current_module(ttorch.funcs.__name__)
current_module(ttorch.__name__)
for _name in sorted(ttorch.funcs.__all__):
_item = getattr(ttorch.funcs, _name)
......
import pytest
import torch
from treevalue import TreeValue
import treetensor.torch as ttorch
......@@ -36,14 +35,14 @@ class TestTorchFuncs:
})).all()
def test_zeros(self):
assert ttorch.all(ttorch.zeros((2, 3)) == torch.zeros(2, 3))
assert ttorch.all(ttorch.zeros(TreeValue({
assert ttorch.all(ttorch.zeros(2, 3) == torch.zeros(2, 3))
assert ttorch.all(ttorch.zeros({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
})) == ttorch.Tensor({
}) == ttorch.Tensor({
'a': torch.zeros(2, 3),
'b': torch.zeros(5, 6),
'x': {
......@@ -57,14 +56,14 @@ class TestTorchFuncs:
torch.tensor([[0, 0, 0], [0, 0, 0]]),
)
assert ttorch.all(
ttorch.zeros_like(({
ttorch.zeros_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
})) == ttorch.Tensor({
}) == ttorch.Tensor({
'a': torch.tensor([[0, 0, 0], [0, 0, 0]]),
'b': torch.tensor([0, 0, 0, 0]),
'x': {
......@@ -75,14 +74,14 @@ class TestTorchFuncs:
)
def test_ones(self):
assert ttorch.all(ttorch.ones((2, 3)) == torch.ones(2, 3))
assert ttorch.all(ttorch.ones(TreeValue({
assert ttorch.all(ttorch.ones(2, 3) == torch.ones(2, 3))
assert ttorch.all(ttorch.ones({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
})) == ttorch.Tensor({
}) == ttorch.Tensor({
'a': torch.ones(2, 3),
'b': torch.ones(5, 6),
'x': {
......@@ -96,14 +95,14 @@ class TestTorchFuncs:
torch.tensor([[1, 1, 1], [1, 1, 1]])
)
assert ttorch.all(
ttorch.ones_like(({
ttorch.ones_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
})) == ttorch.Tensor({
}) == ttorch.Tensor({
'a': torch.tensor([[1, 1, 1], [1, 1, 1]]),
'b': torch.tensor([1, 1, 1, 1]),
'x': {
......@@ -114,18 +113,18 @@ class TestTorchFuncs:
)
def test_randn(self):
_target = ttorch.randn((200, 300))
_target = ttorch.randn(200, 300)
assert -0.02 <= _target.view(60000).mean().tolist() <= 0.02
assert 0.98 <= _target.view(60000).std().tolist() <= 1.02
assert _target.shape == torch.Size([200, 300])
_target = ttorch.randn(TreeValue({
_target = ttorch.randn({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}))
})
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
......@@ -140,14 +139,14 @@ class TestTorchFuncs:
assert 0.98 <= _target.view(60000).std().tolist() <= 1.02
assert _target.shape == torch.Size([200, 300])
_target = ttorch.randn_like(({
_target = ttorch.randn_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32),
'b': torch.tensor([1, 2, 3, 4], dtype=torch.float32),
'x': {
'c': torch.tensor([5, 6, 7], dtype=torch.float32),
'd': torch.tensor([[[8, 9]]], dtype=torch.float32),
}
}))
})
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
......@@ -158,13 +157,13 @@ class TestTorchFuncs:
})
def test_randint(self):
_target = ttorch.randint(-10, 10, TreeValue({
_target = ttorch.randint(-10, 10, {
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}))
})
assert ttorch.all(_target < 10)
assert ttorch.all(-10 <= _target)
assert _target.shape == ttorch.Size({
......@@ -175,13 +174,13 @@ class TestTorchFuncs:
}
})
_target = ttorch.randint(10, TreeValue({
_target = ttorch.randint(10, {
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}))
})
assert ttorch.all(_target < 10)
assert ttorch.all(0 <= _target)
assert _target.shape == ttorch.Size({
......@@ -193,14 +192,14 @@ class TestTorchFuncs:
})
def test_randint_like(self):
_target = ttorch.randint_like(({
_target = ttorch.randint_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
}), -10, 10)
}, -10, 10)
assert ttorch.all(_target < 10)
assert ttorch.all(-10 <= _target)
assert _target.shape == ttorch.Size({
......@@ -212,14 +211,14 @@ class TestTorchFuncs:
}
})
_target = ttorch.randint_like(({
_target = ttorch.randint_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
}), 10)
}, 10)
assert ttorch.all(_target < 10)
assert ttorch.all(0 <= _target)
assert _target.shape == ttorch.Size({
......@@ -232,13 +231,13 @@ class TestTorchFuncs:
})
def test_full(self):
_target = ttorch.full(TreeValue({
_target = ttorch.full({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}), 233)
}, 233)
assert ttorch.all(_target == 233)
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
......@@ -249,14 +248,14 @@ class TestTorchFuncs:
})
def test_full_like(self):
_target = ttorch.full_like(({
_target = ttorch.full_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
}), 233)
}, 233)
assert ttorch.all(_target == 233)
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
......@@ -268,13 +267,13 @@ class TestTorchFuncs:
})
def test_empty(self):
_target = ttorch.empty(TreeValue({
_target = ttorch.empty({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}))
})
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
......@@ -284,14 +283,14 @@ class TestTorchFuncs:
})
def test_empty_like(self):
_target = ttorch.empty_like(({
_target = ttorch.empty_like({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
}))
})
assert _target.shape == ttorch.Size({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
......@@ -317,26 +316,26 @@ class TestTorchFuncs:
assert r3 == torch.tensor(False)
assert not r3
r4 = ttorch.all(({
r4 = ttorch.all({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, True]),
})).all()
}).all()
assert torch.is_tensor(r4)
assert r4 == torch.tensor(True)
assert r4
r5 = ttorch.all(({
r5 = ttorch.all({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]),
})).all()
}).all()
assert torch.is_tensor(r5)
assert r5 == torch.tensor(False)
assert not r5
r6 = ttorch.all(({
r6 = ttorch.all({
'a': torch.tensor([False, False, False]),
'b': torch.tensor([False, False, False]),
})).all()
}).all()
assert torch.is_tensor(r6)
assert r6 == torch.tensor(False)
assert not r6
......@@ -357,26 +356,26 @@ class TestTorchFuncs:
assert r3 == torch.tensor(False)
assert not r3
r4 = ttorch.any(({
r4 = ttorch.any({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, True]),
})).all()
}).all()
assert torch.is_tensor(r4)
assert r4 == torch.tensor(True)
assert r4
r5 = ttorch.any(({
r5 = ttorch.any({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]),
})).all()
}).all()
assert torch.is_tensor(r5)
assert r5 == torch.tensor(True)
assert r5
r6 = ttorch.any(({
r6 = ttorch.any({
'a': torch.tensor([False, False, False]),
'b': torch.tensor([False, False, False]),
})).all()
}).all()
assert torch.is_tensor(r6)
assert r6 == torch.tensor(False)
assert not r6
......@@ -387,17 +386,17 @@ class TestTorchFuncs:
assert ttorch.eq(torch.tensor([1, 1, 1]), 1).all()
assert not ttorch.eq(torch.tensor([1, 1, 2]), 1).all()
assert ttorch.eq(({
assert ttorch.eq({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}), ({
}, ({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
})).all()
assert not ttorch.eq(({
assert not ttorch.eq({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}), ({
}, ({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 5]),
})).all()
......@@ -411,20 +410,20 @@ class TestTorchFuncs:
assert isinstance(p2, bool)
assert not p2
p3 = ttorch.equal(({
p3 = ttorch.equal({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}), ({
}, ({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}))
assert isinstance(p3, bool)
assert p3
p4 = ttorch.equal(({
p4 = ttorch.equal({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}), ({
}, ({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 5]),
}))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册