From e3e3e9522550a9d2c2b0b9d333e2f1e47d12c806 Mon Sep 17 00:00:00 2001 From: HansBug Date: Sat, 11 Sep 2021 19:35:23 +0800 Subject: [PATCH] test(hansbug): refactor test in treetensor.torch --- test/tensor/test_funcs.py | 212 ++++++++++++++++++++++++-------------- 1 file changed, 135 insertions(+), 77 deletions(-) diff --git a/test/tensor/test_funcs.py b/test/tensor/test_funcs.py index f4d5745a3..c825eba6c 100644 --- a/test/tensor/test_funcs.py +++ b/test/tensor/test_funcs.py @@ -2,23 +2,21 @@ import pytest import torch from treevalue import TreeValue -from treetensor.tensor import TreeTensor, zeros, zeros_like, ones, ones_like, randint, randint_like, randn, \ - randn_like, full, full_like, TreeSize -from treetensor.tensor import all as _tensor_all +import treetensor.tensor as ttorch # noinspection DuplicatedCode @pytest.mark.unittest class TestTensorFuncs: def test_zeros(self): - assert _tensor_all(zeros((2, 3)) == torch.zeros(2, 3)) - assert _tensor_all(zeros(TreeValue({ + assert ttorch.all(ttorch.zeros((2, 3)) == torch.zeros(2, 3)) + assert ttorch.all(ttorch.zeros(TreeValue({ 'a': (2, 3), 'b': (5, 6), 'x': { 'c': (2, 3, 4), } - })) == TreeTensor({ + })) == ttorch.TreeTensor({ 'a': torch.zeros(2, 3), 'b': torch.zeros(5, 6), 'x': { @@ -27,19 +25,19 @@ class TestTensorFuncs: })) def test_zeros_like(self): - assert _tensor_all( - zeros_like(torch.tensor([[1, 2, 3], [4, 5, 6]])) == + assert ttorch.all( + ttorch.zeros_like(torch.tensor([[1, 2, 3], [4, 5, 6]])) == torch.tensor([[0, 0, 0], [0, 0, 0]]), ) - assert _tensor_all( - zeros_like(TreeTensor({ + assert ttorch.all( + ttorch.zeros_like(ttorch.TreeTensor({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'b': torch.tensor([1, 2, 3, 4]), 'x': { 'c': torch.tensor([5, 6, 7]), 'd': torch.tensor([[[8, 9]]]), } - })) == TreeTensor({ + })) == ttorch.TreeTensor({ 'a': torch.tensor([[0, 0, 0], [0, 0, 0]]), 'b': torch.tensor([0, 0, 0, 0]), 'x': { @@ -50,14 +48,14 @@ class TestTensorFuncs: ) def test_ones(self): - assert _tensor_all(ones((2, 3)) == torch.ones(2, 3)) - assert _tensor_all(ones(TreeValue({ + assert ttorch.all(ttorch.ones((2, 3)) == torch.ones(2, 3)) + assert ttorch.all(ttorch.ones(TreeValue({ 'a': (2, 3), 'b': (5, 6), 'x': { 'c': (2, 3, 4), } - })) == TreeTensor({ + })) == ttorch.TreeTensor({ 'a': torch.ones(2, 3), 'b': torch.ones(5, 6), 'x': { @@ -66,19 +64,19 @@ class TestTensorFuncs: })) def test_ones_like(self): - assert _tensor_all( - ones_like(torch.tensor([[1, 2, 3], [4, 5, 6]])) == + assert ttorch.all( + ttorch.ones_like(torch.tensor([[1, 2, 3], [4, 5, 6]])) == torch.tensor([[1, 1, 1], [1, 1, 1]]) ) - assert _tensor_all( - ones_like(TreeTensor({ + assert ttorch.all( + ttorch.ones_like(ttorch.TreeTensor({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'b': torch.tensor([1, 2, 3, 4]), 'x': { 'c': torch.tensor([5, 6, 7]), 'd': torch.tensor([[[8, 9]]]), } - })) == TreeTensor({ + })) == ttorch.TreeTensor({ 'a': torch.tensor([[1, 1, 1], [1, 1, 1]]), 'b': torch.tensor([1, 1, 1, 1]), 'x': { @@ -89,19 +87,19 @@ class TestTensorFuncs: ) def test_randn(self): - _target = randn((200, 300)) + _target = ttorch.randn((200, 300)) assert -0.02 <= _target.view(60000).mean().tolist() <= 0.02 assert 0.98 <= _target.view(60000).std().tolist() <= 1.02 assert _target.shape == torch.Size([200, 300]) - _target = randn(TreeValue({ + _target = ttorch.randn(TreeValue({ 'a': (2, 3), 'b': (5, 6), 'x': { 'c': (2, 3, 4), } })) - assert _target.shape == TreeSize({ + assert _target.shape == ttorch.TreeSize({ 'a': torch.Size([2, 3]), 'b': torch.Size([5, 6]), 'x': { @@ -110,12 +108,12 @@ class TestTensorFuncs: }) def test_randn_like(self): - _target = randn_like(torch.ones(200, 300)) + _target = ttorch.randn_like(torch.ones(200, 300)) assert -0.02 <= _target.view(60000).mean().tolist() <= 0.02 assert 0.98 <= _target.view(60000).std().tolist() <= 1.02 assert _target.shape == torch.Size([200, 300]) - _target = randn_like(TreeTensor({ + _target = ttorch.randn_like(ttorch.TreeTensor({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32), 'b': torch.tensor([1, 2, 3, 4], dtype=torch.float32), 'x': { @@ -123,7 +121,7 @@ class TestTensorFuncs: 'd': torch.tensor([[[8, 9]]], dtype=torch.float32), } })) - assert _target.shape == TreeSize({ + assert _target.shape == ttorch.TreeSize({ 'a': torch.Size([2, 3]), 'b': torch.Size([4]), 'x': { @@ -133,16 +131,16 @@ class TestTensorFuncs: }) def test_randint(self): - _target = randint(TreeValue({ + _target = ttorch.randint(TreeValue({ 'a': (2, 3), 'b': (5, 6), 'x': { 'c': (2, 3, 4), } }), -10, 10) - assert _tensor_all(_target < 10) - assert _tensor_all(-10 <= _target) - assert _target.shape == TreeSize({ + assert ttorch.all(_target < 10) + assert ttorch.all(-10 <= _target) + assert _target.shape == ttorch.TreeSize({ 'a': torch.Size([2, 3]), 'b': torch.Size([5, 6]), 'x': { @@ -150,16 +148,16 @@ class TestTensorFuncs: } }) - _target = randint(TreeValue({ + _target = ttorch.randint(TreeValue({ 'a': (2, 3), 'b': (5, 6), 'x': { 'c': (2, 3, 4), } }), 10) - assert _tensor_all(_target < 10) - assert _tensor_all(0 <= _target) - assert _target.shape == TreeSize({ + assert ttorch.all(_target < 10) + assert ttorch.all(0 <= _target) + assert _target.shape == ttorch.TreeSize({ 'a': torch.Size([2, 3]), 'b': torch.Size([5, 6]), 'x': { @@ -168,7 +166,7 @@ class TestTensorFuncs: }) def test_randint_like(self): - _target = randint_like(TreeTensor({ + _target = ttorch.randint_like(ttorch.TreeTensor({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'b': torch.tensor([1, 2, 3, 4]), 'x': { @@ -176,9 +174,9 @@ class TestTensorFuncs: 'd': torch.tensor([[[8, 9]]]), } }), -10, 10) - assert _tensor_all(_target < 10) - assert _tensor_all(-10 <= _target) - assert _target.shape == TreeSize({ + assert ttorch.all(_target < 10) + assert ttorch.all(-10 <= _target) + assert _target.shape == ttorch.TreeSize({ 'a': torch.Size([2, 3]), 'b': torch.Size([4]), 'x': { @@ -187,7 +185,7 @@ class TestTensorFuncs: } }) - _target = randint_like(TreeTensor({ + _target = ttorch.randint_like(ttorch.TreeTensor({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'b': torch.tensor([1, 2, 3, 4]), 'x': { @@ -195,9 +193,9 @@ class TestTensorFuncs: 'd': torch.tensor([[[8, 9]]]), } }), 10) - assert _tensor_all(_target < 10) - assert _tensor_all(0 <= _target) - assert _target.shape == TreeSize({ + assert ttorch.all(_target < 10) + assert ttorch.all(0 <= _target) + assert _target.shape == ttorch.TreeSize({ 'a': torch.Size([2, 3]), 'b': torch.Size([4]), 'x': { @@ -207,15 +205,15 @@ class TestTensorFuncs: }) def test_full(self): - _target = full(TreeValue({ + _target = ttorch.full(TreeValue({ 'a': (2, 3), 'b': (5, 6), 'x': { 'c': (2, 3, 4), } }), 233) - assert _tensor_all(_target == 233) - assert _target.shape == TreeSize({ + assert ttorch.all(_target == 233) + assert _target.shape == ttorch.TreeSize({ 'a': torch.Size([2, 3]), 'b': torch.Size([5, 6]), 'x': { @@ -224,7 +222,7 @@ class TestTensorFuncs: }) def test_full_like(self): - _target = full_like(TreeTensor({ + _target = ttorch.full_like(ttorch.TreeTensor({ 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'b': torch.tensor([1, 2, 3, 4]), 'x': { @@ -232,8 +230,8 @@ class TestTensorFuncs: 'd': torch.tensor([[[8, 9]]]), } }), 233) - assert _tensor_all(_target == 233) - assert _target.shape == TreeSize({ + assert ttorch.all(_target == 233) + assert _target.shape == ttorch.TreeSize({ 'a': torch.Size([2, 3]), 'b': torch.Size([4]), 'x': { @@ -243,42 +241,102 @@ class TestTensorFuncs: }) def test_all(self): - r1 = _tensor_all(torch.tensor([1, 1, 1]) == 1) + r1 = ttorch.all(torch.tensor([True, True, True])) assert torch.is_tensor(r1) assert r1 == torch.tensor(True) + assert r1 - r2 = _tensor_all(torch.tensor([1, 1, 2]) == 1) + r2 = ttorch.all(torch.tensor([True, True, False])) assert torch.is_tensor(r2) assert r2 == torch.tensor(False) + assert not r2 - r3 = _tensor_all(TreeTensor({ - 'a': torch.Tensor([1, 2, 3]), - 'b': torch.Tensor([4, 5, 6]), - 'x': { - 'c': torch.Tensor([7, 8, 9]) - } - }) == TreeTensor({ - 'a': torch.Tensor([1, 2, 3]), - 'b': torch.Tensor([4, 5, 6]), - 'x': { - 'c': torch.Tensor([7, 8, 9]) - } - })) + r3 = ttorch.all(torch.tensor([False, False, False])) assert torch.is_tensor(r3) - assert r3 == torch.tensor(True) + assert r3 == torch.tensor(False) + assert not r3 - r4 = _tensor_all(TreeTensor({ - 'a': torch.Tensor([1, 2, 3]), - 'b': torch.Tensor([4, 5, 6]), - 'x': { - 'c': torch.Tensor([7, 8, 9]) - } - }) == TreeTensor({ - 'a': torch.Tensor([1, 2, 3]), - 'b': torch.Tensor([4, 5, 6]), - 'x': { - 'c': torch.Tensor([7, 8, 8]) - } - })) + r4 = ttorch.all(ttorch.TreeTensor({ + 'a': torch.tensor([True, True, True]), + 'b': torch.tensor([True, True, True]), + })).all() assert torch.is_tensor(r4) - assert r4 == torch.tensor(False) + assert r4 == torch.tensor(True) + assert r4 + + r5 = ttorch.all(ttorch.TreeTensor({ + 'a': torch.tensor([True, True, True]), + 'b': torch.tensor([True, True, False]), + })).all() + assert torch.is_tensor(r5) + assert r5 == torch.tensor(False) + assert not r5 + + r6 = ttorch.all(ttorch.TreeTensor({ + 'a': torch.tensor([False, False, False]), + 'b': torch.tensor([False, False, False]), + })).all() + assert torch.is_tensor(r6) + assert r6 == torch.tensor(False) + assert not r6 + + def test_any(self): + r1 = ttorch.any(torch.tensor([True, True, True])) + assert torch.is_tensor(r1) + assert r1 == torch.tensor(True) + assert r1 + + r2 = ttorch.any(torch.tensor([True, True, False])) + assert torch.is_tensor(r2) + assert r2 == torch.tensor(True) + assert r2 + + r3 = ttorch.any(torch.tensor([False, False, False])) + assert torch.is_tensor(r3) + assert r3 == torch.tensor(False) + assert not r3 + + r4 = ttorch.any(ttorch.TreeTensor({ + 'a': torch.tensor([True, True, True]), + 'b': torch.tensor([True, True, True]), + })).all() + assert torch.is_tensor(r4) + assert r4 == torch.tensor(True) + assert r4 + + r5 = ttorch.any(ttorch.TreeTensor({ + 'a': torch.tensor([True, True, True]), + 'b': torch.tensor([True, True, False]), + })).all() + assert torch.is_tensor(r5) + assert r5 == torch.tensor(True) + assert r5 + + r6 = ttorch.any(ttorch.TreeTensor({ + 'a': torch.tensor([False, False, False]), + 'b': torch.tensor([False, False, False]), + })).all() + assert torch.is_tensor(r6) + assert r6 == torch.tensor(False) + assert not r6 + + def test_eq(self): + assert ttorch.eq(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3])).all() + assert not ttorch.eq(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 2])).all() + assert ttorch.eq(torch.tensor([1, 1, 1]), 1).all() + assert not ttorch.eq(torch.tensor([1, 1, 2]), 1).all() + + assert ttorch.eq(ttorch.TreeTensor({ + 'a': torch.tensor([1, 2, 3]), + 'b': torch.tensor([4, 5, 6]), + }), ttorch.TreeTensor({ + 'a': torch.tensor([1, 2, 3]), + 'b': torch.tensor([4, 5, 6]), + })).all() + assert not ttorch.eq(ttorch.TreeTensor({ + 'a': torch.tensor([1, 2, 3]), + 'b': torch.tensor([4, 5, 6]), + }), ttorch.TreeTensor({ + 'a': torch.tensor([1, 2, 3]), + 'b': torch.tensor([4, 5, 5]), + })).all() -- GitLab