Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3e6f7684
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3e6f7684
编写于
11月 14, 2017
作者:
Z
Zhaolong Xing
提交者:
GitHub
11月 14, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #4891 from NHZlX/poolmaxpool_with_mask
max pool Layer with mask
上级
4adc8a7a
e19b931a
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
357 addition
and
29 deletion
+357
-29
paddle/cuda/include/hl_cnn.h
paddle/cuda/include/hl_cnn.h
+4
-3
paddle/cuda/include/stub/hl_cnn_stub.h
paddle/cuda/include/stub/hl_cnn_stub.h
+2
-1
paddle/cuda/src/hl_cuda_cnn.cu
paddle/cuda/src/hl_cuda_cnn.cu
+14
-5
paddle/gserver/layers/MaxPoolWithMaskLayer.cpp
paddle/gserver/layers/MaxPoolWithMaskLayer.cpp
+109
-0
paddle/gserver/layers/MaxPoolWithMaskLayer.h
paddle/gserver/layers/MaxPoolWithMaskLayer.h
+40
-0
paddle/gserver/layers/PoolLayer.cpp
paddle/gserver/layers/PoolLayer.cpp
+3
-1
paddle/gserver/tests/CMakeLists.txt
paddle/gserver/tests/CMakeLists.txt
+1
-0
paddle/gserver/tests/test_LayerGrad.cpp
paddle/gserver/tests/test_LayerGrad.cpp
+2
-0
paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp
paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp
+117
-0
paddle/math/Matrix.cpp
paddle/math/Matrix.cpp
+36
-7
paddle/math/Matrix.h
paddle/math/Matrix.h
+8
-4
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+3
-3
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+3
-3
python/paddle/trainer_config_helpers/poolings.py
python/paddle/trainer_config_helpers/poolings.py
+15
-2
未找到文件。
paddle/cuda/include/hl_cnn.h
浏览文件 @
3e6f7684
...
...
@@ -18,7 +18,7 @@ limitations under the License. */
#include "hl_base.h"
/**
* @brief Maximum pool forward.
* @brief Maximum pool forward
with Mask output
.
*
* @param[in] frameCnt batch size of input image.
* @param[in] inputData input data.
...
...
@@ -35,7 +35,7 @@ limitations under the License. */
* @param[in] paddingW padding width.
* @param[out] tgtData output data.
* @param[in] tgtStride stride between output data samples.
*
*
@param[out] maskData the location indices of select max data.
*/
extern
void
hl_maxpool_forward
(
const
int
frameCnt
,
const
real
*
inputData
,
...
...
@@ -51,7 +51,8 @@ extern void hl_maxpool_forward(const int frameCnt,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
const
int
tgtStride
);
const
int
tgtStride
,
real
*
maskData
=
NULL
);
/**
* @brief Maximum pool backward.
...
...
paddle/cuda/include/stub/hl_cnn_stub.h
浏览文件 @
3e6f7684
...
...
@@ -31,7 +31,8 @@ inline void hl_maxpool_forward(const int frameCnt,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
const
int
tgtStride
)
{}
const
int
tgtStride
,
real
*
MaskData
)
{}
inline
void
hl_maxpool_backward
(
const
int
frameCnt
,
const
real
*
inputData
,
...
...
paddle/cuda/src/hl_cuda_cnn.cu
浏览文件 @
3e6f7684
...
...
@@ -31,7 +31,8 @@ __global__ void KeMaxPoolForward(const int nthreads,
const
int
offsetH
,
const
int
offsetW
,
real
*
tgtData
,
const
int
tgtStride
)
{
const
int
tgtStride
,
real
*
maskData
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
int
pw
=
index
%
pooledW
;
...
...
@@ -45,16 +46,22 @@ __global__ void KeMaxPoolForward(const int nthreads,
hstart
=
max
(
hstart
,
0
);
wstart
=
max
(
wstart
,
0
);
real
maxval
=
-
FLT_MAX
;
int
max_index
=
-
1
;
inputData
+=
(
frameNum
*
channels
+
c
)
*
height
*
width
;
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
if
(
maxval
<
inputData
[
h
*
width
+
w
])
maxval
=
inputData
[
h
*
width
+
w
];
if
(
maxval
<
inputData
[
h
*
width
+
w
])
{
max_index
=
h
*
width
+
w
;
maxval
=
inputData
[
max_index
];
}
}
}
int
tgtIndex
=
index
%
(
pooledW
*
pooledH
*
channels
)
+
frameNum
*
tgtStride
;
tgtData
[
tgtIndex
]
=
maxval
;
if
(
maskData
!=
NULL
)
{
maskData
[
tgtIndex
]
=
max_index
;
}
}
}
...
...
@@ -72,7 +79,8 @@ void hl_maxpool_forward(const int frameCnt,
const
int
paddingH
,
const
int
paddingW
,
real
*
tgtData
,
const
int
tgtStride
)
{
const
int
tgtStride
,
real
*
maskData
)
{
int
num_kernels
=
pooledH
*
pooledW
*
channels
*
frameCnt
;
int
blocks
=
(
num_kernels
+
1024
-
1
)
/
1024
;
dim3
threads
(
1024
,
1
);
...
...
@@ -92,7 +100,8 @@ void hl_maxpool_forward(const int frameCnt,
paddingH
,
paddingW
,
tgtData
,
tgtStride
);
tgtStride
,
maskData
);
CHECK_SYNC
(
"hl_maxpool_forward failed"
);
}
...
...
paddle/gserver/layers/MaxPoolWithMaskLayer.cpp
0 → 100644
浏览文件 @
3e6f7684
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "MaxPoolWithMaskLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace
paddle
{
bool
MaxPoolWithMaskLayer
::
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
)
{
PoolLayer
::
init
(
layerMap
,
parameterMap
);
setOutput
(
"mask"
,
&
mask_
);
return
true
;
}
size_t
MaxPoolWithMaskLayer
::
getSize
()
{
CHECK_EQ
(
inputLayers_
.
size
(),
1UL
);
size_t
layerSize
=
0
;
outputY_
=
outputSize
(
imgSizeY_
,
sizeY_
,
confPaddingY_
,
strideY_
,
/* caffeMode */
false
);
outputX_
=
outputSize
(
imgSize_
,
sizeX_
,
confPadding_
,
stride_
,
/* caffeMode */
false
);
layerSize
=
outputX_
*
outputY_
*
channels_
;
getOutput
().
setFrameHeight
(
outputY_
);
getOutput
().
setFrameWidth
(
outputX_
);
return
layerSize
;
}
void
MaxPoolWithMaskLayer
::
forward
(
PassType
passType
)
{
size_t
size
=
getSize
();
MatrixPtr
inputV
=
inputLayers_
[
0
]
->
getOutputValue
();
int
batchSize
=
inputV
->
getHeight
();
resetOutput
(
batchSize
,
size
);
MatrixPtr
outV
=
getOutputValue
();
CHECK_EQ
(
size
,
outV
->
getWidth
());
resetSpecifyOutput
(
mask_
,
batchSize
,
size
,
/* isValueClean */
false
,
/* isGradClean */
true
);
MatrixPtr
maskV
=
mask_
.
value
;
outV
->
maxPoolForward
(
*
inputV
,
imgSizeY_
,
imgSize_
,
channels_
,
sizeX_
,
sizeY_
,
strideY_
,
stride_
,
outputY_
,
outputX_
,
confPaddingY_
,
confPadding_
,
maskV
);
}
void
MaxPoolWithMaskLayer
::
backward
(
const
UpdateCallback
&
callback
)
{
(
void
)
callback
;
if
(
NULL
==
getInputGrad
(
0
))
{
return
;
}
MatrixPtr
outGrad
=
getOutputGrad
();
MatrixPtr
inputV
=
inputLayers_
[
0
]
->
getOutputValue
();
MatrixPtr
outV
=
getOutputValue
();
MatrixPtr
inputGrad
=
inputLayers_
[
0
]
->
getOutputGrad
();
inputGrad
->
maxPoolBackward
(
*
inputV
,
imgSizeY_
,
imgSize_
,
*
outGrad
,
*
outV
,
sizeX_
,
sizeY_
,
strideY_
,
stride_
,
outputY_
,
outputX_
,
1
,
1
,
confPaddingY_
,
confPadding_
);
}
}
// namespace paddle
paddle/gserver/layers/MaxPoolWithMaskLayer.h
0 → 100644
浏览文件 @
3e6f7684
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "PoolLayer.h"
#include "paddle/math/Matrix.h"
namespace
paddle
{
/**
* @brief Basic parent layer of different kinds of pooling
*/
class
MaxPoolWithMaskLayer
:
public
PoolLayer
{
protected:
Argument
mask_
;
public:
explicit
MaxPoolWithMaskLayer
(
const
LayerConfig
&
config
)
:
PoolLayer
(
config
)
{}
size_t
getSize
();
void
forward
(
PassType
passType
)
override
;
void
backward
(
const
UpdateCallback
&
callback
=
nullptr
)
override
;
bool
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
)
override
;
};
}
// namespace paddle
paddle/gserver/layers/PoolLayer.cpp
浏览文件 @
3e6f7684
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "PoolLayer.h"
#include "MaxPoolWithMaskLayer.h"
#include "PoolProjectionLayer.h"
#include "paddle/utils/Logging.h"
#ifdef PADDLE_WITH_CUDA
...
...
@@ -44,7 +45,6 @@ bool PoolLayer::init(const LayerMap& layerMap,
strideY_
=
conf
.
has_stride_y
()
?
conf
.
stride_y
()
:
conf
.
stride
();
confPaddingY_
=
conf
.
has_padding_y
()
?
conf
.
padding_y
()
:
conf
.
padding
();
outputY_
=
conf
.
has_output_y
()
?
conf
.
output_y
()
:
conf
.
output_x
();
return
true
;
}
...
...
@@ -57,6 +57,8 @@ Layer* PoolLayer::create(const LayerConfig& config) {
}
else
if
(
CudnnPoolLayer
::
typeCheck
(
pool
))
{
return
new
CudnnPoolLayer
(
config
);
#endif
}
else
if
(
pool
==
"max-pool-with-mask"
)
{
return
new
MaxPoolWithMaskLayer
(
config
);
}
else
{
LOG
(
FATAL
)
<<
"Unknown pool type: "
<<
pool
;
return
nullptr
;
...
...
paddle/gserver/tests/CMakeLists.txt
浏览文件 @
3e6f7684
...
...
@@ -24,6 +24,7 @@ gserver_test(test_ConvUnify)
gserver_test
(
test_BatchNorm
)
gserver_test
(
test_KmaxSeqScore
)
gserver_test
(
test_Expand
)
gserver_test
(
test_MaxPoolingWithMaskOutput
)
########## test_Mkldnn layers and activations ##########
if
(
WITH_MKLDNN
)
...
...
paddle/gserver/tests/test_LayerGrad.cpp
浏览文件 @
3e6f7684
...
...
@@ -1234,6 +1234,7 @@ void testPoolLayer2(const string& poolType, bool trans, bool useGpu) {
TEST
(
Layer
,
PoolLayer
)
{
testPoolLayer
(
"avg-projection"
,
/* trans= */
false
,
/* useGpu= */
false
);
testPoolLayer
(
"max-projection"
,
/* trans= */
false
,
/* useGpu= */
false
);
testPoolLayer
(
"max-pool-with-mask"
,
/* trans= */
false
,
/* useGpu= */
false
);
#ifdef PADDLE_WITH_CUDA
testPoolLayer
(
"avg-projection"
,
/* trans= */
false
,
/* useGpu= */
true
);
...
...
@@ -1242,6 +1243,7 @@ TEST(Layer, PoolLayer) {
testPoolLayer
(
"cudnn-avg-pool"
,
/* trans= */
false
,
/* useGpu= */
true
);
testPoolLayer2
(
"cudnn-max-pool"
,
/* trans= */
false
,
/* useGpu= */
true
);
testPoolLayer2
(
"cudnn-avg-pool"
,
/* trans= */
false
,
/* useGpu= */
true
);
testPoolLayer
(
"max-pool-with-mask"
,
/* trans= */
false
,
/* useGpu= */
true
);
#endif
}
...
...
paddle/gserver/tests/test_MaxPoolingWithMaskOutput.cpp
0 → 100644
浏览文件 @
3e6f7684
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "LayerGradUtil.h"
#include "paddle/math/MathUtils.h"
#include "paddle/testing/TestUtil.h"
using
namespace
paddle
;
void
setPoolConfig
(
TestConfig
*
config
,
PoolConfig
*
pool
,
const
string
&
poolType
)
{
(
*
config
).
biasSize
=
0
;
(
*
config
).
layerConfig
.
set_type
(
"pool"
);
(
*
config
).
layerConfig
.
set_num_filters
(
1
);
int
kw
=
3
,
kh
=
3
;
int
pw
=
0
,
ph
=
0
;
int
sw
=
2
,
sh
=
2
;
pool
->
set_pool_type
(
poolType
);
pool
->
set_channels
(
1
);
pool
->
set_size_x
(
kw
);
pool
->
set_size_y
(
kh
);
pool
->
set_start
(
0
);
pool
->
set_padding
(
pw
);
pool
->
set_padding_y
(
ph
);
pool
->
set_stride
(
sw
);
pool
->
set_stride_y
(
sh
);
int
ow
=
outputSize
(
pool
->
img_size
(),
kw
,
pw
,
sw
,
/* caffeMode */
false
);
int
oh
=
outputSize
(
pool
->
img_size_y
(),
kh
,
ph
,
sh
,
/* caffeMode */
false
);
pool
->
set_output_x
(
ow
);
pool
->
set_output_y
(
oh
);
}
void
doOneMaxPoolingWithMaskOutputTest
(
MatrixPtr
&
inputMat
,
const
string
&
poolType
,
bool
use_gpu
,
MatrixPtr
&
maskMat
)
{
TestConfig
config
;
config
.
inputDefs
.
push_back
({
INPUT_DATA
,
"layer_0"
,
25
,
0
});
LayerInputConfig
*
input
=
config
.
layerConfig
.
add_inputs
();
PoolConfig
*
pool
=
input
->
mutable_pool_conf
();
pool
->
set_img_size
(
5
);
pool
->
set_img_size_y
(
5
);
setPoolConfig
(
&
config
,
pool
,
poolType
);
config
.
layerConfig
.
set_size
(
pool
->
output_x
()
*
pool
->
output_y
()
*
pool
->
channels
());
config
.
layerConfig
.
set_name
(
"MaxPoolWithMask"
);
std
::
vector
<
DataLayerPtr
>
dataLayers
;
LayerMap
layerMap
;
vector
<
Argument
>
datas
;
initDataLayer
(
config
,
&
dataLayers
,
&
datas
,
&
layerMap
,
"MaxPoolWithMask"
,
1
,
false
,
use_gpu
);
dataLayers
[
0
]
->
getOutputValue
()
->
copyFrom
(
*
inputMat
);
FLAGS_use_gpu
=
use_gpu
;
std
::
vector
<
ParameterPtr
>
parameters
;
LayerPtr
maxPoolingWithMaskOutputLayer
;
initTestLayer
(
config
,
&
layerMap
,
&
parameters
,
&
maxPoolingWithMaskOutputLayer
);
maxPoolingWithMaskOutputLayer
->
forward
(
PASS_GC
);
checkMatrixEqual
(
maxPoolingWithMaskOutputLayer
->
getOutput
(
"mask"
).
value
,
maskMat
);
}
TEST
(
Layer
,
maxPoolingWithMaskOutputLayerFwd
)
{
bool
useGpu
=
false
;
MatrixPtr
inputMat
;
MatrixPtr
maskMat
;
real
inputData
[]
=
{
0.1
,
0.1
,
0.5
,
0.5
,
1.1
,
0.2
,
0.2
,
0.6
,
0.1
,
0.1
,
0.3
,
0.3
,
0.7
,
0.1
,
0.1
,
0.4
,
0.4
,
0.8
,
0.8
,
0.1
,
1.0
,
2.0
,
3.0
,
0.0
,
9.0
};
real
maskData
[]
=
{
12
,
4
,
22
,
24
};
inputMat
=
Matrix
::
create
(
1
,
25
,
false
,
useGpu
);
maskMat
=
Matrix
::
create
(
1
,
4
,
false
,
useGpu
);
inputMat
->
setData
(
inputData
);
maskMat
->
setData
(
maskData
);
doOneMaxPoolingWithMaskOutputTest
(
inputMat
,
"max-pool-with-mask"
,
useGpu
,
maskMat
);
#ifdef PADDLE_WITH_CUDA
useGpu
=
true
;
inputMat
=
Matrix
::
create
(
1
,
25
,
false
,
useGpu
);
maskMat
=
Matrix
::
create
(
1
,
4
,
false
,
useGpu
);
inputMat
->
copyFrom
(
inputData
,
25
);
maskMat
->
copyFrom
(
maskData
,
4
);
doOneMaxPoolingWithMaskOutputTest
(
inputMat
,
"max-pool-with-mask"
,
useGpu
,
maskMat
);
#endif
}
paddle/math/Matrix.cpp
浏览文件 @
3e6f7684
...
...
@@ -1028,15 +1028,23 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat,
size_t
outputH
,
size_t
outputW
,
size_t
paddingH
,
size_t
paddingW
)
{
size_t
paddingW
,
MatrixPtr
maskMatP
)
{
CHECK
(
inputMat
.
useGpu_
==
true
)
<<
"Matrix type are not equal"
;
real
*
inputData
=
inputMat
.
getData
();
real
*
maskData
=
NULL
;
size_t
frameNum
=
inputMat
.
getHeight
();
CHECK
(
imgSizeH
*
imgSizeW
*
channels
==
inputMat
.
getWidth
());
CHECK
(
height_
==
inputMat
.
getHeight
());
CHECK
(
width_
==
outputH
*
outputW
*
channels
);
if
(
maskMatP
!=
NULL
)
{
CHECK
(
maskMatP
->
useGpu_
==
true
)
<<
"Matrix type are not equal"
;
CHECK
(
outputH
*
outputW
*
channels
==
maskMatP
->
getWidth
());
maskData
=
maskMatP
->
getData
();
}
hl_maxpool_forward
(
frameNum
,
inputData
,
channels
,
...
...
@@ -1051,7 +1059,8 @@ void GpuMatrix::maxPoolForward(Matrix& inputMat,
paddingH
,
paddingW
,
data_
,
getStride
());
getStride
(),
maskData
);
}
void
GpuMatrix
::
maxPoolBackward
(
Matrix
&
inputMat
,
...
...
@@ -1973,9 +1982,11 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
size_t
outputH
,
size_t
outputW
,
size_t
paddingH
,
size_t
paddingW
)
{
size_t
paddingW
,
MatrixPtr
maskMatP
)
{
real
*
inputData
=
inputMat
.
getData
();
real
*
outData
=
data_
;
real
*
maskData
=
NULL
;
size_t
num
=
inputMat
.
getHeight
();
size_t
inLength
=
imgSizeH
*
imgSizeW
;
size_t
outLength
=
outputH
*
outputW
;
...
...
@@ -1984,6 +1995,11 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
CHECK_EQ
(
channels
*
outLength
,
this
->
getWidth
());
size_t
outStride
=
getStride
();
if
(
maskMatP
!=
NULL
)
{
maskData
=
maskMatP
->
getData
();
CHECK_EQ
(
channels
*
outLength
,
maskMatP
->
getWidth
());
}
/* initialize the data_ */
for
(
size_t
i
=
0
;
i
<
height_
;
i
++
)
{
for
(
size_t
j
=
0
;
j
<
width_
;
j
++
)
{
...
...
@@ -2005,17 +2021,30 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat,
int
wstart
=
pw
*
strideW
-
paddingW
;
int
wend
=
std
::
min
(
wstart
+
sizeX
,
imgSizeW
);
wstart
=
std
::
max
(
wstart
,
0
);
if
(
maskData
==
NULL
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
outData
[
ph
*
outputW
+
pw
]
=
std
::
max
(
outData
[
ph
*
outputW
+
pw
],
inputData
[
h
*
imgSizeW
+
w
]);
}
}
}
else
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
if
(
outData
[
ph
*
outputW
+
pw
]
<
inputData
[
h
*
imgSizeW
+
w
])
{
outData
[
ph
*
outputW
+
pw
]
=
inputData
[
h
*
imgSizeW
+
w
];
maskData
[
ph
*
outputW
+
pw
]
=
h
*
imgSizeW
+
w
;
}
}
}
}
}
}
// compute offset
inputData
+=
inLength
;
outData
+=
outLength
;
if
(
maskData
!=
NULL
)
maskData
+=
outLength
;
}
}
}
...
...
paddle/math/Matrix.h
浏览文件 @
3e6f7684
...
...
@@ -861,7 +861,8 @@ public:
/**
* Pooling forward operation, pick out the largest element
* in the sizeX of value
* in the sizeX of value, if the maskMatP is not NULL, it will
* also caculate the location indices.
*/
virtual
void
maxPoolForward
(
Matrix
&
inputMat
,
size_t
imgSizeH
,
...
...
@@ -874,7 +875,8 @@ public:
size_t
outputH
,
size_t
outputW
,
size_t
paddingH
,
size_t
paddingW
)
{
size_t
paddingW
,
MatrixPtr
maskMatP
=
NULL
)
{
LOG
(
FATAL
)
<<
"Not implemeted"
;
}
...
...
@@ -1426,7 +1428,8 @@ public:
size_t
outputH
,
size_t
outputW
,
size_t
paddingH
,
size_t
paddingW
);
size_t
paddingW
,
MatrixPtr
maskMatP
);
void
maxPoolBackward
(
Matrix
&
image
,
size_t
imgSizeH
,
...
...
@@ -1697,7 +1700,8 @@ public:
size_t
outputH
,
size_t
outputW
,
size_t
paddingH
,
size_t
paddingW
);
size_t
paddingW
,
MatrixPtr
maskMatP
);
void
maxPoolBackward
(
Matrix
&
image
,
size_t
imgSizeH
,
...
...
python/paddle/trainer/config_parser.py
浏览文件 @
3e6f7684
...
...
@@ -1253,9 +1253,9 @@ def parse_bilinear(bilinear, input_layer_name, bilinear_conf):
def
parse_pool
(
pool
,
input_layer_name
,
pool_conf
,
ceil_mode
):
pool_conf
.
pool_type
=
pool
.
pool_type
config_assert
(
pool
.
pool_type
in
[
'max-projection'
,
'avg-projection'
,
'cudnn-max-pool'
,
'cudnn-avg-pool'
],
"pool-type %s is not in "
"['max-projection', 'avg-projection', "
'max-projection'
,
'avg-projection'
,
'
max-pool-with-mask'
,
'
cudnn-max-pool'
,
'cudnn-avg-pool'
],
"pool-type %s is not in "
\
"['max-projection', 'avg-projection', 'max-pool-with-mask',"
\
"'cudnn-max-pool', 'cudnn-avg-pool']"
%
pool
.
pool_type
)
pool_conf
.
channels
=
pool
.
channels
...
...
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
3e6f7684
...
...
@@ -20,7 +20,7 @@ from paddle.trainer.config_parser import *
from
.activations
import
LinearActivation
,
SigmoidActivation
,
TanhActivation
,
\
ReluActivation
,
IdentityActivation
,
SoftmaxActivation
,
BaseActivation
from
.evaluators
import
*
from
.poolings
import
MaxPooling
,
AvgPooling
,
BasePoolingType
,
\
from
.poolings
import
MaxPooling
,
AvgPooling
,
MaxWithMaskPooling
,
BasePoolingType
,
\
CudnnAvgPooling
,
CudnnMaxPooling
from
.attrs
import
*
from
.default_decorators
import
*
...
...
@@ -2699,9 +2699,9 @@ def img_pool_layer(input,
elif
isinstance
(
pool_type
,
AvgPooling
):
pool_type
.
name
=
'avg'
assert
type
(
pool_type
)
in
[
AvgPooling
,
MaxPooling
,
CudnnAvgPooling
,
assert
type
(
pool_type
)
in
[
AvgPooling
,
MaxPooling
,
MaxWithMaskPooling
,
CudnnAvgPooling
,
CudnnMaxPooling
],
\
"only (Cudnn)AvgPooling, (Cudnn)MaxPooling are supported"
"only (Cudnn)AvgPooling, (Cudnn)MaxPooling
, MaxWithMaskPooling
are supported"
type_name
=
pool_type
.
name
+
'-projection'
\
if
(
...
...
python/paddle/trainer_config_helpers/poolings.py
浏览文件 @
3e6f7684
...
...
@@ -15,8 +15,8 @@
"""
__all__
=
[
"BasePoolingType"
,
"MaxPooling"
,
"AvgPooling"
,
"
CudnnMax
Pooling"
,
"CudnnAvgPooling"
,
"SumPooling"
,
"SquareRootNPooling"
"BasePoolingType"
,
"MaxPooling"
,
"AvgPooling"
,
"
MaxWithMask
Pooling"
,
"Cudnn
MaxPooling"
,
"Cudnn
AvgPooling"
,
"SumPooling"
,
"SquareRootNPooling"
]
...
...
@@ -55,6 +55,19 @@ class MaxPooling(BasePoolingType):
self
.
output_max_index
=
output_max_index
class
MaxWithMaskPooling
(
BasePoolingType
):
"""
MaxWithMask pooling.
Not only return the very large values for each dimension in sequence or time steps,
but also the location indices of found maxinum values.
"""
def
__init__
(
self
):
BasePoolingType
.
__init__
(
self
,
"max-pool-with-mask"
)
class
CudnnMaxPooling
(
BasePoolingType
):
"""
Cudnn max pooling only support GPU. Return the maxinum value in the
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录