Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle)

Commit e1e3859e, authored Sep 29, 2017 by chengduoZH
Parent: 3c0f0793

remove custom attr checker and fix code format
Showing 5 changed files with 153 additions and 210 deletions (+153 -210):

paddle/operators/math/pooling.cc  +34  -32
paddle/operators/math/pooling.cu  +69  -63
paddle/operators/math/pooling.h    +5   -6
paddle/operators/pool_op.cc       +31  -93
paddle/operators/pool_op.h        +14  -16
paddle/operators/math/pooling.cc

@@ -24,7 +24,7 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input, framework::Tensor& output,
                   std::vector<int>& ksize, std::vector<int>& strides,
-                  std::vector<int>& paddings, PoolProcess pool_compute) {
+                  std::vector<int>& paddings, PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
@@ -54,14 +54,15 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
           int wstart = pw * stride_width - padding_width;
           int wend = std::min(wstart + ksize_width, input_width);
           wstart = std::max(wstart, 0);
-          T ele = pool_compute.initial();
+          T ele = pool_process.initial();
           for (int h = hstart; h < hend; ++h) {
             for (int w = wstart; w < wend; ++w) {
-              pool_compute.compute(ele, input_data[h * input_width + w]);
+              pool_process.compute(ele, input_data[h * input_width + w]);
             }
           }
           int pool_size = (hend - hstart) * (wend - wstart);
-          pool_compute.finalize(ele, (static_cast<T>(pool_size)));
+          pool_process.finalize(ele, (static_cast<T>(pool_size)));
           output_data[ph * output_width + pw] = ele;
         }
       }
@@ -80,7 +81,7 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad, std::vector<int>& ksize,
                   std::vector<int>& strides,
-                  std::vector<int>& paddings, PoolProcess pool_compute) {
+                  std::vector<int>& paddings, PoolProcess pool_grad_process) {
     const int batch_size = input.dims()[0];
     const int input_height = input.dims()[2];
     const int input_width = input.dims()[3];
@@ -115,7 +116,8 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
           float scale = 1.0 / pool_size;
           for (int h = hstart; h < hend; ++h) {
             for (int w = wstart; w < wend; ++w) {
-              pool_compute.compute(input_data[h * input_width + w],
+              pool_grad_process.compute(
+                  input_data[h * input_width + w],
                   output_data[ph * output_width + pw],
                   output_grad_data[ph * output_width + pw],
                   input_grad_data[h * input_width + w],
@@ -198,21 +200,21 @@ template class MaxPool2dGradFunctor<platform::CPUPlace, float>;
 // template class MaxPool2dGradFunctor<platform::CPUPlace, double>;

 template class Pool2dFunctor<platform::CPUPlace,
-                             paddle::operators::math::maxPool<float>, float>;
+                             paddle::operators::math::MaxPool<float>, float>;
 template class Pool2dFunctor<platform::CPUPlace,
-                             paddle::operators::math::avgPool<float>, float>;
+                             paddle::operators::math::AvgPool<float>, float>;
 template class Pool2dGradFunctor<
-    platform::CPUPlace, paddle::operators::math::maxPoolGrad<float>, float>;
+    platform::CPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
 template class Pool2dGradFunctor<
-    platform::CPUPlace, paddle::operators::math::avgPoolGrad<float>, float>;
+    platform::CPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
 template class Pool2dFunctor<platform::CPUPlace,
-                             paddle::operators::math::maxPool<double>, double>;
+                             paddle::operators::math::MaxPool<double>, double>;
 template class Pool2dFunctor<platform::CPUPlace,
-                             paddle::operators::math::avgPool<double>, double>;
+                             paddle::operators::math::AvgPool<double>, double>;
 template class Pool2dGradFunctor<
-    platform::CPUPlace, paddle::operators::math::maxPoolGrad<double>, double>;
+    platform::CPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
 template class Pool2dGradFunctor<
-    platform::CPUPlace, paddle::operators::math::avgPoolGrad<double>, double>;
+    platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;

 template <typename PoolProcess, class T>
 class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
@@ -220,7 +222,7 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input, framework::Tensor& output,
                   std::vector<int>& ksize, std::vector<int>& strides,
-                  std::vector<int>& paddings, PoolProcess pool_compute) {
+                  std::vector<int>& paddings, PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_depth = input.dims()[2];
     const int input_height = input.dims()[3];
@@ -260,11 +262,11 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
             int wend = std::min(wstart + ksize_width, input_width);
             wstart = std::max(wstart, 0);
             int output_idx = (pd * output_height + ph) * output_width + pw;
-            T ele = pool_compute.initial();
+            T ele = pool_process.initial();
             for (int d = dstart; d < dend; ++d) {
               for (int h = hstart; h < hend; ++h) {
                 for (int w = wstart; w < wend; ++w) {
-                  pool_compute.compute(
+                  pool_process.compute(
                       ele,
                       input_data[(d * input_height + h) * input_width + w]);
                 }
@@ -272,7 +274,7 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
             }
             int pool_size =
                 (dend - dstart) * (hend - hstart) * (wend - wstart);
-            pool_compute.finalize(ele, static_cast<T>(pool_size));
+            pool_process.finalize(ele, static_cast<T>(pool_size));
             output_data[output_idx] = ele;
           }
         }
@@ -292,7 +294,7 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad, std::vector<int>& ksize,
                   std::vector<int>& strides, std::vector<int>& paddings,
-                  PoolProcess pool_compute) {
+                  PoolProcess pool_grad_process) {
     const int batch_size = input.dims()[0];
     const int input_depth = input.dims()[2];
     const int input_height = input.dims()[3];
@@ -343,7 +345,7 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
                   int input_idx = (d * input_height + h) * input_width + w;
                   int output_idx =
                       (pd * output_height + ph) * output_width + pw;
-                  pool_compute.compute(
+                  pool_grad_process.compute(
                       input_data[input_idx], output_data[output_idx],
                       output_grad_data[output_idx], input_grad_data[input_idx],
                       static_cast<T>(scale));
@@ -441,21 +443,21 @@ template class MaxPool3dGradFunctor<platform::CPUPlace, float>;
 // template class MaxPool3dGradFunctor<platform::CPUPlace, double>;

 template class Pool3dFunctor<platform::CPUPlace,
-                             paddle::operators::math::maxPool<float>, float>;
+                             paddle::operators::math::MaxPool<float>, float>;
 template class Pool3dFunctor<platform::CPUPlace,
-                             paddle::operators::math::avgPool<float>, float>;
+                             paddle::operators::math::AvgPool<float>, float>;
 template class Pool3dGradFunctor<
-    platform::CPUPlace, paddle::operators::math::maxPoolGrad<float>, float>;
+    platform::CPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
 template class Pool3dGradFunctor<
-    platform::CPUPlace, paddle::operators::math::avgPoolGrad<float>, float>;
+    platform::CPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
 template class Pool3dFunctor<platform::CPUPlace,
-                             paddle::operators::math::maxPool<double>, double>;
+                             paddle::operators::math::MaxPool<double>, double>;
 template class Pool3dFunctor<platform::CPUPlace,
-                             paddle::operators::math::avgPool<double>, double>;
+                             paddle::operators::math::AvgPool<double>, double>;
 template class Pool3dGradFunctor<
-    platform::CPUPlace, paddle::operators::math::maxPoolGrad<double>, double>;
+    platform::CPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
 template class Pool3dGradFunctor<
-    platform::CPUPlace, paddle::operators::math::avgPoolGrad<double>, double>;
+    platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;

 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
paddle/operators/math/pooling.cu

@@ -20,14 +20,16 @@ namespace operators {
 namespace math {

 template <typename PoolProcess, typename T>
-__global__ void KernelPool2dForward(
-    const int nthreads, const T* input_data, T* output_data, const int channels,
-    const int input_height, const int input_width, const int output_height,
-    const int output_width, const int ksize_height, const int ksize_width,
-    const int stride_height, const int stride_width, const int padding_height,
-    const int padding_width, PoolProcess pool_compute) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < nthreads) {
+__global__ void KernelPool2D(
+    const int nthreads, const T* input_data, T* output_data, const int channels,
+    const int input_height, const int input_width, const int output_height,
+    const int output_width, const int ksize_height, const int ksize_width,
+    const int stride_height, const int stride_width, const int padding_height,
+    const int padding_width, PoolProcess pool_process) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
+       index += blockDim.x * gridDim.x) {
     int pw = index % output_width;
     int ph = (index / output_width) % output_height;
     int c = (index / output_width / output_height) % channels;
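Note on the hunk above: the old KernelPool2dForward guarded its body with a single `if (index < nthreads)`, so the launch had to supply at least one thread per output element; the renamed KernelPool2D iterates with a grid-stride loop, which stays correct for any grid size. A minimal sketch of that pattern as a standalone toy kernel (kernel and names are ours, not part of this commit):

// Grid-stride loop: a thread handles index, index + stride, index + 2*stride,
// ... with stride = blockDim.x * gridDim.x, so any launch covers all n items.
template <typename T>
__global__ void ScaleKernel(const int n, const T alpha, const T* x, T* y) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    y[i] = alpha * x[i];
  }
}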
@@ -42,28 +44,28 @@ __global__ void KernelPool2dForward(
     wstart = max(wstart, 0);

     input_data += (batch_idx * channels + c) * input_height * input_width;
-    T ele = pool_compute.initial();
+    T ele = pool_process.initial();
     for (int h = hstart; h < hend; ++h) {
       for (int w = wstart; w < wend; ++w) {
-        pool_compute.compute(ele, input_data[h * input_width + w]);
+        pool_process.compute(ele, input_data[h * input_width + w]);
       }
     }
     int pool_size = (hend - hstart) * (wend - wstart);
-    pool_compute.finalize(ele, (static_cast<T>(pool_size)));
+    pool_process.finalize(ele, (static_cast<T>(pool_size)));
     output_data[index] = ele;
   }
 }

 template <typename PoolProcess, typename T>
-__global__ void KernelPool2dBackward(
+__global__ void KernelPool2DGrad(
     const int nthreads, const T* input_data, const T* output_data,
     const T* output_grad, T* input_grad, const int channels,
     const int input_height, const int input_width, const int output_height,
     const int output_width, const int ksize_height, const int ksize_width,
     const int stride_height, const int stride_width, const int padding_height,
-    const int padding_width, PoolProcess pool_compute) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < nthreads) {
+    const int padding_width, PoolProcess pool_process) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
+       index += blockDim.x * gridDim.x) {
     int offsetW = index % input_width + padding_width;
     int offsetH = (index / input_width) % input_height + padding_height;
     int offsetC = (index / input_width / input_height) % channels;
@@ -93,7 +95,7 @@ __global__ void KernelPool2dBackward(
         wstart = max(wstart, 0);
         int pool_size = (hend - hstart) * (wend - wstart);
         int output_sub_idx = ph * output_width + pw;
-        pool_compute.compute(input, output_data[output_sub_idx],
+        pool_process.compute(input, output_data[output_sub_idx],
                              output_grad[output_sub_idx], gradient,
                              static_cast<T>(1.0 / pool_size));
       }
@@ -103,15 +105,15 @@ __global__ void KernelPool2dBackward(
 }

 template <typename T>
-__global__ void KernelMaxPool2dBackward(
+__global__ void KernelMaxPool2DGrad(
     const int nthreads, const T* input_data, const T* output_data,
     const T* output_grad, T* input_grad, const int channels,
     const int input_height, const int input_width, const int output_height,
     const int output_width, const int ksize_height, const int ksize_width,
     const int stride_height, const int stride_width, const int padding_height,
     const int padding_width) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  if (index < nthreads) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
+       index += blockDim.x * gridDim.x) {
     int pw = index % output_width;
     int ph = (index / output_width) % output_height;
     int c = (index / output_width / output_height) % channels;
@@ -153,7 +155,7 @@ class Pool2dFunctor<platform::GPUPlace, PoolProcess, T> {
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input, framework::Tensor& output,
                   std::vector<int>& ksize, std::vector<int>& strides,
-                  std::vector<int>& paddings, PoolProcess pool_compute) {
+                  std::vector<int>& paddings, PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_height = input.dims()[2];
@@ -176,7 +178,7 @@ class Pool2dFunctor<platform::GPUPlace, PoolProcess, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);

-    KernelPool2dForward<
+    KernelPool2D<
         PoolProcess,
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
@@ -184,7 +186,7 @@ class Pool2dFunctor<platform::GPUPlace, PoolProcess, T> {
         input_height, input_width, output_height, output_width, ksize_height,
         ksize_width, stride_height, stride_width, padding_height,
-        padding_width, pool_compute);
+        padding_width, pool_process);
   }
 };
@@ -196,7 +198,7 @@ class Pool2dGradFunctor<platform::GPUPlace, PoolProcess, T> {
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad, std::vector<int>& ksize,
                   std::vector<int>& strides, std::vector<int>& paddings,
-                  PoolProcess pool_compute) {
+                  PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_height = input.dims()[2];
@@ -220,7 +222,7 @@ class Pool2dGradFunctor<platform::GPUPlace, PoolProcess, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);

-    KernelPool2dBackward<
+    KernelPool2DGrad<
         PoolProcess,
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
@@ -228,7 +230,7 @@ class Pool2dGradFunctor<platform::GPUPlace, PoolProcess, T> {
         nthreads, input_data, output_data, output_grad_data, input_grad_data,
         input_channels, input_height, input_width, output_height, output_width,
         ksize_height, ksize_width, stride_height, stride_width, padding_height,
-        padding_width, pool_compute);
+        padding_width, pool_process);
   }
 };
@@ -264,7 +266,7 @@ class MaxPool2dGradFunctor<platform::GPUPlace, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);

-    KernelMaxPool2dBackward<
+    KernelMaxPool2DGrad<
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
                  .stream()>>>(
@@ -276,35 +278,37 @@ class MaxPool2dGradFunctor<platform::GPUPlace, T> {
 };

 template class MaxPool2dGradFunctor<platform::GPUPlace, float>;
-// template class MaxPool2dGradFunctor<platform::GPUPlace, double>;
+// template class MaxPool2dGradFunctor<platform::GPUPlace, double>; // The
+// 64-bit floating-point version of atomicAdd() is only supported by devices of
+// compute capability 6.x and higher.

 template class Pool2dFunctor<platform::GPUPlace,
-                             paddle::operators::math::maxPool<float>, float>;
+                             paddle::operators::math::MaxPool<float>, float>;
 template class Pool2dFunctor<platform::GPUPlace,
-                             paddle::operators::math::avgPool<float>, float>;
+                             paddle::operators::math::AvgPool<float>, float>;
 template class Pool2dGradFunctor<
-    platform::GPUPlace, paddle::operators::math::maxPoolGrad<float>, float>;
+    platform::GPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
 template class Pool2dGradFunctor<
-    platform::GPUPlace, paddle::operators::math::avgPoolGrad<float>, float>;
+    platform::GPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
 template class Pool2dFunctor<platform::GPUPlace,
-                             paddle::operators::math::maxPool<double>, double>;
+                             paddle::operators::math::MaxPool<double>, double>;
 template class Pool2dFunctor<platform::GPUPlace,
-                             paddle::operators::math::avgPool<double>, double>;
+                             paddle::operators::math::AvgPool<double>, double>;
 template class Pool2dGradFunctor<
-    platform::GPUPlace, paddle::operators::math::maxPoolGrad<double>, double>;
+    platform::GPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
 template class Pool2dGradFunctor<
-    platform::GPUPlace, paddle::operators::math::avgPoolGrad<double>, double>;
+    platform::GPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;

 template <typename PoolProcess, typename T>
-__global__ void KernelPool3DForward(
+__global__ void KernelPool3D(
     const int nthreads, const T* input_data, T* output_data, const int channels,
     const int input_depth, const int input_height, const int input_width,
     const int output_depth, const int output_height, const int output_width,
     const int ksize_depth, const int ksize_height, const int ksize_width,
     const int stride_depth, const int stride_height, const int stride_width,
     const int padding_depth, const int padding_height, const int padding_width,
-    PoolProcess pool_compute) {
-  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
+    PoolProcess pool_process) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
        index += blockDim.x * gridDim.x) {
     int pw = index % output_width;
     int ph = (index / output_width) % output_height;
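The expanded comment in this hunk records why the double-precision MaxPool2dGradFunctor instantiation stays disabled: hardware atomicAdd() on double requires compute capability 6.x or newer. On older devices the usual workaround is a compare-and-swap retry loop; a sketch of that standard pattern from the CUDA programming guide (illustration only, not code from this commit):

// Emulated double-precision atomicAdd for devices below compute capability
// 6.0: reinterpret the word and retry atomicCAS until no other thread has
// modified it between the read and the swap.
__device__ double AtomicAddDouble(double* address, double val) {
  unsigned long long int* address_as_ull =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *address_as_ull, assumed;
  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);  // another thread won the race; try again
  return __longlong_as_double(old);
}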
@@ -321,25 +325,25 @@ __global__ void KernelPool3DForward(
     dstart = max(dstart, 0);
     hstart = max(hstart, 0);
     wstart = max(wstart, 0);
-    T ele = pool_compute.initial();
+    T ele = pool_process.initial();
     input_data +=
         (batch_idx * channels + c) * input_depth * input_height * input_width;
     for (int d = dstart; d < dend; ++d) {
       for (int h = hstart; h < hend; ++h) {
         for (int w = wstart; w < wend; ++w) {
-          pool_compute.compute(
+          pool_process.compute(
               ele, input_data[(d * input_height + h) * input_width + w]);
         }
       }
     }
     int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
-    pool_compute.finalize(ele, static_cast<T>(pool_size));
+    pool_process.finalize(ele, static_cast<T>(pool_size));
     output_data[index] = ele;
   }
 }

 template <typename PoolProcess, typename T>
-__global__ void KernelPool3DBackward(
+__global__ void KernelPool3DGrad(
     const int nthreads, const T* input_data, const T* output_data,
     const T* output_grad, T* input_grad, const int channels,
     const int input_depth, const int input_height, const int input_width,
@@ -347,8 +351,8 @@ __global__ void KernelPool3DBackward(
     const int ksize_depth, const int ksize_height, const int ksize_width,
     const int stride_depth, const int stride_height, const int stride_width,
     const int padding_depth, const int padding_height, const int padding_width,
-    PoolProcess pool_compute) {
-  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
+    PoolProcess pool_process) {
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
        index += blockDim.x * gridDim.x) {
     int offsetW = index % input_width + padding_width;
     int offsetH = (index / input_width) % input_height + padding_height;
@@ -392,7 +396,7 @@ __global__ void KernelPool3DBackward(
         wstart = max(wstart, 0);
         int pool_size = (dend - dstart) * (hend - hstart) * (wend - wstart);
         int output_sub_idx = (pd * output_height + ph) * output_width + pw;
-        pool_compute.compute(input, output_data[output_sub_idx],
+        pool_process.compute(input, output_data[output_sub_idx],
                              output_grad[output_sub_idx], gradient,
                              static_cast<T>(1.0 / pool_size));
       }
@@ -403,7 +407,7 @@ __global__ void KernelPool3DBackward(
 }

 template <typename T>
-__global__ void KernelMaxPool3DBackward(
+__global__ void KernelMaxPool3DGrad(
     const int nthreads, const T* input_data, const T* output_data,
     const T* output_grad, T* input_grad, const int channels,
     const int input_depth, const int input_height, const int input_width,
@@ -412,7 +416,7 @@ __global__ void KernelMaxPool3DBackward(
     const int stride_depth, const int stride_height, const int stride_width,
     const int padding_depth, const int padding_height,
     const int padding_width) {
-  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < (nthreads);
+  for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
        index += blockDim.x * gridDim.x) {
     int pw = index % output_width;
     int ph = (index / output_width) % output_height;
@@ -460,7 +464,7 @@ class Pool3dFunctor<platform::GPUPlace, PoolProcess, T> {
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& input, framework::Tensor& output,
                   std::vector<int>& ksize, std::vector<int>& strides,
-                  std::vector<int>& paddings, PoolProcess pool_compute) {
+                  std::vector<int>& paddings, PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_depth = input.dims()[2];
@@ -489,7 +493,7 @@ class Pool3dFunctor<platform::GPUPlace, PoolProcess, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);

-    KernelPool3DForward<
+    KernelPool3D<
         PoolProcess,
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
@@ -498,7 +502,7 @@ class Pool3dFunctor<platform::GPUPlace, PoolProcess, T> {
         input_height, input_width, output_depth, output_height, output_width,
         ksize_depth, ksize_height, ksize_width, stride_depth, stride_height,
         stride_width, padding_depth, padding_height, padding_width,
-        pool_compute);
+        pool_process);
   }
 };
@@ -510,7 +514,7 @@ class Pool3dGradFunctor<platform::GPUPlace, PoolProcess, T> {
                   const framework::Tensor& output,
                   const framework::Tensor& output_grad, std::vector<int>& ksize,
                   std::vector<int>& strides, std::vector<int>& paddings,
-                  PoolProcess pool_compute) {
+                  PoolProcess pool_process) {
     const int batch_size = input.dims()[0];
     const int input_channels = input.dims()[1];
     const int input_depth = input.dims()[2];
@@ -541,7 +545,7 @@ class Pool3dGradFunctor<platform::GPUPlace, PoolProcess, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);

-    KernelPool3DBackward<
+    KernelPool3DGrad<
         PoolProcess,
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
@@ -550,7 +554,7 @@ class Pool3dGradFunctor<platform::GPUPlace, PoolProcess, T> {
         input_channels, input_depth, input_height, input_width, output_depth,
         output_height, output_width, ksize_depth, ksize_height, ksize_width,
         stride_depth, stride_height, stride_width, padding_depth,
-        padding_height, padding_width, pool_compute);
+        padding_height, padding_width, pool_process);
   }
 };
@@ -592,7 +596,7 @@ class MaxPool3dGradFunctor<platform::GPUPlace, T> {
     dim3 threads(1024, 1);
     dim3 grid(blocks, 1);

-    KernelMaxPool3DBackward<
+    KernelMaxPool3DGrad<
         T><<<grid, threads, 0,
              reinterpret_cast<const platform::CUDADeviceContext&>(context)
                  .stream()>>>(
@@ -605,24 +609,26 @@ class MaxPool3dGradFunctor<platform::GPUPlace, T> {
 };

 template class MaxPool3dGradFunctor<platform::GPUPlace, float>;
-// template class MaxPool3dGradFunctor<platform::GPUPlace, double>;
+// template class MaxPool3dGradFunctor<platform::GPUPlace, double>; // The
+// 64-bit floating-point version of atomicAdd() is only supported by devices of
+// compute capability 6.x and higher.

 template class Pool3dFunctor<platform::GPUPlace,
-                             paddle::operators::math::maxPool<float>, float>;
+                             paddle::operators::math::MaxPool<float>, float>;
 template class Pool3dFunctor<platform::GPUPlace,
-                             paddle::operators::math::avgPool<float>, float>;
+                             paddle::operators::math::AvgPool<float>, float>;
 template class Pool3dGradFunctor<
-    platform::GPUPlace, paddle::operators::math::maxPoolGrad<float>, float>;
+    platform::GPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
 template class Pool3dGradFunctor<
-    platform::GPUPlace, paddle::operators::math::avgPoolGrad<float>, float>;
+    platform::GPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
 template class Pool3dFunctor<platform::GPUPlace,
-                             paddle::operators::math::maxPool<double>, double>;
+                             paddle::operators::math::MaxPool<double>, double>;
 template class Pool3dFunctor<platform::GPUPlace,
-                             paddle::operators::math::avgPool<double>, double>;
+                             paddle::operators::math::AvgPool<double>, double>;
 template class Pool3dGradFunctor<
-    platform::GPUPlace, paddle::operators::math::maxPoolGrad<double>, double>;
+    platform::GPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
 template class Pool3dGradFunctor<
-    platform::GPUPlace, paddle::operators::math::avgPoolGrad<double>, double>;
+    platform::GPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;

 }  // namespace math
 }  // namespace operators
paddle/operators/math/pooling.h

@@ -22,11 +22,10 @@ namespace paddle {
 namespace operators {
 namespace math {

-//////////////////////
-#define FLT_MAX __FLT_MAX__
-/////////////////////
+#define FLT_MAX __FLT_MAX__  //

 template <class T>
-class maxPool {
+class MaxPool {
  public:
   DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
   DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; }
@@ -34,14 +33,14 @@ class maxPool {
 };

 template <class T>
-class avgPool {
+class AvgPool {
  public:
   DEVICE inline T initial() { return static_cast<T>(0); }
   DEVICE inline void compute(T& y, const T& x) { y += x; }
   DEVICE inline void finalize(T& y, const T& poo_size) { y /= poo_size; }
 };

 template <class T>
-class maxPoolGrad {
+class MaxPoolGrad {
  public:
   DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx,
                              T scale) {
@@ -50,7 +49,7 @@ class maxPoolGrad {
 };

 template <class T>
-class avgPoolGrad {
+class AvgPoolGrad {
  public:
   DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx,
                              T scale) {
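These renamed functors are the whole contract behind the `pool_process` parameter threaded through pooling.cc and pooling.cu: initial() seeds the accumulator, compute() folds in one input element, and finalize() rescales by the window size where the pooling type needs it. A host-side sketch of one window reduction driven through the same three calls (plain C++ stand-ins for the DEVICE-annotated originals; the names below are ours):

#include <cassert>

// Stand-in with the same initial/compute/finalize shape as AvgPool<T>.
template <class T>
struct AvgPoolLike {
  T initial() { return static_cast<T>(0); }
  void compute(T& y, const T& x) { y += x; }
  void finalize(T& y, const T& pool_size) { y /= pool_size; }
};

// Reduce one pooling window the way Pool2dFunctor's inner loop does.
template <class T, class PoolProcess>
T PoolWindow(const T* window, int n, PoolProcess pool_process) {
  T ele = pool_process.initial();
  for (int i = 0; i < n; ++i) {
    pool_process.compute(ele, window[i]);
  }
  pool_process.finalize(ele, static_cast<T>(n));
  return ele;
}

int main() {
  float window[4] = {1.f, 2.f, 3.f, 6.f};
  assert(PoolWindow(window, 4, AvgPoolLike<float>()) == 3.f);  // (1+2+3+6)/4
  return 0;
}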
paddle/operators/pool_op.cc

@@ -51,7 +51,7 @@ class PoolOp : public framework::OperatorWithKernel {
       ksize[i] = static_cast<int>(in_x_dims[i + 2]);
     }
-    PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2,
+    PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
                    "Input size and Pooling size should be consistent.");
     PADDLE_ENFORCE(ksize.size() == 2 || ksize.size() == 3,
                    "Pooling size should be 2 elements. or 3 elements.");
@@ -79,7 +79,6 @@ class PoolOpGrad : public framework::OperatorWithKernel {
                    "X(Input) of Pooling should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                    "Input@Grad of Pooling should not be null.");
-
     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }
 };
@@ -98,66 +97,36 @@ class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
              "The format of output tensor is also NCHW.");

     AddAttr<std::string>("poolingType",
-                         "poolingType of pooling operator."
-                         "str constant equal to 'max' or 'avg'");
+                         "PoolingType of pooling operator."
+                         "Str constant equal to 'max' or 'avg'.")
+        .InEnum({"max", "avg"});

     AddAttr<std::vector<int>>(
         "ksize",
         "Pooling size(depth, height, width) of pooling operator."
-        "If globalPooling = true, ksize is ignored and need not be specified.");
+        "If globalPooling = true, ksize is ignored and need not be "
+        "specified.");  // TODO(Add checker)

     AddAttr<bool>("globalPooling",
-                  "whether to use the globalPooling."
-                  "int constant equal to false or true"
-                  "default false"
+                  "Whether to use the globalPooling."
+                  "Bool constant equal to false or true."
+                  "Default false."
                   "If globalPooling = true, ksize is ignored and need not be specified.")
         .SetDefault(false);

     AddAttr<std::vector<int>>("strides",
-                              "strides(height, width) of pooling operator."
-                              "default {1,1}")
-        .SetDefault({1, 1})
-        .AddCustomChecker(GreaterThanChecker_pool({0, 0}));
+                              "Strides(height, width) of pooling operator."
+                              "Default {1,1}")
+        .SetDefault({1, 1});  // TODO(Add checker)

     AddAttr<std::vector<int>>("paddings",
-                              "paddings(height, width) of pooling operator."
-                              "default {0,0}")
-        .SetDefault({0, 0})
-        .AddCustomChecker(EqualGreaterThanChecker_pool({0, 0}));
+                              "Paddings(height, width) of pooling operator."
+                              "Default {0,0}.")
+        .SetDefault({0, 0});  // TODO(Add checker)

     AddComment(R"DOC(
 The pooling2d operation calculates the output based on
 the input, poolingType and ksize, strides, paddings parameters.
 )DOC");
   }
-
- private:
-  struct GreaterThanChecker_pool {
-   public:
-    explicit GreaterThanChecker_pool(std::vector<int> lower_bound)
-        : lower_bound_(lower_bound) {}
-    void operator()(std::vector<int>& value) const {
-      PADDLE_ENFORCE(value.size() == lower_bound_.size(),
-                     "equal check fails.");
-      for (size_t i = 0; i < value.size(); ++i) {
-        PADDLE_ENFORCE(value[i] > lower_bound_[i],
-                       "larger_than check fails.");
-      }
-    }
-
-   private:
-    std::vector<int> lower_bound_;
-  };
-
-  struct EqualGreaterThanChecker_pool {
-   public:
-    explicit EqualGreaterThanChecker_pool(std::vector<int> lower_bound)
-        : lower_bound_(lower_bound) {}
-    void operator()(std::vector<int>& value) const {
-      PADDLE_ENFORCE(value.size() == lower_bound_.size(),
-                     "equal check fails.");
-      for (size_t i = 0; i < value.size(); ++i) {
-        PADDLE_ENFORCE(value[i] >= lower_bound_[i],
-                       "larger_than check fails.");
-      }
-    }
-
-   private:
-    std::vector<int> lower_bound_;
-  };
 };

 class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   Pool3dOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
@@ -173,67 +142,36 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
              "The format of output tensor is also NCDHW.");

     AddAttr<std::string>("poolingType",
-                         "poolingType of pooling operator."
-                         "str constant equal to 'max' or 'avg'");
+                         "PoolingType of pooling operator."
+                         "str constant equal to 'max' or 'avg'.")
+        .InEnum({"max", "avg"});

     AddAttr<std::vector<int>>(
         "ksize",
-        "pooling size(depth, height, width) of pooling operator."
-        "If globalPooling = true, ksize is ignored and need not be specified.");
+        "Pooling size(depth, height, width) of pooling operator."
+        "If globalPooling = true, ksize is ignored and need not be "
+        "specified.");  // TODO(Add checker)

     AddAttr<bool>("globalPooling",
-                  "whether to use the globalPooling."
-                  "int constant equal to false or true"
-                  "default false"
+                  "Whether to use the globalPooling."
+                  "Bool constant equal to false or true."
+                  "Default false."
                   "If globalPooling = true, ksize is ignored and need not be specified.")
         .SetDefault(false);

     AddAttr<std::vector<int>>("strides",
-                              "strides(depth, height, width) of pooling operator."
-                              "default {1,1,1}")
-        .SetDefault({1, 1, 1})
-        .AddCustomChecker(GreaterThanChecker_pool({0, 0, 0}));
+                              "Strides(depth, height, width) of pooling operator."
+                              "Default {1,1,1}.")
+        .SetDefault({1, 1, 1});  // TODO(Add checker)

     AddAttr<std::vector<int>>("paddings",
-                              "paddings(depth, height, width) of pooling operator."
-                              "default {0,0,0}")
-        .SetDefault({0, 0, 0})
-        .AddCustomChecker(EqualGreaterThanChecker_pool({0, 0, 0}));
+                              "Paddings(depth, height, width) of pooling operator."
+                              "Default {0,0,0}.")
+        .SetDefault({0, 0, 0});  // TODO(Add checker)

     AddComment(R"DOC(
 The pooling3d operation calculates the output based on
 the input, poolingType and ksize, strides, paddings parameters.
 )DOC");
   }
-
- private:
-  struct GreaterThanChecker_pool {
-   public:
-    explicit GreaterThanChecker_pool(std::vector<int> lower_bound)
-        : lower_bound_(lower_bound) {}
-    void operator()(std::vector<int>& value) const {
-      PADDLE_ENFORCE(value.size() == lower_bound_.size(),
-                     "equal check fails.");
-      for (size_t i = 0; i < value.size(); ++i) {
-        PADDLE_ENFORCE(value[i] > lower_bound_[i],
-                       "larger_than check fails.");
-      }
-    }
-
-   private:
-    std::vector<int> lower_bound_;
-  };
-
-  struct EqualGreaterThanChecker_pool {
-   public:
-    explicit EqualGreaterThanChecker_pool(std::vector<int> lower_bound)
-        : lower_bound_(lower_bound) {}
-    void operator()(std::vector<int>& value) const {
-      PADDLE_ENFORCE(value.size() == lower_bound_.size(),
-                     "equal check fails.");
-      for (size_t i = 0; i < value.size(); ++i) {
-        PADDLE_ENFORCE(value[i] >= lower_bound_[i],
-                       "larger_than check fails.");
-      }
-    }
-
-   private:
-    std::vector<int> lower_bound_;
-  };
 };

 }  // namespace operators
 }  // namespace paddle
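With GreaterThanChecker_pool and EqualGreaterThanChecker_pool deleted and only `TODO(Add checker)` markers left behind, strides and paddings are no longer range-validated by this op. For reference, the removed validation reduces to an element-wise bound test; a standalone sketch of that logic (free function and names are ours, not Paddle API):

#include <cassert>
#include <cstddef>
#include <vector>

// strict = true mirrors GreaterThanChecker_pool (value[i] > bound[i]);
// strict = false mirrors EqualGreaterThanChecker_pool (value[i] >= bound[i]).
bool CheckLowerBound(const std::vector<int>& value,
                     const std::vector<int>& lower_bound, bool strict) {
  if (value.size() != lower_bound.size()) return false;  // the "equal check"
  for (std::size_t i = 0; i < value.size(); ++i) {
    if (strict ? value[i] <= lower_bound[i] : value[i] < lower_bound[i]) {
      return false;
    }
  }
  return true;
}

int main() {
  assert(CheckLowerBound({1, 1}, {0, 0}, true));   // valid strides
  assert(!CheckLowerBound({1, 0}, {0, 0}, true));  // a zero stride is rejected
  assert(CheckLowerBound({0, 0}, {0, 0}, false));  // zero padding is allowed
  return 0;
}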
paddle/operators/pool_op.h

@@ -31,12 +31,11 @@ class PoolKernel : public framework::OpKernel {
     const Tensor* in_x = context.Input<Tensor>("X");
     Tensor* out = context.Output<Tensor>("Out");

-    bool global_pooling = context.Attr<bool>("globalPooling");
     std::string pooling_type = context.Attr<std::string>("poolingType");
     std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    if (global_pooling) {
+    if (context.Attr<bool>("globalPooling")) {
       for (size_t i = 0; i < ksize.size(); ++i) {
         ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
       }
@@ -46,17 +45,17 @@ class PoolKernel : public framework::OpKernel {
       case 2: {
         if (pooling_type == "max") {
           paddle::operators::math::Pool2dFunctor<
-              Place, paddle::operators::math::maxPool<T>, T>
+              Place, paddle::operators::math::MaxPool<T>, T>
               pool2d_forward;
-          paddle::operators::math::maxPool<T> pool_process;
+          paddle::operators::math::MaxPool<T> pool_process;
           pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
                          paddings, pool_process);

         } else if (pooling_type == "avg") {
           paddle::operators::math::Pool2dFunctor<
-              Place, paddle::operators::math::avgPool<T>, T>
+              Place, paddle::operators::math::AvgPool<T>, T>
               pool2d_forward;
-          paddle::operators::math::avgPool<T> pool_process;
+          paddle::operators::math::AvgPool<T> pool_process;
           pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
                          paddings, pool_process);
         }
@@ -64,16 +63,16 @@ class PoolKernel : public framework::OpKernel {
       case 3: {
         if (pooling_type == "max") {
           paddle::operators::math::Pool3dFunctor<
-              Place, paddle::operators::math::maxPool<T>, T>
+              Place, paddle::operators::math::MaxPool<T>, T>
               pool3d_forward;
-          paddle::operators::math::maxPool<T> pool_process;
+          paddle::operators::math::MaxPool<T> pool_process;
           pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
                          paddings, pool_process);

         } else if (pooling_type == "avg") {
           paddle::operators::math::Pool3dFunctor<
-              Place, paddle::operators::math::avgPool<T>, T>
+              Place, paddle::operators::math::AvgPool<T>, T>
               pool3d_forward;
-          paddle::operators::math::avgPool<T> pool_process;
+          paddle::operators::math::AvgPool<T> pool_process;
           pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
                          paddings, pool_process);
         }
@@ -92,13 +91,12 @@ class PoolGradKernel : public framework::OpKernel {
         context.Input<Tensor>(framework::GradVarName("Out"));
     Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));

-    bool global_pooling = context.Attr<bool>("globalPooling");
     std::string pooling_type = context.Attr<std::string>("poolingType");
     std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");

-    if (global_pooling) {
+    if (context.Attr<bool>("globalPooling")) {
       for (size_t i = 0; i < ksize.size(); ++i)
         ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
     }
@@ -118,9 +116,9 @@ class PoolGradKernel : public framework::OpKernel {
                           *out_grad, ksize, strides, paddings);
         } else if (pooling_type == "avg") {
           paddle::operators::math::Pool2dGradFunctor<
-              Place, paddle::operators::math::avgPoolGrad<T>, T>
+              Place, paddle::operators::math::AvgPoolGrad<T>, T>
               pool2d_backward;
-          paddle::operators::math::avgPoolGrad<T> pool_process;
+          paddle::operators::math::AvgPoolGrad<T> pool_process;
           pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out,
                           *out_grad, ksize, strides, paddings, pool_process);
         }
@@ -133,9 +131,9 @@ class PoolGradKernel : public framework::OpKernel {
                           *out_grad, ksize, strides, paddings);
         } else if (pooling_type == "avg") {
           paddle::operators::math::Pool3dGradFunctor<
-              Place, paddle::operators::math::avgPoolGrad<T>, T>
+              Place, paddle::operators::math::AvgPoolGrad<T>, T>
               pool3d_backward;
-          paddle::operators::math::avgPoolGrad<T> pool_process;
+          paddle::operators::math::AvgPoolGrad<T> pool_process;
           pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out,
                           *out_grad, ksize, strides, paddings, pool_process);
         }
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录