Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
daed473d
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
daed473d
编写于
11月 03, 2018
作者:
K
Kaipeng Deng
提交者:
GitHub
11月 03, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14089 from heavengate/pool_exclude
add inclusive/exclusive mode in avg pool
上级
64f3e3ed
da8ee1fb
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
176 addition
and
76 deletion
+176
-76
paddle/fluid/API.spec
paddle/fluid/API.spec
+2
-2
paddle/fluid/operators/math/pooling.cc
paddle/fluid/operators/math/pooling.cc
+14
-8
paddle/fluid/operators/math/pooling.cu
paddle/fluid/operators/math/pooling.cu
+30
-25
paddle/fluid/operators/math/pooling.h
paddle/fluid/operators/math/pooling.h
+4
-4
paddle/fluid/operators/pool_cudnn_op.cu.cc
paddle/fluid/operators/pool_cudnn_op.cu.cc
+6
-2
paddle/fluid/operators/pool_op.cc
paddle/fluid/operators/pool_op.cc
+29
-0
paddle/fluid/operators/pool_op.h
paddle/fluid/operators/pool_op.h
+8
-6
paddle/fluid/operators/spp_op.h
paddle/fluid/operators/spp_op.h
+5
-3
paddle/fluid/platform/cudnn_helper.h
paddle/fluid/platform/cudnn_helper.h
+8
-3
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+15
-7
python/paddle/fluid/tests/unittests/test_pool2d_op.py
python/paddle/fluid/tests/unittests/test_pool2d_op.py
+27
-8
python/paddle/fluid/tests/unittests/test_pool3d_op.py
python/paddle/fluid/tests/unittests/test_pool3d_op.py
+28
-8
未找到文件。
paddle/fluid/API.spec
浏览文件 @
daed473d
...
@@ -67,8 +67,8 @@ paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size',
...
@@ -67,8 +67,8 @@ paddle.fluid.layers.conv3d ArgSpec(args=['input', 'num_filters', 'filter_size',
paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type', 'is_test'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.sequence_pool ArgSpec(args=['input', 'pool_type', 'is_test'], varargs=None, keywords=None, defaults=(False,))
paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.layers.sequence_softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(False, None))
paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None))
paddle.fluid.layers.softmax ArgSpec(args=['input', 'use_cudnn', 'name'], varargs=None, keywords=None, defaults=(True, None))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'
], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, Non
e))
paddle.fluid.layers.pool2d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'
, 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, Tru
e))
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'
], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, Non
e))
paddle.fluid.layers.pool3d ArgSpec(args=['input', 'pool_size', 'pool_type', 'pool_stride', 'pool_padding', 'global_pooling', 'use_cudnn', 'ceil_mode', 'name'
, 'exclusive'], varargs=None, keywords=None, defaults=(-1, 'max', 1, 0, False, True, False, None, Tru
e))
paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False))
paddle.fluid.layers.batch_norm ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False))
paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.beam_search_decode ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
paddle.fluid.layers.conv2d_transpose ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None))
...
...
paddle/fluid/operators/math/pooling.cc
浏览文件 @
daed473d
...
@@ -31,7 +31,7 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
...
@@ -31,7 +31,7 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
bool
exclusive
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
input_width
=
input
.
dims
()[
3
];
...
@@ -68,7 +68,8 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
...
@@ -68,7 +68,8 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
pool_process
.
compute
(
input_data
[
h
*
input_width
+
w
],
&
ele
);
pool_process
.
compute
(
input_data
[
h
*
input_width
+
w
],
&
ele
);
}
}
}
}
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
pool_size
=
exclusive
?
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
ksize_height
*
ksize_width
;
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
output_data
[
ph
*
output_width
+
pw
]
=
ele
;
output_data
[
ph
*
output_width
+
pw
]
=
ele
;
}
}
...
@@ -93,7 +94,7 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
...
@@ -93,7 +94,7 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_grad_process
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_grad_process
,
framework
::
Tensor
*
input_grad
)
{
bool
exclusive
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
input_width
=
input
.
dims
()[
3
];
...
@@ -124,7 +125,8 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
...
@@ -124,7 +125,8 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
int
wstart
=
pw
*
stride_width
-
padding_width
;
int
wstart
=
pw
*
stride_width
-
padding_width
;
int
wend
=
std
::
min
(
wstart
+
ksize_width
,
input_width
);
int
wend
=
std
::
min
(
wstart
+
ksize_width
,
input_width
);
wstart
=
std
::
max
(
wstart
,
0
);
wstart
=
std
::
max
(
wstart
,
0
);
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
pool_size
=
exclusive
?
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
ksize_height
*
ksize_width
;
float
scale
=
1.0
/
pool_size
;
float
scale
=
1.0
/
pool_size
;
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
...
@@ -249,7 +251,7 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
...
@@ -249,7 +251,7 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
bool
exclusive
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
const
int
input_height
=
input
.
dims
()[
3
];
...
@@ -300,7 +302,9 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
...
@@ -300,7 +302,9 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
}
}
}
}
int
pool_size
=
int
pool_size
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
exclusive
?
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
ksize_depth
*
ksize_height
*
ksize_width
;
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
output_data
[
output_idx
]
=
ele
;
output_data
[
output_idx
]
=
ele
;
}
}
...
@@ -326,7 +330,7 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
...
@@ -326,7 +330,7 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
output_grad
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_grad_process
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_grad_process
,
framework
::
Tensor
*
input_grad
)
{
bool
exclusive
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
const
int
input_height
=
input
.
dims
()[
3
];
...
@@ -369,7 +373,9 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
...
@@ -369,7 +373,9 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
wstart
=
std
::
max
(
wstart
,
0
);
wstart
=
std
::
max
(
wstart
,
0
);
int
pool_size
=
int
pool_size
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
exclusive
?
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
ksize_depth
*
ksize_height
*
ksize_width
;
float
scale
=
1.0
/
pool_size
;
float
scale
=
1.0
/
pool_size
;
for
(
int
d
=
dstart
;
d
<
dend
;
++
d
)
{
for
(
int
d
=
dstart
;
d
<
dend
;
++
d
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
...
...
paddle/fluid/operators/math/pooling.cu
浏览文件 @
daed473d
...
@@ -29,7 +29,7 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
...
@@ -29,7 +29,7 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
const
int
ksize_width
,
const
int
stride_height
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
,
const
int
padding_width
,
PoolProcess
pool_process
,
T
*
output_data
)
{
bool
exclusive
,
T
*
output_data
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
pw
=
index
%
output_width
;
int
pw
=
index
%
output_width
;
...
@@ -52,7 +52,8 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
...
@@ -52,7 +52,8 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
pool_process
.
compute
(
input_data
[
h
*
input_width
+
w
],
&
ele
);
pool_process
.
compute
(
input_data
[
h
*
input_width
+
w
],
&
ele
);
}
}
}
}
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
pool_size
=
exclusive
?
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
ksize_height
*
ksize_width
;
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
output_data
[
index
]
=
ele
;
output_data
[
index
]
=
ele
;
}
}
...
@@ -65,7 +66,7 @@ __global__ void KernelPool2DGrad(
...
@@ -65,7 +66,7 @@ __global__ void KernelPool2DGrad(
const
int
input_width
,
const
int
output_height
,
const
int
output_width
,
const
int
input_width
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
,
T
*
input_grad
)
{
PoolProcess
pool_process
,
bool
exclusive
,
T
*
input_grad
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
offsetW
=
index
%
input_width
+
padding_width
;
int
offsetW
=
index
%
input_width
+
padding_width
;
...
@@ -95,7 +96,8 @@ __global__ void KernelPool2DGrad(
...
@@ -95,7 +96,8 @@ __global__ void KernelPool2DGrad(
int
wend
=
min
(
wstart
+
ksize_width
,
input_width
);
int
wend
=
min
(
wstart
+
ksize_width
,
input_width
);
hstart
=
max
(
hstart
,
0
);
hstart
=
max
(
hstart
,
0
);
wstart
=
max
(
wstart
,
0
);
wstart
=
max
(
wstart
,
0
);
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
pool_size
=
exclusive
?
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
ksize_height
*
ksize_width
;
int
output_sub_idx
=
ph
*
output_width
+
pw
;
int
output_sub_idx
=
ph
*
output_width
+
pw
;
pool_process
.
compute
(
input
,
output_data
[
output_sub_idx
],
pool_process
.
compute
(
input
,
output_data
[
output_sub_idx
],
output_grad
[
output_sub_idx
],
output_grad
[
output_sub_idx
],
...
@@ -163,7 +165,7 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
...
@@ -163,7 +165,7 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
bool
exclusive
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
2
];
...
@@ -189,7 +191,8 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
...
@@ -189,7 +191,8 @@ class Pool2dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
KernelPool2D
<
PoolProcess
,
T
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
KernelPool2D
<
PoolProcess
,
T
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
nthreads
,
input_data
,
input_channels
,
input_height
,
input_width
,
nthreads
,
input_data
,
input_channels
,
input_height
,
input_width
,
output_height
,
output_width
,
ksize_height
,
ksize_width
,
stride_height
,
output_height
,
output_width
,
ksize_height
,
ksize_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
pool_process
,
output_data
);
stride_width
,
padding_height
,
padding_width
,
pool_process
,
exclusive
,
output_data
);
}
}
};
};
...
@@ -208,7 +211,7 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
...
@@ -208,7 +211,7 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
input_grad
)
{
bool
exclusive
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
2
];
...
@@ -236,7 +239,7 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
...
@@ -236,7 +239,7 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
nthreads
,
input_data
,
output_data
,
output_grad_data
,
input_channels
,
nthreads
,
input_data
,
output_data
,
output_grad_data
,
input_channels
,
input_height
,
input_width
,
output_height
,
output_width
,
ksize_height
,
input_height
,
input_width
,
output_height
,
output_width
,
ksize_height
,
ksize_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
ksize_width
,
stride_height
,
stride_width
,
padding_height
,
padding_width
,
pool_process
,
input_grad_data
);
pool_process
,
exclusive
,
input_grad_data
);
}
}
};
};
...
@@ -313,16 +316,14 @@ template class Pool2dGradFunctor<platform::CUDADeviceContext,
...
@@ -313,16 +316,14 @@ template class Pool2dGradFunctor<platform::CUDADeviceContext,
double
>
;
double
>
;
template
<
typename
PoolProcess
,
typename
T
>
template
<
typename
PoolProcess
,
typename
T
>
__global__
void
KernelPool3D
(
const
int
nthreads
,
const
T
*
input_data
,
__global__
void
KernelPool3D
(
const
int
channels
,
const
int
input_depth
,
const
int
nthreads
,
const
T
*
input_data
,
const
int
channels
,
const
int
input_height
,
const
int
input_width
,
const
int
input_depth
,
const
int
input_height
,
const
int
input_width
,
const
int
output_depth
,
const
int
output_height
,
const
int
output_depth
,
const
int
output_height
,
const
int
output_width
,
const
int
output_width
,
const
int
ksize_depth
,
const
int
ksize_depth
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
padding_width
,
const
int
stride_width
,
const
int
padding_depth
,
PoolProcess
pool_process
,
bool
exclusive
,
T
*
output_data
)
{
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
,
T
*
output_data
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
pw
=
index
%
output_width
;
int
pw
=
index
%
output_width
;
...
@@ -351,7 +352,9 @@ __global__ void KernelPool3D(const int nthreads, const T* input_data,
...
@@ -351,7 +352,9 @@ __global__ void KernelPool3D(const int nthreads, const T* input_data,
}
}
}
}
}
}
int
pool_size
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
pool_size
=
exclusive
?
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
ksize_depth
*
ksize_height
*
ksize_width
;
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
pool_process
.
finalize
(
static_cast
<
T
>
(
pool_size
),
&
ele
);
output_data
[
index
]
=
ele
;
output_data
[
index
]
=
ele
;
}
}
...
@@ -366,7 +369,7 @@ __global__ void KernelPool3DGrad(
...
@@ -366,7 +369,7 @@ __global__ void KernelPool3DGrad(
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
,
const
int
padding_height
,
const
int
padding_width
,
PoolProcess
pool_process
,
T
*
input_grad
)
{
bool
exclusive
,
T
*
input_grad
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
offsetW
=
index
%
input_width
+
padding_width
;
int
offsetW
=
index
%
input_width
+
padding_width
;
...
@@ -409,7 +412,9 @@ __global__ void KernelPool3DGrad(
...
@@ -409,7 +412,9 @@ __global__ void KernelPool3DGrad(
dstart
=
max
(
dstart
,
0
);
dstart
=
max
(
dstart
,
0
);
hstart
=
max
(
hstart
,
0
);
hstart
=
max
(
hstart
,
0
);
wstart
=
max
(
wstart
,
0
);
wstart
=
max
(
wstart
,
0
);
int
pool_size
=
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
pool_size
=
exclusive
?
(
dend
-
dstart
)
*
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
ksize_depth
*
ksize_height
*
ksize_width
;
int
output_sub_idx
=
(
pd
*
output_height
+
ph
)
*
output_width
+
pw
;
int
output_sub_idx
=
(
pd
*
output_height
+
ph
)
*
output_width
+
pw
;
pool_process
.
compute
(
input
,
output_data
[
output_sub_idx
],
pool_process
.
compute
(
input
,
output_data
[
output_sub_idx
],
output_grad
[
output_sub_idx
],
output_grad
[
output_sub_idx
],
...
@@ -484,7 +489,7 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
...
@@ -484,7 +489,7 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
output
)
{
bool
exclusive
,
framework
::
Tensor
*
output
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_depth
=
input
.
dims
()[
2
];
...
@@ -517,7 +522,7 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
...
@@ -517,7 +522,7 @@ class Pool3dFunctor<platform::CUDADeviceContext, PoolProcess, T> {
nthreads
,
input_data
,
input_channels
,
input_depth
,
input_height
,
nthreads
,
input_data
,
input_channels
,
input_depth
,
input_height
,
input_width
,
output_depth
,
output_height
,
output_width
,
ksize_depth
,
input_width
,
output_depth
,
output_height
,
output_width
,
ksize_depth
,
ksize_height
,
ksize_width
,
stride_depth
,
stride_height
,
stride_width
,
ksize_height
,
ksize_width
,
stride_depth
,
stride_height
,
stride_width
,
padding_depth
,
padding_height
,
padding_width
,
pool_process
,
padding_depth
,
padding_height
,
padding_width
,
pool_process
,
exclusive
,
output_data
);
output_data
);
}
}
};
};
...
@@ -537,7 +542,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
...
@@ -537,7 +542,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_process
,
framework
::
Tensor
*
input_grad
)
{
bool
exclusive
,
framework
::
Tensor
*
input_grad
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_channels
=
input
.
dims
()[
1
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_depth
=
input
.
dims
()[
2
];
...
@@ -573,7 +578,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
...
@@ -573,7 +578,7 @@ class Pool3dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
input_depth
,
input_height
,
input_width
,
output_depth
,
output_height
,
input_depth
,
input_height
,
input_width
,
output_depth
,
output_height
,
output_width
,
ksize_depth
,
ksize_height
,
ksize_width
,
stride_depth
,
output_width
,
ksize_depth
,
ksize_height
,
ksize_width
,
stride_depth
,
stride_height
,
stride_width
,
padding_depth
,
padding_height
,
stride_height
,
stride_width
,
padding_depth
,
padding_height
,
padding_width
,
pool_process
,
input_grad_data
);
padding_width
,
pool_process
,
exclusive
,
input_grad_data
);
}
}
};
};
...
...
paddle/fluid/operators/math/pooling.h
浏览文件 @
daed473d
...
@@ -89,7 +89,7 @@ class Pool2dFunctor {
...
@@ -89,7 +89,7 @@ class Pool2dFunctor {
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
output
);
bool
exclusive
,
framework
::
Tensor
*
output
);
};
};
template
<
typename
DeviceContext
,
typename
PoolProcess
,
typename
T
>
template
<
typename
DeviceContext
,
typename
PoolProcess
,
typename
T
>
...
@@ -101,7 +101,7 @@ class Pool2dGradFunctor {
...
@@ -101,7 +101,7 @@ class Pool2dGradFunctor {
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
input_grad
);
bool
exclusive
,
framework
::
Tensor
*
input_grad
);
};
};
template
<
typename
DeviceContext
,
class
T
>
template
<
typename
DeviceContext
,
class
T
>
...
@@ -123,7 +123,7 @@ class Pool3dFunctor {
...
@@ -123,7 +123,7 @@ class Pool3dFunctor {
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
output
);
bool
exclusive
,
framework
::
Tensor
*
output
);
};
};
template
<
typename
DeviceContext
,
typename
PoolProcess
,
typename
T
>
template
<
typename
DeviceContext
,
typename
PoolProcess
,
typename
T
>
...
@@ -135,7 +135,7 @@ class Pool3dGradFunctor {
...
@@ -135,7 +135,7 @@ class Pool3dGradFunctor {
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
ksize
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
const
std
::
vector
<
int
>&
paddings
,
PoolProcess
pool_compute
,
framework
::
Tensor
*
input_grad
);
bool
exclusive
,
framework
::
Tensor
*
input_grad
);
};
};
template
<
typename
DeviceContext
,
class
T
>
template
<
typename
DeviceContext
,
class
T
>
...
...
paddle/fluid/operators/pool_cudnn_op.cu.cc
浏览文件 @
daed473d
...
@@ -41,6 +41,7 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
...
@@ -41,6 +41,7 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
T
*
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
std
::
string
pooling_type
=
ctx
.
Attr
<
std
::
string
>
(
"pooling_type"
);
std
::
string
pooling_type
=
ctx
.
Attr
<
std
::
string
>
(
"pooling_type"
);
bool
exclusive
=
ctx
.
Attr
<
bool
>
(
"exclusive"
);
std
::
vector
<
int
>
ksize
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
ksize
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
strides
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
paddings
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
...
@@ -72,7 +73,8 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
...
@@ -72,7 +73,8 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
if
(
pooling_type
==
"max"
)
{
if
(
pooling_type
==
"max"
)
{
pooling_mode
=
PoolingMode
::
kMaximum
;
pooling_mode
=
PoolingMode
::
kMaximum
;
}
else
{
}
else
{
pooling_mode
=
PoolingMode
::
kAverage
;
pooling_mode
=
exclusive
?
PoolingMode
::
kAverageExclusive
:
PoolingMode
::
kAverageInclusive
;
}
}
cudnnPoolingDescriptor_t
cudnn_pool_desc
=
cudnnPoolingDescriptor_t
cudnn_pool_desc
=
...
@@ -101,6 +103,7 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
...
@@ -101,6 +103,7 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
Tensor
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
Tensor
*
input_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
std
::
string
pooling_type
=
ctx
.
Attr
<
std
::
string
>
(
"pooling_type"
);
std
::
string
pooling_type
=
ctx
.
Attr
<
std
::
string
>
(
"pooling_type"
);
bool
exclusive
=
ctx
.
Attr
<
bool
>
(
"exclusive"
);
std
::
vector
<
int
>
ksize
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
ksize
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
strides
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
paddings
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
...
@@ -141,7 +144,8 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
...
@@ -141,7 +144,8 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
pooling_mode
=
PoolingMode
::
kMaximum
;
pooling_mode
=
PoolingMode
::
kMaximum
;
}
}
}
else
{
}
else
{
pooling_mode
=
PoolingMode
::
kAverage
;
pooling_mode
=
exclusive
?
PoolingMode
::
kAverageExclusive
:
PoolingMode
::
kAverageInclusive
;
}
}
cudnnPoolingDescriptor_t
cudnn_pool_desc
=
cudnnPoolingDescriptor_t
cudnn_pool_desc
=
...
...
paddle/fluid/operators/pool_op.cc
浏览文件 @
daed473d
...
@@ -180,6 +180,12 @@ void Pool2dOpMaker::Make() {
...
@@ -180,6 +180,12 @@ void Pool2dOpMaker::Make() {
"operator."
"operator."
"If global_pooling = true, paddings and ksize will be ignored."
)
"If global_pooling = true, paddings and ksize will be ignored."
)
.
SetDefault
({
0
,
0
});
.
SetDefault
({
0
,
0
});
AddAttr
<
bool
>
(
"exclusive"
,
"(bool, default True) When true, will exclude the zero-padding in the "
"averaging calculating, otherwise, include the zero-padding. Note, it "
"is only used when pooling_type is avg. The defalut is True."
)
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
AddAttr
<
bool
>
(
"use_cudnn"
,
"use_cudnn"
,
"(bool, default false) Only used in cudnn kernel, need install cudnn"
)
"(bool, default false) Only used in cudnn kernel, need install cudnn"
)
...
@@ -236,6 +242,23 @@ Example:
...
@@ -236,6 +242,23 @@ Example:
W_{out} = \\frac{(W_{in} - ksize[1] + 2 * paddings[1] + strides[1] - 1)}{strides[1]} + 1
W_{out} = \\frac{(W_{in} - ksize[1] + 2 * paddings[1] + strides[1] - 1)}{strides[1]} + 1
$$
$$
For exclusive = true:
$$
hstart = i * strides[0] - paddings[0]
hend = hstart + ksize[0]
wstart = j * strides[1] - paddings[1]
wend = wstart + ksize[1]
Output(i ,j) = \\frac{sum(Input[hstart:hend, wstart:wend])}{ksize[0] * ksize[1]}
$$
For exclusive = false:
$$
hstart = max(0, i * strides[0] - paddings[0])
hend = min(H, hstart + ksize[0])
wstart = max(0, j * strides[1] - paddings[1])
wend = min(W, wstart + ksize[1])
Output(i ,j) = \\frac{sum(Input[hstart:hend, wstart:wend])}{(hend - hstart) * (wend - wstart)}
$$
)DOC"
);
)DOC"
);
}
}
...
@@ -283,6 +306,12 @@ void Pool3dOpMaker::Make() {
...
@@ -283,6 +306,12 @@ void Pool3dOpMaker::Make() {
"If global_pooling = true, ksize and paddings will be ignored."
)
"If global_pooling = true, ksize and paddings will be ignored."
)
.
SetDefault
({
0
,
0
,
0
});
// TODO(Chengduo): Add checker. (Currently,
.
SetDefault
({
0
,
0
,
0
});
// TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
// TypedAttrChecker don't support vector type.)
AddAttr
<
bool
>
(
"exclusive"
,
"(bool, default True) When true, will exclude the zero-padding in the "
"averaging calculating, otherwise, include the zero-padding. Note, it "
"is only used when pooling_type is avg. The defalut is True."
)
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
AddAttr
<
bool
>
(
"use_cudnn"
,
"use_cudnn"
,
...
...
paddle/fluid/operators/pool_op.h
浏览文件 @
daed473d
...
@@ -69,6 +69,7 @@ class PoolKernel : public framework::OpKernel<T> {
...
@@ -69,6 +69,7 @@ class PoolKernel : public framework::OpKernel<T> {
std
::
vector
<
int
>
ksize
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
ksize
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
bool
exclusive
=
context
.
Attr
<
bool
>
(
"exclusive"
);
if
(
context
.
Attr
<
bool
>
(
"global_pooling"
))
{
if
(
context
.
Attr
<
bool
>
(
"global_pooling"
))
{
for
(
size_t
i
=
0
;
i
<
ksize
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
ksize
.
size
();
++
i
)
{
paddings
[
i
]
=
0
;
paddings
[
i
]
=
0
;
...
@@ -84,7 +85,7 @@ class PoolKernel : public framework::OpKernel<T> {
...
@@ -84,7 +85,7 @@ class PoolKernel : public framework::OpKernel<T> {
pool2d_forward
;
pool2d_forward
;
paddle
::
operators
::
math
::
MaxPool
<
T
>
pool_process
;
paddle
::
operators
::
math
::
MaxPool
<
T
>
pool_process
;
pool2d_forward
(
dev_ctx
,
*
in_x
,
ksize
,
strides
,
paddings
,
pool_process
,
pool2d_forward
(
dev_ctx
,
*
in_x
,
ksize
,
strides
,
paddings
,
pool_process
,
out
);
true
,
out
);
}
else
if
(
pooling_type
==
"avg"
)
{
}
else
if
(
pooling_type
==
"avg"
)
{
paddle
::
operators
::
math
::
Pool2dFunctor
<
paddle
::
operators
::
math
::
Pool2dFunctor
<
...
@@ -92,7 +93,7 @@ class PoolKernel : public framework::OpKernel<T> {
...
@@ -92,7 +93,7 @@ class PoolKernel : public framework::OpKernel<T> {
pool2d_forward
;
pool2d_forward
;
paddle
::
operators
::
math
::
AvgPool
<
T
>
pool_process
;
paddle
::
operators
::
math
::
AvgPool
<
T
>
pool_process
;
pool2d_forward
(
dev_ctx
,
*
in_x
,
ksize
,
strides
,
paddings
,
pool_process
,
pool2d_forward
(
dev_ctx
,
*
in_x
,
ksize
,
strides
,
paddings
,
pool_process
,
out
);
exclusive
,
out
);
}
}
}
break
;
}
break
;
case
3
:
{
case
3
:
{
...
@@ -102,14 +103,14 @@ class PoolKernel : public framework::OpKernel<T> {
...
@@ -102,14 +103,14 @@ class PoolKernel : public framework::OpKernel<T> {
pool3d_forward
;
pool3d_forward
;
paddle
::
operators
::
math
::
MaxPool
<
T
>
pool_process
;
paddle
::
operators
::
math
::
MaxPool
<
T
>
pool_process
;
pool3d_forward
(
dev_ctx
,
*
in_x
,
ksize
,
strides
,
paddings
,
pool_process
,
pool3d_forward
(
dev_ctx
,
*
in_x
,
ksize
,
strides
,
paddings
,
pool_process
,
out
);
true
,
out
);
}
else
if
(
pooling_type
==
"avg"
)
{
}
else
if
(
pooling_type
==
"avg"
)
{
paddle
::
operators
::
math
::
Pool3dFunctor
<
paddle
::
operators
::
math
::
Pool3dFunctor
<
DeviceContext
,
paddle
::
operators
::
math
::
AvgPool
<
T
>
,
T
>
DeviceContext
,
paddle
::
operators
::
math
::
AvgPool
<
T
>
,
T
>
pool3d_forward
;
pool3d_forward
;
paddle
::
operators
::
math
::
AvgPool
<
T
>
pool_process
;
paddle
::
operators
::
math
::
AvgPool
<
T
>
pool_process
;
pool3d_forward
(
dev_ctx
,
*
in_x
,
ksize
,
strides
,
paddings
,
pool_process
,
pool3d_forward
(
dev_ctx
,
*
in_x
,
ksize
,
strides
,
paddings
,
pool_process
,
out
);
exclusive
,
out
);
}
}
}
break
;
}
break
;
default:
{
PADDLE_THROW
(
"Pool op only supports 2D and 3D input."
);
}
default:
{
PADDLE_THROW
(
"Pool op only supports 2D and 3D input."
);
}
...
@@ -131,6 +132,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
...
@@ -131,6 +132,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
std
::
vector
<
int
>
ksize
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
ksize
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
bool
exclusive
=
context
.
Attr
<
bool
>
(
"exclusive"
);
if
(
context
.
Attr
<
bool
>
(
"global_pooling"
))
{
if
(
context
.
Attr
<
bool
>
(
"global_pooling"
))
{
for
(
size_t
i
=
0
;
i
<
ksize
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
ksize
.
size
();
++
i
)
{
...
@@ -157,7 +159,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
...
@@ -157,7 +159,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
pool2d_backward
;
pool2d_backward
;
paddle
::
operators
::
math
::
AvgPoolGrad
<
T
>
pool_process
;
paddle
::
operators
::
math
::
AvgPoolGrad
<
T
>
pool_process
;
pool2d_backward
(
dev_ctx
,
*
in_x
,
*
out
,
*
out_grad
,
ksize
,
strides
,
pool2d_backward
(
dev_ctx
,
*
in_x
,
*
out
,
*
out_grad
,
ksize
,
strides
,
paddings
,
pool_process
,
in_x_grad
);
paddings
,
pool_process
,
exclusive
,
in_x_grad
);
}
}
}
break
;
}
break
;
case
3
:
{
case
3
:
{
...
@@ -172,7 +174,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
...
@@ -172,7 +174,7 @@ class PoolGradKernel : public framework::OpKernel<T> {
pool3d_backward
;
pool3d_backward
;
paddle
::
operators
::
math
::
AvgPoolGrad
<
T
>
pool_process
;
paddle
::
operators
::
math
::
AvgPoolGrad
<
T
>
pool_process
;
pool3d_backward
(
dev_ctx
,
*
in_x
,
*
out
,
*
out_grad
,
ksize
,
strides
,
pool3d_backward
(
dev_ctx
,
*
in_x
,
*
out
,
*
out_grad
,
ksize
,
strides
,
paddings
,
pool_process
,
in_x_grad
);
paddings
,
pool_process
,
exclusive
,
in_x_grad
);
}
}
}
break
;
}
break
;
default:
{
PADDLE_THROW
(
"Pool op only supports 2D and 3D input."
);
}
default:
{
PADDLE_THROW
(
"Pool op only supports 2D and 3D input."
);
}
...
...
paddle/fluid/operators/spp_op.h
浏览文件 @
daed473d
...
@@ -56,12 +56,14 @@ class SppKernel : public framework::OpKernel<T> {
...
@@ -56,12 +56,14 @@ class SppKernel : public framework::OpKernel<T> {
math
::
Pool2dFunctor
<
DeviceContext
,
math
::
MaxPool
<
T
>
,
T
>
pool_forward
;
math
::
Pool2dFunctor
<
DeviceContext
,
math
::
MaxPool
<
T
>
,
T
>
pool_forward
;
math
::
MaxPool
<
T
>
max_process
;
math
::
MaxPool
<
T
>
max_process
;
pool_forward
(
context
.
template
device_context
<
DeviceContext
>(),
*
in_x
,
pool_forward
(
context
.
template
device_context
<
DeviceContext
>(),
*
in_x
,
kernel_size
,
strides
,
paddings
,
max_process
,
&
out_level
);
kernel_size
,
strides
,
paddings
,
max_process
,
true
,
&
out_level
);
}
else
if
(
pooling_type
==
"avg"
)
{
}
else
if
(
pooling_type
==
"avg"
)
{
math
::
Pool2dFunctor
<
DeviceContext
,
math
::
AvgPool
<
T
>
,
T
>
pool_forward
;
math
::
Pool2dFunctor
<
DeviceContext
,
math
::
AvgPool
<
T
>
,
T
>
pool_forward
;
math
::
AvgPool
<
T
>
avg_process
;
math
::
AvgPool
<
T
>
avg_process
;
pool_forward
(
context
.
template
device_context
<
DeviceContext
>(),
*
in_x
,
pool_forward
(
context
.
template
device_context
<
DeviceContext
>(),
*
in_x
,
kernel_size
,
strides
,
paddings
,
avg_process
,
&
out_level
);
kernel_size
,
strides
,
paddings
,
avg_process
,
true
,
&
out_level
);
}
}
// flatten pooling output shape
// flatten pooling output shape
int
output_flatten_w
=
in_x
->
dims
()[
1
]
*
bins
*
bins
;
int
output_flatten_w
=
in_x
->
dims
()[
1
]
*
bins
*
bins
;
...
@@ -154,7 +156,7 @@ class SppGradKernel : public framework::OpKernel<T> {
...
@@ -154,7 +156,7 @@ class SppGradKernel : public framework::OpKernel<T> {
math
::
AvgPoolGrad
<
T
>
avg_process
;
math
::
AvgPoolGrad
<
T
>
avg_process
;
pool_backward
(
context
.
template
device_context
<
DeviceContext
>(),
*
in_x
,
pool_backward
(
context
.
template
device_context
<
DeviceContext
>(),
*
in_x
,
*&
out_level
,
*&
outgrad_level
,
kernel_size
,
strides
,
*&
out_level
,
*&
outgrad_level
,
kernel_size
,
strides
,
paddings
,
avg_process
,
in_x_grad
);
paddings
,
avg_process
,
true
,
in_x_grad
);
}
}
}
}
}
}
...
...
paddle/fluid/platform/cudnn_helper.h
浏览文件 @
daed473d
...
@@ -76,8 +76,9 @@ enum class DataLayout { // Not use
...
@@ -76,8 +76,9 @@ enum class DataLayout { // Not use
enum
class
PoolingMode
{
enum
class
PoolingMode
{
kMaximum
,
kMaximum
,
kAverage
,
kMaximumDeterministic
,
kMaximumDeterministic
,
kAverageExclusive
,
kAverageInclusive
,
};
};
#if CUDNN_VERSION < 6000
#if CUDNN_VERSION < 6000
...
@@ -91,8 +92,10 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
...
@@ -91,8 +92,10 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
switch
(
mode
)
{
switch
(
mode
)
{
case
PoolingMode
::
kMaximumDeterministic
:
case
PoolingMode
::
kMaximumDeterministic
:
return
CUDNN_POOLING_MAX
;
return
CUDNN_POOLING_MAX
;
case
PoolingMode
::
kAverage
:
case
PoolingMode
::
kAverage
Exclusive
:
return
CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING
;
return
CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING
;
case
PoolingMode
::
kAverageInclusive
:
return
CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
;
case
PoolingMode
::
kMaximum
:
case
PoolingMode
::
kMaximum
:
return
CUDNN_POOLING_MAX
;
return
CUDNN_POOLING_MAX
;
default:
default:
...
@@ -105,8 +108,10 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
...
@@ -105,8 +108,10 @@ inline cudnnPoolingMode_t GetPoolingMode(const PoolingMode& mode) {
switch
(
mode
)
{
switch
(
mode
)
{
case
PoolingMode
::
kMaximumDeterministic
:
case
PoolingMode
::
kMaximumDeterministic
:
return
CUDNN_POOLING_MAX_DETERMINISTIC
;
return
CUDNN_POOLING_MAX_DETERMINISTIC
;
case
PoolingMode
::
kAverage
:
case
PoolingMode
::
kAverage
Exclusive
:
return
CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING
;
return
CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING
;
case
PoolingMode
::
kAverageInclusive
:
return
CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING
;
case
PoolingMode
::
kMaximum
:
case
PoolingMode
::
kMaximum
:
return
CUDNN_POOLING_MAX
;
return
CUDNN_POOLING_MAX
;
default:
default:
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
daed473d
...
@@ -2102,7 +2102,8 @@ def pool2d(input,
...
@@ -2102,7 +2102,8 @@ def pool2d(input,
global_pooling
=
False
,
global_pooling
=
False
,
use_cudnn
=
True
,
use_cudnn
=
True
,
ceil_mode
=
False
,
ceil_mode
=
False
,
name
=
None
):
name
=
None
,
exclusive
=
True
):
"""
"""
${comment}
${comment}
...
@@ -2116,11 +2117,13 @@ def pool2d(input,
...
@@ -2116,11 +2117,13 @@ def pool2d(input,
pool_type: ${pooling_type_comment}
pool_type: ${pooling_type_comment}
pool_stride (int): stride of the pooling layer.
pool_stride (int): stride of the pooling layer.
pool_padding (int): padding size.
pool_padding (int): padding size.
global_pooling: ${global_pooling_comment}
global_pooling
(bool)
: ${global_pooling_comment}
use_cudnn: ${use_cudnn_comment}
use_cudnn
(bool)
: ${use_cudnn_comment}
ceil_mode: ${ceil_mode_comment}
ceil_mode
(bool)
: ${ceil_mode_comment}
name (str|None): A name for this layer(optional). If set None, the
name (str|None): A name for this layer(optional). If set None, the
layer will be named automatically.
layer will be named automatically.
exclusive (bool): Whether to exclude padding points in average pooling
mode, default is true
Returns:
Returns:
Variable: The pooling result.
Variable: The pooling result.
...
@@ -2178,7 +2181,8 @@ def pool2d(input,
...
@@ -2178,7 +2181,8 @@ def pool2d(input,
"paddings"
:
pool_padding
,
"paddings"
:
pool_padding
,
"use_cudnn"
:
use_cudnn
,
"use_cudnn"
:
use_cudnn
,
"ceil_mode"
:
ceil_mode
,
"ceil_mode"
:
ceil_mode
,
"use_mkldnn"
:
False
"use_mkldnn"
:
False
,
"exclusive"
:
exclusive
,
})
})
return
pool_out
return
pool_out
...
@@ -2192,7 +2196,8 @@ def pool3d(input,
...
@@ -2192,7 +2196,8 @@ def pool3d(input,
global_pooling
=
False
,
global_pooling
=
False
,
use_cudnn
=
True
,
use_cudnn
=
True
,
ceil_mode
=
False
,
ceil_mode
=
False
,
name
=
None
):
name
=
None
,
exclusive
=
True
):
"""
"""
This function adds the operator for pooling in 3-dimensions, using the
This function adds the operator for pooling in 3-dimensions, using the
pooling configurations mentioned in input parameters.
pooling configurations mentioned in input parameters.
...
@@ -2208,6 +2213,8 @@ def pool3d(input,
...
@@ -2208,6 +2213,8 @@ def pool3d(input,
ceil_mode (bool): ${ceil_mode_comment}
ceil_mode (bool): ${ceil_mode_comment}
name (str): A name for this layer(optional). If set None, the layer
name (str): A name for this layer(optional). If set None, the layer
will be named automatically.
will be named automatically.
exclusive (bool): Whether to exclude padding points in average pooling
mode, default is true
Returns:
Returns:
Variable: output of pool3d layer.
Variable: output of pool3d layer.
...
@@ -2246,7 +2253,8 @@ def pool3d(input,
...
@@ -2246,7 +2253,8 @@ def pool3d(input,
"paddings"
:
pool_padding
,
"paddings"
:
pool_padding
,
"use_cudnn"
:
use_cudnn
,
"use_cudnn"
:
use_cudnn
,
"ceil_mode"
:
ceil_mode
,
"ceil_mode"
:
ceil_mode
,
"use_mkldnn"
:
False
"use_mkldnn"
:
False
,
"exclusive"
:
exclusive
,
})
})
return
pool_out
return
pool_out
...
...
python/paddle/fluid/tests/unittests/test_pool2d_op.py
浏览文件 @
daed473d
...
@@ -26,7 +26,8 @@ def max_pool2D_forward_naive(x,
...
@@ -26,7 +26,8 @@ def max_pool2D_forward_naive(x,
strides
,
strides
,
paddings
,
paddings
,
global_pool
=
0
,
global_pool
=
0
,
ceil_mode
=
False
):
ceil_mode
=
False
,
exclusive
=
True
):
N
,
C
,
H
,
W
=
x
.
shape
N
,
C
,
H
,
W
=
x
.
shape
if
global_pool
==
1
:
if
global_pool
==
1
:
ksize
=
[
H
,
W
]
ksize
=
[
H
,
W
]
...
@@ -54,7 +55,8 @@ def avg_pool2D_forward_naive(x,
...
@@ -54,7 +55,8 @@ def avg_pool2D_forward_naive(x,
strides
,
strides
,
paddings
,
paddings
,
global_pool
=
0
,
global_pool
=
0
,
ceil_mode
=
False
):
ceil_mode
=
False
,
exclusive
=
True
):
N
,
C
,
H
,
W
=
x
.
shape
N
,
C
,
H
,
W
=
x
.
shape
if
global_pool
==
1
:
if
global_pool
==
1
:
ksize
=
[
H
,
W
]
ksize
=
[
H
,
W
]
...
@@ -73,8 +75,9 @@ def avg_pool2D_forward_naive(x,
...
@@ -73,8 +75,9 @@ def avg_pool2D_forward_naive(x,
c_end
=
np
.
min
((
j
*
strides
[
1
]
+
ksize
[
1
]
-
paddings
[
1
],
W
))
c_end
=
np
.
min
((
j
*
strides
[
1
]
+
ksize
[
1
]
-
paddings
[
1
],
W
))
x_masked
=
x
[:,
:,
r_start
:
r_end
,
c_start
:
c_end
]
x_masked
=
x
[:,
:,
r_start
:
r_end
,
c_start
:
c_end
]
out
[:,
:,
i
,
j
]
=
np
.
sum
(
x_masked
,
axis
=
(
2
,
3
))
/
(
field_size
=
((
r_end
-
r_start
)
*
(
c_end
-
c_start
))
if
exclusive
\
(
r_end
-
r_start
)
*
(
c_end
-
c_start
))
else
(
ksize
[
0
]
*
ksize
[
1
])
out
[:,
:,
i
,
j
]
=
np
.
sum
(
x_masked
,
axis
=
(
2
,
3
))
/
field_size
return
out
return
out
...
@@ -89,12 +92,13 @@ class TestPool2d_Op(OpTest):
...
@@ -89,12 +92,13 @@ class TestPool2d_Op(OpTest):
self
.
init_kernel_type
()
self
.
init_kernel_type
()
self
.
init_pool_type
()
self
.
init_pool_type
()
self
.
init_ceil_mode
()
self
.
init_ceil_mode
()
self
.
init_exclusive
()
if
self
.
global_pool
:
if
self
.
global_pool
:
self
.
paddings
=
[
0
for
_
in
range
(
len
(
self
.
paddings
))]
self
.
paddings
=
[
0
for
_
in
range
(
len
(
self
.
paddings
))]
input
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
input
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
output
=
self
.
pool2D_forward_naive
(
input
,
self
.
ksize
,
self
.
strides
,
output
=
self
.
pool2D_forward_naive
(
self
.
paddings
,
self
.
global_pool
,
input
,
self
.
ksize
,
self
.
strides
,
self
.
paddings
,
self
.
global_pool
,
self
.
ceil_mod
e
).
astype
(
self
.
dtype
)
self
.
ceil_mode
,
self
.
exclusiv
e
).
astype
(
self
.
dtype
)
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
input
)}
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
input
)}
self
.
attrs
=
{
self
.
attrs
=
{
...
@@ -106,7 +110,9 @@ class TestPool2d_Op(OpTest):
...
@@ -106,7 +110,9 @@ class TestPool2d_Op(OpTest):
'use_cudnn'
:
self
.
use_cudnn
,
'use_cudnn'
:
self
.
use_cudnn
,
'use_mkldnn'
:
self
.
use_mkldnn
,
'use_mkldnn'
:
self
.
use_mkldnn
,
'ceil_mode'
:
self
.
ceil_mode
,
'ceil_mode'
:
self
.
ceil_mode
,
'data_format'
:
'AnyLayout'
# TODO(dzhwinter) : should be fix latter
'data_format'
:
'AnyLayout'
,
# TODO(dzhwinter) : should be fix latter
'exclusive'
:
self
.
exclusive
}
}
self
.
outputs
=
{
'Out'
:
output
}
self
.
outputs
=
{
'Out'
:
output
}
...
@@ -150,6 +156,9 @@ class TestPool2d_Op(OpTest):
...
@@ -150,6 +156,9 @@ class TestPool2d_Op(OpTest):
def
init_ceil_mode
(
self
):
def
init_ceil_mode
(
self
):
self
.
ceil_mode
=
False
self
.
ceil_mode
=
False
def
init_exclusive
(
self
):
self
.
exclusive
=
True
class
TestCase1
(
TestPool2d_Op
):
class
TestCase1
(
TestPool2d_Op
):
def
init_test_case
(
self
):
def
init_test_case
(
self
):
...
@@ -322,5 +331,15 @@ class TestCeilModeCase4(TestCase2):
...
@@ -322,5 +331,15 @@ class TestCeilModeCase4(TestCase2):
self
.
ceil_mode
=
True
self
.
ceil_mode
=
True
class
TestAvgInclude
(
TestCase2
):
def
init_exclusive
(
self
):
self
.
exclusive
=
False
class
TestCUDNNAvgInclude
(
TestCUDNNCase3
):
def
init_exclusive
(
self
):
self
.
exclusive
=
False
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_pool3d_op.py
浏览文件 @
daed473d
...
@@ -26,7 +26,8 @@ def max_pool3D_forward_naive(x,
...
@@ -26,7 +26,8 @@ def max_pool3D_forward_naive(x,
strides
,
strides
,
paddings
,
paddings
,
global_pool
=
0
,
global_pool
=
0
,
ceil_mode
=
False
):
ceil_mode
=
False
,
exclusive
=
True
):
N
,
C
,
D
,
H
,
W
=
x
.
shape
N
,
C
,
D
,
H
,
W
=
x
.
shape
if
global_pool
==
1
:
if
global_pool
==
1
:
ksize
=
[
D
,
H
,
W
]
ksize
=
[
D
,
H
,
W
]
...
@@ -60,7 +61,8 @@ def avg_pool3D_forward_naive(x,
...
@@ -60,7 +61,8 @@ def avg_pool3D_forward_naive(x,
strides
,
strides
,
paddings
,
paddings
,
global_pool
=
0
,
global_pool
=
0
,
ceil_mode
=
False
):
ceil_mode
=
False
,
exclusive
=
True
):
N
,
C
,
D
,
H
,
W
=
x
.
shape
N
,
C
,
D
,
H
,
W
=
x
.
shape
if
global_pool
==
1
:
if
global_pool
==
1
:
ksize
=
[
D
,
H
,
W
]
ksize
=
[
D
,
H
,
W
]
...
@@ -85,8 +87,10 @@ def avg_pool3D_forward_naive(x,
...
@@ -85,8 +87,10 @@ def avg_pool3D_forward_naive(x,
w_end
=
np
.
min
((
j
*
strides
[
1
]
+
ksize
[
1
]
-
paddings
[
1
],
W
))
w_end
=
np
.
min
((
j
*
strides
[
1
]
+
ksize
[
1
]
-
paddings
[
1
],
W
))
x_masked
=
x
[:,
:,
d_start
:
d_end
,
h_start
:
h_end
,
w_start
:
w_end
]
x_masked
=
x
[:,
:,
d_start
:
d_end
,
h_start
:
h_end
,
w_start
:
w_end
]
out
[:,
:,
k
,
i
,
j
]
=
np
.
sum
(
x_masked
,
axis
=
(
2
,
3
,
4
))
/
(
field_size
=
(
d_end
-
d_start
)
*
(
h_end
-
h_start
)
*
(
w_end
-
w_start
)
\
(
d_end
-
d_start
)
*
(
h_end
-
h_start
)
*
(
w_end
-
w_start
))
if
exclusive
else
ksize
[
0
]
*
ksize
[
1
]
*
ksize
[
2
]
out
[:,
:,
k
,
i
,
j
]
=
np
.
sum
(
x_masked
,
axis
=
(
2
,
3
,
4
))
/
field_size
return
out
return
out
...
@@ -100,13 +104,14 @@ class TestPool3d_Op(OpTest):
...
@@ -100,13 +104,14 @@ class TestPool3d_Op(OpTest):
self
.
init_kernel_type
()
self
.
init_kernel_type
()
self
.
init_pool_type
()
self
.
init_pool_type
()
self
.
init_ceil_mode
()
self
.
init_ceil_mode
()
self
.
init_exclusive
()
if
self
.
global_pool
:
if
self
.
global_pool
:
self
.
paddings
=
[
0
for
_
in
range
(
len
(
self
.
paddings
))]
self
.
paddings
=
[
0
for
_
in
range
(
len
(
self
.
paddings
))]
input
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
input
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
output
=
self
.
pool3D_forward_naive
(
input
,
self
.
ksize
,
self
.
strides
,
output
=
self
.
pool3D_forward_naive
(
self
.
paddings
,
self
.
global_pool
,
input
,
self
.
ksize
,
self
.
strides
,
self
.
paddings
,
self
.
global_pool
,
self
.
ceil_mod
e
).
astype
(
self
.
dtype
)
self
.
ceil_mode
,
self
.
exclusiv
e
).
astype
(
self
.
dtype
)
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
input
)}
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
input
)}
self
.
attrs
=
{
self
.
attrs
=
{
...
@@ -117,7 +122,9 @@ class TestPool3d_Op(OpTest):
...
@@ -117,7 +122,9 @@ class TestPool3d_Op(OpTest):
'global_pooling'
:
self
.
global_pool
,
'global_pooling'
:
self
.
global_pool
,
'use_cudnn'
:
self
.
use_cudnn
,
'use_cudnn'
:
self
.
use_cudnn
,
'ceil_mode'
:
self
.
ceil_mode
,
'ceil_mode'
:
self
.
ceil_mode
,
'data_format'
:
'AnyLayout'
# TODO(dzhwinter) : should be fix latter
'data_format'
:
'AnyLayout'
,
# TODO(dzhwinter) : should be fix latter
'exclusive'
:
self
.
exclusive
}
}
self
.
outputs
=
{
'Out'
:
output
}
self
.
outputs
=
{
'Out'
:
output
}
...
@@ -161,6 +168,9 @@ class TestPool3d_Op(OpTest):
...
@@ -161,6 +168,9 @@ class TestPool3d_Op(OpTest):
def
init_ceil_mode
(
self
):
def
init_ceil_mode
(
self
):
self
.
ceil_mode
=
False
self
.
ceil_mode
=
False
def
init_exclusive
(
self
):
self
.
exclusive
=
True
class
TestCase1
(
TestPool3d_Op
):
class
TestCase1
(
TestPool3d_Op
):
def
init_test_case
(
self
):
def
init_test_case
(
self
):
...
@@ -333,5 +343,15 @@ class TestCeilModeCase4(TestCase2):
...
@@ -333,5 +343,15 @@ class TestCeilModeCase4(TestCase2):
self
.
ceil_mode
=
True
self
.
ceil_mode
=
True
class
TestAvgInclude
(
TestCase2
):
def
init_exclusive
(
self
):
self
.
exclusive
=
False
class
TestCUDNNAvgInclude
(
TestCUDNNCase3
):
def
init_exclusive
(
self
):
self
.
exclusive
=
False
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录