Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
18860735
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
18860735
编写于
9月 22, 2022
作者:
A
Aurelius84
提交者:
GitHub
9月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[BugFix]Fix pooling output_size bug if encounter list[Tensor] (#46352)
* [Check]Enhance pooling output_size type check * add unittest
上级
f30ead13
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
52 addition
and
22 deletion
+52
-22
python/paddle/fluid/tests/unittests/test_unpool_op.py
python/paddle/fluid/tests/unittests/test_unpool_op.py
+29
-0
python/paddle/nn/functional/pooling.py
python/paddle/nn/functional/pooling.py
+23
-22
未找到文件。
python/paddle/fluid/tests/unittests/test_unpool_op.py
浏览文件 @
18860735
...
...
@@ -436,6 +436,35 @@ class TestZOutputSizeTensor2(unittest.TestCase):
np
.
testing
.
assert_array_equal
(
unpool_out
.
shape
,
[
1
,
3
,
7
,
7
])
class
TestZOutputSizeTensor3
(
unittest
.
TestCase
):
def
setUp
(
self
):
paddle
.
disable_static
()
def
tearDown
(
self
):
paddle
.
enable_static
()
def
test_dygraph
(
self
):
x
=
paddle
.
randn
([
1
,
3
,
6
,
6
])
pool_out
,
indices
=
F
.
max_pool2d
(
x
,
kernel_size
=
2
,
stride
=
2
,
padding
=
0
,
return_mask
=
True
)
output_size
=
[
paddle
.
assign
([
1
]),
paddle
.
assign
([
1
]),
paddle
.
assign
([
7
]),
paddle
.
assign
([
7
])
]
unpool_out
=
F
.
max_unpool2d
(
pool_out
,
indices
,
kernel_size
=
2
,
padding
=
0
,
output_size
=
output_size
)
np
.
testing
.
assert_array_equal
(
unpool_out
.
shape
,
[
1
,
3
,
7
,
7
])
if
__name__
==
'__main__'
:
paddle
.
enable_static
()
unittest
.
main
()
python/paddle/nn/functional/pooling.py
浏览文件 @
18860735
...
...
@@ -646,6 +646,9 @@ def max_pool1d(x,
def
_unpool_output_size
(
x
,
kernel_size
,
stride
,
padding
,
output_size
):
assert
output_size
is
None
or
isinstance
(
output_size
,
(
list
,
tuple
)
),
"Required output_size is None|list|tuple, but received %s"
%
output_size
input_size
=
x
.
shape
default_size
=
[]
for
d
in
range
(
len
(
kernel_size
)):
...
...
@@ -654,7 +657,7 @@ def _unpool_output_size(x, kernel_size, stride, padding, output_size):
has_static_var
=
False
if
output_size
is
None
:
ret
=
default_size
ret
urn
default_size
elif
utils
.
_contain_var
(
output_size
):
if
not
_non_static_mode
():
has_static_var
=
True
...
...
@@ -663,27 +666,25 @@ def _unpool_output_size(x, kernel_size, stride, padding, output_size):
for
i
,
var
in
enumerate
(
output_size
):
if
isinstance
(
var
,
Variable
):
output_size
[
i
]
=
var
.
numpy
()[
0
]
ret
=
output_size
else
:
if
len
(
output_size
)
==
len
(
kernel_size
)
+
2
:
output_size
=
output_size
[
2
:]
if
len
(
output_size
)
!=
len
(
kernel_size
):
raise
ValueError
(
"output_size should be a sequence containing "
"{} or {} elements, but it has a length of '{}'"
.
format
(
len
(
kernel_size
),
len
(
kernel_size
)
+
2
,
len
(
output_size
)))
if
not
has_static_var
:
for
d
in
range
(
len
(
kernel_size
)):
min_size
=
default_size
[
d
]
-
stride
[
d
]
max_size
=
default_size
[
d
]
+
stride
[
d
]
if
not
(
min_size
<
output_size
[
d
]
<
max_size
):
raise
ValueError
(
'invalid output_size "{}" (dim {} must be between {} and {})'
.
format
(
output_size
,
d
,
min_size
,
max_size
))
ret
=
output_size
return
ret
if
len
(
output_size
)
==
len
(
kernel_size
)
+
2
:
output_size
=
output_size
[
2
:]
if
len
(
output_size
)
!=
len
(
kernel_size
):
raise
ValueError
(
"output_size should be a sequence containing "
"{} or {} elements, but it has a length of '{}'"
.
format
(
len
(
kernel_size
),
len
(
kernel_size
)
+
2
,
len
(
output_size
)))
if
not
has_static_var
:
for
d
in
range
(
len
(
kernel_size
)):
min_size
=
default_size
[
d
]
-
stride
[
d
]
max_size
=
default_size
[
d
]
+
stride
[
d
]
if
not
(
min_size
<
output_size
[
d
]
<
max_size
):
raise
ValueError
(
'invalid output_size "{}" (dim {} must be between {} and {})'
.
format
(
output_size
,
d
,
min_size
,
max_size
))
return
output_size
def
max_unpool1d
(
x
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录