Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
flybirding10011
DI-treetensor
提交
8217663f
D
DI-treetensor
项目概览
flybirding10011
/
DI-treetensor
与 Fork 源项目一致
Fork自
OpenDILab开源决策智能平台 / DI-treetensor
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DI-treetensor
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
8217663f
编写于
9月 25, 2021
作者:
HansBug
😆
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
dev, doc, test(hansbug): add squeeze, unsqueeze, where, reshape
上级
b3286b03
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
449 addition
and
2 deletion
+449
-2
test/torch/test_funcs.py
test/torch/test_funcs.py
+91
-0
test/torch/test_tensor.py
test/torch/test_tensor.py
+136
-0
treetensor/torch/funcs.py
treetensor/torch/funcs.py
+170
-2
treetensor/torch/tensor.py
treetensor/torch/tensor.py
+52
-0
未找到文件。
test/torch/test_funcs.py
浏览文件 @
8217663f
...
...
@@ -1655,3 +1655,94 @@ class TestTorchFuncs:
[[
18
,
21
,
17
,
12
],
[
36
,
30
,
33
,
31
]]]},
})).
all
()
@
choose_mark
()
def
test_reshape
(
self
):
t1
=
ttorch
.
reshape
(
torch
.
tensor
([[
1
,
2
],
[
3
,
4
]]),
(
-
1
,))
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
(
t1
==
ttorch
.
tensor
([
1
,
2
,
3
,
4
])).
all
()
t2
=
ttorch
.
reshape
(
ttorch
.
tensor
({
'a'
:
[[
1
,
2
],
[
3
,
4
]],
'b'
:
{
'x'
:
[[
2
],
[
3
],
[
5
],
[
7
],
[
11
],
[
13
]]},
}),
(
-
1
,))
assert
(
t2
==
ttorch
.
tensor
({
'a'
:
[
1
,
2
,
3
,
4
],
'b'
:
{
'x'
:
[
2
,
3
,
5
,
7
,
11
,
13
]},
})).
all
()
@
choose_mark
()
def
test_squeeze
(
self
):
t1
=
torch
.
randint
(
100
,
(
2
,
1
,
2
,
1
,
2
))
assert
t1
.
shape
==
torch
.
Size
([
2
,
1
,
2
,
1
,
2
])
assert
ttorch
.
squeeze
(
t1
).
shape
==
torch
.
Size
([
2
,
2
,
2
])
t2
=
ttorch
.
randint
(
100
,
{
'a'
:
(
2
,
1
,
2
,
1
,
2
),
'b'
:
{
'x'
:
(
2
,
1
,
1
,
3
)},
})
assert
t2
.
shape
==
ttorch
.
Size
({
'a'
:
(
2
,
1
,
2
,
1
,
2
),
'b'
:
{
'x'
:
(
2
,
1
,
1
,
3
)},
})
assert
ttorch
.
squeeze
(
t2
).
shape
==
ttorch
.
Size
({
'a'
:
(
2
,
2
,
2
),
'b'
:
{
'x'
:
(
2
,
3
)},
})
@
choose_mark
()
def
test_unsqueeze
(
self
):
t1
=
torch
.
randint
(
100
,
(
100
,))
assert
t1
.
shape
==
torch
.
Size
([
100
])
assert
ttorch
.
unsqueeze
(
t1
,
0
).
shape
==
torch
.
Size
([
1
,
100
])
tt1
=
ttorch
.
randint
(
100
,
{
'a'
:
(
2
,
2
,
2
),
'b'
:
{
'x'
:
(
2
,
3
)},
})
assert
tt1
.
shape
==
ttorch
.
Size
({
'a'
:
(
2
,
2
,
2
),
'b'
:
{
'x'
:
(
2
,
3
)},
})
assert
ttorch
.
unsqueeze
(
tt1
,
1
).
shape
==
ttorch
.
Size
({
'a'
:
(
2
,
1
,
2
,
2
),
'b'
:
{
'x'
:
(
2
,
1
,
3
)},
})
@
choose_mark
()
def
test_where
(
self
):
t1
=
ttorch
.
where
(
torch
.
tensor
([[
True
,
False
],
[
False
,
True
]]),
torch
.
tensor
([[
2
,
8
],
[
16
,
4
]]),
torch
.
tensor
([[
3
,
11
],
[
5
,
7
]]),
)
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
(
t1
==
ttorch
.
tensor
([[
2
,
11
],
[
5
,
4
]])).
all
()
t2
=
ttorch
.
tensor
({
'a'
:
[[
27
,
90
,
80
],
[
12
,
59
,
5
]],
'b'
:
{
'x'
:
[[[
71
,
52
,
92
,
79
],
[
48
,
4
,
13
,
96
]],
[[
72
,
89
,
44
,
62
],
[
32
,
4
,
29
,
76
]],
[[
6
,
3
,
93
,
89
],
[
44
,
89
,
85
,
90
]]]},
})
assert
(
ttorch
.
where
(
t2
%
2
==
1
,
t2
,
ttorch
.
zeros
({
'a'
:
(
2
,
3
),
'b'
:
{
'x'
:
(
3
,
2
,
4
)}},
dtype
=
torch
.
long
))
==
ttorch
.
tensor
({
'a'
:
[[
27
,
0
,
0
],
[
0
,
59
,
5
]],
'b'
:
{
'x'
:
[[[
71
,
0
,
0
,
79
],
[
0
,
0
,
13
,
0
]],
[[
0
,
89
,
0
,
0
],
[
0
,
0
,
29
,
0
]],
[[
0
,
3
,
93
,
89
],
[
0
,
89
,
85
,
0
]]]},
})).
all
()
test/torch/test_tensor.py
浏览文件 @
8217663f
...
...
@@ -1324,3 +1324,139 @@ class TestTorchTensor:
[[
73
,
81
,
11
],
[
58
,
54
,
78
]]]
})).
all
()
@
choose_mark
()
def
test_reshape
(
self
):
t1
=
torch
.
tensor
([[
1
,
2
],
[
3
,
4
]]).
reshape
((
-
1
,))
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
(
t1
==
ttorch
.
tensor
([
1
,
2
,
3
,
4
])).
all
()
t2
=
ttorch
.
tensor
({
'a'
:
[[
1
,
2
],
[
3
,
4
]],
'b'
:
{
'x'
:
[[
2
],
[
3
],
[
5
],
[
7
],
[
11
],
[
13
]]},
}).
reshape
((
-
1
,))
assert
(
t2
==
ttorch
.
tensor
({
'a'
:
[
1
,
2
,
3
,
4
],
'b'
:
{
'x'
:
[
2
,
3
,
5
,
7
,
11
,
13
]},
})).
all
()
@
choose_mark
()
def
test_squeeze
(
self
):
t1
=
torch
.
randint
(
100
,
(
2
,
1
,
2
,
1
,
2
))
assert
t1
.
shape
==
torch
.
Size
([
2
,
1
,
2
,
1
,
2
])
assert
t1
.
squeeze
().
shape
==
torch
.
Size
([
2
,
2
,
2
])
t2
=
ttorch
.
randint
(
100
,
{
'a'
:
(
2
,
1
,
2
,
1
,
2
),
'b'
:
{
'x'
:
(
2
,
1
,
1
,
3
)},
})
assert
t2
.
shape
==
ttorch
.
Size
({
'a'
:
(
2
,
1
,
2
,
1
,
2
),
'b'
:
{
'x'
:
(
2
,
1
,
1
,
3
)},
})
assert
t2
.
squeeze
().
shape
==
ttorch
.
Size
({
'a'
:
(
2
,
2
,
2
),
'b'
:
{
'x'
:
(
2
,
3
)},
})
@
choose_mark
()
def
test_squeeze_
(
self
):
t1
=
torch
.
randint
(
100
,
(
2
,
1
,
2
,
1
,
2
))
assert
t1
.
shape
==
torch
.
Size
([
2
,
1
,
2
,
1
,
2
])
t1r
=
t1
.
squeeze_
()
assert
t1r
is
t1
assert
t1
.
shape
==
torch
.
Size
([
2
,
2
,
2
])
t2
=
ttorch
.
randint
(
100
,
{
'a'
:
(
2
,
1
,
2
,
1
,
2
),
'b'
:
{
'x'
:
(
2
,
1
,
1
,
3
)},
})
assert
t2
.
shape
==
ttorch
.
Size
({
'a'
:
(
2
,
1
,
2
,
1
,
2
),
'b'
:
{
'x'
:
(
2
,
1
,
1
,
3
)},
})
t2r
=
t2
.
squeeze_
()
assert
t2r
is
t2
assert
t2
.
shape
==
ttorch
.
Size
({
'a'
:
(
2
,
2
,
2
),
'b'
:
{
'x'
:
(
2
,
3
)},
})
@
choose_mark
()
def
test_unsqueeze
(
self
):
t1
=
torch
.
randint
(
100
,
(
100
,))
assert
t1
.
shape
==
torch
.
Size
([
100
])
assert
t1
.
unsqueeze
(
0
).
shape
==
torch
.
Size
([
1
,
100
])
tt1
=
ttorch
.
randint
(
100
,
{
'a'
:
(
2
,
2
,
2
),
'b'
:
{
'x'
:
(
2
,
3
)},
})
assert
tt1
.
shape
==
ttorch
.
Size
({
'a'
:
(
2
,
2
,
2
),
'b'
:
{
'x'
:
(
2
,
3
)},
})
assert
tt1
.
unsqueeze
(
1
).
shape
==
ttorch
.
Size
({
'a'
:
(
2
,
1
,
2
,
2
),
'b'
:
{
'x'
:
(
2
,
1
,
3
)},
})
@
choose_mark
()
def
test_unsqueeze_
(
self
):
t1
=
torch
.
randint
(
100
,
(
100
,))
assert
t1
.
shape
==
torch
.
Size
([
100
])
t1r
=
t1
.
unsqueeze_
(
0
)
assert
t1r
is
t1
assert
t1
.
shape
==
torch
.
Size
([
1
,
100
])
tt1
=
ttorch
.
randint
(
100
,
{
'a'
:
(
2
,
2
,
2
),
'b'
:
{
'x'
:
(
2
,
3
)},
})
assert
tt1
.
shape
==
ttorch
.
Size
({
'a'
:
(
2
,
2
,
2
),
'b'
:
{
'x'
:
(
2
,
3
)},
})
tt1r
=
tt1
.
unsqueeze_
(
1
)
assert
tt1r
is
tt1
assert
tt1
.
shape
==
ttorch
.
Size
({
'a'
:
(
2
,
1
,
2
,
2
),
'b'
:
{
'x'
:
(
2
,
1
,
3
)},
})
@
choose_mark
()
def
test_where
(
self
):
t1
=
torch
.
tensor
([[
2
,
8
],
[
16
,
4
]]).
where
(
torch
.
tensor
([[
True
,
False
],
[
False
,
True
]]),
torch
.
tensor
([[
3
,
11
],
[
5
,
7
]]),
)
assert
isinstance
(
t1
,
torch
.
Tensor
)
assert
(
t1
==
ttorch
.
tensor
([[
2
,
11
],
[
5
,
4
]])).
all
()
t2
=
ttorch
.
tensor
({
'a'
:
[[
27
,
90
,
80
],
[
12
,
59
,
5
]],
'b'
:
{
'x'
:
[[[
71
,
52
,
92
,
79
],
[
48
,
4
,
13
,
96
]],
[[
72
,
89
,
44
,
62
],
[
32
,
4
,
29
,
76
]],
[[
6
,
3
,
93
,
89
],
[
44
,
89
,
85
,
90
]]]},
})
assert
(
t2
.
where
(
t2
%
2
==
1
,
ttorch
.
zeros
({
'a'
:
(
2
,
3
),
'b'
:
{
'x'
:
(
3
,
2
,
4
)}},
dtype
=
torch
.
long
))
==
ttorch
.
tensor
({
'a'
:
[[
27
,
0
,
0
],
[
0
,
59
,
5
]],
'b'
:
{
'x'
:
[[[
71
,
0
,
0
,
79
],
[
0
,
0
,
13
,
0
]],
[[
0
,
89
,
0
,
0
],
[
0
,
0
,
29
,
0
]],
[[
0
,
3
,
93
,
89
],
[
0
,
89
,
85
,
0
]]]},
})).
all
()
treetensor/torch/funcs.py
浏览文件 @
8217663f
...
...
@@ -31,7 +31,7 @@ __all__ = [
'add'
,
'sub'
,
'mul'
,
'div'
,
'pow'
,
'neg'
,
'neg_'
,
'exp'
,
'exp_'
,
'exp2'
,
'exp2_'
,
'sqrt'
,
'sqrt_'
,
'log'
,
'log_'
,
'log2'
,
'log2_'
,
'log10'
,
'log10_'
,
'cat'
,
'split'
,
'stack'
,
'cat'
,
'split'
,
'stack'
,
'reshape'
,
'where'
,
'squeeze'
,
'unsqueeze'
,
]
func_treelize
=
post_process
(
post_process
(
args_mapping
(
...
...
@@ -2443,4 +2443,172 @@ def stack(tensors, *args, **kwargs):
return
torch
.
stack
(
tensors
,
*
args
,
**
kwargs
)
sys
.
modules
[
__name__
]
=
module_autoremove
(
sys
.
modules
[
__name__
])
# noinspection PyShadowingBuiltins
@
doc_from_base
()
@
func_treelize
()
def
reshape
(
input
,
shape
):
"""
Returns a tensor with the same data and number of elements as ``input``,
but with the specified shape. When possible, the returned tensor will be a view of ``input``.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.reshape(torch.tensor([[1, 2], [3, 4]]), (-1, ))
tensor([1, 2, 3, 4])
>>> ttorch.reshape(ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': {'x': [[2], [3], [5], [7], [11], [13]]},
... }), (-1, ))
<Tensor 0x7fc9efa3bda0>
├── a --> tensor([1, 2, 3, 4])
└── b --> <Tensor 0x7fc9efa3bcf8>
└── x --> tensor([ 2, 3, 5, 7, 11, 13])
.. note::
If the given ``shape`` is only one tuple, it should make sure that all the tensors
in this tree can be reshaped to the given ``shape``. Or you can give a tree of tuples
to reshape the tensors to different shapes.
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.reshape(ttorch.tensor({
... 'a': [[1, 2], [3, 4]],
... 'b': {'x': [[2], [3], [5], [7], [11], [13]]},
... }), {'a': (4, ), 'b': {'x': (3, 2)}})
<Tensor 0x7fc9efa3bd68>
├── a --> tensor([1, 2, 3, 4])
└── b --> <Tensor 0x7fc9efa3bf28>
└── x --> tensor([[ 2, 3],
[ 5, 7],
[11, 13]])
"""
return
torch
.
reshape
(
input
,
shape
)
# noinspection PyShadowingBuiltins
@
doc_from_base
()
@
func_treelize
()
def
squeeze
(
input
,
*
args
,
**
kwargs
):
"""
Returns a tensor with all the dimensions of ``input`` of size 1 removed.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t1 = torch.randint(100, (2, 1, 2, 1, 2))
>>> t1.shape
torch.Size([2, 1, 2, 1, 2])
>>> ttorch.squeeze(t1).shape
torch.Size([2, 2, 2])
>>> tt1 = ttorch.randint(100, {
... 'a': (2, 1, 2, 1, 2),
... 'b': {'x': (2, 1, 1, 3)},
... })
>>> tt1.shape
<Size 0x7fa4c1b05410>
├── a --> torch.Size([2, 1, 2, 1, 2])
└── b --> <Size 0x7fa4c1b05510>
└── x --> torch.Size([2, 1, 1, 3])
>>> ttorch.squeeze(tt1).shape
<Size 0x7fa4c1b9f3d0>
├── a --> torch.Size([2, 2, 2])
└── b --> <Size 0x7fa4c1afe710>
└── x --> torch.Size([2, 3])
"""
return
torch
.
squeeze
(
input
,
*
args
,
*
kwargs
)
# noinspection PyShadowingBuiltins
@
doc_from_base
()
@
func_treelize
()
def
unsqueeze
(
input
,
dim
):
"""
Returns a new tensor with a dimension of size one inserted at the specified position.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> t1 = torch.randint(100, (100, ))
>>> t1.shape
torch.Size([100])
>>> ttorch.unsqueeze(t1, 0).shape
torch.Size([1, 100])
>>> tt1 = ttorch.randint(100, {
... 'a': (2, 2, 2),
... 'b': {'x': (2, 3)},
... })
>>> tt1.shape
<Size 0x7f5d1a5741d0>
├── a --> torch.Size([2, 2, 2])
└── b --> <Size 0x7f5d1a5740b8>
└── x --> torch.Size([2, 3])
>>> ttorch.unsqueeze(tt1, 1).shape
<Size 0x7f5d1a5c98d0>
├── a --> torch.Size([2, 1, 2, 2])
└── b --> <Size 0x7f5d1a5c99b0>
└── x --> torch.Size([2, 1, 3])
"""
return
torch
.
unsqueeze
(
input
,
dim
)
@
doc_from_base
()
@
func_treelize
()
def
where
(
condition
,
x
,
y
):
"""
Return a tree of tensors of elements selected from either ``x`` or ``y``, depending on ``condition``.
Examples::
>>> import torch
>>> import treetensor.torch as ttorch
>>> ttorch.where(
... torch.tensor([[True, False], [False, True]]),
... torch.tensor([[2, 8], [16, 4]]),
... torch.tensor([[3, 11], [5, 7]]),
... )
tensor([[ 2, 11],
[ 5, 4]])
>>> tt1 = ttorch.randint(1, 99, {'a': (2, 3), 'b': {'x': (3, 2, 4)}})
>>> tt1
<Tensor 0x7f6760ad9908>
├── a --> tensor([[27, 90, 80],
│ [12, 59, 5]])
└── b --> <Tensor 0x7f6760ad9860>
└── x --> tensor([[[71, 52, 92, 79],
[48, 4, 13, 96]],
[[72, 89, 44, 62],
[32, 4, 29, 76]],
[[ 6, 3, 93, 89],
[44, 89, 85, 90]]])
>>> ttorch.where(tt1 % 2 == 1, tt1, 0)
<Tensor 0x7f6760ad9d30>
├── a --> tensor([[27, 0, 0],
│ [ 0, 59, 5]])
└── b --> <Tensor 0x7f6760ad9f98>
└── x --> tensor([[[71, 0, 0, 79],
[ 0, 0, 13, 0]],
[[ 0, 89, 0, 0],
[ 0, 0, 29, 0]],
[[ 0, 3, 93, 89],
[ 0, 89, 85, 0]]])
"""
return
torch
.
where
(
condition
,
x
,
y
)
_current_module
=
sys
.
modules
[
__name__
]
_current_module
=
module_autoremove
(
_current_module
)
sys
.
modules
[
__name__
]
=
_current_module
treetensor/torch/tensor.py
浏览文件 @
8217663f
...
...
@@ -661,3 +661,55 @@ class Tensor(Torch, metaclass=clsmeta(_to_tensor, allow_dict=True)):
See :func:`treetensor.torch.split`.
"""
return
self
.
split
(
split_size
,
*
args
,
**
kwargs
)
@
doc_from_base
()
@
method_treelize
()
def
reshape
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.reshape`.
"""
return
self
.
reshape
(
*
args
,
**
kwargs
)
@
doc_from_base
()
@
method_treelize
()
def
squeeze
(
self
,
*
args
,
**
kwargs
):
"""
See :func:`treetensor.torch.squeeze`.
"""
return
self
.
squeeze
(
*
args
,
**
kwargs
)
@
doc_from_base
()
@
return_self
@
method_treelize
()
def
squeeze_
(
self
,
*
args
,
**
kwargs
):
"""
In-place version of :meth:`Tensor.squeeze'.
"""
return
self
.
squeeze_
(
*
args
,
**
kwargs
)
@
doc_from_base
()
@
method_treelize
()
def
unsqueeze
(
self
,
dim
):
"""
See :func:`treetensor.torch.unsqueeze`.
"""
return
self
.
unsqueeze
(
dim
)
@
doc_from_base
()
@
return_self
@
method_treelize
()
def
unsqueeze_
(
self
,
dim
):
"""
In-place version of :meth:`Tensor.unsqueeze'.
"""
return
self
.
unsqueeze_
(
dim
)
@
doc_from_base
()
@
method_treelize
()
def
where
(
self
,
condition
,
y
,
*
args
,
**
kwargs
):
"""
``self.where(condition, y)`` is equivalent to
``treetensor.torch.where(condition, self, y)``.
See :func:`treetensor.torch.where`.
"""
return
self
.
where
(
condition
,
y
,
*
args
,
**
kwargs
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录