Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
feaf1e2d
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
feaf1e2d
编写于
11月 14, 2017
作者:
C
chengduo
提交者:
GitHub
11月 14, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5608 from chengduoZH/fix_pooling_function_parameter_order
fix pooling functor parameter order
上级
d7319c22
21604977
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
278 addition
and
272 deletion
+278
-272
paddle/operators/math/pooling.cc
paddle/operators/math/pooling.cc
+61
-57
paddle/operators/math/pooling.cu
paddle/operators/math/pooling.cu
+167
-167
paddle/operators/math/pooling.h
paddle/operators/math/pooling.h
+26
-24
paddle/operators/pool_op.h
paddle/operators/pool_op.h
+16
-16
paddle/operators/pool_with_index_op.h
paddle/operators/pool_with_index_op.h
+8
-8
未找到文件。
paddle/operators/math/pooling.cc
浏览文件 @
feaf1e2d
...
...
@@ -27,15 +27,15 @@ template <typename PoolProcess, typename T>
class
Pool2dFunctor
<
platform
::
CPUPlace
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
stride
s
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
)
{
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
padding
s
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
output_channels
=
output
.
dims
()[
1
];
const
int
output_height
=
output
.
dims
()[
2
];
const
int
output_width
=
output
.
dims
()[
3
];
const
int
output_channels
=
output
->
dims
()[
1
];
const
int
output_height
=
output
->
dims
()[
2
];
const
int
output_width
=
output
->
dims
()[
3
];
const
int
ksize_height
=
ksize
[
0
];
const
int
ksize_width
=
ksize
[
1
];
const
int
stride_height
=
strides
[
0
];
...
...
@@ -47,7 +47,7 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
const
int
output_stride
=
output_height
*
output_width
;
const
T
*
input_data
=
input
.
data
<
T
>
();
T
*
output_data
=
output
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
output_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
...
...
@@ -87,11 +87,12 @@ template <typename PoolProcess, class T>
class
Pool2dGradFunctor
<
platform
::
CPUPlace
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_grad_process
)
{
PoolProcess
pool_grad_process
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
...
...
@@ -110,7 +111,7 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
output_data
=
output
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
...
...
@@ -154,10 +155,11 @@ template <class T>
class
MaxPool2dGradFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
...
...
@@ -176,7 +178,7 @@ class MaxPool2dGradFunctor<platform::CPUPlace, T> {
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
output_data
=
output
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
...
...
@@ -240,17 +242,17 @@ template <typename PoolProcess, class T>
class
Pool3dFunctor
<
platform
::
CPUPlace
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
stride
s
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
)
{
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
padding
s
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
const
int
input_width
=
input
.
dims
()[
4
];
const
int
output_channels
=
output
.
dims
()[
1
];
const
int
output_depth
=
output
.
dims
()[
2
];
const
int
output_height
=
output
.
dims
()[
3
];
const
int
output_width
=
output
.
dims
()[
4
];
const
int
output_channels
=
output
->
dims
()[
1
];
const
int
output_depth
=
output
->
dims
()[
2
];
const
int
output_height
=
output
->
dims
()[
3
];
const
int
output_width
=
output
->
dims
()[
4
];
const
int
ksize_depth
=
ksize
[
0
];
const
int
ksize_height
=
ksize
[
1
];
const
int
ksize_width
=
ksize
[
2
];
...
...
@@ -265,7 +267,7 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
const
int
output_stride
=
output_depth
*
output_height
*
output_width
;
const
T
*
input_data
=
input
.
data
<
T
>
();
T
*
output_data
=
output
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
output_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
...
...
@@ -315,11 +317,12 @@ template <typename PoolProcess, class T>
class
Pool3dGradFunctor
<
platform
::
CPUPlace
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_grad_process
)
{
PoolProcess
pool_grad_process
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
...
...
@@ -343,7 +346,7 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
output_data
=
output
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
...
...
@@ -398,10 +401,11 @@ template <class T>
class
MaxPool3dGradFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
...
...
@@ -425,7 +429,7 @@ class MaxPool3dGradFunctor<platform::CPUPlace, T> {
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
output_data
=
output
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
...
...
@@ -498,15 +502,15 @@ template <typename T>
class
MaxPool2dWithIndexFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
output_channels
=
output
.
dims
()[
1
];
const
int
output_height
=
output
.
dims
()[
2
];
const
int
output_width
=
output
.
dims
()[
3
];
const
int
output_channels
=
output
->
dims
()[
1
];
const
int
output_height
=
output
->
dims
()[
2
];
const
int
output_width
=
output
->
dims
()[
3
];
const
int
ksize_height
=
ksize
[
0
];
const
int
ksize_width
=
ksize
[
1
];
const
int
stride_height
=
strides
[
0
];
...
...
@@ -517,8 +521,8 @@ class MaxPool2dWithIndexFunctor<platform::CPUPlace, T> {
const
int
output_stride
=
output_height
*
output_width
;
const
T
*
input_data
=
input
.
data
<
T
>
();
T
*
output_data
=
output
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
mask_data
=
mask
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
output_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
...
...
@@ -563,13 +567,13 @@ template <typename T>
class
MaxPool2dWithIndexGradFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
const
int
batch_size
=
input_grad
.
dims
()[
0
];
const
int
input_height
=
input_grad
.
dims
()[
2
];
const
int
input_width
=
input_grad
.
dims
()[
3
];
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input_grad
->
dims
()[
0
];
const
int
input_height
=
input_grad
->
dims
()[
2
];
const
int
input_width
=
input_grad
->
dims
()[
3
];
const
int
output_channels
=
output_grad
.
dims
()[
1
];
const
int
output_height
=
output_grad
.
dims
()[
2
];
const
int
output_width
=
output_grad
.
dims
()[
3
];
...
...
@@ -578,7 +582,7 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUPlace, T> {
const
T
*
mask_data
=
mask
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
n
=
0
;
n
<
batch_size
;
++
n
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
...
...
@@ -612,17 +616,17 @@ template <typename T>
class
MaxPool3dWithIndexFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
const
int
input_width
=
input
.
dims
()[
4
];
const
int
output_channels
=
output
.
dims
()[
1
];
const
int
output_depth
=
output
.
dims
()[
2
];
const
int
output_height
=
output
.
dims
()[
3
];
const
int
output_width
=
output
.
dims
()[
4
];
const
int
output_channels
=
output
->
dims
()[
1
];
const
int
output_depth
=
output
->
dims
()[
2
];
const
int
output_height
=
output
->
dims
()[
3
];
const
int
output_width
=
output
->
dims
()[
4
];
const
int
ksize_depth
=
ksize
[
0
];
const
int
ksize_height
=
ksize
[
1
];
const
int
ksize_width
=
ksize
[
2
];
...
...
@@ -636,8 +640,8 @@ class MaxPool3dWithIndexFunctor<platform::CPUPlace, T> {
const
int
output_stride
=
output_depth
*
output_height
*
output_width
;
const
T
*
input_data
=
input
.
data
<
T
>
();
T
*
output_data
=
output
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
mask_data
=
mask
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
output_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
...
...
@@ -691,14 +695,14 @@ template <typename T>
class
MaxPool3dWithIndexGradFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
const
int
batch_size
=
input_grad
.
dims
()[
0
];
const
int
input_depth
=
input_grad
.
dims
()[
2
];
const
int
input_height
=
input_grad
.
dims
()[
3
];
const
int
input_width
=
input_grad
.
dims
()[
4
];
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input_grad
->
dims
()[
0
];
const
int
input_depth
=
input_grad
->
dims
()[
2
];
const
int
input_height
=
input_grad
->
dims
()[
3
];
const
int
input_width
=
input_grad
->
dims
()[
4
];
const
int
output_channels
=
output_grad
.
dims
()[
1
];
const
int
output_depth
=
output_grad
.
dims
()[
2
];
const
int
output_height
=
output_grad
.
dims
()[
3
];
...
...
@@ -708,7 +712,7 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUPlace, T> {
const
T
*
mask_data
=
mask
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
n
=
0
;
n
<
batch_size
;
++
n
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
...
...
paddle/operators/math/pooling.cu
浏览文件 @
feaf1e2d
...
...
@@ -21,13 +21,13 @@ namespace math {
template
<
typename
PoolProcess
,
typename
T
>
__global__
void
KernelPool2D
(
const
int
nthreads
,
const
T
*
input_data
,
T
*
output_data
,
const
int
channels
,
const
int
input_
height
,
const
int
input_width
,
const
int
output_
height
,
const
int
output_width
,
const
int
ksize_
height
,
const
int
ksize_width
,
const
int
stride_
height
,
const
int
stride_width
,
const
int
padding_
height
,
const
int
padding_width
,
PoolProcess
pool_process
)
{
const
int
channels
,
const
int
input_height
,
const
int
input_
width
,
const
int
output_height
,
const
int
output_
width
,
const
int
ksize_height
,
const
int
ksize_
width
,
const
int
stride_height
,
const
int
stride_
width
,
const
int
padding_height
,
const
int
padding_
width
,
PoolProcess
pool_process
,
T
*
output_data
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
pw
=
index
%
output_width
;
...
...
@@ -59,11 +59,11 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
template
<
typename
PoolProcess
,
typename
T
>
__global__
void
KernelPool2DGrad
(
const
int
nthreads
,
const
T
*
input_data
,
const
T
*
output_data
,
const
T
*
output_grad
,
T
*
input_grad
,
const
int
channels
,
const
int
input_
height
,
const
int
input_width
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_
height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
)
{
const
T
*
output_grad
,
const
int
channels
,
const
int
input_height
,
const
int
input_
width
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_
width
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
,
T
*
input_grad
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
offsetW
=
index
%
input_width
+
padding_width
;
...
...
@@ -107,11 +107,11 @@ __global__ void KernelPool2DGrad(
template
<
typename
T
>
__global__
void
KernelMaxPool2DGrad
(
const
int
nthreads
,
const
T
*
input_data
,
const
T
*
output_data
,
const
T
*
output_grad
,
T
*
input_grad
,
const
int
channels
,
const
int
input_
height
,
const
int
input_width
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_
height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
)
{
const
T
*
output_grad
,
const
int
channels
,
const
int
input_height
,
const
int
input_
width
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_
width
,
const
int
padding_height
,
const
int
padding_width
,
T
*
input_grad
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
pw
=
index
%
output_width
;
...
...
@@ -158,16 +158,16 @@ template <typename PoolProcess, typename T>
class
Pool2dFunctor
<
platform
::
GPUPlace
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
stride
s
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
)
{
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
padding
s
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
output_channels
=
output
.
dims
()[
1
];
const
int
output_height
=
output
.
dims
()[
2
];
const
int
output_width
=
output
.
dims
()[
3
];
const
int
output_channels
=
output
->
dims
()[
1
];
const
int
output_height
=
output
->
dims
()[
2
];
const
int
output_width
=
output
->
dims
()[
3
];
const
int
ksize_height
=
ksize
[
0
];
const
int
ksize_width
=
ksize
[
1
];
const
int
stride_height
=
strides
[
0
];
...
...
@@ -176,7 +176,7 @@ class Pool2dFunctor<platform::GPUPlace, PoolProcess, T> {
const
int
padding_width
=
paddings
[
1
];
const
T
*
input_data
=
input
.
data
<
T
>
();
T
*
output_data
=
output
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
output_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
nthreads
=
batch_size
*
output_channels
*
output_height
*
output_width
;
int
blocks
=
(
nthreads
+
1024
-
1
)
/
1024
;
...
...
@@ -187,11 +187,10 @@ class Pool2dFunctor<platform::GPUPlace, PoolProcess, T> {
PoolProcess
,
T
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
nthreads
,
input_data
,
output_data
,
input_channels
,
input_height
,
input_width
,
output_height
,
output_width
,
ksize_height
,
ksize_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
pool_process
);
.
stream
()
>>>
(
nthreads
,
input_data
,
input_channels
,
input_height
,
input_width
,
output_height
,
output_width
,
ksize_height
,
ksize_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
pool_process
,
output_data
);
}
};
...
...
@@ -204,11 +203,11 @@ template <typename PoolProcess, typename T>
class
Pool2dGradFunctor
<
platform
::
GPUPlace
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
)
{
PoolProcess
pool_process
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
...
...
@@ -225,7 +224,7 @@ class Pool2dGradFunctor<platform::GPUPlace, PoolProcess, T> {
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
output_data
=
output
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
nthreads
=
batch_size
*
input_channels
*
input_height
*
input_width
;
int
blocks
=
(
nthreads
+
1024
-
1
)
/
1024
;
...
...
@@ -237,10 +236,10 @@ class Pool2dGradFunctor<platform::GPUPlace, PoolProcess, T> {
T
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
nthreads
,
input_data
,
output_data
,
output_grad_data
,
input_
grad_data
,
input_
channels
,
input_height
,
input_width
,
output_height
,
output_width
,
ksize_
height
,
ksize_width
,
stride_height
,
stride_width
,
padding_height
,
p
adding_width
,
pool_process
);
nthreads
,
input_data
,
output_data
,
output_grad_data
,
input_
channels
,
input_
height
,
input_width
,
output_height
,
output_width
,
ksize_height
,
ksize_
width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
p
ool_process
,
input_grad_data
);
}
};
...
...
@@ -253,10 +252,11 @@ template <typename T>
class
MaxPool2dGradFunctor
<
platform
::
GPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
...
...
@@ -274,7 +274,7 @@ class MaxPool2dGradFunctor<platform::GPUPlace, T> {
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
output_data
=
output
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
nthreads
=
batch_size
*
output_channels
*
output_height
*
output_width
;
int
blocks
=
(
nthreads
+
1024
-
1
)
/
1024
;
...
...
@@ -285,10 +285,10 @@ class MaxPool2dGradFunctor<platform::GPUPlace, T> {
T
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
nthreads
,
input_data
,
output_data
,
output_grad_data
,
input_
grad_data
,
input_
channels
,
input_height
,
input_width
,
output_height
,
output_width
,
ksize_
height
,
ksize_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
);
nthreads
,
input_data
,
output_data
,
output_grad_data
,
input_
channels
,
input_
height
,
input_width
,
output_height
,
output_width
,
ksize_height
,
ksize_
width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
input_grad_data
);
}
};
...
...
@@ -313,14 +313,16 @@ template class Pool2dGradFunctor<
platform
::
GPUPlace
,
paddle
::
operators
::
math
::
AvgPoolGrad
<
double
>,
double
>
;
template
<
typename
PoolProcess
,
typename
T
>
__global__
void
KernelPool3D
(
const
int
nthreads
,
const
T
*
input_data
,
T
*
output_data
,
const
int
channels
,
const
int
input_depth
,
const
int
input_height
,
const
int
input_width
,
const
int
output_depth
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_depth
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
)
{
__global__
void
KernelPool3D
(
const
int
nthreads
,
const
T
*
input_data
,
const
int
channels
,
const
int
input_depth
,
const
int
input_height
,
const
int
input_width
,
const
int
output_depth
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_depth
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
,
T
*
output_data
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
pw
=
index
%
output_width
;
...
...
@@ -358,13 +360,13 @@ __global__ void KernelPool3D(
template
<
typename
PoolProcess
,
typename
T
>
__global__
void
KernelPool3DGrad
(
const
int
nthreads
,
const
T
*
input_data
,
const
T
*
output_data
,
const
T
*
output_grad
,
T
*
input_grad
,
const
int
channels
,
const
int
input_
depth
,
const
int
input_height
,
const
int
input_wid
th
,
const
int
output_
depth
,
const
int
output_height
,
const
int
output_wid
th
,
const
int
ksize_
depth
,
const
int
ksize_height
,
const
int
ksize_wid
th
,
const
int
stride_
depth
,
const
int
stride_height
,
const
int
stride_wid
th
,
const
int
padding_
depth
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
)
{
const
T
*
output_grad
,
const
int
channels
,
const
int
input_depth
,
const
int
input_
height
,
const
int
input_width
,
const
int
output_dep
th
,
const
int
output_
height
,
const
int
output_width
,
const
int
ksize_dep
th
,
const
int
ksize_
height
,
const
int
ksize_width
,
const
int
stride_dep
th
,
const
int
stride_
height
,
const
int
stride_width
,
const
int
padding_dep
th
,
const
int
padding_
height
,
const
int
padding_width
,
PoolProcess
pool_process
,
T
*
input_grad
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
offsetW
=
index
%
input_width
+
padding_width
;
...
...
@@ -422,13 +424,12 @@ __global__ void KernelPool3DGrad(
template
<
typename
T
>
__global__
void
KernelMaxPool3DGrad
(
const
int
nthreads
,
const
T
*
input_data
,
const
T
*
output_data
,
const
T
*
output_grad
,
T
*
input_grad
,
const
int
channels
,
const
int
input_depth
,
const
int
input_height
,
const
int
input_width
,
const
int
output_depth
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_depth
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
padding_width
)
{
const
T
*
output_grad
,
const
int
channels
,
const
int
input_depth
,
const
int
input_height
,
const
int
input_width
,
const
int
output_depth
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_depth
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
padding_width
,
T
*
input_grad
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
pw
=
index
%
output_width
;
...
...
@@ -480,18 +481,18 @@ template <typename PoolProcess, class T>
class
Pool3dFunctor
<
platform
::
GPUPlace
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
stride
s
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
)
{
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
padding
s
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
const
int
input_width
=
input
.
dims
()[
4
];
const
int
output_channels
=
output
.
dims
()[
1
];
const
int
output_depth
=
output
.
dims
()[
2
];
const
int
output_height
=
output
.
dims
()[
3
];
const
int
output_width
=
output
.
dims
()[
4
];
const
int
output_channels
=
output
->
dims
()[
1
];
const
int
output_depth
=
output
->
dims
()[
2
];
const
int
output_height
=
output
->
dims
()[
3
];
const
int
output_width
=
output
->
dims
()[
4
];
const
int
ksize_depth
=
ksize
[
0
];
const
int
ksize_height
=
ksize
[
1
];
const
int
ksize_width
=
ksize
[
2
];
...
...
@@ -503,7 +504,7 @@ class Pool3dFunctor<platform::GPUPlace, PoolProcess, T> {
const
int
padding_width
=
paddings
[
2
];
const
T
*
input_data
=
input
.
data
<
T
>
();
T
*
output_data
=
output
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
output_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
nthreads
=
batch_size
*
output_channels
*
output_depth
*
output_height
*
output_width
;
...
...
@@ -516,11 +517,11 @@ class Pool3dFunctor<platform::GPUPlace, PoolProcess, T> {
T
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
nthreads
,
input_data
,
output_data
,
input_channels
,
input_depth
,
input_
height
,
input_width
,
output_depth
,
output_height
,
output_wid
th
,
ksize_
depth
,
ksize_height
,
ksize_width
,
stride_depth
,
stride_height
,
stride_width
,
padding_depth
,
padding_height
,
padding_width
,
pool_process
);
nthreads
,
input_data
,
input_channels
,
input_depth
,
input_height
,
input_
width
,
output_depth
,
output_height
,
output_width
,
ksize_dep
th
,
ksize_
height
,
ksize_width
,
stride_depth
,
stride_height
,
stride_width
,
padding_depth
,
padding_height
,
padding_width
,
pool_process
,
output_data
);
}
};
...
...
@@ -533,11 +534,11 @@ template <typename PoolProcess, class T>
class
Pool3dGradFunctor
<
platform
::
GPUPlace
,
PoolProcess
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
)
{
PoolProcess
pool_process
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_depth
=
input
.
dims
()[
2
];
...
...
@@ -560,7 +561,7 @@ class Pool3dGradFunctor<platform::GPUPlace, PoolProcess, T> {
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
output_data
=
output
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
nthreads
=
batch_size
*
input_channels
*
input_depth
*
input_height
*
input_width
;
...
...
@@ -573,11 +574,11 @@ class Pool3dGradFunctor<platform::GPUPlace, PoolProcess, T> {
T
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
nthreads
,
input_data
,
output_data
,
output_grad_data
,
input_
grad_data
,
input_
channels
,
input_depth
,
input_height
,
input_width
,
output_depth
,
output_
height
,
output_width
,
ksize_depth
,
ksize_height
,
ksize_wid
th
,
stride_
depth
,
stride_height
,
stride_width
,
padding_depth
,
padding_
height
,
padding_width
,
pool_process
);
nthreads
,
input_data
,
output_data
,
output_grad_data
,
input_
channels
,
input_
depth
,
input_height
,
input_width
,
output_depth
,
output_height
,
output_
width
,
ksize_depth
,
ksize_height
,
ksize_width
,
stride_dep
th
,
stride_
height
,
stride_width
,
padding_depth
,
padding_height
,
padding_
width
,
pool_process
,
input_grad_data
);
}
};
...
...
@@ -590,10 +591,11 @@ template <class T>
class
MaxPool3dGradFunctor
<
platform
::
GPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_depth
=
input
.
dims
()[
2
];
...
...
@@ -616,7 +618,7 @@ class MaxPool3dGradFunctor<platform::GPUPlace, T> {
const
T
*
input_data
=
input
.
data
<
T
>
();
const
T
*
output_data
=
output
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
nthreads
=
batch_size
*
output_channels
*
output_depth
*
output_height
*
output_width
;
...
...
@@ -628,11 +630,11 @@ class MaxPool3dGradFunctor<platform::GPUPlace, T> {
T
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
nthreads
,
input_data
,
output_data
,
output_grad_data
,
input_
grad_data
,
input_
channels
,
input_depth
,
input_height
,
input_width
,
output_depth
,
output_
height
,
output_width
,
ksize_depth
,
ksize_height
,
ksize_wid
th
,
stride_
depth
,
stride_height
,
stride_width
,
padding_depth
,
padding_
height
,
padding_width
);
nthreads
,
input_data
,
output_data
,
output_grad_data
,
input_
channels
,
input_
depth
,
input_height
,
input_width
,
output_depth
,
output_height
,
output_
width
,
ksize_depth
,
ksize_height
,
ksize_width
,
stride_dep
th
,
stride_
height
,
stride_width
,
padding_depth
,
padding_height
,
padding_
width
,
input_grad_data
);
}
};
...
...
@@ -658,11 +660,11 @@ template class Pool3dGradFunctor<
template
<
typename
T
>
__global__
void
KernelMaxPool2dWithIdx
(
const
int
nthreads
,
const
T
*
input_data
,
T
*
output_data
,
T
*
mask_data
,
const
int
channels
,
const
int
input_height
,
const
int
input_width
,
const
int
output_
height
,
const
int
output_width
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_
height
,
const
int
padding_width
)
{
const
int
nthreads
,
const
T
*
input_data
,
const
int
channels
,
const
int
input_height
,
const
int
input_width
,
const
int
output_height
,
const
int
output_
width
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_
width
,
T
*
output_data
,
T
*
mask_data
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
pw
=
index
%
output_width
;
...
...
@@ -697,11 +699,11 @@ __global__ void KernelMaxPool2dWithIdx(
template
<
typename
T
>
__global__
void
KernelMaxPool2DWithIdxGrad
(
const
int
nthreads
,
T
*
input_grad
,
const
T
*
output_grad
,
const
T
*
mask_data
,
const
int
nthreads
,
const
T
*
output_grad
,
const
T
*
mask_data
,
const
int
channels
,
const
int
input_height
,
const
int
input_width
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
)
{
const
int
padding_height
,
const
int
padding_width
,
T
*
input_grad
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
w_offset
=
index
%
input_width
;
...
...
@@ -748,16 +750,16 @@ template <typename T>
class
MaxPool2dWithIndexFunctor
<
platform
::
GPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
output_channels
=
output
.
dims
()[
1
];
const
int
output_height
=
output
.
dims
()[
2
];
const
int
output_width
=
output
.
dims
()[
3
];
const
int
output_channels
=
output
->
dims
()[
1
];
const
int
output_height
=
output
->
dims
()[
2
];
const
int
output_width
=
output
->
dims
()[
3
];
const
int
ksize_height
=
ksize
[
0
];
const
int
ksize_width
=
ksize
[
1
];
const
int
stride_height
=
strides
[
0
];
...
...
@@ -766,8 +768,8 @@ class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> {
const
int
padding_width
=
paddings
[
1
];
const
T
*
input_data
=
input
.
data
<
T
>
();
T
*
output_data
=
output
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
mask_data
=
mask
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
output_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
nthreads
=
batch_size
*
output_channels
*
output_height
*
output_width
;
int
blocks
=
(
nthreads
+
1024
-
1
)
/
1024
;
...
...
@@ -777,11 +779,10 @@ class MaxPool2dWithIndexFunctor<platform::GPUPlace, T> {
KernelMaxPool2dWithIdx
<
T
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
nthreads
,
input_data
,
output_data
,
mask_data
,
input_channels
,
input_height
,
input_width
,
output_height
,
output_width
,
ksize_height
,
ksize_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
);
.
stream
()
>>>
(
nthreads
,
input_data
,
input_channels
,
input_height
,
input_width
,
output_height
,
output_width
,
ksize_height
,
ksize_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
output_data
,
mask_data
);
}
};
...
...
@@ -794,14 +795,14 @@ template <typename T>
class
MaxPool2dWithIndexGradFunctor
<
platform
::
GPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
const
int
batch_size
=
input_grad
.
dims
()[
0
];
const
int
input_channels
=
input_grad
.
dims
()[
1
];
const
int
input_height
=
input_grad
.
dims
()[
2
];
const
int
input_width
=
input_grad
.
dims
()[
3
];
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input_grad
->
dims
()[
0
];
const
int
input_channels
=
input_grad
->
dims
()[
1
];
const
int
input_height
=
input_grad
->
dims
()[
2
];
const
int
input_width
=
input_grad
->
dims
()[
3
];
const
int
output_height
=
output_grad
.
dims
()[
2
];
const
int
output_width
=
output_grad
.
dims
()[
3
];
const
int
ksize_height
=
ksize
[
0
];
...
...
@@ -813,7 +814,7 @@ class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T> {
const
T
*
mask_data
=
mask
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
nthreads
=
batch_size
*
input_channels
*
input_height
*
input_width
;
int
blocks
=
(
nthreads
+
1024
-
1
)
/
1024
;
...
...
@@ -823,11 +824,11 @@ class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T> {
KernelMaxPool2DWithIdxGrad
<
T
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
nthreads
,
input_grad_data
,
output_grad
_data
,
mask_data
,
input_channels
,
input_height
,
input_width
,
output_height
,
output_width
,
ksize_
height
,
ksize_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
);
.
stream
()
>>>
(
nthreads
,
output_grad_data
,
mask
_data
,
input_channels
,
input_height
,
input_width
,
output_height
,
output_width
,
ksize_height
,
ksize_
width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
input_grad_data
);
}
};
...
...
@@ -838,13 +839,13 @@ template class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, double>;
template
<
typename
T
>
__global__
void
KernelMaxPool3DWithIdx
(
const
int
nthreads
,
const
T
*
input_data
,
T
*
output_data
,
T
*
mask_data
,
const
int
channels
,
const
int
input_depth
,
const
int
input_height
,
const
int
input_width
,
const
int
output_depth
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_depth
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
padding_width
)
{
const
int
nthreads
,
const
T
*
input_data
,
const
int
channels
,
const
int
input_depth
,
const
int
input_height
,
const
int
input_width
,
const
int
output_depth
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_depth
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
padding_width
,
T
*
output_data
,
T
*
mask_data
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
pw
=
index
%
output_width
;
...
...
@@ -886,13 +887,13 @@ __global__ void KernelMaxPool3DWithIdx(
template
<
typename
T
>
__global__
void
KernelMaxPool3DWithIdxGrad
(
const
int
nthreads
,
T
*
input_grad
,
const
T
*
output_grad
,
const
T
*
mask
,
const
int
channels
,
const
int
input_depth
,
const
int
input_height
,
const
int
input_width
,
const
int
output_depth
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_depth
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
padding_width
)
{
const
int
nthreads
,
const
T
*
output_grad
,
const
T
*
mask
,
const
int
channels
,
const
int
input_depth
,
const
int
input_height
,
const
int
input_width
,
const
int
output_depth
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_depth
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
padding_width
,
T
*
input_grad
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
w_offset
=
index
%
input_width
;
...
...
@@ -952,18 +953,18 @@ template <typename T>
class
MaxPool3dWithIndexFunctor
<
platform
::
GPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
const
int
input_width
=
input
.
dims
()[
4
];
const
int
output_channels
=
output
.
dims
()[
1
];
const
int
output_depth
=
output
.
dims
()[
2
];
const
int
output_height
=
output
.
dims
()[
3
];
const
int
output_width
=
output
.
dims
()[
4
];
const
int
output_channels
=
output
->
dims
()[
1
];
const
int
output_depth
=
output
->
dims
()[
2
];
const
int
output_height
=
output
->
dims
()[
3
];
const
int
output_width
=
output
->
dims
()[
4
];
const
int
ksize_depth
=
ksize
[
0
];
const
int
ksize_height
=
ksize
[
1
];
const
int
ksize_width
=
ksize
[
2
];
...
...
@@ -975,8 +976,8 @@ class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> {
const
int
padding_width
=
paddings
[
2
];
const
T
*
input_data
=
input
.
data
<
T
>
();
T
*
output_data
=
output
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
mask_data
=
mask
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
output_data
=
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
mask_data
=
mask
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
nthreads
=
batch_size
*
output_channels
*
output_depth
*
output_height
*
output_width
;
...
...
@@ -988,11 +989,10 @@ class MaxPool3dWithIndexFunctor<platform::GPUPlace, T> {
T
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
nthreads
,
input_data
,
output_data
,
mask_data
,
input_channels
,
input_depth
,
input_height
,
input_width
,
output_depth
,
output_height
,
output_width
,
ksize_depth
,
ksize_height
,
ksize_width
,
stride_depth
,
stride_height
,
stride_width
,
padding_depth
,
padding_height
,
padding_width
);
nthreads
,
input_data
,
input_channels
,
input_depth
,
input_height
,
input_width
,
output_depth
,
output_height
,
output_width
,
ksize_depth
,
ksize_height
,
ksize_width
,
stride_depth
,
stride_height
,
stride_width
,
padding_depth
,
padding_height
,
padding_width
,
output_data
,
mask_data
);
}
};
...
...
@@ -1005,15 +1005,15 @@ template <typename T>
class
MaxPool3dWithIndexGradFunctor
<
platform
::
GPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
const
int
batch_size
=
input_grad
.
dims
()[
0
];
const
int
input_channels
=
input_grad
.
dims
()[
1
];
const
int
input_depth
=
input_grad
.
dims
()[
2
];
const
int
input_height
=
input_grad
.
dims
()[
3
];
const
int
input_width
=
input_grad
.
dims
()[
4
];
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input_grad
->
dims
()[
0
];
const
int
input_channels
=
input_grad
->
dims
()[
1
];
const
int
input_depth
=
input_grad
->
dims
()[
2
];
const
int
input_height
=
input_grad
->
dims
()[
3
];
const
int
input_width
=
input_grad
->
dims
()[
4
];
const
int
output_depth
=
output_grad
.
dims
()[
2
];
const
int
output_height
=
output_grad
.
dims
()[
3
];
const
int
output_width
=
output_grad
.
dims
()[
4
];
...
...
@@ -1029,7 +1029,7 @@ class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
const
T
*
mask_data
=
mask
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
input_grad_data
=
input_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
nthreads
=
batch_size
*
input_channels
*
input_depth
*
input_height
*
input_width
;
...
...
@@ -1041,11 +1041,11 @@ class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
T
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
nthreads
,
input_grad_data
,
output_grad_data
,
mask_data
,
input_channels
,
input_
depth
,
input_height
,
input_width
,
output_depth
,
output_height
,
output_width
,
ksize_depth
,
ksize_height
,
ksize_width
,
stride_depth
,
stride_
height
,
stride_width
,
padding_depth
,
padding_height
,
padding_width
);
nthreads
,
output_grad_data
,
mask_data
,
input_channels
,
input_depth
,
input_
height
,
input_width
,
output_depth
,
output_height
,
output_width
,
ksize_depth
,
ksize_height
,
ksize_width
,
stride_depth
,
stride_height
,
stride_
width
,
padding_depth
,
padding_height
,
padding_width
,
input_grad_data
);
}
};
...
...
paddle/operators/math/pooling.h
浏览文件 @
feaf1e2d
...
...
@@ -88,60 +88,62 @@ template <typename Place, typename PoolProcess, typename T>
class
Pool2dFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
stride
s
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
);
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
padding
s
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
output
);
};
template
<
typename
Place
,
typename
PoolProcess
,
typename
T
>
class
Pool2dGradFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
);
PoolProcess
pool_compute
,
framework
::
Tensor
*
input_grad
);
};
template
<
typename
Place
,
class
T
>
class
MaxPool2dGradFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
);
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
);
};
template
<
typename
Place
,
typename
PoolProcess
,
typename
T
>
class
Pool3dFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
stride
s
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
);
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
padding
s
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
output
);
};
template
<
typename
Place
,
typename
PoolProcess
,
typename
T
>
class
Pool3dGradFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
);
PoolProcess
pool_compute
,
framework
::
Tensor
*
input_grad
);
};
template
<
typename
Place
,
class
T
>
class
MaxPool3dGradFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
);
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
);
};
/*
...
...
@@ -155,38 +157,38 @@ template <typename Place, typename T>
class
MaxPool2dWithIndexFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
);
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
);
};
template
<
typename
Place
,
typename
T
>
class
MaxPool2dWithIndexGradFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
);
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
);
};
template
<
typename
Place
,
typename
T
>
class
MaxPool3dWithIndexFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
);
const
framework
::
Tensor
&
input
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
mask
);
};
template
<
typename
Place
,
typename
T
>
class
MaxPool3dWithIndexGradFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
);
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
,
framework
::
Tensor
*
input_grad
);
};
}
// namespace math
...
...
paddle/operators/pool_op.h
浏览文件 @
feaf1e2d
...
...
@@ -75,16 +75,16 @@ class PoolKernel : public framework::OpKernel<T> {
Place
,
paddle
::
operators
::
math
::
MaxPool
<
T
>
,
T
>
pool2d_forward
;
paddle
::
operators
::
math
::
MaxPool
<
T
>
pool_process
;
pool2d_forward
(
context
.
device_context
(),
*
in_x
,
*
out
,
ksize
,
strides
,
paddings
,
pool_process
);
pool2d_forward
(
context
.
device_context
(),
*
in_x
,
ksize
,
strides
,
paddings
,
pool_process
,
out
);
}
else
if
(
pooling_type
==
"avg"
)
{
paddle
::
operators
::
math
::
Pool2dFunctor
<
Place
,
paddle
::
operators
::
math
::
AvgPool
<
T
>
,
T
>
pool2d_forward
;
paddle
::
operators
::
math
::
AvgPool
<
T
>
pool_process
;
pool2d_forward
(
context
.
device_context
(),
*
in_x
,
*
out
,
ksize
,
strides
,
paddings
,
pool_process
);
pool2d_forward
(
context
.
device_context
(),
*
in_x
,
ksize
,
strides
,
paddings
,
pool_process
,
out
);
}
}
break
;
case
3
:
{
...
...
@@ -93,15 +93,15 @@ class PoolKernel : public framework::OpKernel<T> {
Place
,
paddle
::
operators
::
math
::
MaxPool
<
T
>
,
T
>
pool3d_forward
;
paddle
::
operators
::
math
::
MaxPool
<
T
>
pool_process
;
pool3d_forward
(
context
.
device_context
(),
*
in_x
,
*
out
,
ksize
,
strides
,
paddings
,
pool_process
);
pool3d_forward
(
context
.
device_context
(),
*
in_x
,
ksize
,
strides
,
paddings
,
pool_process
,
out
);
}
else
if
(
pooling_type
==
"avg"
)
{
paddle
::
operators
::
math
::
Pool3dFunctor
<
Place
,
paddle
::
operators
::
math
::
AvgPool
<
T
>
,
T
>
pool3d_forward
;
paddle
::
operators
::
math
::
AvgPool
<
T
>
pool_process
;
pool3d_forward
(
context
.
device_context
(),
*
in_x
,
*
out
,
ksize
,
strides
,
paddings
,
pool_process
);
pool3d_forward
(
context
.
device_context
(),
*
in_x
,
ksize
,
strides
,
paddings
,
pool_process
,
out
);
}
}
break
;
default:
{
PADDLE_THROW
(
"Pool op only supports 2D and 3D input."
);
}
...
...
@@ -142,30 +142,30 @@ class PoolGradKernel : public framework::OpKernel<T> {
if
(
pooling_type
==
"max"
)
{
paddle
::
operators
::
math
::
MaxPool2dGradFunctor
<
Place
,
T
>
pool2d_backward
;
pool2d_backward
(
context
.
device_context
(),
*
in_x
,
*
in_x_grad
,
*
out
,
*
out_grad
,
ksize
,
strides
,
paddings
);
pool2d_backward
(
context
.
device_context
(),
*
in_x
,
*
out
,
*
out_grad
,
ksize
,
strides
,
paddings
,
in_x_grad
);
}
else
if
(
pooling_type
==
"avg"
)
{
paddle
::
operators
::
math
::
Pool2dGradFunctor
<
Place
,
paddle
::
operators
::
math
::
AvgPoolGrad
<
T
>
,
T
>
pool2d_backward
;
paddle
::
operators
::
math
::
AvgPoolGrad
<
T
>
pool_process
;
pool2d_backward
(
context
.
device_context
(),
*
in_x
,
*
in_x_grad
,
*
out
,
*
out_grad
,
ksize
,
strides
,
paddings
,
pool_process
);
pool2d_backward
(
context
.
device_context
(),
*
in_x
,
*
out
,
*
out_grad
,
ksize
,
strides
,
paddings
,
pool_process
,
in_x_grad
);
}
}
break
;
case
3
:
{
if
(
pooling_type
==
"max"
)
{
paddle
::
operators
::
math
::
MaxPool3dGradFunctor
<
Place
,
T
>
pool3d_backward
;
pool3d_backward
(
context
.
device_context
(),
*
in_x
,
*
in_x_grad
,
*
out
,
*
out_grad
,
ksize
,
strides
,
paddings
);
pool3d_backward
(
context
.
device_context
(),
*
in_x
,
*
out
,
*
out_grad
,
ksize
,
strides
,
paddings
,
in_x_grad
);
}
else
if
(
pooling_type
==
"avg"
)
{
paddle
::
operators
::
math
::
Pool3dGradFunctor
<
Place
,
paddle
::
operators
::
math
::
AvgPoolGrad
<
T
>
,
T
>
pool3d_backward
;
paddle
::
operators
::
math
::
AvgPoolGrad
<
T
>
pool_process
;
pool3d_backward
(
context
.
device_context
(),
*
in_x
,
*
in_x_grad
,
*
out
,
*
out_grad
,
ksize
,
strides
,
paddings
,
pool_process
);
pool3d_backward
(
context
.
device_context
(),
*
in_x
,
*
out
,
*
out_grad
,
ksize
,
strides
,
paddings
,
pool_process
,
in_x_grad
);
}
}
break
;
default:
{
PADDLE_THROW
(
"Pool op only supports 2D and 3D input."
);
}
...
...
paddle/operators/pool_with_index_op.h
浏览文件 @
feaf1e2d
...
...
@@ -46,14 +46,14 @@ class MaxPoolWithIndexKernel : public framework::OpKernel<T> {
case
2
:
{
paddle
::
operators
::
math
::
MaxPool2dWithIndexFunctor
<
Place
,
T
>
pool2d_forward
;
pool2d_forward
(
context
.
device_context
(),
*
in_x
,
*
out
,
*
mask
,
ksize
,
strides
,
paddings
);
pool2d_forward
(
context
.
device_context
(),
*
in_x
,
ksize
,
strides
,
paddings
,
out
,
mask
);
}
break
;
case
3
:
{
paddle
::
operators
::
math
::
MaxPool3dWithIndexFunctor
<
Place
,
T
>
pool3d_forward
;
pool3d_forward
(
context
.
device_context
(),
*
in_x
,
*
out
,
*
mask
,
ksize
,
strides
,
paddings
);
pool3d_forward
(
context
.
device_context
(),
*
in_x
,
ksize
,
strides
,
paddings
,
out
,
mask
);
}
break
;
default:
{
PADDLE_THROW
(
"Pool op only supports 2D and 3D input."
);
}
}
...
...
@@ -89,14 +89,14 @@ class MaxPoolWithIndexGradKernel : public framework::OpKernel<T> {
case
2
:
{
paddle
::
operators
::
math
::
MaxPool2dWithIndexGradFunctor
<
Place
,
T
>
pool2d_backward
;
pool2d_backward
(
context
.
device_context
(),
*
in_x_grad
,
*
out_grad
,
*
mask
,
ksize
,
strides
,
paddings
);
pool2d_backward
(
context
.
device_context
(),
*
out_grad
,
*
mask
,
ksize
,
strides
,
paddings
,
in_x_grad
);
}
break
;
case
3
:
{
paddle
::
operators
::
math
::
MaxPool3dWithIndexGradFunctor
<
Place
,
T
>
pool3d_backward
;
pool3d_backward
(
context
.
device_context
(),
*
in_x_grad
,
*
out_grad
,
*
mask
,
ksize
,
strides
,
paddings
);
pool3d_backward
(
context
.
device_context
(),
*
out_grad
,
*
mask
,
ksize
,
strides
,
paddings
,
in_x_grad
);
}
break
;
default:
{
PADDLE_THROW
(
"Pool op only supports 2D and 3D input."
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录