Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
flybirding10011
DI-treetensor
提交
ddc4ecec
D
DI-treetensor
项目概览
flybirding10011
/
DI-treetensor
与 Fork 源项目一致
Fork自
OpenDILab开源决策智能平台 / DI-treetensor
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ddc4ecec
编写于
9月 07, 2021
作者:
HansBug
😆
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dev(hansbug): add new unittests
上级
8f298c55
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
243 addition
and
4 deletion
+243
-4
test/tensor/test_funcs.py
test/tensor/test_funcs.py
+165
-1
treetensor/common/__init__.py
treetensor/common/__init__.py
+1
-0
treetensor/common/treelist.py
treetensor/common/treelist.py
+5
-0
treetensor/numpy/numpy.py
treetensor/numpy/numpy.py
+5
-1
treetensor/tensor/funcs.py
treetensor/tensor/funcs.py
+1
-1
treetensor/tensor/treetensor.py
treetensor/tensor/treetensor.py
+66
-1
未找到文件。
test/tensor/test_funcs.py
浏览文件 @
ddc4ecec
import
pytest
import
torch
from
treetensor.tensor
import
TreeTensor
,
zeros
,
all_equal
,
zeros_like
,
ones
,
ones_like
from
treetensor.tensor
import
TreeTensor
,
zeros
,
all_equal
,
zeros_like
,
ones
,
ones_like
,
randint
,
randint_like
,
randn
,
\
randn_like
,
full
,
full_like
from
treetensor.tensor
import
all
as
_tensor_all
# noinspection DuplicatedCode
...
...
@@ -48,6 +50,10 @@ class TestTensorFuncs:
}))
def
test_zeros_like
(
self
):
assert
all_equal
(
zeros_like
(
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]])),
torch
.
tensor
([[
0
,
0
,
0
],
[
0
,
0
,
0
]]),
)
assert
all_equal
(
zeros_like
(
TreeTensor
({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
...
...
@@ -84,6 +90,10 @@ class TestTensorFuncs:
}))
def
test_ones_like
(
self
):
assert
all_equal
(
ones_like
(
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]])),
torch
.
tensor
([[
1
,
1
,
1
],
[
1
,
1
,
1
]])
)
assert
all_equal
(
ones_like
(
TreeTensor
({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
...
...
@@ -102,3 +112,157 @@ class TestTensorFuncs:
}
})
)
def
test_randn
(
self
):
_target
=
randn
((
200
,
300
))
assert
-
0.02
<=
_target
.
view
(
60000
).
mean
().
tolist
()
<=
0.02
assert
0.98
<=
_target
.
view
(
60000
).
std
().
tolist
()
<=
1.02
assert
_target
.
shape
==
torch
.
Size
([
200
,
300
])
_target
=
randn
({
'a'
:
(
2
,
3
),
'b'
:
(
5
,
6
),
'x'
:
{
'c'
:
(
2
,
3
,
4
),
}
})
assert
_target
.
raw_shape
==
TreeTensor
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
5
,
6
]),
'x'
:
{
'c'
:
torch
.
Size
([
2
,
3
,
4
]),
}
})
def
test_randn_like
(
self
):
_target
=
randn_like
(
torch
.
ones
(
200
,
300
))
assert
-
0.02
<=
_target
.
view
(
60000
).
mean
().
tolist
()
<=
0.02
assert
0.98
<=
_target
.
view
(
60000
).
std
().
tolist
()
<=
1.02
assert
_target
.
shape
==
torch
.
Size
([
200
,
300
])
_target
=
randn_like
(
TreeTensor
({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]],
dtype
=
torch
.
float32
),
'b'
:
torch
.
tensor
([
1
,
2
,
3
,
4
],
dtype
=
torch
.
float32
),
'x'
:
{
'c'
:
torch
.
tensor
([
5
,
6
,
7
],
dtype
=
torch
.
float32
),
'd'
:
torch
.
tensor
([[[
8
,
9
]]],
dtype
=
torch
.
float32
),
}
}))
assert
_target
.
raw_shape
==
TreeTensor
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
4
]),
'x'
:
{
'c'
:
torch
.
Size
([
3
]),
'd'
:
torch
.
Size
([
1
,
1
,
2
]),
}
})
def
test_randint
(
self
):
_target
=
randint
({
'a'
:
(
2
,
3
),
'b'
:
(
5
,
6
),
'x'
:
{
'c'
:
(
2
,
3
,
4
),
}
},
-
10
,
10
)
assert
_tensor_all
(
_target
<
10
).
all
()
assert
_tensor_all
(
-
10
<=
_target
).
all
()
assert
_target
.
raw_shape
==
TreeTensor
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
5
,
6
]),
'x'
:
{
'c'
:
torch
.
Size
([
2
,
3
,
4
]),
}
})
_target
=
randint
({
'a'
:
(
2
,
3
),
'b'
:
(
5
,
6
),
'x'
:
{
'c'
:
(
2
,
3
,
4
),
}
},
10
)
assert
_tensor_all
(
_target
<
10
).
all
()
assert
_tensor_all
(
0
<=
_target
).
all
()
assert
_target
.
raw_shape
==
TreeTensor
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
5
,
6
]),
'x'
:
{
'c'
:
torch
.
Size
([
2
,
3
,
4
]),
}
})
def
test_randint_like
(
self
):
_target
=
randint_like
(
TreeTensor
({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'b'
:
torch
.
tensor
([
1
,
2
,
3
,
4
]),
'x'
:
{
'c'
:
torch
.
tensor
([
5
,
6
,
7
]),
'd'
:
torch
.
tensor
([[[
8
,
9
]]]),
}
}),
-
10
,
10
)
assert
_tensor_all
(
_target
<
10
).
all
()
assert
_tensor_all
(
-
10
<=
_target
).
all
()
assert
_target
.
raw_shape
==
TreeTensor
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
4
]),
'x'
:
{
'c'
:
torch
.
Size
([
3
]),
'd'
:
torch
.
Size
([
1
,
1
,
2
]),
}
})
_target
=
randint_like
(
TreeTensor
({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'b'
:
torch
.
tensor
([
1
,
2
,
3
,
4
]),
'x'
:
{
'c'
:
torch
.
tensor
([
5
,
6
,
7
]),
'd'
:
torch
.
tensor
([[[
8
,
9
]]]),
}
}),
10
)
assert
_tensor_all
(
_target
<
10
).
all
()
assert
_tensor_all
(
0
<=
_target
).
all
()
assert
_target
.
raw_shape
==
TreeTensor
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
4
]),
'x'
:
{
'c'
:
torch
.
Size
([
3
]),
'd'
:
torch
.
Size
([
1
,
1
,
2
]),
}
})
def
test_full
(
self
):
_target
=
full
({
'a'
:
(
2
,
3
),
'b'
:
(
5
,
6
),
'x'
:
{
'c'
:
(
2
,
3
,
4
),
}
},
233
)
assert
_tensor_all
(
_target
.
tensor_eq
(
233
)).
all
()
assert
_target
.
raw_shape
==
TreeTensor
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
5
,
6
]),
'x'
:
{
'c'
:
torch
.
Size
([
2
,
3
,
4
]),
}
})
def
test_full_like
(
self
):
_target
=
full_like
(
TreeTensor
({
'a'
:
torch
.
tensor
([[
1
,
2
,
3
],
[
4
,
5
,
6
]]),
'b'
:
torch
.
tensor
([
1
,
2
,
3
,
4
]),
'x'
:
{
'c'
:
torch
.
tensor
([
5
,
6
,
7
]),
'd'
:
torch
.
tensor
([[[
8
,
9
]]]),
}
}),
233
)
assert
_tensor_all
(
_target
.
tensor_eq
(
233
)).
all
()
assert
_target
.
raw_shape
==
TreeTensor
({
'a'
:
torch
.
Size
([
2
,
3
]),
'b'
:
torch
.
Size
([
4
]),
'x'
:
{
'c'
:
torch
.
Size
([
3
]),
'd'
:
torch
.
Size
([
1
,
1
,
2
]),
}
})
treetensor/common/__init__.py
0 → 100644
浏览文件 @
ddc4ecec
from
.treelist
import
TreeList
treetensor/common/treelist.py
0 → 100644
浏览文件 @
ddc4ecec
from
treevalue
import
general_tree_value
class
TreeList
(
general_tree_value
()):
pass
treetensor/numpy/numpy.py
浏览文件 @
ddc4ecec
from
treevalue
import
general_tree_value
from
treevalue
import
general_tree_value
,
method_treelize
from
..common
import
TreeList
class
TreeNumpy
(
general_tree_value
()):
...
...
@@ -7,6 +9,8 @@ class TreeNumpy(general_tree_value()):
Real numpy tree.
"""
tolist
=
method_treelize
(
return_type
=
TreeList
)(
lambda
d
:
d
.
tolist
())
@
property
def
size
(
self
)
->
int
:
return
self
\
...
...
treetensor/tensor/funcs.py
浏览文件 @
ddc4ecec
...
...
@@ -34,7 +34,7 @@ zeros = _size_based_treelize()(torch.zeros)
randn
=
_size_based_treelize
()(
torch
.
randn
)
randint
=
_size_based_treelize
(
prefix
=
True
,
tuple_
=
True
)(
torch
.
randint
)
ones
=
_size_based_treelize
()(
torch
.
ones
)
full
=
_size_based_treelize
()(
torch
.
full
)
full
=
_size_based_treelize
(
tuple_
=
True
)(
torch
.
full
)
empty
=
_size_based_treelize
()(
torch
.
empty
)
# Tensor generation based on another tensor
...
...
treetensor/tensor/treetensor.py
浏览文件 @
ddc4ecec
from
functools
import
partial
from
operator
import
__eq__
from
torch
import
Tensor
from
treevalue
import
general_tree_value
,
method_treelize
from
treevalue
import
general_tree_value
,
method_treelize
,
TreeValue
from
..common
import
TreeList
from
..numpy
import
TreeNumpy
def
_same_merge
(
eq
,
hash_
,
**
kwargs
):
kws
=
{
key
:
value
for
key
,
value
in
kwargs
.
items
()
if
not
(
isinstance
(
value
,
TreeValue
)
and
not
value
)
}
class
_Wrapper
:
def
__init__
(
self
,
v
):
self
.
v
=
v
def
__hash__
(
self
):
return
hash_
(
self
.
v
)
def
__eq__
(
self
,
other
):
return
eq
(
self
.
v
,
other
.
v
)
if
len
(
set
(
_Wrapper
(
v
)
for
v
in
kws
.
values
()))
==
1
:
return
list
(
kws
.
values
())[
0
]
else
:
return
TreeTensor
(
kws
)
# noinspection PyTypeChecker,PyShadowingBuiltins
class
TreeTensor
(
general_tree_value
()):
def
numel
(
self
)
->
int
:
...
...
@@ -11,7 +37,46 @@ class TreeTensor(general_tree_value()):
.
map
(
lambda
t
:
t
.
numel
())
\
.
reduce
(
lambda
**
kws
:
sum
(
kws
.
values
()))
@
property
def
raw_shape
(
self
):
return
self
.
map
(
lambda
t
:
t
.
shape
)
@
property
def
shape
(
self
):
return
self
.
raw_shape
.
reduce
(
partial
(
_same_merge
,
__eq__
,
hash
))
numpy
=
method_treelize
(
return_type
=
TreeNumpy
)(
Tensor
.
numpy
)
tolist
=
method_treelize
(
return_type
=
TreeList
)(
Tensor
.
tolist
)
cpu
=
method_treelize
()(
Tensor
.
cpu
)
cuda
=
method_treelize
()(
Tensor
.
cuda
)
to
=
method_treelize
()(
Tensor
.
to
)
@
method_treelize
()
def
__lt__
(
self
,
other
):
return
self
<
other
@
method_treelize
()
def
__le__
(
self
,
other
):
return
self
<=
other
@
method_treelize
()
def
__gt__
(
self
,
other
):
return
self
>
other
@
method_treelize
()
def
__ge__
(
self
,
other
):
return
self
>=
other
@
method_treelize
()
def
tensor_eq
(
self
,
other
):
return
self
==
other
@
method_treelize
()
def
tensor_ne
(
self
,
other
):
return
self
!=
other
def
all
(
self
):
return
self
.
reduce
(
lambda
**
kws
:
all
(
kws
.
values
()))
def
any
(
self
):
return
self
.
reduce
(
lambda
**
kws
:
any
(
kws
.
values
()))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录