Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
de80c569
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
de80c569
编写于
12月 06, 2017
作者:
T
Tao Luo
提交者:
GitHub
12月 06, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #6100 from guoshengCS/enhance-include-pool
Enhance AvgPooling to support both include_mode and exclude_mode
上级
c4599d3e
e1358945
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
114 addition
and
39 deletion
+114
-39
paddle/cuda/include/hl_cnn.h
paddle/cuda/include/hl_cnn.h
+6
-2
paddle/cuda/include/stub/hl_cnn_stub.h
paddle/cuda/include/stub/hl_cnn_stub.h
+4
-2
paddle/cuda/src/hl_cuda_cnn.cu
paddle/cuda/src/hl_cuda_cnn.cu
+18
-10
paddle/gserver/layers/PoolLayer.cpp
paddle/gserver/layers/PoolLayer.cpp
+2
-0
paddle/gserver/layers/PoolLayer.h
paddle/gserver/layers/PoolLayer.h
+2
-0
paddle/gserver/layers/PoolProjection.cpp
paddle/gserver/layers/PoolProjection.cpp
+6
-2
paddle/gserver/layers/PoolProjection.h
paddle/gserver/layers/PoolProjection.h
+1
-0
paddle/gserver/tests/test_LayerGrad.cpp
paddle/gserver/tests/test_LayerGrad.cpp
+15
-1
paddle/math/Matrix.cpp
paddle/math/Matrix.cpp
+16
-8
paddle/math/Matrix.h
paddle/math/Matrix.h
+13
-6
proto/ModelConfig.proto
proto/ModelConfig.proto
+2
-0
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+6
-3
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+11
-4
python/paddle/trainer_config_helpers/poolings.py
python/paddle/trainer_config_helpers/poolings.py
+12
-1
未找到文件。
paddle/cuda/include/hl_cnn.h
浏览文件 @
de80c569
...
...
@@ -116,6 +116,7 @@ extern void hl_maxpool_backward(const int frameCnt,
* @param[in] paddingW padding width.
* @param[out] tgtData output data.
* @param[in] tgtStride stride between output data samples.
* @param[in] excludeMode whether to consider paddings for size.
*
*/
extern
void
hl_avgpool_forward
(
const
int
frameCnt
,
...
...
@@ -132,7 +133,8 @@ extern void hl_avgpool_forward(const int frameCnt,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
const
int
tgtStride
);
const
int
tgtStride
,
bool
excludeMode
);
/**
* @brief Maximum pool backward.
...
...
@@ -154,6 +156,7 @@ extern void hl_avgpool_forward(const int frameCnt,
* @param[in] scaleB scale.
* @param[out] backGrad output grad.
* @param[in] outStride stride between output data samples.
* @param[in] excludeMode whether to consider paddings for size.
*
*/
extern
void
hl_avgpool_backward
(
const
int
frameCnt
,
...
...
@@ -172,7 +175,8 @@ extern void hl_avgpool_backward(const int frameCnt,
real
scaleA
,
real
scaleB
,
real
*
backGrad
,
const
int
outStride
);
const
int
outStride
,
bool
excludeMode
);
extern
void
hl_maxpool3D_forward
(
const
int
frameCnt
,
const
real
*
inputData
,
...
...
paddle/cuda/include/stub/hl_cnn_stub.h
浏览文件 @
de80c569
...
...
@@ -68,7 +68,8 @@ inline void hl_avgpool_forward(const int frameCnt,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
const
int
tgtStride
)
{}
const
int
tgtStride
,
const
bool
excludeMode
)
{}
inline
void
hl_avgpool_backward
(
const
int
frameCnt
,
const
real
*
outGrad
,
...
...
@@ -86,7 +87,8 @@ inline void hl_avgpool_backward(const int frameCnt,
real
scaleA
,
real
scaleB
,
real
*
backGrad
,
const
int
outStride
)
{}
const
int
outStride
,
const
bool
excludeMode
)
{}
inline
void
hl_maxpool3D_forward
(
const
int
frameCnt
,
const
real
*
inputData
,
...
...
paddle/cuda/src/hl_cuda_cnn.cu
浏览文件 @
de80c569
...
...
@@ -210,7 +210,8 @@ __global__ void KeAvgPoolForward(const int nthreads,
const
int
padH
,
const
int
padW
,
real
*
tgtData
,
const
int
tgtStride
)
{
const
int
tgtStride
,
const
bool
excludeMode
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
int
pw
=
index
%
pooledW
;
...
...
@@ -224,7 +225,8 @@ __global__ void KeAvgPoolForward(const int nthreads,
int
wend
=
min
(
wstart
+
sizeX
,
width
);
hstart
=
max
(
hstart
,
0
);
wstart
=
max
(
wstart
,
0
);
int
pool_size
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
poolSize
=
excludeMode
?
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
sizeY
*
sizeX
;
real
aveval
=
0
;
inputData
+=
(
frameNum
*
channels
+
c
)
*
height
*
width
;
...
...
@@ -235,7 +237,7 @@ __global__ void KeAvgPoolForward(const int nthreads,
}
int
tgtIndex
=
index
%
(
pooledW
*
pooledH
*
channels
)
+
frameNum
*
tgtStride
;
tgtData
[
tgtIndex
]
=
aveval
/
pool
_s
ize
;
tgtData
[
tgtIndex
]
=
aveval
/
pool
S
ize
;
}
}
...
...
@@ -253,7 +255,8 @@ void hl_avgpool_forward(const int frameCnt,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
const
int
tgtStride
)
{
const
int
tgtStride
,
const
bool
excludeMode
)
{
int
num_kernels
=
pooledH
*
pooledW
*
channels
*
frameCnt
;
int
blocks
=
(
num_kernels
+
1024
-
1
)
/
1024
;
KeAvgPoolForward
<<<
blocks
,
1024
,
0
,
STREAM_DEFAULT
>>>
(
num_kernels
,
...
...
@@ -270,7 +273,8 @@ void hl_avgpool_forward(const int frameCnt,
paddingH
,
paddingW
,
tgtData
,
tgtStride
);
tgtStride
,
excludeMode
);
CHECK_SYNC
(
"hl_avgpool_forward failed"
);
}
...
...
@@ -290,7 +294,8 @@ __global__ void KeAvgPoolBackward(const int nthreads,
real
scaleA
,
real
scaleB
,
real
*
tgtGrad
,
const
int
outStride
)
{
const
int
outStride
,
const
bool
excludeMode
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
int
offsetW
=
index
%
width
+
padW
;
...
...
@@ -314,8 +319,9 @@ __global__ void KeAvgPoolBackward(const int nthreads,
int
wstart
=
pw
*
strideW
-
padW
;
int
wend
=
min
(
wstart
+
sizeX
,
width
);
wstart
=
max
(
wstart
,
0
);
int
poolsize
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
gradient
+=
outGrad
[
ph
*
pooledW
+
pw
]
/
poolsize
;
int
poolSize
=
excludeMode
?
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
sizeY
*
sizeX
;
gradient
+=
outGrad
[
ph
*
pooledW
+
pw
]
/
poolSize
;
}
}
tgtGrad
[
index
]
=
scaleB
*
tgtGrad
[
index
]
+
scaleA
*
gradient
;
...
...
@@ -338,7 +344,8 @@ void hl_avgpool_backward(const int frameCnt,
real
scaleA
,
real
scaleB
,
real
*
backGrad
,
const
int
outStride
)
{
const
int
outStride
,
const
bool
excludeMode
)
{
int
num_kernels
=
height
*
width
*
channels
*
frameCnt
;
int
blocks
=
(
num_kernels
+
1024
-
1
)
/
1024
;
...
...
@@ -358,7 +365,8 @@ void hl_avgpool_backward(const int frameCnt,
scaleA
,
scaleB
,
backGrad
,
outStride
);
outStride
,
excludeMode
);
CHECK_SYNC
(
"hl_avgpool_backward failed"
);
}
...
...
paddle/gserver/layers/PoolLayer.cpp
浏览文件 @
de80c569
...
...
@@ -45,6 +45,8 @@ bool PoolLayer::init(const LayerMap& layerMap,
strideY_
=
conf
.
has_stride_y
()
?
conf
.
stride_y
()
:
conf
.
stride
();
confPaddingY_
=
conf
.
has_padding_y
()
?
conf
.
padding_y
()
:
conf
.
padding
();
outputY_
=
conf
.
has_output_y
()
?
conf
.
output_y
()
:
conf
.
output_x
();
excludeMode_
=
conf
.
has_exclude_mode
()
?
conf
.
exclude_mode
()
:
true
;
return
true
;
}
...
...
paddle/gserver/layers/PoolLayer.h
浏览文件 @
de80c569
...
...
@@ -38,6 +38,8 @@ protected:
std
::
string
poolType_
;
bool
excludeMode_
;
public:
explicit
PoolLayer
(
const
LayerConfig
&
config
)
:
Layer
(
config
)
{}
...
...
paddle/gserver/layers/PoolProjection.cpp
浏览文件 @
de80c569
...
...
@@ -36,6 +36,8 @@ PoolProjection::PoolProjection(const ProjectionConfig& config,
strideY_
=
conf
.
has_stride_y
()
?
conf
.
stride_y
()
:
conf
.
stride
();
confPaddingY_
=
conf
.
has_padding_y
()
?
conf
.
padding_y
()
:
conf
.
padding
();
outputY_
=
conf
.
has_output_y
()
?
conf
.
output_y
()
:
conf
.
output_x
();
excludeMode_
=
conf
.
has_exclude_mode
()
?
conf
.
exclude_mode
()
:
true
;
}
size_t
PoolProjection
::
getSize
()
{
...
...
@@ -141,7 +143,8 @@ void AvgPoolProjection::forward() {
outputY_
,
outputX_
,
confPaddingY_
,
confPadding_
);
confPadding_
,
excludeMode_
);
}
void
AvgPoolProjection
::
backward
(
const
UpdateCallback
&
callback
)
{
...
...
@@ -166,6 +169,7 @@ void AvgPoolProjection::backward(const UpdateCallback& callback) {
1
,
1
,
confPaddingY_
,
confPadding_
);
confPadding_
,
excludeMode_
);
}
}
// namespace paddle
paddle/gserver/layers/PoolProjection.h
浏览文件 @
de80c569
...
...
@@ -28,6 +28,7 @@ protected:
int
confPaddingY_
,
confPadding_
;
size_t
channels_
;
std
::
string
poolType_
;
bool
excludeMode_
;
public:
PoolProjection
(
const
ProjectionConfig
&
config
,
...
...
paddle/gserver/tests/test_LayerGrad.cpp
浏览文件 @
de80c569
...
...
@@ -1211,7 +1211,10 @@ void setPoolConfig(TestConfig* config,
pool
->
set_output_y
(
oh
);
}
void
testPoolLayer
(
const
string
&
poolType
,
bool
trans
,
bool
useGpu
)
{
void
testPoolLayer
(
const
string
&
poolType
,
bool
trans
,
bool
useGpu
,
bool
excludeMode
=
true
)
{
TestConfig
config
;
config
.
inputDefs
.
push_back
({
INPUT_DATA
,
"layer_0"
,
3136
,
0
});
LayerInputConfig
*
input
=
config
.
layerConfig
.
add_inputs
();
...
...
@@ -1219,6 +1222,7 @@ void testPoolLayer(const string& poolType, bool trans, bool useGpu) {
pool
->
set_img_size
(
14
);
pool
->
set_img_size_y
(
14
);
pool
->
set_exclude_mode
(
excludeMode
);
setPoolConfig
(
&
config
,
pool
,
poolType
);
config
.
layerConfig
.
set_size
(
pool
->
output_x
()
*
pool
->
output_y
()
*
pool
->
channels
());
...
...
@@ -1250,16 +1254,26 @@ void testPoolLayer2(const string& poolType, bool trans, bool useGpu) {
TEST
(
Layer
,
PoolLayer
)
{
testPoolLayer
(
"avg-projection"
,
/* trans= */
false
,
/* useGpu= */
false
);
testPoolLayer
(
"avg-projection"
,
/* trans= */
false
,
/* useGpu= */
false
,
/* excludeMode= */
false
);
testPoolLayer
(
"max-projection"
,
/* trans= */
false
,
/* useGpu= */
false
);
testPoolLayer
(
"max-pool-with-mask"
,
/* trans= */
false
,
/* useGpu= */
false
);
#ifdef PADDLE_WITH_CUDA
testPoolLayer
(
"avg-projection"
,
/* trans= */
false
,
/* useGpu= */
true
);
testPoolLayer
(
"avg-projection"
,
/* trans= */
false
,
/* useGpu= */
true
,
/* excludeMode= */
false
);
testPoolLayer
(
"max-projection"
,
/* trans= */
false
,
/* useGpu= */
true
);
testPoolLayer
(
"cudnn-max-pool"
,
/* trans= */
false
,
/* useGpu= */
true
);
testPoolLayer
(
"cudnn-avg-pool"
,
/* trans= */
false
,
/* useGpu= */
true
);
testPoolLayer2
(
"cudnn-max-pool"
,
/* trans= */
false
,
/* useGpu= */
true
);
testPoolLayer2
(
"cudnn-avg-pool"
,
/* trans= */
false
,
/* useGpu= */
true
);
testPoolLayer2
(
"cudnn-avg-incl-pad-pool"
,
/* trans= */
false
,
/* useGpu= */
true
);
testPoolLayer
(
"max-pool-with-mask"
,
/* trans= */
false
,
/* useGpu= */
true
);
#endif
}
...
...
paddle/math/Matrix.cpp
浏览文件 @
de80c569
...
...
@@ -1130,7 +1130,8 @@ void GpuMatrix::avgPoolForward(Matrix& inputMat,
size_t
outputH
,
size_t
outputW
,
size_t
paddingH
,
size_t
paddingW
)
{
size_t
paddingW
,
bool
excludeMode
)
{
CHECK
(
inputMat
.
useGpu_
==
true
)
<<
"Matrix type are not equal"
;
real
*
inputData
=
inputMat
.
getData
();
...
...
@@ -1153,7 +1154,8 @@ void GpuMatrix::avgPoolForward(Matrix& inputMat,
paddingH
,
paddingW
,
data_
,
getStride
());
getStride
(),
excludeMode
);
}
void
GpuMatrix
::
avgPoolBackward
(
Matrix
&
outGrad
,
...
...
@@ -1168,7 +1170,8 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad,
real
scaleTargets
,
real
scaleOutput
,
size_t
paddingH
,
size_t
paddingW
)
{
size_t
paddingW
,
bool
excludeMode
)
{
CHECK
(
outGrad
.
useGpu_
==
true
)
<<
"Matrix type are not equal"
;
real
*
outDiff
=
outGrad
.
getData
();
...
...
@@ -1194,7 +1197,8 @@ void GpuMatrix::avgPoolBackward(Matrix& outGrad,
scaleTargets
,
scaleOutput
,
data_
,
outGrad
.
getStride
());
outGrad
.
getStride
(),
excludeMode
);
}
void
GpuMatrix
::
maxPool3DForward
(
Matrix
&
inputMat
,
...
...
@@ -2136,7 +2140,8 @@ void CpuMatrix::avgPoolForward(Matrix& input,
size_t
outputH
,
size_t
outputW
,
size_t
paddingH
,
size_t
paddingW
)
{
size_t
paddingW
,
bool
excludeMode
)
{
// The main loop
size_t
num
=
input
.
getHeight
();
size_t
inLength
=
imgSizeH
*
imgSizeW
;
...
...
@@ -2165,7 +2170,8 @@ void CpuMatrix::avgPoolForward(Matrix& input,
tgtData
[
ph
*
outputW
+
pw
]
+=
inData
[
h
*
imgSizeW
+
w
];
}
}
int
poolSize
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
poolSize
=
excludeMode
?
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
sizeY
*
sizeX
;
CHECK
(
poolSize
);
tgtData
[
ph
*
outputW
+
pw
]
/=
poolSize
;
}
...
...
@@ -2189,7 +2195,8 @@ void CpuMatrix::avgPoolBackward(Matrix& input,
real
scaleTargets
,
real
scaleOutput
,
size_t
paddingH
,
size_t
paddingW
)
{
size_t
paddingW
,
bool
excludeMode
)
{
size_t
num
=
input
.
getHeight
();
size_t
channels
=
input
.
getWidth
()
/
outputH
/
outputW
;
size_t
inLength
=
imgSizeH
*
imgSizeW
;
...
...
@@ -2211,7 +2218,8 @@ void CpuMatrix::avgPoolBackward(Matrix& input,
int
wstart
=
pw
*
strideW
-
paddingW
;
int
wend
=
std
::
min
(
wstart
+
sizeX
,
imgSizeW
);
wstart
=
std
::
max
(
wstart
,
0
);
int
poolSize
=
(
hend
-
hstart
)
*
(
wend
-
wstart
);
int
poolSize
=
excludeMode
?
(
hend
-
hstart
)
*
(
wend
-
wstart
)
:
sizeY
*
sizeX
;
CHECK
(
poolSize
);
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
...
...
paddle/math/Matrix.h
浏览文件 @
de80c569
...
...
@@ -911,7 +911,8 @@ public:
size_t
outputH
,
size_t
outputW
,
size_t
paddingH
,
size_t
paddingW
)
{
size_t
paddingW
,
bool
excludeMode
=
true
)
{
LOG
(
FATAL
)
<<
"Not implemeted"
;
}
...
...
@@ -927,9 +928,11 @@ public:
real
scaleTargets
,
real
scaleOutput
,
size_t
paddingH
,
size_t
paddingW
)
{
size_t
paddingW
,
bool
excludeMode
=
true
)
{
LOG
(
FATAL
)
<<
"Not implemeted"
;
}
/**
* Pooling 3D forward operation, pick out the largest element
* in the sizeX of value
...
...
@@ -1458,7 +1461,8 @@ public:
size_t
outputH
,
size_t
outputW
,
size_t
paddingH
,
size_t
paddingW
);
size_t
paddingW
,
bool
excludeMode
=
true
);
void
avgPoolBackward
(
Matrix
&
input
,
size_t
imgSizeH
,
...
...
@@ -1472,7 +1476,8 @@ public:
real
scaleTargets
,
real
scaleOutput
,
size_t
paddingH
,
size_t
paddingW
);
size_t
paddingW
,
bool
excludeMode
=
true
);
void
maxPool3DForward
(
Matrix
&
inputMat
,
Matrix
&
maxPoolIdx
,
...
...
@@ -1730,7 +1735,8 @@ public:
size_t
outputH
,
size_t
outputW
,
size_t
paddingH
,
size_t
paddingW
);
size_t
paddingW
,
bool
excludeMode
=
true
);
void
avgPoolBackward
(
Matrix
&
input
,
size_t
imgSizeH
,
...
...
@@ -1744,7 +1750,8 @@ public:
real
scaleTargets
,
real
scaleOutput
,
size_t
paddingH
,
size_t
paddingW
);
size_t
paddingW
,
bool
excludeMode
=
true
);
void
maxPool3DForward
(
Matrix
&
inputMat
,
Matrix
&
maxPoolIdx
,
...
...
proto/ModelConfig.proto
浏览文件 @
de80c569
...
...
@@ -139,6 +139,8 @@ message PoolConfig {
optional
uint32
output_z
=
16
[
default
=
1
];
optional
uint32
img_size_z
=
17
[
default
=
1
];
optional
uint32
padding_z
=
18
[
default
=
1
];
optional
bool
exclude_mode
=
19
;
}
message
SppConfig
{
...
...
python/paddle/trainer/config_parser.py
浏览文件 @
de80c569
...
...
@@ -1233,7 +1233,7 @@ def parse_bilinear(bilinear, input_layer_name, bilinear_conf):
bilinear_conf
.
out_size_y
=
bilinear
.
out_size_y
def
parse_pool
(
pool
,
input_layer_name
,
pool_conf
,
ceil_mode
):
def
parse_pool
(
pool
,
input_layer_name
,
pool_conf
,
ceil_mode
,
exclude_mode
):
pool_conf
.
pool_type
=
pool
.
pool_type
config_assert
(
pool
.
pool_type
in
[
'max-projection'
,
'avg-projection'
,
'max-pool-with-mask'
,
'cudnn-max-pool'
,
'cudnn-avg-pool'
...
...
@@ -1262,6 +1262,8 @@ def parse_pool(pool, input_layer_name, pool_conf, ceil_mode):
pool_conf
.
output_y
=
cnn_output_size
(
pool_conf
.
img_size_y
,
pool_conf
.
size_y
,
pool_conf
.
padding_y
,
pool_conf
.
stride_y
,
not
ceil_mode
)
if
exclude_mode
!=
None
:
pool_conf
.
exclude_mode
=
exclude_mode
def
parse_pool3d
(
pool
,
input_layer_name
,
pool_conf
,
ceil_mode
):
...
...
@@ -2303,7 +2305,8 @@ class NormLayer(LayerBase):
class
PoolLayer
(
LayerBase
):
layer_type
=
'pool'
def
__init__
(
self
,
name
,
inputs
,
ceil_mode
=
True
,
**
xargs
):
def
__init__
(
self
,
name
,
inputs
,
ceil_mode
=
True
,
exclude_mode
=
None
,
**
xargs
):
use_mkldnn
=
int
(
g_command_config_args
.
get
(
"use_mkldnn"
,
0
))
if
self
.
layer_type
==
"mkldnn_pool"
:
config_assert
(
use_mkldnn
,
"mkldnn_pool only support MKLDNN"
)
...
...
@@ -2314,7 +2317,7 @@ class PoolLayer(LayerBase):
input_layer
=
self
.
get_input_layer
(
input_index
)
pool_conf
=
self
.
config
.
inputs
[
input_index
].
pool_conf
parse_pool
(
self
.
inputs
[
input_index
].
pool
,
input_layer
.
name
,
pool_conf
,
ceil_mode
)
pool_conf
,
ceil_mode
,
exclude_mode
)
self
.
set_cnn_layer
(
name
,
pool_conf
.
output_y
,
pool_conf
.
output_x
,
pool_conf
.
channels
)
...
...
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
de80c569
...
...
@@ -21,7 +21,7 @@ from .activations import LinearActivation, SigmoidActivation, TanhActivation, \
ReluActivation
,
IdentityActivation
,
SoftmaxActivation
,
BaseActivation
from
.evaluators
import
*
from
.poolings
import
MaxPooling
,
AvgPooling
,
MaxWithMaskPooling
,
BasePoolingType
,
\
CudnnAvgPooling
,
CudnnMaxPooling
CudnnAvgPooling
,
Cudnn
AvgInclPadPooling
,
Cudnn
MaxPooling
from
.attrs
import
*
from
.default_decorators
import
*
...
...
@@ -2709,7 +2709,8 @@ def img_pool_layer(input,
pool_size_y
=
None
,
stride_y
=
None
,
padding_y
=
None
,
ceil_mode
=
True
):
ceil_mode
=
True
,
exclude_mode
=
None
):
"""
Image pooling Layer.
...
...
@@ -2773,10 +2774,15 @@ def img_pool_layer(input,
:param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
details.
:type layer_attr: ExtraLayerAttribute
:param ceil_mode: Wether to use the ceil function to calculate output height and width.
:param ceil_mode: W
h
ether to use the ceil function to calculate output height and width.
True is the default. If it is set to False, the floor function will
be used.
:type ceil_mode: bool
:param exclude_mode: Whether to exclude the padding cells when calculating, but only
work when pool_type is AvgPooling. If None, also exclude the padding
cells. If use cudnn, use CudnnAvgPooling or CudnnAvgInclPadPooling
as pool_type to identify the mode.
:type exclude_mode: bool
:return: LayerOutput object.
:rtype: LayerOutput
"""
...
...
@@ -2790,7 +2796,7 @@ def img_pool_layer(input,
pool_type
.
name
=
'avg'
assert
type
(
pool_type
)
in
[
AvgPooling
,
MaxPooling
,
MaxWithMaskPooling
,
CudnnAvgPooling
,
CudnnMaxPooling
],
\
CudnnMaxPooling
,
CudnnAvgInclPadPooling
],
\
"only (Cudnn)AvgPooling, (Cudnn)MaxPooling, MaxWithMaskPooling are supported"
type_name
=
pool_type
.
name
+
'-projection'
\
...
...
@@ -2819,6 +2825,7 @@ def img_pool_layer(input,
padding_y
=
padding_y
))
],
ceil_mode
=
ceil_mode
,
exclude_mode
=
exclude_mode
,
**
ExtraLayerAttribute
.
to_kwargs
(
layer_attr
))
return
LayerOutput
(
name
,
...
...
python/paddle/trainer_config_helpers/poolings.py
浏览文件 @
de80c569
...
...
@@ -16,7 +16,8 @@
__all__
=
[
"BasePoolingType"
,
"MaxPooling"
,
"AvgPooling"
,
"MaxWithMaskPooling"
,
"CudnnMaxPooling"
,
"CudnnAvgPooling"
,
"SumPooling"
,
"SquareRootNPooling"
"CudnnMaxPooling"
,
"CudnnAvgPooling"
,
"CudnnAvgInclPadPooling"
,
"SumPooling"
,
"SquareRootNPooling"
]
...
...
@@ -88,6 +89,16 @@ class CudnnAvgPooling(BasePoolingType):
BasePoolingType
.
__init__
(
self
,
"cudnn-avg-pool"
)
class
CudnnAvgInclPadPooling
(
BasePoolingType
):
"""
Cudnn average pooling only support GPU. Return the average value in the
pooling window taking into account the padding cells.
"""
def
__init__
(
self
):
BasePoolingType
.
__init__
(
self
,
"cudnn-avg-incl-pad-pool"
)
class
AvgPooling
(
BasePoolingType
):
"""
Average pooling.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录