提交 02cd9976 编写于 作者: HansBug's avatar HansBug 😆

dev(hansbug): save the new unittest and development code && use coverage>=5

上级 9ee559bc
coverage>=4,<5
coverage>=5
mock>=4.0.3
flake8~=3.5
pytest~=5.4.3
......
from .config import *
from .numpy import *
from .tensor import *
from .test_fake import TestNumpyFake
from .test_numpy import TestNumpyReal
import pytest
from treetensor import TreeNumpy
try:
import numpy as np
except ImportError:
need_fake = True
from treetensor.numpy.fake import FakeTreeNumpy
else:
need_fake = False
unittest_mark = pytest.mark.unittest if need_fake else pytest.mark.ignore
@unittest_mark
class TestNumpyFake:
def test_base(self):
assert TreeNumpy is FakeTreeNumpy
import pytest
from treetensor import TreeNumpy
try:
import numpy as np
except ImportError:
need_real = False
else:
need_real = True
from treetensor.numpy.numpy import TreeNumpy as RealTreeNumpy
unittest_mark = pytest.mark.unittest if need_real else pytest.mark.ignore
@unittest_mark
class TestNumpyReal:
def test_base(self):
assert TreeNumpy is RealTreeNumpy
_DEMO_1 = TreeNumpy({
'a': np.array([[1, 2, 3], [4, 5, 6]]),
'b': np.array([1, 3, 5, 7]),
'x': {
'c': np.array([[11], [23]]),
'd': np.array([3, 9, 11.0])
}
})
def test_size(self):
assert self._DEMO_1.size == 15
def test_nbytes(self):
assert self._DEMO_1.nbytes == 120
def test_sum(self):
assert self._DEMO_1.sum() == 94.0
from .test_treetensor import TestTensorTreetensor
import pytest
import torch
from treetensor import TreeTensor
@pytest.mark.unittest
class TestTensorTreetensor:
_DEMO_1 = TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([[1, 2], [5, 6]]),
'x': {
'c': torch.tensor([3, 5, 6, 7]),
'd': torch.tensor([[[1, 2], [8, 9]]]),
}
})
def test_numel(self):
assert self._DEMO_1.numel() == 18
from .tensor import TreeTensor
from .numpy import *
from .tensor import *
try:
import numpy as np
except ImportError: # numpy not exist
from .fake import FakeTreeNumpy as TreeNumpy
else:
from .numpy import TreeNumpy as TreeNumpy
from treevalue import general_tree_value
class FakeTreeNumpy(general_tree_value()):
pass
from treevalue import general_tree_value
class TreeNumpy(general_tree_value()):
"""
Overview:
Real numpy tree.
"""
@property
def size(self) -> int:
return self \
.map(lambda d: d.size) \
.reduce(lambda **kwargs: sum(kwargs.values()))
@property
def nbytes(self) -> int:
return self \
.map(lambda d: d.nbytes) \
.reduce(lambda **kwargs: sum(kwargs.values()))
def sum(self):
return self \
.map(lambda d: d.sum()) \
.reduce(lambda **kwargs: sum(kwargs.values()))
# noinspection PyUnresolvedReferences
from .funcs import zeros_like, full_like, ones_like, randint_like, randn_like, empty_like, zeros, randn, randint, \
func_treelize, ones, empty, full
from .treetensor import TreeTensor
from functools import partial, wraps
from typing import Tuple
import torch
from treevalue import func_treelize, TreeValue
from .treetensor import TreeTensor
_treelize = partial(func_treelize, return_type=TreeTensor)
def _size_based_treelize(*args_, prefix: bool = False, tuple_: bool = False, **kwargs_):
def _decorator(func):
@_treelize(*args_, **kwargs_)
def _sub_func(size: Tuple[int, ...], *args, **kwargs):
_size_args = (size,) if tuple_ else size
_args = (*args, *_size_args) if prefix else (*_size_args, *args)
return func(*_args, **kwargs)
@wraps(func)
def _new_func(size, *args, **kwargs):
if isinstance(size, (TreeValue, dict)):
size = TreeTensor(size)
return _sub_func(size, *args, **kwargs)
return _new_func
return _decorator
zeros = _size_based_treelize()(torch.zeros)
randn = _size_based_treelize()(torch.randn)
randint = _size_based_treelize(prefix=True, tuple_=True)(torch.randint)
ones = _size_based_treelize()(torch.ones)
full = _size_based_treelize()(torch.full)
empty = _size_based_treelize()(torch.empty)
zeros_like = _treelize()(torch.zeros_like)
randn_like = _treelize()(torch.randn_like)
randint_like = _treelize()(torch.randint_like)
ones_like = _treelize()(torch.ones_like)
full_like = _treelize()(torch.full_like)
empty_like = _treelize()(torch.empty_like)
from treevalue import general_tree_value
from torch import Tensor
from treevalue import general_tree_value, method_treelize
from ..numpy import TreeNumpy
# noinspection PyTypeChecker,PyShadowingBuiltins
class TreeTensor(general_tree_value()):
pass
def numel(self) -> int:
return self \
.map(lambda t: t.numel()) \
.reduce(lambda **kws: sum(kws.values()))
numpy = method_treelize(return_type=TreeNumpy)(Tensor.numpy)
cpu = method_treelize()(Tensor.cpu)
cuda = method_treelize()(Tensor.cuda)
to = method_treelize()(Tensor.to)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册