Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8986a821
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8986a821
编写于
8月 26, 2020
作者:
B
Bai Yifan
提交者:
GitHub
8月 26, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix adaptive gpu grad bug, add doc refine (#26660)
上级
98e057bb
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
47 addition
and
25 deletion
+47
-25
paddle/fluid/operators/math/pooling.cu
paddle/fluid/operators/math/pooling.cu
+14
-15
python/paddle/fluid/tests/unittests/test_pool2d_op.py
python/paddle/fluid/tests/unittests/test_pool2d_op.py
+13
-0
python/paddle/fluid/tests/unittests/test_pool3d_op.py
python/paddle/fluid/tests/unittests/test_pool3d_op.py
+12
-0
python/paddle/nn/functional/pooling.py
python/paddle/nn/functional/pooling.py
+6
-8
python/paddle/nn/layer/pooling.py
python/paddle/nn/layer/pooling.py
+2
-2
未找到文件。
paddle/fluid/operators/math/pooling.cu
浏览文件 @
8986a821
...
@@ -111,12 +111,11 @@ __global__ void KernelPool2DGrad(
...
@@ -111,12 +111,11 @@ __global__ void KernelPool2DGrad(
int
phstart
,
phend
;
int
phstart
,
phend
;
int
pwstart
,
pwend
;
int
pwstart
,
pwend
;
if
(
adaptive
)
{
if
(
adaptive
)
{
phstart
=
h_offset
*
output_height
/
input_height
;
phstart
=
AdaptStartIndex
(
h_offset
,
output_height
,
input_height
);
phend
=
phend
=
AdaptEndIndex
(
h_offset
,
output_height
,
input_height
);
min
((
h_offset
+
1
)
*
output_height
/
input_height
+
1
,
output_height
);
pwstart
=
w_offset
*
output_width
/
input_width
;
pwstart
=
AdaptStartIndex
(
w_offset
,
output_width
,
input_width
);
pwend
=
pwend
=
AdaptEndIndex
(
w_offset
,
output_width
,
input_width
);
min
((
w_offset
+
1
)
*
output_width
/
input_width
+
1
,
output_width
);
}
else
{
}
else
{
phstart
=
(
h_offset
<
ksize_height
)
phstart
=
(
h_offset
<
ksize_height
)
?
0
?
0
...
@@ -159,6 +158,7 @@ __global__ void KernelPool2DGrad(
...
@@ -159,6 +158,7 @@ __global__ void KernelPool2DGrad(
pool_size
=
exclusive
?
(
hend
-
hstart
)
*
(
wend
-
wstart
)
pool_size
=
exclusive
?
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
ksize_height
*
ksize_width
;
:
ksize_height
*
ksize_width
;
}
}
int
output_sub_idx
=
channel_last
int
output_sub_idx
=
channel_last
?
(
ph
*
output_width
+
pw
)
*
channels
+
offsetC
?
(
ph
*
output_width
+
pw
)
*
channels
+
offsetC
:
ph
*
output_width
+
pw
;
:
ph
*
output_width
+
pw
;
...
@@ -689,15 +689,14 @@ __global__ void KernelPool3DGrad(
...
@@ -689,15 +689,14 @@ __global__ void KernelPool3DGrad(
int
phstart
,
phend
;
int
phstart
,
phend
;
int
pwstart
,
pwend
;
int
pwstart
,
pwend
;
if
(
adaptive
)
{
if
(
adaptive
)
{
pdstart
=
d_offset
*
output_depth
/
input_depth
;
pdstart
=
AdaptStartIndex
(
d_offset
,
output_depth
,
input_depth
);
pdend
=
pdend
=
AdaptEndIndex
(
d_offset
,
output_depth
,
input_depth
);
min
((
d_offset
+
1
)
*
output_depth
/
input_depth
+
1
,
output_depth
);
phstart
=
h_offset
*
output_height
/
input_height
;
phstart
=
AdaptStartIndex
(
h_offset
,
output_height
,
input_height
);
phend
=
phend
=
AdaptEndIndex
(
h_offset
,
output_height
,
input_height
);
min
((
h_offset
+
1
)
*
output_height
/
input_height
+
1
,
output_height
);
pwstart
=
w_offset
*
output_width
/
input_width
;
pwstart
=
AdaptStartIndex
(
w_offset
,
output_width
,
input_width
);
pwend
=
pwend
=
AdaptEndIndex
(
w_offset
,
output_width
,
input_width
);
min
((
w_offset
+
1
)
*
output_width
/
input_width
+
1
,
output_width
);
}
else
{
}
else
{
pdstart
=
(
d_offset
<
ksize_depth
)
pdstart
=
(
d_offset
<
ksize_depth
)
?
0
?
0
...
...
python/paddle/fluid/tests/unittests/test_pool2d_op.py
浏览文件 @
8986a821
...
@@ -517,6 +517,19 @@ class TestAvgPoolAdaptive(TestCase1):
...
@@ -517,6 +517,19 @@ class TestAvgPoolAdaptive(TestCase1):
self
.
adaptive
=
True
self
.
adaptive
=
True
class
TestAvgPoolAdaptiveAsyOutSize
(
TestCase1
):
def
init_adaptive
(
self
):
self
.
adaptive
=
True
def
init_shape
(
self
):
self
.
shape
=
[
8
,
3
,
6
,
6
]
def
init_test_case
(
self
):
self
.
ksize
=
[
2
,
3
]
self
.
strides
=
[
1
,
1
]
self
.
paddings
=
[
0
,
0
,
0
,
0
]
#-------test pool2d with asymmetric padding-----
#-------test pool2d with asymmetric padding-----
...
...
python/paddle/fluid/tests/unittests/test_pool3d_op.py
浏览文件 @
8986a821
...
@@ -453,6 +453,18 @@ class TestAvgPoolAdaptive(TestCase1):
...
@@ -453,6 +453,18 @@ class TestAvgPoolAdaptive(TestCase1):
self
.
adaptive
=
True
self
.
adaptive
=
True
class
TestAvgPoolAdaptiveAsyOutSize
(
TestCase1
):
def
init_adaptive
(
self
):
self
.
adaptive
=
True
def
init_shape
(
self
):
self
.
shape
=
[
8
,
3
,
2
,
4
,
4
]
def
init_test_case
(
self
):
self
.
ksize
=
[
2
,
2
,
3
]
self
.
strides
=
[
1
,
1
,
1
]
#-------test pool3d with asymmetric padding------
#-------test pool3d with asymmetric padding------
class
TestPool3d_Op_AsyPadding
(
TestPool3d_Op
):
class
TestPool3d_Op_AsyPadding
(
TestPool3d_Op
):
def
init_test_case
(
self
):
def
init_test_case
(
self
):
...
...
python/paddle/nn/functional/pooling.py
100644 → 100755
浏览文件 @
8986a821
...
@@ -1238,7 +1238,7 @@ def adaptive_avg_pool2d(x, output_size, data_format='NCHW', name=None):
...
@@ -1238,7 +1238,7 @@ def adaptive_avg_pool2d(x, output_size, data_format='NCHW', name=None):
Args:
Args:
x (Tensor): The input tensor of adaptive avg pool2d operator, which is a 4-D tensor.
x (Tensor): The input tensor of adaptive avg pool2d operator, which is a 4-D tensor.
The data type can be float
16, float32, float64, int32 or in
t64.
The data type can be float
32 or floa
t64.
output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list,
output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list,
it must contain two element, (H, W). H and W can be either a int, or None which means
it must contain two element, (H, W). H and W can be either a int, or None which means
the size will be the same as that of the input.
the size will be the same as that of the input.
...
@@ -1285,8 +1285,7 @@ def adaptive_avg_pool2d(x, output_size, data_format='NCHW', name=None):
...
@@ -1285,8 +1285,7 @@ def adaptive_avg_pool2d(x, output_size, data_format='NCHW', name=None):
# pool_out.shape is [2, 3, 3, 3]
# pool_out.shape is [2, 3, 3, 3]
"""
"""
if
not
in_dygraph_mode
():
if
not
in_dygraph_mode
():
check_variable_and_dtype
(
check_variable_and_dtype
(
x
,
'x'
,
[
'float32'
,
'float64'
],
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'adaptive_avg_pool2d'
)
'adaptive_avg_pool2d'
)
check_type
(
data_format
,
'data_format'
,
str
,
'adaptive_avg_pool2d'
)
check_type
(
data_format
,
'data_format'
,
str
,
'adaptive_avg_pool2d'
)
...
@@ -1363,7 +1362,7 @@ def adaptive_avg_pool3d(x, output_size, data_format='NCDHW', name=None):
...
@@ -1363,7 +1362,7 @@ def adaptive_avg_pool3d(x, output_size, data_format='NCDHW', name=None):
Args:
Args:
x (Tensor): The input tensor of adaptive avg pool3d operator, which is a 5-D tensor.
x (Tensor): The input tensor of adaptive avg pool3d operator, which is a 5-D tensor.
The data type can be float
16, float32, float64, int32 or in
t64.
The data type can be float
32, floa
t64.
output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list,
output_size (int|list|tuple): The pool kernel size. If pool kernel size is a tuple or list,
it must contain three elements, (D, H, W). D, H and W can be either a int, or None which means
it must contain three elements, (D, H, W). D, H and W can be either a int, or None which means
the size will be the same as that of the input.
the size will be the same as that of the input.
...
@@ -1413,8 +1412,7 @@ def adaptive_avg_pool3d(x, output_size, data_format='NCDHW', name=None):
...
@@ -1413,8 +1412,7 @@ def adaptive_avg_pool3d(x, output_size, data_format='NCDHW', name=None):
# pool_out.shape is [2, 3, 3, 3, 3]
# pool_out.shape is [2, 3, 3, 3, 3]
"""
"""
if
not
in_dygraph_mode
():
if
not
in_dygraph_mode
():
check_variable_and_dtype
(
check_variable_and_dtype
(
x
,
'x'
,
[
'float32'
,
'float64'
],
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
'adaptive_avg_pool3d'
)
'adaptive_avg_pool3d'
)
check_type
(
data_format
,
'data_format'
,
str
,
'adaptive_avg_pool3d'
)
check_type
(
data_format
,
'data_format'
,
str
,
'adaptive_avg_pool3d'
)
...
...
python/paddle/nn/layer/pooling.py
浏览文件 @
8986a821
...
@@ -67,7 +67,7 @@ class AdaptiveAvgPool2d(layers.Layer):
...
@@ -67,7 +67,7 @@ class AdaptiveAvgPool2d(layers.Layer):
None by default.
None by default.
Shape:
Shape:
x (Tensor): The input tensor of adaptive avg pool2d operator, which is a 4-D tensor. The data type can be float
16, float32, float64, int32 or in
t64.
x (Tensor): The input tensor of adaptive avg pool2d operator, which is a 4-D tensor. The data type can be float
32 or floa
t64.
output (Tensor): The output tensor of adaptive avg pool2d operator, which is a 4-D tensor. The data type is same as input x.
output (Tensor): The output tensor of adaptive avg pool2d operator, which is a 4-D tensor. The data type is same as input x.
Returns:
Returns:
...
@@ -152,7 +152,7 @@ class AdaptiveAvgPool3d(layers.Layer):
...
@@ -152,7 +152,7 @@ class AdaptiveAvgPool3d(layers.Layer):
to :ref:`api_guide_Name`. Usually name is no need to set and
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
None by default.
Shape:
Shape:
x (Tensor): The input tensor of adaptive avg pool3d operator, which is a 5-D tensor. The data type can be float
16, float32, float64, int32 or in
t64.
x (Tensor): The input tensor of adaptive avg pool3d operator, which is a 5-D tensor. The data type can be float
32 or floa
t64.
output (Tensor): The output tensor of adaptive avg pool3d operator, which is a 5-D tensor. The data type is same as input x.
output (Tensor): The output tensor of adaptive avg pool3d operator, which is a 5-D tensor. The data type is same as input x.
Returns:
Returns:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录