Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
flybirding10011
DI-treetensor
提交
dbf50440
D
DI-treetensor
项目概览
flybirding10011
/
DI-treetensor
与 Fork 源项目一致
Fork自
OpenDILab开源决策智能平台 / DI-treetensor
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
dbf50440
编写于
9月 12, 2021
作者:
HansBug
😆
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test(hansbug): refactor test in test_numpy.py and test_treetensor.py
上级
e3e3e952
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
26 addition
and
28 deletion
+26
-28
test/numpy/test_numpy.py
test/numpy/test_numpy.py
+18
-18
test/tensor/test_treetensor.py
test/tensor/test_treetensor.py
+8
-10
未找到文件。
test/numpy/test_numpy.py
浏览文件 @
dbf50440
import
numpy
as
np
import
pytest
import
treetensor.numpy
as
tnp
from
treetensor.common
import
TreeObject
from
treetensor.numpy
import
TreeNumpy
# noinspection DuplicatedCode
@
pytest
.
mark
.
unittest
class
TestNumpyNumpy
:
_DEMO_1
=
TreeNumpy
({
_DEMO_1
=
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'b'
:
np
.
array
([
1
,
3
,
5
,
7
]),
'x'
:
{
...
...
@@ -17,7 +17,7 @@ class TestNumpyNumpy:
}
})
_DEMO_2
=
TreeNumpy
({
_DEMO_2
=
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
1
,
22
,
3
],
[
4
,
5
,
6
]]),
'b'
:
np
.
array
([
1
,
3
,
5
,
7
]),
'x'
:
{
...
...
@@ -26,7 +26,7 @@ class TestNumpyNumpy:
}
})
_DEMO_3
=
TreeNumpy
({
_DEMO_3
=
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
0
,
0
,
0
],
[
0
,
0
,
0
]]),
'b'
:
np
.
array
([
0
,
0
,
0
,
0
]),
'x'
:
{
...
...
@@ -54,7 +54,7 @@ class TestNumpyNumpy:
assert
self
.
_DEMO_1
.
all
()
assert
not
self
.
_DEMO_2
.
all
()
assert
not
self
.
_DEMO_3
.
all
()
assert
TreeNumpy
({
assert
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
True
,
True
,
True
],
[
True
,
True
,
True
]]),
'b'
:
np
.
array
([
True
,
True
,
True
,
True
]),
'x'
:
{
...
...
@@ -62,7 +62,7 @@ class TestNumpyNumpy:
'd'
:
np
.
array
([
True
,
True
,
True
])
}
}).
all
()
assert
not
TreeNumpy
({
assert
not
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
True
,
True
,
True
],
[
True
,
True
,
True
]]),
'b'
:
np
.
array
([
True
,
True
,
True
,
True
]),
'x'
:
{
...
...
@@ -70,7 +70,7 @@ class TestNumpyNumpy:
'd'
:
np
.
array
([
True
,
True
,
False
])
}
}).
all
()
assert
not
TreeNumpy
({
assert
not
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
False
,
False
,
False
],
[
False
,
False
,
False
]]),
'b'
:
np
.
array
([
False
,
False
,
False
,
False
]),
'x'
:
{
...
...
@@ -83,7 +83,7 @@ class TestNumpyNumpy:
assert
self
.
_DEMO_1
.
any
()
assert
self
.
_DEMO_2
.
any
()
assert
not
self
.
_DEMO_3
.
any
()
assert
TreeNumpy
({
assert
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
True
,
True
,
True
],
[
True
,
True
,
True
]]),
'b'
:
np
.
array
([
True
,
True
,
True
,
True
]),
'x'
:
{
...
...
@@ -91,7 +91,7 @@ class TestNumpyNumpy:
'd'
:
np
.
array
([
True
,
True
,
True
])
}
}).
any
()
assert
TreeNumpy
({
assert
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
True
,
True
,
True
],
[
True
,
True
,
True
]]),
'b'
:
np
.
array
([
True
,
True
,
True
,
True
]),
'x'
:
{
...
...
@@ -99,7 +99,7 @@ class TestNumpyNumpy:
'd'
:
np
.
array
([
True
,
True
,
False
])
}
}).
any
()
assert
not
TreeNumpy
({
assert
not
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
False
,
False
,
False
],
[
False
,
False
,
False
]]),
'b'
:
np
.
array
([
False
,
False
,
False
,
False
]),
'x'
:
{
...
...
@@ -121,7 +121,7 @@ class TestNumpyNumpy:
def
test_gt
(
self
):
assert
not
(
self
.
_DEMO_1
>
self
.
_DEMO_1
).
any
()
assert
not
(
self
.
_DEMO_2
>
self
.
_DEMO_2
).
any
()
assert
((
self
.
_DEMO_1
>
self
.
_DEMO_2
)
==
TreeNumpy
({
assert
((
self
.
_DEMO_1
>
self
.
_DEMO_2
)
==
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
False
,
False
,
False
],
[
False
,
False
,
False
]]),
'b'
:
np
.
array
([
False
,
False
,
False
,
False
]),
'x'
:
{
...
...
@@ -129,7 +129,7 @@ class TestNumpyNumpy:
'd'
:
np
.
array
([
False
,
False
,
False
])
}
})).
all
()
assert
((
self
.
_DEMO_2
>
self
.
_DEMO_1
)
==
TreeNumpy
({
assert
((
self
.
_DEMO_2
>
self
.
_DEMO_1
)
==
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
False
,
True
,
False
],
[
False
,
False
,
False
]]),
'b'
:
np
.
array
([
False
,
False
,
False
,
False
]),
'x'
:
{
...
...
@@ -141,7 +141,7 @@ class TestNumpyNumpy:
def
test_ge
(
self
):
assert
(
self
.
_DEMO_1
>=
self
.
_DEMO_1
).
all
()
assert
(
self
.
_DEMO_2
>=
self
.
_DEMO_2
).
all
()
assert
((
self
.
_DEMO_1
>=
self
.
_DEMO_2
)
==
TreeNumpy
({
assert
((
self
.
_DEMO_1
>=
self
.
_DEMO_2
)
==
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
True
,
False
,
True
],
[
True
,
True
,
True
]]),
'b'
:
np
.
array
([
True
,
True
,
True
,
True
]),
'x'
:
{
...
...
@@ -149,7 +149,7 @@ class TestNumpyNumpy:
'd'
:
np
.
array
([
True
,
True
,
True
])
}
})).
all
()
assert
((
self
.
_DEMO_2
>=
self
.
_DEMO_1
)
==
TreeNumpy
({
assert
((
self
.
_DEMO_2
>=
self
.
_DEMO_1
)
==
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
True
,
True
,
True
],
[
True
,
True
,
True
]]),
'b'
:
np
.
array
([
True
,
True
,
True
,
True
]),
'x'
:
{
...
...
@@ -161,7 +161,7 @@ class TestNumpyNumpy:
def
test_lt
(
self
):
assert
not
(
self
.
_DEMO_1
<
self
.
_DEMO_1
).
any
()
assert
not
(
self
.
_DEMO_2
<
self
.
_DEMO_2
).
any
()
assert
((
self
.
_DEMO_1
<
self
.
_DEMO_2
)
==
TreeNumpy
({
assert
((
self
.
_DEMO_1
<
self
.
_DEMO_2
)
==
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
False
,
True
,
False
],
[
False
,
False
,
False
]]),
'b'
:
np
.
array
([
False
,
False
,
False
,
False
]),
'x'
:
{
...
...
@@ -169,7 +169,7 @@ class TestNumpyNumpy:
'd'
:
np
.
array
([
False
,
False
,
False
])
}
})).
all
()
assert
((
self
.
_DEMO_2
<
self
.
_DEMO_1
)
==
TreeNumpy
({
assert
((
self
.
_DEMO_2
<
self
.
_DEMO_1
)
==
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
False
,
False
,
False
],
[
False
,
False
,
False
]]),
'b'
:
np
.
array
([
False
,
False
,
False
,
False
]),
'x'
:
{
...
...
@@ -181,7 +181,7 @@ class TestNumpyNumpy:
def
test_le
(
self
):
assert
(
self
.
_DEMO_1
<=
self
.
_DEMO_1
).
all
()
assert
(
self
.
_DEMO_2
<=
self
.
_DEMO_2
).
all
()
assert
((
self
.
_DEMO_1
<=
self
.
_DEMO_2
)
==
TreeNumpy
({
assert
((
self
.
_DEMO_1
<=
self
.
_DEMO_2
)
==
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
True
,
True
,
True
],
[
True
,
True
,
True
]]),
'b'
:
np
.
array
([
True
,
True
,
True
,
True
]),
'x'
:
{
...
...
@@ -189,7 +189,7 @@ class TestNumpyNumpy:
'd'
:
np
.
array
([
True
,
True
,
True
])
}
})).
all
()
assert
((
self
.
_DEMO_2
<=
self
.
_DEMO_1
)
==
TreeNumpy
({
assert
((
self
.
_DEMO_2
<=
self
.
_DEMO_1
)
==
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
True
,
False
,
True
],
[
True
,
True
,
True
]]),
'b'
:
np
.
array
([
True
,
True
,
True
,
True
]),
'x'
:
{
...
...
test/tensor/test_treetensor.py
浏览文件 @
dbf50440
...
...
@@ -3,17 +3,15 @@ import pytest
import
torch
from
treevalue
import
func_treelize
from
treetensor.numpy
import
TreeNumpy
from
treetensor.numpy
import
all
as
_numpy_all
from
treetensor.tensor
import
TreeTensor
from
treetensor.tensor
import
all
as
_tensor_all
import
treetensor.numpy
as
tnp
import
treetensor.tensor
as
ttorch
_all_is
=
func_treelize
(
return_type
=
TreeTensor
)(
lambda
x
,
y
:
x
is
y
)
_all_is
=
func_treelize
(
return_type
=
ttorch
.
TreeTensor
)(
lambda
x
,
y
:
x
is
y
)
@
pytest
.
mark
.
unittest
class
TestTensorTreetensor
:
_DEMO_1
=
TreeTensor
({
_DEMO_1
=
ttorch
.
TreeTensor
({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'b'
:
torch
.
tensor
([[
1
,
2
],
[
5
,
6
]]),
'x'
:
{
...
...
@@ -22,7 +20,7 @@ class TestTensorTreetensor:
}
})
_DEMO_2
=
TreeTensor
({
_DEMO_2
=
ttorch
.
TreeTensor
({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'b'
:
torch
.
tensor
([[
1
,
2
],
[
5
,
60
]]),
'x'
:
{
...
...
@@ -35,7 +33,7 @@ class TestTensorTreetensor:
assert
self
.
_DEMO_1
.
numel
()
==
18
def
test_numpy
(
self
):
assert
_numpy_all
(
self
.
_DEMO_1
.
numpy
()
==
TreeNumpy
({
assert
tnp
.
all
(
self
.
_DEMO_1
.
numpy
()
==
tnp
.
TreeNumpy
({
'a'
:
np
.
array
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'b'
:
np
.
array
([[
1
,
2
],
[
5
,
6
]]),
'x'
:
{
...
...
@@ -45,11 +43,11 @@ class TestTensorTreetensor:
}))
def
test_cpu
(
self
):
assert
_tensor_
all
(
self
.
_DEMO_1
.
cpu
()
==
self
.
_DEMO_1
)
assert
ttorch
.
all
(
self
.
_DEMO_1
.
cpu
()
==
self
.
_DEMO_1
)
assert
_all_is
(
self
.
_DEMO_1
.
cpu
(),
self
.
_DEMO_1
).
reduce
(
lambda
**
kws
:
all
(
kws
.
values
()))
def
test_to
(
self
):
assert
_tensor_all
(
self
.
_DEMO_1
.
to
(
torch
.
float32
)
==
TreeTensor
({
assert
ttorch
.
all
(
self
.
_DEMO_1
.
to
(
torch
.
float32
)
==
ttorch
.
TreeTensor
({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
torch
.
float32
),
'b'
:
torch
.
tensor
([[
1
,
2
],
[
5
,
6
]],
dtype
=
torch
.
float32
),
'x'
:
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录