Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
ca0bb40c
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ca0bb40c
编写于
11月 11, 2016
作者:
T
Tao Luo
提交者:
GitHub
11月 11, 2016
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #300 from QiJune/feature/sppnet
add SpatialPyramidPoolLayer c++ support
上级
c8da3f4b
9dd588b4
变更
20
展开全部
隐藏空白更改
内联
并排
Showing
20 changed file
with
795 addition
and
268 deletion
+795
-268
doc/ui/api/trainer_config_helpers/layers.rst
doc/ui/api/trainer_config_helpers/layers.rst
+12
-0
paddle/cuda/include/hl_cnn.h
paddle/cuda/include/hl_cnn.h
+10
-4
paddle/cuda/include/stub/hl_cnn_stub.h
paddle/cuda/include/stub/hl_cnn_stub.h
+6
-4
paddle/cuda/src/hl_cuda_cnn.cu
paddle/cuda/src/hl_cuda_cnn.cu
+23
-17
paddle/gserver/layers/PoolLayer.cpp
paddle/gserver/layers/PoolLayer.cpp
+2
-4
paddle/gserver/layers/PoolProjection.cpp
paddle/gserver/layers/PoolProjection.cpp
+123
-0
paddle/gserver/layers/PoolProjection.h
paddle/gserver/layers/PoolProjection.h
+63
-0
paddle/gserver/layers/PoolProjectionLayer.cpp
paddle/gserver/layers/PoolProjectionLayer.cpp
+7
-57
paddle/gserver/layers/PoolProjectionLayer.h
paddle/gserver/layers/PoolProjectionLayer.h
+11
-26
paddle/gserver/layers/Projection.h
paddle/gserver/layers/Projection.h
+9
-4
paddle/gserver/layers/SpatialPyramidPoolLayer.cpp
paddle/gserver/layers/SpatialPyramidPoolLayer.cpp
+130
-0
paddle/gserver/layers/SpatialPyramidPoolLayer.h
paddle/gserver/layers/SpatialPyramidPoolLayer.h
+57
-0
paddle/gserver/tests/test_LayerGrad.cpp
paddle/gserver/tests/test_LayerGrad.cpp
+29
-3
paddle/math/Matrix.cpp
paddle/math/Matrix.cpp
+143
-146
proto/ModelConfig.proto.m4
proto/ModelConfig.proto.m4
+12
-0
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+47
-1
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+60
-1
python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh
...trainer_config_helpers/tests/configs/generate_protostr.sh
+1
-1
python/paddle/trainer_config_helpers/tests/configs/protostr/test_spp_layer.protostr
...ig_helpers/tests/configs/protostr/test_spp_layer.protostr
+34
-0
python/paddle/trainer_config_helpers/tests/configs/test_spp_layer.py
...le/trainer_config_helpers/tests/configs/test_spp_layer.py
+16
-0
未找到文件。
doc/ui/api/trainer_config_helpers/layers.rst
浏览文件 @
ca0bb40c
...
...
@@ -46,6 +46,12 @@ conv_operator
:members: conv_operator
:noindex:
conv_projection
-------------
.. automodule:: paddle.trainer_config_helpers.layers
:members: conv_projection
:noindex:
conv_shift_layer
------------------
.. automodule:: paddle.trainer_config_helpers.layers
...
...
@@ -71,6 +77,12 @@ img_pool_layer
--------------
.. automodule:: paddle.trainer_config_helpers.layers
:members: img_pool_layer
:noindex:
spp_layer
--------------
.. automodule:: paddle.trainer_config_helpers.layers
:members: spp_layer
:noindex:
maxout_layer
...
...
paddle/cuda/include/hl_cnn.h
浏览文件 @
ca0bb40c
...
...
@@ -91,6 +91,7 @@ extern void hl_expand_feature2col(
* @param[in] paddingH padding height.
* @param[in] paddingW padding width.
* @param[out] tgtData output data.
* @param[in] tgtStride stride between output data samples.
*
*/
extern
void
hl_maxpool_forward
(
...
...
@@ -100,7 +101,8 @@ extern void hl_maxpool_forward(
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeX
,
const
int
sizeY
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
);
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
const
int
tgtStride
);
/**
* @brief Maximum pool backward.
...
...
@@ -123,6 +125,7 @@ extern void hl_maxpool_forward(
* @param[in] paddingH padding height.
* @param[in] paddingW padding width.
* @param[out] targetGrad output grad.
* @param[in] outStride stride between output data samples.
*
*/
extern
void
hl_maxpool_backward
(
...
...
@@ -135,7 +138,7 @@ extern void hl_maxpool_backward(
const
int
strideH
,
const
int
strideW
,
const
int
paddingH
,
const
int
paddingW
,
real
scaleA
,
real
scaleB
,
real
*
targetGrad
);
real
*
targetGrad
,
const
int
outStride
);
/**
* @brief Averge pool forward.
...
...
@@ -154,6 +157,7 @@ extern void hl_maxpool_backward(
* @param[in] paddingH padding height.
* @param[in] paddingW padding width.
* @param[out] tgtData output data.
* @param[in] tgtStride stride between output data samples.
*
*/
extern
void
hl_avgpool_forward
(
...
...
@@ -163,7 +167,8 @@ extern void hl_avgpool_forward(
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeX
,
const
int
sizeY
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
);
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
const
int
tgtStride
);
/**
* @brief Maximum pool backward.
...
...
@@ -184,6 +189,7 @@ extern void hl_avgpool_forward(
* @param[in] scaleA scale.
* @param[in] scaleB scale.
* @param[out] backGrad output grad.
* @param[in] outStride stride between output data samples.
*
*/
extern
void
hl_avgpool_backward
(
...
...
@@ -195,7 +201,7 @@ extern void hl_avgpool_backward(
const
int
strideH
,
const
int
strideW
,
int
paddingH
,
int
paddingW
,
real
scaleA
,
real
scaleB
,
real
*
backGrad
);
real
*
backGrad
,
const
int
outStride
);
/**
* @brief Cross-map-respose normalize forward.
...
...
paddle/cuda/include/stub/hl_cnn_stub.h
浏览文件 @
ca0bb40c
...
...
@@ -44,7 +44,8 @@ inline void hl_maxpool_forward(
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeX
,
const
int
sizeY
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
)
{}
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
const
int
tgtStride
)
{}
inline
void
hl_maxpool_backward
(
const
int
frameCnt
,
const
real
*
inputData
,
...
...
@@ -56,7 +57,7 @@ inline void hl_maxpool_backward(
const
int
strideH
,
const
int
strideW
,
const
int
paddingH
,
const
int
paddingW
,
real
scaleA
,
real
scaleB
,
real
*
targetGrad
)
{}
real
*
targetGrad
,
const
int
outStride
)
{}
inline
void
hl_avgpool_forward
(
const
int
frameCnt
,
const
real
*
inputData
,
...
...
@@ -65,7 +66,8 @@ inline void hl_avgpool_forward(
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeX
,
const
int
sizeY
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
)
{}
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
const
int
tgtStride
)
{}
inline
void
hl_avgpool_backward
(
const
int
frameCnt
,
const
real
*
outGrad
,
...
...
@@ -76,7 +78,7 @@ inline void hl_avgpool_backward(
const
int
strideH
,
const
int
strideW
,
int
paddingH
,
int
paddingW
,
real
scaleA
,
real
scaleB
,
real
*
backGrad
)
{}
real
*
backGrad
,
const
int
outStride
)
{}
inline
void
hl_CMRNorm_forward
(
size_t
frameCnt
,
const
real
*
in
,
real
*
scale
,
real
*
out
,
...
...
paddle/cuda/src/hl_cuda_cnn.cu
浏览文件 @
ca0bb40c
...
...
@@ -152,7 +152,7 @@ __global__ void KeMaxPoolForward(const int nthreads, const real* inputData,
const
int
ksizeW
,
const
int
ksizeH
,
const
int
strideH
,
const
int
strideW
,
const
int
offsetH
,
const
int
offsetW
,
real
*
tgtData
)
{
real
*
tgtData
,
const
int
tgtStride
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
int
pw
=
index
%
pooledW
;
...
...
@@ -173,7 +173,9 @@ __global__ void KeMaxPoolForward(const int nthreads, const real* inputData,
maxval
=
inputData
[
h
*
width
+
w
];
}
}
tgtData
[
index
]
=
maxval
;
int
tgtIndex
=
index
%
(
pooledW
*
pooledH
*
channels
)
+
frameNum
*
tgtStride
;
tgtData
[
tgtIndex
]
=
maxval
;
}
}
...
...
@@ -184,7 +186,7 @@ void hl_maxpool_forward(const int frameCnt, const real* inputData,
const
int
sizeX
,
const
int
sizeY
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
)
{
real
*
tgtData
,
const
int
tgtStride
)
{
int
num_kernels
=
pooledH
*
pooledW
*
channels
*
frameCnt
;
int
blocks
=
(
num_kernels
+
1024
-
1
)
/
1024
;
...
...
@@ -194,7 +196,7 @@ void hl_maxpool_forward(const int frameCnt, const real* inputData,
KeMaxPoolForward
<<<
grid
,
threads
,
0
,
STREAM_DEFAULT
>>>
(
num_kernels
,
inputData
,
channels
,
height
,
width
,
pooledH
,
pooledW
,
sizeX
,
sizeY
,
strideH
,
strideW
,
paddingH
,
paddingW
,
tgtData
);
paddingH
,
paddingW
,
tgtData
,
tgtStride
);
CHECK_SYNC
(
"hl_maxpool_forward failed"
);
}
...
...
@@ -207,7 +209,7 @@ __global__ void KeMaxPoolBackward(const int nthreads, const real* inputData,
const
int
strideH
,
const
int
strideW
,
const
int
padH
,
const
int
padW
,
real
scaleA
,
real
scaleB
,
real
*
targetGrad
)
{
real
*
targetGrad
,
const
int
outStride
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
// find out the local index
...
...
@@ -223,8 +225,8 @@ __global__ void KeMaxPoolBackward(const int nthreads, const real* inputData,
int
pwend
=
offsetW
>=
0
?
min
(
offsetW
/
strideW
+
1
,
pooledW
)
:
0
;
real
gradient
=
0
;
real
input
=
inputData
[
index
];
outData
+=
(
frameNum
*
channels
+
offsetC
)
*
pooledH
*
pooledW
;
outGrad
+=
(
frameNum
*
channels
+
offsetC
)
*
pooledH
*
pooledW
;
outData
+=
(
frameNum
*
outStride
+
offsetC
*
pooledH
*
pooledW
)
;
outGrad
+=
(
frameNum
*
outStride
+
offsetC
*
pooledH
*
pooledW
)
;
for
(
int
ph
=
phstart
;
ph
<
phend
;
++
ph
)
{
for
(
int
pw
=
pwstart
;
pw
<
pwend
;
++
pw
)
{
if
(
input
==
outData
[
ph
*
pooledW
+
pw
])
{
...
...
@@ -246,7 +248,7 @@ void hl_maxpool_backward(const int frameCnt, const real* inputData,
const
int
strideH
,
const
int
strideW
,
const
int
paddingH
,
const
int
paddingW
,
real
scaleA
,
real
scaleB
,
real
*
targetGrad
)
{
real
*
targetGrad
,
const
int
outStride
)
{
int
num_kernels
=
height
*
width
*
channels
*
frameCnt
;
int
blocks
=
(
num_kernels
+
1024
-
1
)
/
1024
;
...
...
@@ -257,7 +259,7 @@ void hl_maxpool_backward(const int frameCnt, const real* inputData,
strideH
,
strideW
,
paddingH
,
paddingW
,
scaleA
,
scaleB
,
targetGrad
);
targetGrad
,
outStride
);
CHECK_SYNC
(
"hl_maxpool_backward"
);
}
...
...
@@ -268,7 +270,7 @@ __global__ void KeAvgPoolForward(const int nthreads, const real* inputData,
const
int
sizeX
,
const
int
sizeY
,
const
int
strideH
,
const
int
strideW
,
const
int
padH
,
const
int
padW
,
real
*
tgtData
)
{
real
*
tgtData
,
const
int
tgtStride
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
int
pw
=
index
%
pooledW
;
...
...
@@ -293,7 +295,9 @@ __global__ void KeAvgPoolForward(const int nthreads, const real* inputData,
aveval
+=
inputData
[
h
*
width
+
w
];
}
}
tgtData
[
index
]
=
aveval
/
pool_size
;
int
tgtIndex
=
index
%
(
pooledW
*
pooledH
*
channels
)
+
frameNum
*
tgtStride
;
tgtData
[
tgtIndex
]
=
aveval
/
pool_size
;
}
}
...
...
@@ -303,14 +307,15 @@ void hl_avgpool_forward(const int frameCnt, const real* inputData,
const
int
pooledH
,
const
int
pooledW
,
const
int
sizeX
,
const
int
sizeY
,
const
int
strideH
,
const
int
strideW
,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
)
{
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
const
int
tgtStride
)
{
int
num_kernels
=
pooledH
*
pooledW
*
channels
*
frameCnt
;
int
blocks
=
(
num_kernels
+
1024
-
1
)
/
1024
;
KeAvgPoolForward
<<<
blocks
,
1024
,
0
,
STREAM_DEFAULT
>>>
(
num_kernels
,
inputData
,
channels
,
height
,
width
,
pooledH
,
pooledW
,
sizeX
,
sizeY
,
strideH
,
strideW
,
paddingH
,
paddingW
,
tgtData
);
paddingH
,
paddingW
,
tgtData
,
tgtStride
);
CHECK_SYNC
(
"hl_avgpool_forward failed"
);
}
...
...
@@ -322,7 +327,7 @@ __global__ void KeAvgPoolBackward(const int nthreads, const real* outGrad,
const
int
strideH
,
const
int
strideW
,
const
int
padH
,
const
int
padW
,
real
scaleA
,
real
scaleB
,
real
*
tgtGrad
)
{
real
*
tgtGrad
,
const
int
outStride
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
int
offsetW
=
index
%
width
+
padW
;
...
...
@@ -335,7 +340,8 @@ __global__ void KeAvgPoolBackward(const int nthreads, const real* outGrad,
int
phend
=
offsetH
>=
0
?
min
(
offsetH
/
strideH
+
1
,
pooledH
)
:
0
;
int
pwend
=
offsetW
>=
0
?
min
(
offsetW
/
strideW
+
1
,
pooledW
)
:
0
;
real
gradient
=
0
;
outGrad
+=
(
frameNum
*
channels
+
offsetC
)
*
pooledH
*
pooledW
;
outGrad
+=
(
frameNum
*
outStride
+
offsetC
*
pooledH
*
pooledW
);
for
(
int
ph
=
phstart
;
ph
<
phend
;
++
ph
)
{
for
(
int
pw
=
pwstart
;
pw
<
pwend
;
++
pw
)
{
...
...
@@ -360,7 +366,7 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad,
const
int
strideH
,
const
int
strideW
,
const
int
paddingH
,
const
int
paddingW
,
real
scaleA
,
real
scaleB
,
real
*
backGrad
)
{
real
*
backGrad
,
const
int
outStride
)
{
int
num_kernels
=
height
*
width
*
channels
*
frameCnt
;
int
blocks
=
(
num_kernels
+
1024
-
1
)
/
1024
;
...
...
@@ -370,7 +376,7 @@ void hl_avgpool_backward(const int frameCnt, const real* outGrad,
strideH
,
strideW
,
paddingH
,
paddingW
,
scaleA
,
scaleB
,
backGrad
);
backGrad
,
outStride
);
CHECK_SYNC
(
"hl_avgpool_backward failed"
);
}
...
...
paddle/gserver/layers/PoolLayer.cpp
浏览文件 @
ca0bb40c
...
...
@@ -52,10 +52,8 @@ bool PoolLayer::init(const LayerMap& layerMap,
Layer
*
PoolLayer
::
create
(
const
LayerConfig
&
config
)
{
CHECK_EQ
(
config
.
inputs_size
(),
1
);
const
std
::
string
&
pool
=
config
.
inputs
(
0
).
pool_conf
().
pool_type
();
if
(
pool
==
"max-projection"
)
{
return
new
MaxPoolProjectionLayer
(
config
);
}
else
if
(
pool
==
"avg-projection"
)
{
return
new
AvgPoolProjectionLayer
(
config
);
if
(
pool
==
"max-projection"
||
pool
==
"avg-projection"
)
{
return
new
PoolProjectionLayer
(
config
);
#ifndef PADDLE_ONLY_CPU
}
else
if
(
CudnnPoolLayer
::
typeCheck
(
pool
))
{
return
new
CudnnPoolLayer
(
config
);
...
...
paddle/gserver/layers/PoolProjection.cpp
0 → 100644
浏览文件 @
ca0bb40c
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "PoolProjection.h"
namespace
paddle
{
REGISTER_PROJECTION_CREATE_FUNC
(
pool
,
&
PoolProjection
::
create
);
PoolProjection
::
PoolProjection
(
const
ProjectionConfig
&
config
,
ParameterPtr
parameter
,
bool
useGpu
)
:
Projection
(
config
,
parameter
,
useGpu
)
{
const
PoolConfig
&
conf
=
config_
.
pool_conf
();
poolType_
=
conf
.
pool_type
();
channels_
=
conf
.
channels
();
sizeX_
=
conf
.
size_x
();
stride_
=
conf
.
stride
();
outputX_
=
conf
.
output_x
();
imgSize_
=
conf
.
img_size
();
confPadding_
=
conf
.
padding
();
sizeY_
=
conf
.
has_size_y
()
?
conf
.
size_y
()
:
conf
.
size_x
();
imgSizeY_
=
conf
.
has_img_size_y
()
?
conf
.
img_size_y
()
:
conf
.
img_size
();
strideY_
=
conf
.
has_stride_y
()
?
conf
.
stride_y
()
:
conf
.
stride
();
confPaddingY_
=
conf
.
has_padding_y
()
?
conf
.
padding_y
()
:
conf
.
padding
();
outputY_
=
conf
.
has_output_y
()
?
conf
.
output_y
()
:
conf
.
output_x
();
}
size_t
PoolProjection
::
getSize
()
{
imgSizeY_
=
in_
->
getFrameHeight
();
imgSize_
=
in_
->
getFrameWidth
();
const
PoolConfig
&
conf
=
config_
.
pool_conf
();
if
(
imgSizeY_
==
0
)
{
imgSizeY_
=
conf
.
has_img_size_y
()
?
conf
.
img_size_y
()
:
conf
.
img_size
();
}
if
(
imgSize_
==
0
)
{
imgSize_
=
conf
.
img_size
();
}
outputY_
=
outputSize
(
imgSizeY_
,
sizeY_
,
confPaddingY_
,
strideY_
,
/* caffeMode */
false
);
outputX_
=
outputSize
(
imgSize_
,
sizeX_
,
confPadding_
,
stride_
,
/* caffeMode */
false
);
const_cast
<
Argument
*>
(
out_
)
->
setFrameHeight
(
outputY_
);
const_cast
<
Argument
*>
(
out_
)
->
setFrameWidth
(
outputX_
);
return
outputY_
*
outputX_
*
channels_
;
}
PoolProjection
*
PoolProjection
::
create
(
const
ProjectionConfig
&
config
,
ParameterPtr
parameter
,
bool
useGpu
)
{
const
std
::
string
&
pool
=
config
.
pool_conf
().
pool_type
();
if
(
pool
==
"max-projection"
)
{
return
new
MaxPoolProjection
(
config
,
parameter
,
useGpu
);
}
else
if
(
pool
==
"avg-projection"
)
{
return
new
AvgPoolProjection
(
config
,
parameter
,
useGpu
);
}
else
{
LOG
(
FATAL
)
<<
"Unknown pool type: "
<<
pool
;
return
nullptr
;
}
}
void
MaxPoolProjection
::
forward
()
{
size_t
width
=
getSize
();
CHECK_EQ
(
width
,
out_
->
value
->
getWidth
());
MatrixPtr
inputV
=
in_
->
value
;
MatrixPtr
outV
=
out_
->
value
;
outV
->
maxPoolForward
(
*
inputV
,
imgSizeY_
,
imgSize_
,
channels_
,
sizeX_
,
sizeY_
,
strideY_
,
stride_
,
outputY_
,
outputX_
,
confPaddingY_
,
confPadding_
);
}
void
MaxPoolProjection
::
backward
(
const
UpdateCallback
&
callback
)
{
(
void
)
callback
;
MatrixPtr
outGrad
=
out_
->
grad
;
MatrixPtr
inputV
=
in_
->
value
;
MatrixPtr
outV
=
out_
->
value
;
MatrixPtr
inputGrad
=
in_
->
grad
;
if
(
NULL
==
inputGrad
)
{
return
;
}
inputGrad
->
maxPoolBackward
(
*
inputV
,
imgSizeY_
,
imgSize_
,
*
outGrad
,
*
outV
,
sizeX_
,
sizeY_
,
strideY_
,
stride_
,
outputY_
,
outputX_
,
1
,
1
,
confPaddingY_
,
confPadding_
);
}
void
AvgPoolProjection
::
forward
()
{
size_t
width
=
getSize
();
CHECK_EQ
(
width
,
out_
->
value
->
getWidth
());
MatrixPtr
inputV
=
in_
->
value
;
MatrixPtr
outV
=
out_
->
value
;
outV
->
avgPoolForward
(
*
inputV
,
imgSizeY_
,
imgSize_
,
channels_
,
sizeX_
,
sizeY_
,
strideY_
,
stride_
,
outputY_
,
outputX_
,
confPaddingY_
,
confPadding_
);
}
void
AvgPoolProjection
::
backward
(
const
UpdateCallback
&
callback
)
{
(
void
)
callback
;
MatrixPtr
outputGrad
=
out_
->
grad
;
MatrixPtr
inputGrad
=
in_
->
grad
;
if
(
NULL
==
inputGrad
)
{
return
;
}
inputGrad
->
avgPoolBackward
(
*
outputGrad
,
imgSizeY_
,
imgSize_
,
sizeX_
,
sizeY_
,
strideY_
,
stride_
,
outputY_
,
outputX_
,
1
,
1
,
confPaddingY_
,
confPadding_
);
}
}
// namespace paddle
paddle/gserver/layers/PoolProjection.h
0 → 100644
浏览文件 @
ca0bb40c
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Projection.h"
#include "paddle/math/MathUtils.h"
namespace
paddle
{
class
PoolProjection
:
public
Projection
{
protected:
size_t
imgSizeY_
,
imgSize_
;
size_t
outputY_
,
outputX_
;
size_t
strideY_
,
stride_
;
size_t
sizeY_
,
sizeX_
;
int
confPaddingY_
,
confPadding_
;
size_t
channels_
;
std
::
string
poolType_
;
public:
PoolProjection
(
const
ProjectionConfig
&
config
,
ParameterPtr
parameter
,
bool
useGpu
);
static
PoolProjection
*
create
(
const
ProjectionConfig
&
config
,
ParameterPtr
parameter
,
bool
useGpu
);
const
std
::
string
&
getPoolType
()
const
{
return
poolType_
;
}
size_t
getSize
();
};
class
MaxPoolProjection
:
public
PoolProjection
{
public:
MaxPoolProjection
(
const
ProjectionConfig
&
config
,
ParameterPtr
parameter
,
bool
useGpu
)
:
PoolProjection
(
config
,
parameter
,
useGpu
)
{}
virtual
void
forward
();
virtual
void
backward
(
const
UpdateCallback
&
callback
=
nullptr
);
};
class
AvgPoolProjection
:
public
PoolProjection
{
public:
AvgPoolProjection
(
const
ProjectionConfig
&
config
,
ParameterPtr
parameter
,
bool
useGpu
)
:
PoolProjection
(
config
,
parameter
,
useGpu
)
{}
virtual
void
forward
();
virtual
void
backward
(
const
UpdateCallback
&
callback
=
nullptr
);
};
}
// namespace paddle
paddle/gserver/layers/PoolProjectionLayer.cpp
浏览文件 @
ca0bb40c
...
...
@@ -18,6 +18,7 @@ limitations under the License. */
namespace
paddle
{
size_t
PoolProjectionLayer
::
getSize
()
{
CHECK_EQ
(
inputLayers_
.
size
(),
1UL
);
size_t
layerSize
=
0
;
...
...
@@ -37,74 +38,23 @@ size_t PoolProjectionLayer::getSize() {
layerSize
=
outputH_
*
outputW_
*
channels_
;
getOutput
().
setFrameHeight
(
outputH_
);
getOutput
().
setFrameWidth
(
outputW_
);
return
layerSize
;
}
void
MaxPoolProjectionLayer
::
forward
(
PassType
passType
)
{
Layer
::
forward
(
passType
);
/* malloc memory for the output_ if necessary */
/* note: one sample correspond to one ROW */
MatrixPtr
input
=
getInputValue
(
0
);
int
batchSize
=
input
->
getHeight
();
int
size
=
getSize
();
resetOutput
(
batchSize
,
size
);
MatrixPtr
outV
=
getOutputValue
();
outV
->
maxPoolForward
(
*
input
,
imgSizeH_
,
imgSizeW_
,
channels_
,
sizeX_
,
sizeY_
,
strideY_
,
stride_
,
outputH_
,
outputW_
,
confPaddingY_
,
confPadding_
);
}
void
MaxPoolProjectionLayer
::
backward
(
const
UpdateCallback
&
callback
)
{
(
void
)
callback
;
if
(
NULL
==
getInputGrad
(
0
))
{
return
;
}
/* Do derivation */
MatrixPtr
outGrad
=
getOutputGrad
();
MatrixPtr
inputV
=
getInputValue
(
0
);
MatrixPtr
outV
=
getOutputValue
();
MatrixPtr
inputGrad
=
getInputGrad
(
0
);
inputGrad
->
maxPoolBackward
(
*
inputV
,
imgSizeH_
,
imgSizeW_
,
*
outGrad
,
*
outV
,
sizeX_
,
sizeY_
,
strideY_
,
stride_
,
outputH_
,
outputW_
,
1
,
1
,
confPaddingY_
,
confPadding_
);
}
void
AvgPoolProjectionLayer
::
forward
(
PassType
passType
)
{
void
PoolProjectionLayer
::
forward
(
PassType
passType
)
{
Layer
::
forward
(
passType
);
/* malloc memory for the output_ if necessary */
/* note: one sample correspond to one ROW */
MatrixPtr
input
=
getInputValue
(
0
);
int
batchSize
=
input
->
getHeight
();
const
Argument
&
in
=
getInput
(
0
);
int
batchSize
=
in
.
value
->
getHeight
();
int
size
=
getSize
();
resetOutput
(
batchSize
,
size
);
MatrixPtr
outV
=
getOutputValue
();
outV
->
avgPoolForward
(
*
input
,
imgSizeH_
,
imgSizeW_
,
channels_
,
sizeX_
,
sizeY_
,
strideY_
,
stride_
,
outputH_
,
outputW_
,
confPaddingY_
,
confPadding_
);
poolProjection_
->
forward
(
&
in
,
&
output_
,
passType
);
}
void
Avg
PoolProjectionLayer
::
backward
(
const
UpdateCallback
&
callback
)
{
void
PoolProjectionLayer
::
backward
(
const
UpdateCallback
&
callback
)
{
(
void
)
callback
;
if
(
NULL
==
getInputGrad
(
0
))
{
return
;
}
/* Do derivation */
MatrixPtr
outputGrad
=
getOutputGrad
();
MatrixPtr
inputGrad
=
getInputGrad
(
0
);
inputGrad
->
avgPoolBackward
(
*
outputGrad
,
imgSizeH_
,
imgSizeW_
,
sizeX_
,
sizeY_
,
strideY_
,
stride_
,
outputH_
,
outputW_
,
1
,
1
,
confPaddingY_
,
confPadding_
);
poolProjection_
->
backward
(
callback
);
}
}
// namespace paddle
paddle/gserver/layers/PoolProjectionLayer.h
浏览文件 @
ca0bb40c
...
...
@@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "PoolLayer.h"
#include "PoolProjection.h"
#include "paddle/math/Matrix.h"
#include <vector>
namespace
paddle
{
/**
...
...
@@ -27,33 +27,18 @@ class PoolProjectionLayer : public PoolLayer {
protected:
size_t
imgSizeH_
,
imgSizeW_
;
size_t
outputH_
,
outputW_
;
std
::
unique_ptr
<
PoolProjection
>
poolProjection_
;
ProjectionConfig
projectionConfig_
;
public:
size_t
getSize
();
explicit
PoolProjectionLayer
(
const
LayerConfig
&
config
)
:
PoolLayer
(
config
)
{}
};
/**
* @brief A layer for max pooling
*/
class
MaxPoolProjectionLayer
:
public
PoolProjectionLayer
{
public:
explicit
MaxPoolProjectionLayer
(
const
LayerConfig
&
config
)
:
PoolProjectionLayer
(
config
)
{}
~
MaxPoolProjectionLayer
()
{}
explicit
PoolProjectionLayer
(
const
LayerConfig
&
config
)
:
PoolLayer
(
config
)
{
PoolConfig
*
conf
=
projectionConfig_
.
mutable_pool_conf
();
*
conf
=
config_
.
inputs
(
0
).
pool_conf
();
poolProjection_
.
reset
(
PoolProjection
::
create
(
projectionConfig_
,
nullptr
,
useGpu_
));
}
virtual
void
forward
(
PassType
passType
);
virtual
void
backward
(
const
UpdateCallback
&
callback
=
nullptr
);
};
/**
* @brief A layer for average pooling
*/
class
AvgPoolProjectionLayer
:
public
PoolProjectionLayer
{
public:
explicit
AvgPoolProjectionLayer
(
const
LayerConfig
&
config
)
:
PoolProjectionLayer
(
config
)
{}
~
AvgPoolProjectionLayer
()
{}
size_t
getSize
();
virtual
void
forward
(
PassType
passType
);
virtual
void
backward
(
const
UpdateCallback
&
callback
=
nullptr
);
...
...
paddle/gserver/layers/Projection.h
浏览文件 @
ca0bb40c
...
...
@@ -12,12 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/parameter/Parameter.h"
#include "ModelConfig.pb.h"
#include "Layer.h"
#include "ModelConfig.pb.h"
#include "paddle/parameter/Parameter.h"
namespace
paddle
{
...
...
@@ -28,6 +27,11 @@ namespace paddle {
Projection::registrar_.registerClass<__class_name>(#__type_name); \
})
#define REGISTER_PROJECTION_CREATE_FUNC(__type_name, createFunction) \
static InitFunction __reg_type_##__type_name([]() { \
Projection::registrar_.registerClass(#__type_name, createFunction); \
})
/**
* A projection takes one Argument as input, calculate the result and add it
* to output Argument.
...
...
@@ -50,7 +54,8 @@ public:
registrar_
;
/**
* Forward propagation. If backward() will be called, in and out must be kept valid until then.
* Forward propagation. If backward() will be called, in and out must be kept
* valid until then.
* @param in input of projection
* @param out output of projection
* @param passType PASS_TRAIN of PASS_TEST
...
...
paddle/gserver/layers/SpatialPyramidPoolLayer.cpp
0 → 100644
浏览文件 @
ca0bb40c
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "SpatialPyramidPoolLayer.h"
namespace
paddle
{
REGISTER_LAYER
(
spp
,
SpatialPyramidPoolLayer
);
ProjectionConfig
SpatialPyramidPoolLayer
::
getConfig
(
size_t
imgSizeW
,
size_t
imgSizeH
,
size_t
channels
,
size_t
pyramidLevel
,
std
::
string
&
poolType
)
{
ProjectionConfig
config
;
config
.
set_type
(
"pool"
);
PoolConfig
*
conf
=
config
.
mutable_pool_conf
();
conf
->
set_channels
(
channels
);
conf
->
set_img_size
(
imgSizeW
);
conf
->
set_img_size_y
(
imgSizeH
);
conf
->
set_pool_type
(
poolType
);
int
numBins
=
std
::
pow
(
2
,
pyramidLevel
);
int
sizeH
=
std
::
ceil
(
imgSizeH
/
static_cast
<
double
>
(
numBins
));
int
paddingH
=
(
sizeH
*
numBins
-
imgSizeH
+
1
)
/
2
;
int
outSizeH
=
outputSize
(
imgSizeH
,
sizeH
,
paddingH
,
sizeH
,
true
);
int
sizeW
=
std
::
ceil
(
imgSizeW
/
static_cast
<
double
>
(
numBins
));
int
paddingW
=
(
sizeW
*
numBins
-
imgSizeW
+
1
)
/
2
;
int
outSizeW
=
outputSize
(
imgSizeW
,
sizeW
,
paddingW
,
sizeW
,
true
);
conf
->
set_stride
(
sizeW
);
conf
->
set_stride_y
(
sizeH
);
conf
->
set_size_x
(
sizeW
);
conf
->
set_size_y
(
sizeH
);
conf
->
set_padding
(
paddingW
);
conf
->
set_padding_y
(
paddingH
);
conf
->
set_output_x
(
outSizeW
);
conf
->
set_output_y
(
outSizeH
);
config
.
set_output_size
(
outSizeH
*
outSizeW
*
channels
);
return
config
;
}
size_t
SpatialPyramidPoolLayer
::
getSize
()
{
CHECK_EQ
(
inputLayers_
.
size
(),
1UL
);
size_t
layerSize
=
0
;
const
SppConfig
&
sppConf
=
config_
.
inputs
(
0
).
spp_conf
();
imgSizeH_
=
inputLayers_
[
0
]
->
getOutput
().
getFrameHeight
();
imgSizeW_
=
inputLayers_
[
0
]
->
getOutput
().
getFrameWidth
();
if
(
imgSizeH_
==
0
)
{
imgSizeH_
=
sppConf
.
has_img_size_y
()
?
sppConf
.
img_size_y
()
:
imgSizeW_
;
}
if
(
imgSizeW_
==
0
)
{
imgSizeW_
=
sppConf
.
img_size
();
}
size_t
outputH
=
1
;
size_t
outputW
=
(
std
::
pow
(
4
,
pyramidHeight_
)
-
1
)
/
(
4
-
1
);
layerSize
=
outputH
*
outputW
*
channels_
;
return
layerSize
;
}
bool
SpatialPyramidPoolLayer
::
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
)
{
Layer
::
init
(
layerMap
,
parameterMap
);
CHECK_EQ
(
config_
.
inputs_size
(),
1
);
const
SppConfig
&
sppConf
=
config_
.
inputs
(
0
).
spp_conf
();
pyramidHeight_
=
sppConf
.
pyramid_height
();
poolType_
=
sppConf
.
pool_type
();
channels_
=
sppConf
.
channels
();
imgSizeW_
=
sppConf
.
img_size
();
imgSizeH_
=
sppConf
.
has_img_size_y
()
?
sppConf
.
img_size_y
()
:
imgSizeW_
;
poolProjections_
.
reserve
(
pyramidHeight_
);
projCol_
.
reserve
(
pyramidHeight_
);
projOutput_
.
resize
(
pyramidHeight_
);
size_t
startCol
=
0
;
size_t
endCol
=
0
;
for
(
size_t
i
=
0
;
i
<
pyramidHeight_
;
i
++
)
{
poolProjections_
.
emplace_back
(
PoolProjection
::
create
(
getConfig
(
imgSizeW_
,
imgSizeH_
,
channels_
,
i
,
poolType_
),
nullptr
,
useGpu_
));
endCol
+=
poolProjections_
[
i
]
->
getOutputSize
();
projCol_
.
push_back
(
std
::
make_pair
(
startCol
,
endCol
));
startCol
=
endCol
;
}
CHECK_EQ
(
endCol
,
getSize
());
return
true
;
}
void
SpatialPyramidPoolLayer
::
forward
(
PassType
passType
)
{
Layer
::
forward
(
passType
);
int
batchSize
=
getInput
(
0
).
getBatchSize
();
resetOutput
(
batchSize
,
getSize
());
for
(
size_t
i
=
0
;
i
<
pyramidHeight_
;
i
++
)
{
size_t
startCol
=
projCol_
[
i
].
first
;
size_t
endCol
=
projCol_
[
i
].
second
;
projOutput_
[
i
].
value
=
output_
.
value
->
subColMatrix
(
startCol
,
endCol
);
projOutput_
[
i
].
grad
=
output_
.
grad
->
subColMatrix
(
startCol
,
endCol
);
}
for
(
size_t
i
=
0
;
i
<
pyramidHeight_
;
i
++
)
{
poolProjections_
[
i
]
->
forward
(
&
getInput
(
0
),
&
projOutput_
[
i
],
passType
);
}
}
void
SpatialPyramidPoolLayer
::
backward
(
const
UpdateCallback
&
callback
)
{
for
(
size_t
i
=
0
;
i
<
pyramidHeight_
;
i
++
)
{
if
(
poolProjections_
[
i
])
{
poolProjections_
[
i
]
->
backward
(
callback
);
}
}
}
}
// namespace paddle
paddle/gserver/layers/SpatialPyramidPoolLayer.h
0 → 100644
浏览文件 @
ca0bb40c
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Layer.h"
#include "PoolProjection.h"
#include "paddle/math/MathUtils.h"
#include "paddle/utils/Logging.h"
namespace
paddle
{
/**
* @brief A layer for spatial pyramid pooling on the input image by taking
* the max, average, etc. within regions, so that the result vector of
* different sized images are of the same size.
*
* The config file api is spp_layer.
*/
class
SpatialPyramidPoolLayer
:
public
Layer
{
protected:
size_t
channels_
;
size_t
imgSizeW_
;
size_t
imgSizeH_
;
size_t
pyramidHeight_
;
std
::
string
poolType_
;
std
::
vector
<
std
::
unique_ptr
<
PoolProjection
>>
poolProjections_
;
std
::
vector
<
Argument
>
projOutput_
;
std
::
vector
<
std
::
pair
<
size_t
,
size_t
>>
projCol_
;
public:
explicit
SpatialPyramidPoolLayer
(
const
LayerConfig
&
config
)
:
Layer
(
config
)
{}
~
SpatialPyramidPoolLayer
()
{}
virtual
bool
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
);
ProjectionConfig
getConfig
(
size_t
sizeX_
,
size_t
sizeY_
,
size_t
channels
,
size_t
pyamidLevel_
,
std
::
string
&
poolType_
);
size_t
getSize
();
virtual
void
forward
(
PassType
passType
);
virtual
void
backward
(
const
UpdateCallback
&
callback
=
nullptr
);
};
}
// namespace paddle
paddle/gserver/tests/test_LayerGrad.cpp
浏览文件 @
ca0bb40c
...
...
@@ -13,15 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <vector>
#include <string>
#include
"paddle/gserver/layers/DataLayer.h"
#include
<vector>
#include "ModelConfig.pb.h"
#include "paddle/gserver/layers/DataLayer.h"
#include "paddle/trainer/Trainer.h"
#include "paddle/math/MathUtils.h"
#include "TestUtil.h"
#include "LayerGradUtil.h"
#include "TestUtil.h"
using
namespace
paddle
;
// NOLINT
using
namespace
std
;
// NOLINT
...
...
@@ -981,6 +981,32 @@ TEST(Layer, PoolLayer) {
#endif
}
void
testSppLayer
(
const
string
&
poolType
,
const
int
pyramidHeight
,
bool
trans
,
bool
useGpu
)
{
TestConfig
config
;
config
.
layerConfig
.
set_type
(
"spp"
);
config
.
inputDefs
.
push_back
({
INPUT_DATA
,
"layer_0"
,
3200
,
0
});
LayerInputConfig
*
input
=
config
.
layerConfig
.
add_inputs
();
SppConfig
*
sppConfig
=
input
->
mutable_spp_conf
();
sppConfig
->
set_pool_type
(
poolType
);
sppConfig
->
set_pyramid_height
(
pyramidHeight
);
sppConfig
->
set_channels
(
16
);
sppConfig
->
set_img_size
(
10
);
sppConfig
->
set_img_size_y
(
20
);
int
outputSize
=
(
std
::
pow
(
4
,
sppConfig
->
pyramid_height
())
-
1
)
/
(
4
-
1
);
config
.
layerConfig
.
set_size
(
outputSize
*
sppConfig
->
channels
());
testLayerGrad
(
config
,
"spp"
,
100
,
trans
,
useGpu
);
}
TEST
(
Layer
,
SpatialPyramidPoolLayer
)
{
for
(
auto
useGpu
:
{
false
,
true
})
{
for
(
auto
pyramidHeight
:
{
1
,
2
,
3
})
{
testSppLayer
(
"avg-projection"
,
pyramidHeight
,
false
,
useGpu
);
testSppLayer
(
"max-projection"
,
pyramidHeight
,
false
,
useGpu
);
}
}
}
TEST
(
Layer
,
rankCostLayer
)
{
TestConfig
config
;
config
.
layerConfig
.
set_type
(
"rank-cost"
);
...
...
paddle/math/Matrix.cpp
浏览文件 @
ca0bb40c
此差异已折叠。
点击以展开。
proto/ModelConfig.proto.m4
浏览文件 @
ca0bb40c
...
...
@@ -120,6 +120,14 @@ message PoolConfig {
optional uint32 padding_y = 13 [default = 0];
}
message SppConfig {
required string pool_type = 1;
required uint32 pyramid_height = 2;
required uint32 channels = 3;
required uint32 img_size = 4;
optional uint32 img_size_y = 5;
}
message NormConfig {
// rnorm or cmrnorm
required string norm_type = 1;
...
...
@@ -196,6 +204,9 @@ message ProjectionConfig {
// For IdentityOffsetProjection
optional uint64 offset = 11 [default = 0];
// For pool
optional PoolConfig pool_conf = 12;
}
message OperatorConfig {
...
...
@@ -245,6 +256,7 @@ message LayerInputConfig {
optional string input_layer_argument = 9;
optional BilinearInterpConfig bilinear_interp_conf = 10;
optional MaxOutConfig maxout_conf = 11;
optional SppConfig spp_conf = 12;
}
message LayerConfig {
...
...
python/paddle/trainer/config_parser.py
浏览文件 @
ca0bb40c
...
...
@@ -471,6 +471,7 @@ class Input(Cfg):
image
=
None
,
block_expand
=
None
,
maxout
=
None
,
spp
=
None
,
format
=
None
,
nnz
=
None
,
is_static
=
None
,
...
...
@@ -671,7 +672,6 @@ class ConvProjection(Projection):
def
calc_parameter_dims
(
self
,
input_size
,
output_size
):
return
None
# Define a operator for mixed layer
@
config_class
class
Operator
(
Cfg
):
...
...
@@ -795,6 +795,17 @@ class Pool(Cfg):
padding
=
None
,
padding_y
=
None
):
self
.
add_keys
(
locals
())
# please refer to the comments in proto/ModelConfig.proto
@
config_class
class
SpatialPyramidPool
(
Cfg
):
def
__init__
(
self
,
pool_type
,
pyramid_height
,
channels
,
img_width
=
None
):
self
.
add_keys
(
locals
())
# please refer to the comments in proto/ModelConfig.proto
@
config_class
...
...
@@ -1081,6 +1092,22 @@ def parse_pool(pool, input_layer_name, pool_conf):
pool_conf
.
output_y
=
cnn_output_size
(
pool_conf
.
img_size_y
,
pool_conf
.
size_y
,
pool_conf
.
padding_y
,
pool_conf
.
stride_y
,
False
)
def
parse_spp
(
spp
,
input_layer_name
,
spp_conf
):
spp_conf
.
pool_type
=
spp
.
pool_type
config_assert
(
spp
.
pool_type
in
[
'max-projection'
,
'avg-projection'
],
"pool-type %s is not in "
"['max-projection', 'avg-projection']"
%
spp
.
pool_type
)
spp_conf
.
pyramid_height
=
spp
.
pyramid_height
spp_conf
.
channels
=
spp
.
channels
img_pixels
=
g_layer_map
[
input_layer_name
].
size
/
spp_conf
.
channels
spp_conf
.
img_size
=
default
(
spp
.
img_width
,
int
(
img_pixels
**
0.5
))
spp_conf
.
img_size_y
=
img_pixels
/
spp_conf
.
img_size
config_assert
(
spp_conf
.
img_size
*
spp_conf
.
img_size_y
==
img_pixels
,
"Incorrect input image size %d for input image pixels %d"
%
(
spp_conf
.
img_size
,
img_pixels
))
def
parse_image
(
image
,
input_layer_name
,
image_conf
):
image_conf
.
channels
=
image
.
channels
image_pixels
=
g_layer_map
[
input_layer_name
].
size
/
image_conf
.
channels
...
...
@@ -1756,6 +1783,25 @@ class PoolLayer(LayerBase):
name
,
pool_conf
.
output_y
,
pool_conf
.
output_x
))
self
.
set_layer_size
((
pool_conf
.
output_x
*
pool_conf
.
output_y
)
*
pool_conf
.
channels
)
@
config_layer
(
'spp'
)
class
SpatialPyramidPoolLayer
(
LayerBase
):
def
__init__
(
self
,
name
,
inputs
,
device
=
None
):
super
(
SpatialPyramidPoolLayer
,
self
).
__init__
(
name
,
'spp'
,
0
,
inputs
=
inputs
,
device
=
device
)
for
input_index
in
xrange
(
len
(
self
.
inputs
)):
input_layer
=
self
.
get_input_layer
(
input_index
)
parse_spp
(
self
.
inputs
[
input_index
].
spp
,
input_layer
.
name
,
self
.
config
.
inputs
[
input_index
].
spp_conf
)
spp_conf
=
self
.
config
.
inputs
[
input_index
].
spp_conf
output_size
=
(
pow
(
4
,
spp_conf
.
pyramid_height
)
-
1
)
/
(
4
-
1
)
print
(
"output size for %s is %d "
%
(
name
,
output_size
))
self
.
set_layer_size
(
output_size
*
spp_conf
.
channels
)
@
config_layer
(
'batch_norm'
)
class
BatchNormLayer
(
LayerBase
):
layer_type
=
'batch_norm'
...
...
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
ca0bb40c
...
...
@@ -56,7 +56,8 @@ __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel",
'multi_binary_label_cross_entropy'
,
'sum_cost'
,
'rank_cost'
,
'lambda_cost'
,
'huber_cost'
,
'block_expand_layer'
,
'maxout_layer'
,
'out_prod_layer'
,
'print_layer'
'maxout_layer'
,
'out_prod_layer'
,
'print_layer'
,
'spp_layer'
,
]
...
...
@@ -115,6 +116,7 @@ class LayerType(object):
LINEAR_COMBINATION_LAYER
=
"convex_comb"
BLOCK_EXPAND
=
"blockexpand"
MAXOUT
=
"maxout"
SPP_LAYER
=
"spp"
PRINT_LAYER
=
"print"
...
...
@@ -877,6 +879,7 @@ def pooling_layer(input, pooling_type=None, name=None, bias_attr=None,
size
=
input
.
size
)
@
wrap_bias_attr_default
()
@
wrap_param_attr_default
()
@
wrap_act_default
(
param_names
=
[
'gate_act'
],
...
...
@@ -1820,6 +1823,62 @@ def img_pool_layer(input, pool_size, name=None,
num_filters
=
num_channels
,
size
=
l
.
config
.
size
)
@
wrap_name_default
(
"spp"
)
@
layer_support
()
def
spp_layer
(
input
,
name
=
None
,
num_channels
=
None
,
pool_type
=
None
,
pyramid_height
=
None
,
img_width
=
None
,
layer_attr
=
None
):
pass
"""
Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition.
The details please refer to
`Kaiming He's paper <https://arxiv.org/abs/1406.4729>`_.
:param name: layer name.
:type name: basestring
:param input: layer's input.
:type input: LayerOutput
:param num_channels: number of input channel.
:type num_channels: int
:param pool_type: Pooling type. MaxPooling or AveragePooling. Default is MaxPooling.
:type scale: BasePoolingType
:param pyramid_height: pyramid height.
:type pyramid_height: int
:param img_width: the width of input feature map. If it is None, the input feature
map should be square.
:type img_width: int|None
:param layer_attr: Extra Layer Attribute.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput
"""
if
num_channels
is
None
:
assert
input
.
num_filters
is
not
None
num_channels
=
input
.
num_filters
if
pool_type
is
None
:
pool_type
=
MaxPooling
()
elif
isinstance
(
pool_type
,
AvgPooling
):
pool_type
.
name
=
'avg'
type_name
=
pool_type
.
name
if
(
isinstance
(
pool_type
,
AvgPooling
)
or
isinstance
(
pool_type
,
MaxPooling
)):
type_name
+=
'-projection'
Layer
(
name
=
name
,
type
=
LayerType
.
SPP_LAYER
,
inputs
=
Input
(
input
.
name
,
spp
=
SpatialPyramidPool
(
pool_type
=
type_name
,
channels
=
num_channels
,
pyramid_height
=
pyramid_height
,
img_width
=
img_width
)
),
**
ExtraLayerAttribute
.
to_kwargs
(
layer_attr
)
)
return
LayerOutput
(
name
,
LayerType
.
SPP_LAYER
,
parents
=
[
input
],
num_filters
=
num_channels
)
def
__img_norm_layer__
(
name
,
input
,
size
,
norm_type
,
scale
,
power
,
num_channels
,
blocked
,
layer_attr
):
if
num_channels
is
None
:
...
...
python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh
浏览文件 @
ca0bb40c
...
...
@@ -11,7 +11,7 @@ test_sequence_pooling test_lstmemory_layer test_grumemory_layer
last_first_seq test_expand_layer test_ntm_layers test_hsigmoid
img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cost_layers
test_rnn_group shared_fc shared_lstm test_cost_layers_with_weight
test_bilinear_interp test_maxout test_bi_grumemory math_ops
)
test_
spp_layer test_
bilinear_interp test_maxout test_bi_grumemory math_ops
)
for
conf
in
${
configs
[*]
}
...
...
python/paddle/trainer_config_helpers/tests/configs/protostr/test_spp_layer.protostr
0 → 100644
浏览文件 @
ca0bb40c
type: "nn"
layers {
name: "data"
type: "data"
size: 3200
active_type: ""
}
layers {
name: "__spp_0__"
type: "spp"
size: 80
active_type: ""
inputs {
input_layer_name: "data"
spp_conf {
pool_type: "max-projection"
pyramid_height: 2
channels: 16
img_size: 10
img_size_y: 20
}
}
}
input_layer_names: "data"
output_layer_names: "__spp_0__"
sub_models {
name: "root"
layer_names: "data"
layer_names: "__spp_0__"
input_layer_names: "data"
output_layer_names: "__spp_0__"
is_recurrent_layer_group: false
}
python/paddle/trainer_config_helpers/tests/configs/test_spp_layer.py
0 → 100644
浏览文件 @
ca0bb40c
from
paddle.trainer_config_helpers
import
*
settings
(
batch_size
=
100
,
learning_rate
=
1e-5
)
data
=
data_layer
(
name
=
'data'
,
size
=
3200
)
spp
=
spp_layer
(
input
=
data
,
pyramid_height
=
2
,
num_channels
=
16
,
pool_type
=
MaxPooling
(),
img_width
=
10
)
outputs
(
spp
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录