提交 dbf50440 编写于 作者: HansBug's avatar HansBug 😆

test(hansbug): refactor test in test_numpy.py and test_treetensor.py

上级 e3e3e952
import numpy as np
import pytest
import treetensor.numpy as tnp
from treetensor.common import TreeObject
from treetensor.numpy import TreeNumpy
# noinspection DuplicatedCode
@pytest.mark.unittest
class TestNumpyNumpy:
_DEMO_1 = TreeNumpy({
_DEMO_1 = tnp.TreeNumpy({
'a': np.array([[1, 2, 3], [4, 5, 6]]),
'b': np.array([1, 3, 5, 7]),
'x': {
......@@ -17,7 +17,7 @@ class TestNumpyNumpy:
}
})
_DEMO_2 = TreeNumpy({
_DEMO_2 = tnp.TreeNumpy({
'a': np.array([[1, 22, 3], [4, 5, 6]]),
'b': np.array([1, 3, 5, 7]),
'x': {
......@@ -26,7 +26,7 @@ class TestNumpyNumpy:
}
})
_DEMO_3 = TreeNumpy({
_DEMO_3 = tnp.TreeNumpy({
'a': np.array([[0, 0, 0], [0, 0, 0]]),
'b': np.array([0, 0, 0, 0]),
'x': {
......@@ -54,7 +54,7 @@ class TestNumpyNumpy:
assert self._DEMO_1.all()
assert not self._DEMO_2.all()
assert not self._DEMO_3.all()
assert TreeNumpy({
assert tnp.TreeNumpy({
'a': np.array([[True, True, True], [True, True, True]]),
'b': np.array([True, True, True, True]),
'x': {
......@@ -62,7 +62,7 @@ class TestNumpyNumpy:
'd': np.array([True, True, True])
}
}).all()
assert not TreeNumpy({
assert not tnp.TreeNumpy({
'a': np.array([[True, True, True], [True, True, True]]),
'b': np.array([True, True, True, True]),
'x': {
......@@ -70,7 +70,7 @@ class TestNumpyNumpy:
'd': np.array([True, True, False])
}
}).all()
assert not TreeNumpy({
assert not tnp.TreeNumpy({
'a': np.array([[False, False, False], [False, False, False]]),
'b': np.array([False, False, False, False]),
'x': {
......@@ -83,7 +83,7 @@ class TestNumpyNumpy:
assert self._DEMO_1.any()
assert self._DEMO_2.any()
assert not self._DEMO_3.any()
assert TreeNumpy({
assert tnp.TreeNumpy({
'a': np.array([[True, True, True], [True, True, True]]),
'b': np.array([True, True, True, True]),
'x': {
......@@ -91,7 +91,7 @@ class TestNumpyNumpy:
'd': np.array([True, True, True])
}
}).any()
assert TreeNumpy({
assert tnp.TreeNumpy({
'a': np.array([[True, True, True], [True, True, True]]),
'b': np.array([True, True, True, True]),
'x': {
......@@ -99,7 +99,7 @@ class TestNumpyNumpy:
'd': np.array([True, True, False])
}
}).any()
assert not TreeNumpy({
assert not tnp.TreeNumpy({
'a': np.array([[False, False, False], [False, False, False]]),
'b': np.array([False, False, False, False]),
'x': {
......@@ -121,7 +121,7 @@ class TestNumpyNumpy:
def test_gt(self):
assert not (self._DEMO_1 > self._DEMO_1).any()
assert not (self._DEMO_2 > self._DEMO_2).any()
assert ((self._DEMO_1 > self._DEMO_2) == TreeNumpy({
assert ((self._DEMO_1 > self._DEMO_2) == tnp.TreeNumpy({
'a': np.array([[False, False, False], [False, False, False]]),
'b': np.array([False, False, False, False]),
'x': {
......@@ -129,7 +129,7 @@ class TestNumpyNumpy:
'd': np.array([False, False, False])
}
})).all()
assert ((self._DEMO_2 > self._DEMO_1) == TreeNumpy({
assert ((self._DEMO_2 > self._DEMO_1) == tnp.TreeNumpy({
'a': np.array([[False, True, False], [False, False, False]]),
'b': np.array([False, False, False, False]),
'x': {
......@@ -141,7 +141,7 @@ class TestNumpyNumpy:
def test_ge(self):
assert (self._DEMO_1 >= self._DEMO_1).all()
assert (self._DEMO_2 >= self._DEMO_2).all()
assert ((self._DEMO_1 >= self._DEMO_2) == TreeNumpy({
assert ((self._DEMO_1 >= self._DEMO_2) == tnp.TreeNumpy({
'a': np.array([[True, False, True], [True, True, True]]),
'b': np.array([True, True, True, True]),
'x': {
......@@ -149,7 +149,7 @@ class TestNumpyNumpy:
'd': np.array([True, True, True])
}
})).all()
assert ((self._DEMO_2 >= self._DEMO_1) == TreeNumpy({
assert ((self._DEMO_2 >= self._DEMO_1) == tnp.TreeNumpy({
'a': np.array([[True, True, True], [True, True, True]]),
'b': np.array([True, True, True, True]),
'x': {
......@@ -161,7 +161,7 @@ class TestNumpyNumpy:
def test_lt(self):
assert not (self._DEMO_1 < self._DEMO_1).any()
assert not (self._DEMO_2 < self._DEMO_2).any()
assert ((self._DEMO_1 < self._DEMO_2) == TreeNumpy({
assert ((self._DEMO_1 < self._DEMO_2) == tnp.TreeNumpy({
'a': np.array([[False, True, False], [False, False, False]]),
'b': np.array([False, False, False, False]),
'x': {
......@@ -169,7 +169,7 @@ class TestNumpyNumpy:
'd': np.array([False, False, False])
}
})).all()
assert ((self._DEMO_2 < self._DEMO_1) == TreeNumpy({
assert ((self._DEMO_2 < self._DEMO_1) == tnp.TreeNumpy({
'a': np.array([[False, False, False], [False, False, False]]),
'b': np.array([False, False, False, False]),
'x': {
......@@ -181,7 +181,7 @@ class TestNumpyNumpy:
def test_le(self):
assert (self._DEMO_1 <= self._DEMO_1).all()
assert (self._DEMO_2 <= self._DEMO_2).all()
assert ((self._DEMO_1 <= self._DEMO_2) == TreeNumpy({
assert ((self._DEMO_1 <= self._DEMO_2) == tnp.TreeNumpy({
'a': np.array([[True, True, True], [True, True, True]]),
'b': np.array([True, True, True, True]),
'x': {
......@@ -189,7 +189,7 @@ class TestNumpyNumpy:
'd': np.array([True, True, True])
}
})).all()
assert ((self._DEMO_2 <= self._DEMO_1) == TreeNumpy({
assert ((self._DEMO_2 <= self._DEMO_1) == tnp.TreeNumpy({
'a': np.array([[True, False, True], [True, True, True]]),
'b': np.array([True, True, True, True]),
'x': {
......
......@@ -3,17 +3,15 @@ import pytest
import torch
from treevalue import func_treelize
from treetensor.numpy import TreeNumpy
from treetensor.numpy import all as _numpy_all
from treetensor.tensor import TreeTensor
from treetensor.tensor import all as _tensor_all
import treetensor.numpy as tnp
import treetensor.tensor as ttorch
_all_is = func_treelize(return_type=TreeTensor)(lambda x, y: x is y)
_all_is = func_treelize(return_type=ttorch.TreeTensor)(lambda x, y: x is y)
@pytest.mark.unittest
class TestTensorTreetensor:
_DEMO_1 = TreeTensor({
_DEMO_1 = ttorch.TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([[1, 2], [5, 6]]),
'x': {
......@@ -22,7 +20,7 @@ class TestTensorTreetensor:
}
})
_DEMO_2 = TreeTensor({
_DEMO_2 = ttorch.TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([[1, 2], [5, 60]]),
'x': {
......@@ -35,7 +33,7 @@ class TestTensorTreetensor:
assert self._DEMO_1.numel() == 18
def test_numpy(self):
assert _numpy_all(self._DEMO_1.numpy() == TreeNumpy({
assert tnp.all(self._DEMO_1.numpy() == tnp.TreeNumpy({
'a': np.array([[1, 2, 3], [4, 5, 6]]),
'b': np.array([[1, 2], [5, 6]]),
'x': {
......@@ -45,11 +43,11 @@ class TestTensorTreetensor:
}))
def test_cpu(self):
assert _tensor_all(self._DEMO_1.cpu() == self._DEMO_1)
assert ttorch.all(self._DEMO_1.cpu() == self._DEMO_1)
assert _all_is(self._DEMO_1.cpu(), self._DEMO_1).reduce(lambda **kws: all(kws.values()))
def test_to(self):
assert _tensor_all(self._DEMO_1.to(torch.float32) == TreeTensor({
assert ttorch.all(self._DEMO_1.to(torch.float32) == ttorch.TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32),
'b': torch.tensor([[1, 2], [5, 6]], dtype=torch.float32),
'x': {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册