提交 ddc4ecec 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): add new unittests

上级 8f298c55
import pytest import pytest
import torch import torch
from treetensor.tensor import TreeTensor, zeros, all_equal, zeros_like, ones, ones_like from treetensor.tensor import TreeTensor, zeros, all_equal, zeros_like, ones, ones_like, randint, randint_like, randn, \
randn_like, full, full_like
from treetensor.tensor import all as _tensor_all
# noinspection DuplicatedCode # noinspection DuplicatedCode
...@@ -48,6 +50,10 @@ class TestTensorFuncs: ...@@ -48,6 +50,10 @@ class TestTensorFuncs:
})) }))
def test_zeros_like(self): def test_zeros_like(self):
assert all_equal(
zeros_like(torch.tensor([[1, 2, 3], [4, 5, 6]])),
torch.tensor([[0, 0, 0], [0, 0, 0]]),
)
assert all_equal( assert all_equal(
zeros_like(TreeTensor({ zeros_like(TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
...@@ -84,6 +90,10 @@ class TestTensorFuncs: ...@@ -84,6 +90,10 @@ class TestTensorFuncs:
})) }))
def test_ones_like(self): def test_ones_like(self):
assert all_equal(
ones_like(torch.tensor([[1, 2, 3], [4, 5, 6]])),
torch.tensor([[1, 1, 1], [1, 1, 1]])
)
assert all_equal( assert all_equal(
ones_like(TreeTensor({ ones_like(TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]), 'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
...@@ -102,3 +112,157 @@ class TestTensorFuncs: ...@@ -102,3 +112,157 @@ class TestTensorFuncs:
} }
}) })
) )
def test_randn(self):
_target = randn((200, 300))
assert -0.02 <= _target.view(60000).mean().tolist() <= 0.02
assert 0.98 <= _target.view(60000).std().tolist() <= 1.02
assert _target.shape == torch.Size([200, 300])
_target = randn({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
})
assert _target.raw_shape == TreeTensor({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
'c': torch.Size([2, 3, 4]),
}
})
def test_randn_like(self):
_target = randn_like(torch.ones(200, 300))
assert -0.02 <= _target.view(60000).mean().tolist() <= 0.02
assert 0.98 <= _target.view(60000).std().tolist() <= 1.02
assert _target.shape == torch.Size([200, 300])
_target = randn_like(TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32),
'b': torch.tensor([1, 2, 3, 4], dtype=torch.float32),
'x': {
'c': torch.tensor([5, 6, 7], dtype=torch.float32),
'd': torch.tensor([[[8, 9]]], dtype=torch.float32),
}
}))
assert _target.raw_shape == TreeTensor({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
'c': torch.Size([3]),
'd': torch.Size([1, 1, 2]),
}
})
def test_randint(self):
_target = randint({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}, -10, 10)
assert _tensor_all(_target < 10).all()
assert _tensor_all(-10 <= _target).all()
assert _target.raw_shape == TreeTensor({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
'c': torch.Size([2, 3, 4]),
}
})
_target = randint({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}, 10)
assert _tensor_all(_target < 10).all()
assert _tensor_all(0 <= _target).all()
assert _target.raw_shape == TreeTensor({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
'c': torch.Size([2, 3, 4]),
}
})
def test_randint_like(self):
_target = randint_like(TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
}), -10, 10)
assert _tensor_all(_target < 10).all()
assert _tensor_all(-10 <= _target).all()
assert _target.raw_shape == TreeTensor({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
'c': torch.Size([3]),
'd': torch.Size([1, 1, 2]),
}
})
_target = randint_like(TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
}), 10)
assert _tensor_all(_target < 10).all()
assert _tensor_all(0 <= _target).all()
assert _target.raw_shape == TreeTensor({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
'c': torch.Size([3]),
'd': torch.Size([1, 1, 2]),
}
})
def test_full(self):
_target = full({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}, 233)
assert _tensor_all(_target.tensor_eq(233)).all()
assert _target.raw_shape == TreeTensor({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
'c': torch.Size([2, 3, 4]),
}
})
def test_full_like(self):
_target = full_like(TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
}), 233)
assert _tensor_all(_target.tensor_eq(233)).all()
assert _target.raw_shape == TreeTensor({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
'c': torch.Size([3]),
'd': torch.Size([1, 1, 2]),
}
})
from .treelist import TreeList
from treevalue import general_tree_value
class TreeList(general_tree_value()):
pass
from treevalue import general_tree_value from treevalue import general_tree_value, method_treelize
from ..common import TreeList
class TreeNumpy(general_tree_value()): class TreeNumpy(general_tree_value()):
...@@ -7,6 +9,8 @@ class TreeNumpy(general_tree_value()): ...@@ -7,6 +9,8 @@ class TreeNumpy(general_tree_value()):
Real numpy tree. Real numpy tree.
""" """
tolist = method_treelize(return_type=TreeList)(lambda d: d.tolist())
@property @property
def size(self) -> int: def size(self) -> int:
return self \ return self \
......
...@@ -34,7 +34,7 @@ zeros = _size_based_treelize()(torch.zeros) ...@@ -34,7 +34,7 @@ zeros = _size_based_treelize()(torch.zeros)
randn = _size_based_treelize()(torch.randn) randn = _size_based_treelize()(torch.randn)
randint = _size_based_treelize(prefix=True, tuple_=True)(torch.randint) randint = _size_based_treelize(prefix=True, tuple_=True)(torch.randint)
ones = _size_based_treelize()(torch.ones) ones = _size_based_treelize()(torch.ones)
full = _size_based_treelize()(torch.full) full = _size_based_treelize(tuple_=True)(torch.full)
empty = _size_based_treelize()(torch.empty) empty = _size_based_treelize()(torch.empty)
# Tensor generation based on another tensor # Tensor generation based on another tensor
......
from functools import partial
from operator import __eq__
from torch import Tensor from torch import Tensor
from treevalue import general_tree_value, method_treelize from treevalue import general_tree_value, method_treelize, TreeValue
from ..common import TreeList
from ..numpy import TreeNumpy from ..numpy import TreeNumpy
def _same_merge(eq, hash_, **kwargs):
kws = {
key: value for key, value in kwargs.items()
if not (isinstance(value, TreeValue) and not value)
}
class _Wrapper:
def __init__(self, v):
self.v = v
def __hash__(self):
return hash_(self.v)
def __eq__(self, other):
return eq(self.v, other.v)
if len(set(_Wrapper(v) for v in kws.values())) == 1:
return list(kws.values())[0]
else:
return TreeTensor(kws)
# noinspection PyTypeChecker,PyShadowingBuiltins # noinspection PyTypeChecker,PyShadowingBuiltins
class TreeTensor(general_tree_value()): class TreeTensor(general_tree_value()):
def numel(self) -> int: def numel(self) -> int:
...@@ -11,7 +37,46 @@ class TreeTensor(general_tree_value()): ...@@ -11,7 +37,46 @@ class TreeTensor(general_tree_value()):
.map(lambda t: t.numel()) \ .map(lambda t: t.numel()) \
.reduce(lambda **kws: sum(kws.values())) .reduce(lambda **kws: sum(kws.values()))
@property
def raw_shape(self):
return self.map(lambda t: t.shape)
@property
def shape(self):
return self.raw_shape.reduce(partial(_same_merge, __eq__, hash))
numpy = method_treelize(return_type=TreeNumpy)(Tensor.numpy) numpy = method_treelize(return_type=TreeNumpy)(Tensor.numpy)
tolist = method_treelize(return_type=TreeList)(Tensor.tolist)
cpu = method_treelize()(Tensor.cpu) cpu = method_treelize()(Tensor.cpu)
cuda = method_treelize()(Tensor.cuda) cuda = method_treelize()(Tensor.cuda)
to = method_treelize()(Tensor.to) to = method_treelize()(Tensor.to)
@method_treelize()
def __lt__(self, other):
return self < other
@method_treelize()
def __le__(self, other):
return self <= other
@method_treelize()
def __gt__(self, other):
return self > other
@method_treelize()
def __ge__(self, other):
return self >= other
@method_treelize()
def tensor_eq(self, other):
return self == other
@method_treelize()
def tensor_ne(self, other):
return self != other
def all(self):
return self.reduce(lambda **kws: all(kws.values()))
def any(self):
return self.reduce(lambda **kws: any(kws.values()))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册