提交 7cb29d8e 编写于 作者: HansBug's avatar HansBug 😆

dev,test(hansbug): add test for tensor/funcs.py

上级 dbf50440
......@@ -240,6 +240,40 @@ class TestTensorFuncs:
}
})
def test_empty(self):
_target = ttorch.empty(TreeValue({
'a': (2, 3),
'b': (5, 6),
'x': {
'c': (2, 3, 4),
}
}))
assert _target.shape == ttorch.TreeSize({
'a': torch.Size([2, 3]),
'b': torch.Size([5, 6]),
'x': {
'c': torch.Size([2, 3, 4]),
}
})
def test_empty_like(self):
_target = ttorch.empty_like(ttorch.TreeTensor({
'a': torch.tensor([[1, 2, 3], [4, 5, 6]]),
'b': torch.tensor([1, 2, 3, 4]),
'x': {
'c': torch.tensor([5, 6, 7]),
'd': torch.tensor([[[8, 9]]]),
}
}))
assert _target.shape == ttorch.TreeSize({
'a': torch.Size([2, 3]),
'b': torch.Size([4]),
'x': {
'c': torch.Size([3]),
'd': torch.Size([1, 1, 2]),
}
})
def test_all(self):
r1 = ttorch.all(torch.tensor([True, True, True]))
assert torch.is_tensor(r1)
......@@ -340,3 +374,32 @@ class TestTensorFuncs:
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 5]),
})).all()
def test_equal(self):
p1 = ttorch.equal(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3]))
assert isinstance(p1, bool)
assert p1
p2 = ttorch.equal(torch.tensor([1, 2, 3]), torch.tensor([1, 2, 4]))
assert isinstance(p2, bool)
assert not p2
p3 = ttorch.equal(ttorch.TreeTensor({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}), ttorch.TreeTensor({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}))
assert isinstance(p3, bool)
assert p3
p4 = ttorch.equal(ttorch.TreeTensor({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 6]),
}), ttorch.TreeTensor({
'a': torch.tensor([1, 2, 3]),
'b': torch.tensor([4, 5, 5]),
}))
assert isinstance(p4, bool)
assert not p4
import builtins
import numpy as np
from treevalue import func_treelize as original_func_treelize
......@@ -13,13 +15,13 @@ __all__ = [
func_treelize = replaceable_partial(original_func_treelize, return_type=TreeNumpy)
@ireduce(all)
@ireduce(builtins.all)
@func_treelize(return_type=TreeObject)
def all(a, *args, **kwargs):
return np.all(a, *args, **kwargs)
@ireduce(any)
@ireduce(builtins.any)
@func_treelize()
def any(a, *args, **kwargs):
return np.any(a, *args, **kwargs)
......
import builtins
import torch
from treevalue import func_treelize as original_func_treelize
from .tensor import TreeTensor, tireduce
from ..common import TreeObject
from ..common import TreeObject, ireduce
from ..utils import replaceable_partial
func_treelize = replaceable_partial(original_func_treelize, return_type=TreeTensor)
......@@ -96,6 +98,7 @@ def eq(input_, other, *args, **kwargs):
return torch.eq(input_, other, *args, **kwargs)
@ireduce(builtins.all)
@func_treelize()
def equal(input_, other, *args, **kwargs):
return torch.equal(input_, other, *args, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册