提交 e3e3e952 编写于 作者: HansBug's avatar HansBug 😆

test(hansbug): refactor test in treetensor.torch

上级 9dd98666
......@@ -2,23 +2,21 @@ import pytest
import torch
from treevalue import TreeValue
from treetensor.tensor import TreeTensor, zeros, zeros_like, ones, ones_like, randint, randint_like, randn, \
randn_like, full, full_like, TreeSize
from treetensor.tensor import all as _tensor_all
import treetensor.tensor as ttorch
# noinspection DuplicatedCode
@pytest.mark.unittest
class TestTensorFuncs:
def test_zeros(self):
assert _tensor_all(zeros((2, 3)) == torch.zeros(2, 3))
assert _tensor_all(zeros(TreeValue({
assert ttorch.all(ttorch.zeros((2, 3)) == torch.zeros(2, 3))
assert ttorch.all(ttorch.zeros(TreeValue({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
})) == TreeTensor({
})) == ttorch.TreeTensor({
'a': torch.zeros(2, 3),
'b': torch.zeros(5, 6),
'x': {
......@@ -27,19 +25,19 @@ class TestTensorFuncs:
}))
def test_zeros_like(self):
assert _tensor_all(
zeros_like(torch.tensor([[1, 2, 3], [4, 5, 6]])) ==
assert ttorch.all(
ttorch.zeros_like(torch.tensor([[1, 2, 3], [4, 5, 6]])) ==
torch.tensor([[0, 0, 0], [0, 0, 0]]),
)
assert _tensor_all(
zeros_like(TreeTensor({
assert ttorch.all(
ttorch.zeros_like(ttorch.TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
})) == TreeTensor({
})) == ttorch.TreeTensor({
'a': torch.tensor([[0, 0, 0], [0, 0, 0]]),
'b': torch.tensor([0, 0, 0, 0]),
'x': {
......@@ -50,14 +48,14 @@ class TestTensorFuncs:
)
def test_ones(self):
assert _tensor_all(ones((2, 3)) == torch.ones(2, 3))
assert _tensor_all(ones(TreeValue({
assert ttorch.all(ttorch.ones((2, 3)) == torch.ones(2, 3))
assert ttorch.all(ttorch.ones(TreeValue({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
})) == TreeTensor({
})) == ttorch.TreeTensor({
'a': torch.ones(2, 3),
'b': torch.ones(5, 6),
'x': {
......@@ -66,19 +64,19 @@ class TestTensorFuncs:
}))
def test_ones_like(self):
assert _tensor_all(
ones_like(torch.tensor([[1, 2, 3], [4, 5, 6]])) ==
assert ttorch.all(
ttorch.ones_like(torch.tensor([[1, 2, 3], [4, 5, 6]])) ==
torch.tensor([[1, 1, 1], [1, 1, 1]])
)
assert _tensor_all(
ones_like(TreeTensor({
assert ttorch.all(
ttorch.ones_like(ttorch.TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
})) == TreeTensor({
})) == ttorch.TreeTensor({
'a': torch.tensor([[1, 1, 1], [1, 1, 1]]),
'b': torch.tensor([1, 1, 1, 1]),
'x': {
......@@ -89,19 +87,19 @@ class TestTensorFuncs:
)
def test_randn(self):
_target = randn((200, 300))
_target = ttorch.randn((200, 300))
assert -0.02 <= _target.view(60000).mean().tolist() <= 0.02
assert 0.98 <= _target.view(60000).std().tolist() <= 1.02
assert _target.shape == torch.Size([200, 300])
_target = randn(TreeValue({
_target = ttorch.randn(TreeValue({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}))
assert _target.shape == TreeSize({
assert _target.shape == ttorch.TreeSize({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
......@@ -110,12 +108,12 @@ class TestTensorFuncs:
})
def test_randn_like(self):
_target = randn_like(torch.ones(200, 300))
_target = ttorch.randn_like(torch.ones(200, 300))
assert -0.02 <= _target.view(60000).mean().tolist() <= 0.02
assert 0.98 <= _target.view(60000).std().tolist() <= 1.02
assert _target.shape == torch.Size([200, 300])
_target = randn_like(TreeTensor({
_target = ttorch.randn_like(ttorch.TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32),
'b': torch.tensor([1, 2, 3, 4], dtype=torch.float32),
'x': {
......@@ -123,7 +121,7 @@ class TestTensorFuncs:
'd': torch.tensor([[[8, 9]]], dtype=torch.float32),
}
}))
assert _target.shape == TreeSize({
assert _target.shape == ttorch.TreeSize({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
......@@ -133,16 +131,16 @@ class TestTensorFuncs:
})
def test_randint(self):
_target = randint(TreeValue({
_target = ttorch.randint(TreeValue({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}), -10, 10)
assert _tensor_all(_target < 10)
assert _tensor_all(-10 <= _target)
assert _target.shape == TreeSize({
assert ttorch.all(_target < 10)
assert ttorch.all(-10 <= _target)
assert _target.shape == ttorch.TreeSize({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
......@@ -150,16 +148,16 @@ class TestTensorFuncs:
}
})
_target = randint(TreeValue({
_target = ttorch.randint(TreeValue({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}), 10)
assert _tensor_all(_target < 10)
assert _tensor_all(0 <= _target)
assert _target.shape == TreeSize({
assert ttorch.all(_target < 10)
assert ttorch.all(0 <= _target)
assert _target.shape == ttorch.TreeSize({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
......@@ -168,7 +166,7 @@ class TestTensorFuncs:
})
def test_randint_like(self):
_target = randint_like(TreeTensor({
_target = ttorch.randint_like(ttorch.TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
......@@ -176,9 +174,9 @@ class TestTensorFuncs:
'd': torch.tensor([[[8, 9]]]),
}
}), -10, 10)
assert _tensor_all(_target < 10)
assert _tensor_all(-10 <= _target)
assert _target.shape == TreeSize({
assert ttorch.all(_target < 10)
assert ttorch.all(-10 <= _target)
assert _target.shape == ttorch.TreeSize({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
......@@ -187,7 +185,7 @@ class TestTensorFuncs:
}
})
_target = randint_like(TreeTensor({
_target = ttorch.randint_like(ttorch.TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
......@@ -195,9 +193,9 @@ class TestTensorFuncs:
'd': torch.tensor([[[8, 9]]]),
}
}), 10)
assert _tensor_all(_target < 10)
assert _tensor_all(0 <= _target)
assert _target.shape == TreeSize({
assert ttorch.all(_target < 10)
assert ttorch.all(0 <= _target)
assert _target.shape == ttorch.TreeSize({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
......@@ -207,15 +205,15 @@ class TestTensorFuncs:
})
def test_full(self):
_target = full(TreeValue({
_target = ttorch.full(TreeValue({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}), 233)
assert _tensor_all(_target == 233)
assert _target.shape == TreeSize({
assert ttorch.all(_target == 233)
assert _target.shape == ttorch.TreeSize({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
......@@ -224,7 +222,7 @@ class TestTensorFuncs:
})
def test_full_like(self):
_target = full_like(TreeTensor({
_target = ttorch.full_like(ttorch.TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
......@@ -232,8 +230,8 @@ class TestTensorFuncs:
'd': torch.tensor([[[8, 9]]]),
}
}), 233)
assert _tensor_all(_target == 233)
assert _target.shape == TreeSize({
assert ttorch.all(_target == 233)
assert _target.shape == ttorch.TreeSize({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
......@@ -243,42 +241,102 @@ class TestTensorFuncs:
})
def test_all(self):
r1 = _tensor_all(torch.tensor([1, 1, 1]) == 1)
r1 = ttorch.all(torch.tensor([True, True, True]))
assert torch.is_tensor(r1)
assert r1 == torch.tensor(True)
assert r1
r2 = _tensor_all(torch.tensor([1, 1, 2]) == 1)
r2 = ttorch.all(torch.tensor([True, True, False]))
assert torch.is_tensor(r2)
assert r2 == torch.tensor(False)
assert not r2
r3 = _tensor_all(TreeTensor({
'a': torch.Tensor([1, 2, 3]),
'b': torch.Tensor([4, 5, 6]),
'x': {
'c': torch.Tensor([7, 8, 9])
}
}) == TreeTensor({
'a': torch.Tensor([1, 2, 3]),
'b': torch.Tensor([4, 5, 6]),
'x': {
'c': torch.Tensor([7, 8, 9])
}
}))
r3 = ttorch.all(torch.tensor([False, False, False]))
assert torch.is_tensor(r3)
assert r3 == torch.tensor(True)
assert r3 == torch.tensor(False)
assert not r3
r4 = _tensor_all(TreeTensor({
'a': torch.Tensor([1, 2, 3]),
'b': torch.Tensor([4, 5, 6]),
'x': {
'c': torch.Tensor([7, 8, 9])
}
}) == TreeTensor({
'a': torch.Tensor([1, 2, 3]),
'b': torch.Tensor([4, 5, 6]),
'x': {
'c': torch.Tensor([7, 8, 8])
}
}))
r4 = ttorch.all(ttorch.TreeTensor({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, True]),
})).all()
assert torch.is_tensor(r4)
assert r4 == torch.tensor(False)
assert r4 == torch.tensor(True)
assert r4
r5 = ttorch.all(ttorch.TreeTensor({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]),
})).all()
assert torch.is_tensor(r5)
assert r5 == torch.tensor(False)
assert not r5
r6 = ttorch.all(ttorch.TreeTensor({
'a': torch.tensor([False, False, False]),
'b': torch.tensor([False, False, False]),
})).all()
assert torch.is_tensor(r6)
assert r6 == torch.tensor(False)
assert not r6
def test_any(self):
r1 = ttorch.any(torch.tensor([True, True, True]))
assert torch.is_tensor(r1)
assert r1 == torch.tensor(True)
assert r1
r2 = ttorch.any(torch.tensor([True, True, False]))
assert torch.is_tensor(r2)
assert r2 == torch.tensor(True)
assert r2
r3 = ttorch.any(torch.tensor([False, False, False]))
assert torch.is_tensor(r3)
assert r3 == torch.tensor(False)
assert not r3
r4 = ttorch.any(ttorch.TreeTensor({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, True]),
})).all()
assert torch.is_tensor(r4)
assert r4 == torch.tensor(True)
assert r4
r5 = ttorch.any(ttorch.TreeTensor({
'a': torch.tensor([True, True, True]),
'b': torch.tensor([True, True, False]),
})).all()
assert torch.is_tensor(r5)
assert r5 == torch.tensor(True)
assert r5
r6 = ttorch.any(ttorch.TreeTensor({
'a': torch.tensor([False, False, False]),
'b': torch.tensor([False, False, False]),
})).all()
assert torch.is_tensor(r6)
assert r6 == torch.tensor(False)
assert not r6
def test_eq(self):
assert ttorch.eq(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3])).all()
assert not ttorch.eq(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 2])).all()
assert ttorch.eq(torch.tensor([1, 1, 1]), 1).all()
assert not ttorch.eq(torch.tensor([1, 1, 2]), 1).all()
assert ttorch.eq(ttorch.TreeTensor({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}), ttorch.TreeTensor({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
})).all()
assert not ttorch.eq(ttorch.TreeTensor({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}), ttorch.TreeTensor({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 5]),
})).all()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册