You need to sign in or sign up before continuing.
提交 8f298c55 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): add plenty of basic functions

上级 02cd9976
from .test_fake import TestNumpyFake from .test_funcs import TestNumpyFuncs
from .test_numpy import TestNumpyReal from .test_numpy import TestNumpyNumpy
import pytest
from treetensor import TreeNumpy
try:
import numpy as np
except ImportError:
need_fake = True
from treetensor.numpy.fake import FakeTreeNumpy
else:
need_fake = False
unittest_mark = pytest.mark.unittest if need_fake else pytest.mark.ignore
@unittest_mark
class TestNumpyFake:
def test_base(self):
assert TreeNumpy is FakeTreeNumpy
import numpy as np
import pytest
from treetensor.numpy import TreeNumpy, all_array_equal, equal, array_equal
# noinspection DuplicatedCode
@pytest.mark.unittest
class TestNumpyFuncs:
_DEMO_1 = TreeNumpy({
'a': np.array([[1, 2, 3], [5, 6, 7]]),
'b': np.array([1, 3, 5, 7]),
'x': {
'c': np.array([3, 5, 7]),
'd': np.array([[7, 9]]),
}
})
_DEMO_2 = TreeNumpy({
'a': np.array([[1, 2, 3], [5, 6, 8]]),
'b': np.array([1, 3, 5, 7]),
'x': {
'c': np.array([3, 5, 7]),
'd': np.array([[7, 9]]),
}
})
_DEMO_3 = TreeNumpy({
'a': np.array([[1, 2, 3], [5, 6, 7]]),
'b': np.array([1, 3, 5, 7]),
'x': {
'c': np.array([3, 5, 7]),
'd': np.array([[7, 9]]),
}
})
def test_all_array_equal(self):
assert not all_array_equal(self._DEMO_1, self._DEMO_2)
assert all_array_equal(self._DEMO_1, self._DEMO_3)
assert not all_array_equal(np.array([1, 2, 3]), np.array([1, 2, 4]))
assert all_array_equal(np.array([1, 2, 3]), np.array([1, 2, 3]))
def test_equal(self):
assert all_array_equal(
equal(self._DEMO_1, self._DEMO_2),
TreeNumpy({
'a': np.array([[True, True, True], [True, True, False]]),
'b': np.array([True, True, True, True]),
'x': {
'c': np.array([True, True, True]),
'd': np.array([[True, True]]),
}
})
)
assert all_array_equal(
equal(self._DEMO_1, self._DEMO_3),
TreeNumpy({
'a': np.array([[True, True, True], [True, True, True]]),
'b': np.array([True, True, True, True]),
'x': {
'c': np.array([True, True, True]),
'd': np.array([[True, True]]),
}
})
)
def test_array_equal(self):
assert array_equal(self._DEMO_1, self._DEMO_2) == TreeNumpy({
'a': False,
'b': True,
'x': {
'c': True,
'd': True,
}
})
assert array_equal(self._DEMO_1, self._DEMO_3) == TreeNumpy({
'a': True,
'b': True,
'x': {
'c': True,
'd': True,
}
})
import numpy as np
import pytest import pytest
from treetensor import TreeNumpy from treetensor.numpy import TreeNumpy
try:
import numpy as np
except ImportError:
need_real = False
else:
need_real = True
from treetensor.numpy.numpy import TreeNumpy as RealTreeNumpy
unittest_mark = pytest.mark.unittest if need_real else pytest.mark.ignore
@unittest_mark
class TestNumpyReal:
def test_base(self):
assert TreeNumpy is RealTreeNumpy
@pytest.mark.unittest
class TestNumpyNumpy:
_DEMO_1 = TreeNumpy({ _DEMO_1 = TreeNumpy({
'a': np.array([[1, 2, 3], [4, 5, 6]]), 'a': np.array([[1, 2, 3], [4, 5, 6]]),
'b': np.array([1, 3, 5, 7]), 'b': np.array([1, 3, 5, 7]),
......
from .test_funcs import TestTensorFuncs
from .test_treetensor import TestTensorTreetensor from .test_treetensor import TestTensorTreetensor
import pytest
import torch
from treetensor.tensor import TreeTensor, zeros, all_equal, zeros_like, ones, ones_like
# noinspection DuplicatedCode
@pytest.mark.unittest
class TestTensorFuncs:
def test_all_equal(self):
assert all_equal(
torch.tensor([[1, 2, 3], [4, 5, 6]]),
torch.tensor([[1, 2, 3], [4, 5, 6]]),
)
assert all_equal(
TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
}),
TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
}),
)
def test_zeros(self):
assert all_equal(zeros((2, 3)), torch.zeros(2, 3))
assert all_equal(zeros({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}), TreeTensor({
'a': torch.zeros(2, 3),
'b': torch.zeros(5, 6),
'x': {
'c': torch.zeros(2, 3, 4),
}
}))
def test_zeros_like(self):
assert all_equal(
zeros_like(TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
})),
TreeTensor({
'a': torch.tensor([[0, 0, 0], [0, 0, 0]]),
'b': torch.tensor([0, 0, 0, 0]),
'x': {
'c': torch.tensor([0, 0, 0]),
'd': torch.tensor([[[0, 0]]]),
}
})
)
def test_ones(self):
assert all_equal(ones((2, 3)), torch.ones(2, 3))
assert all_equal(ones({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}), TreeTensor({
'a': torch.ones(2, 3),
'b': torch.ones(5, 6),
'x': {
'c': torch.ones(2, 3, 4),
}
}))
def test_ones_like(self):
assert all_equal(
ones_like(TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
})),
TreeTensor({
'a': torch.tensor([[1, 1, 1], [1, 1, 1]]),
'b': torch.tensor([1, 1, 1, 1]),
'x': {
'c': torch.tensor([1, 1, 1]),
'd': torch.tensor([[[1, 1]]]),
}
})
)
import numpy as np
import pytest import pytest
import torch import torch
from treevalue import func_treelize
from treetensor import TreeTensor from treetensor.numpy import all_array_equal, TreeNumpy
from treetensor.tensor import TreeTensor, all_equal
_all_is = func_treelize(return_type=TreeTensor)(lambda x, y: x is y)
@pytest.mark.unittest @pytest.mark.unittest
...@@ -17,3 +22,27 @@ class TestTensorTreetensor: ...@@ -17,3 +22,27 @@ class TestTensorTreetensor:
def test_numel(self): def test_numel(self):
assert self._DEMO_1.numel() == 18 assert self._DEMO_1.numel() == 18
def test_numpy(self):
assert all_array_equal(self._DEMO_1.numpy(), TreeNumpy({
'a': np.array([[1, 2, 3], [4, 5, 6]]),
'b': np.array([[1, 2], [5, 6]]),
'x': {
'c': np.array([3, 5, 6, 7]),
'd': np.array([[[1, 2], [8, 9]]]),
}
}))
def test_cpu(self):
assert all_equal(self._DEMO_1.cpu(), self._DEMO_1)
assert _all_is(self._DEMO_1.cpu(), self._DEMO_1).reduce(lambda **kws: all(kws.values()))
def test_to(self):
assert all_equal(self._DEMO_1.to(torch.float32), TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32),
'b': torch.tensor([[1, 2], [5, 6]], dtype=torch.float32),
'x': {
'c': torch.tensor([3, 5, 6, 7], dtype=torch.float32),
'd': torch.tensor([[[1, 2], [8, 9]]], dtype=torch.float32),
}
}))
from .numpy import * from .numpy import TreeNumpy
from .tensor import * from .tensor import TreeTensor
try: from .funcs import equal, array_equal, all_array_equal
import numpy as np from .numpy import TreeNumpy
except ImportError: # numpy not exist
from .fake import FakeTreeNumpy as TreeNumpy
else:
from .numpy import TreeNumpy as TreeNumpy
from treevalue import general_tree_value
class FakeTreeNumpy(general_tree_value()):
pass
from functools import partial
import numpy as np
from treevalue import func_treelize, TreeValue
from .numpy import TreeNumpy
_treelize = partial(func_treelize, return_type=TreeNumpy)
equal = _treelize()(np.equal)
array_equal = _treelize()(np.array_equal)
def all_array_equal(tx, ty, *args, **kwargs) -> bool:
_result = array_equal(tx, ty, *args, **kwargs)
if isinstance(tx, TreeValue) and isinstance(ty, TreeValue):
return _result.reduce(lambda **kws: all(list(kws.values())))
else:
return _result
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
from .funcs import zeros_like, full_like, ones_like, randint_like, randn_like, empty_like, zeros, randn, randint, \ from .funcs import zeros_like, full_like, ones_like, randint_like, randn_like, empty_like, zeros, randn, randint, \
func_treelize, ones, empty, full func_treelize, ones, empty, full, all, eq, equal, all_eq, all_equal
from .treetensor import TreeTensor from .treetensor import TreeTensor
...@@ -7,6 +7,7 @@ from treevalue import func_treelize, TreeValue ...@@ -7,6 +7,7 @@ from treevalue import func_treelize, TreeValue
from .treetensor import TreeTensor from .treetensor import TreeTensor
_treelize = partial(func_treelize, return_type=TreeTensor) _treelize = partial(func_treelize, return_type=TreeTensor)
_python_all = all
def _size_based_treelize(*args_, prefix: bool = False, tuple_: bool = False, **kwargs_): def _size_based_treelize(*args_, prefix: bool = False, tuple_: bool = False, **kwargs_):
...@@ -28,6 +29,7 @@ def _size_based_treelize(*args_, prefix: bool = False, tuple_: bool = False, **k ...@@ -28,6 +29,7 @@ def _size_based_treelize(*args_, prefix: bool = False, tuple_: bool = False, **k
return _decorator return _decorator
# Tensor generation based on shapes
zeros = _size_based_treelize()(torch.zeros) zeros = _size_based_treelize()(torch.zeros)
randn = _size_based_treelize()(torch.randn) randn = _size_based_treelize()(torch.randn)
randint = _size_based_treelize(prefix=True, tuple_=True)(torch.randint) randint = _size_based_treelize(prefix=True, tuple_=True)(torch.randint)
...@@ -35,9 +37,31 @@ ones = _size_based_treelize()(torch.ones) ...@@ -35,9 +37,31 @@ ones = _size_based_treelize()(torch.ones)
full = _size_based_treelize()(torch.full) full = _size_based_treelize()(torch.full)
empty = _size_based_treelize()(torch.empty) empty = _size_based_treelize()(torch.empty)
# Tensor generation based on another tensor
zeros_like = _treelize()(torch.zeros_like) zeros_like = _treelize()(torch.zeros_like)
randn_like = _treelize()(torch.randn_like) randn_like = _treelize()(torch.randn_like)
randint_like = _treelize()(torch.randint_like) randint_like = _treelize()(torch.randint_like)
ones_like = _treelize()(torch.ones_like) ones_like = _treelize()(torch.ones_like)
full_like = _treelize()(torch.full_like) full_like = _treelize()(torch.full_like)
empty_like = _treelize()(torch.empty_like) empty_like = _treelize()(torch.empty_like)
# Tensor operators
all = _treelize()(torch.all)
eq = _treelize()(torch.eq)
equal = _treelize()(torch.equal)
def all_eq(tx, ty, *args, **kwargs) -> bool:
_result = eq(tx, ty, *args, **kwargs)
if isinstance(tx, TreeValue) and isinstance(ty, TreeValue):
return _result.reduce(lambda **kws: _python_all(kws.values()))
else:
return _result
def all_equal(tx, ty, *args, **kwargs) -> bool:
_result = equal(tx, ty, *args, **kwargs)
if isinstance(tx, TreeValue) and isinstance(ty, TreeValue):
return _result.reduce(lambda **kws: _python_all(kws.values()))
else:
return _result
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册