提交 41907116 编写于 作者: HansBug's avatar HansBug 😆

fix(hansbug): fix test for torch 1.1.0 && exclude some versions

上级 cea90ec7
......@@ -29,6 +29,24 @@ jobs:
- '1.9.0'
- '1.10.0'
exclude:
- os: 'ubuntu-18.04'
python-version: '3.8'
torch-version: '1.1.0'
- os: 'ubuntu-18.04'
python-version: '3.8'
torch-version: '1.2.0'
- os: 'ubuntu-18.04'
python-version: '3.8'
torch-version: '1.3.0'
- os: 'ubuntu-18.04'
python-version: '3.9'
torch-version: '1.1.0'
- os: 'ubuntu-18.04'
python-version: '3.9'
torch-version: '1.2.0'
- os: 'ubuntu-18.04'
python-version: '3.9'
torch-version: '1.3.0'
- os: 'ubuntu-18.04'
python-version: '3.9'
torch-version: '1.4.0'
......
......@@ -237,14 +237,14 @@ class TestNumpyArray:
})
def test_tensor(self):
assert (self._DEMO_1.tensor() == ttorch.Tensor({
assert ttorch.isclose(self._DEMO_1.tensor().double(), ttorch.Tensor({
'a': ttorch.Tensor([[1, 2, 3], [4, 5, 6]]),
'b': ttorch.Tensor([1, 3, 5, 7]),
'x': {
'c': ttorch.Tensor([[11], [23]]),
'd': ttorch.Tensor([3, 9, 11.0])
}
})).all()
}).double()).all()
assert (self._DEMO_1.tensor(dtype=torch.float64) == ttorch.Tensor({
'a': ttorch.Tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float64),
......
......@@ -4,6 +4,8 @@ import torch
import treetensor.torch as ttorch
from .base import choose_mark
bool_init_dtype = torch.tensor([True, False]).dtype
# noinspection DuplicatedCode,PyUnresolvedReferences
class TestTorchTensorReduction:
......@@ -14,7 +16,7 @@ class TestTorchTensorReduction:
'b': {'x': [[True, True, ], [True, True, ]]}
}).all()
assert isinstance(t1, torch.Tensor)
assert t1.dtype == torch.bool
assert t1.dtype == bool_init_dtype
assert t1
t2 = ttorch.Tensor({
......@@ -22,7 +24,7 @@ class TestTorchTensorReduction:
'b': {'x': [[True, True, ], [True, True, ]]}
}).all()
assert isinstance(t2, torch.Tensor)
assert t2.dtype == torch.bool
assert t2.dtype == bool_init_dtype
assert not t2
t3 = ttorch.tensor({
......@@ -48,7 +50,7 @@ class TestTorchTensorReduction:
'b': {'x': [[False, False, ], [False, False, ]]}
}).any()
assert isinstance(t1, torch.Tensor)
assert t1.dtype == torch.bool
assert t1.dtype == bool_init_dtype
assert t1
t2 = ttorch.Tensor({
......@@ -56,7 +58,7 @@ class TestTorchTensorReduction:
'b': {'x': [[False, False, ], [False, False, ]]}
}).any()
assert isinstance(t2, torch.Tensor)
assert t2.dtype == torch.bool
assert t2.dtype == bool_init_dtype
assert not t2
t3 = ttorch.Tensor({
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册