Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
7dc584f5
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7dc584f5
编写于
11月 20, 2017
作者:
X
xzl
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add upsample layer
上级
7c3ec220
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
607 addition
and
0 deletion
+607
-0
paddle/cuda/include/hl_cnn.h
paddle/cuda/include/hl_cnn.h
+42
-0
paddle/cuda/include/stub/hl_cnn_stub.h
paddle/cuda/include/stub/hl_cnn_stub.h
+18
-0
paddle/cuda/src/hl_cuda_cnn.cu
paddle/cuda/src/hl_cuda_cnn.cu
+76
-0
paddle/gserver/layers/UpsampleLayer.cpp
paddle/gserver/layers/UpsampleLayer.cpp
+107
-0
paddle/gserver/layers/UpsampleLayer.h
paddle/gserver/layers/UpsampleLayer.h
+54
-0
paddle/math/Matrix.cpp
paddle/math/Matrix.cpp
+126
-0
paddle/math/Matrix.h
paddle/math/Matrix.h
+52
-0
proto/ModelConfig.proto
proto/ModelConfig.proto
+11
-0
python/paddle/trainer/config_parser.py
python/paddle/trainer/config_parser.py
+44
-0
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+77
-0
未找到文件。
paddle/cuda/include/hl_cnn.h
浏览文件 @
7dc584f5
...
...
@@ -366,4 +366,46 @@ extern void hl_maxout_backward(real* inGrad,
size_t
featLen
,
size_t
groups
);
/**
* @brief Upsample forward.
* @param[in] inputData input data.
* @param[out] maskData the mask data from MaxPoolWithMaskLayer.
* @param[out] batchSize the batch size of the input.
* @param[in] imgSizeH image height.
* @param[in] imgSizeW image width.
* @param[in] channels the input channels.
* @param[in] outputH the output height.
* @param[in] outputW the output widht.
* @param[out] outputData output data.
*/
extern
void
hl_upsample_forward
(
real
*
inputData
,
real
*
maskData
,
size_t
batchSize
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
outputH
,
size_t
outputW
,
real
*
outputData
);
/**
* @brief Upsample backward.
* @param[in] outputGradData the output grad data.
* @param[out] maskData the mask data from MaxPoolWithMaskLayer.
* @param[out] batchSize the batch size of the input.
* @param[in] imgSizeH image height.
* @param[in] imgSizeW image width.
* @param[in] channels the input channels.
* @param[in] outputH the output height.
* @param[in] outputW the output widht.
* @param[out] inputGradData the input grad data.
*/
extern
void
hl_upsample_backward
(
real
*
outputGradData
,
real
*
maskData
,
size_t
batchSize
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
outputH
,
size_t
outputW
,
real
*
inputGradData
);
#endif // HL_CNN_H_
paddle/cuda/include/stub/hl_cnn_stub.h
浏览文件 @
7dc584f5
...
...
@@ -222,4 +222,22 @@ inline void hl_maxout_backward(real* inGrad,
size_t
featLen
,
size_t
group
)
{}
inline
void
hl_upsample_forward
(
real
*
inputData
,
real
*
maskData
,
size_t
batchSize
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
outputH
,
size_t
outputW
,
real
*
outputData
)
{}
inline
void
hl_upsample_backward
(
real
*
outputGradData
,
real
*
maskData
,
size_t
batchSize
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
outputH
,
size_t
outputW
,
real
*
inputGradData
)
{}
#endif // HL_CNN_STUB_H_
paddle/cuda/src/hl_cuda_cnn.cu
浏览文件 @
7dc584f5
...
...
@@ -1020,3 +1020,79 @@ void hl_maxout_backward(real* inGrad,
num_kernels
,
inGrad
,
outGrad
,
idData
,
size
,
featLen
,
groups
);
CHECK_SYNC
(
"hl_maxout_backward failed"
);
}
__global__
void
upsampleForwardCompute
(
real
*
input_data
,
real
*
mask_data
,
size_t
nthreads
,
size_t
in_h
,
size_t
in_w
,
size_t
out_h
,
size_t
out_w
,
real
*
output_data
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
int
offset
=
index
/
(
in_w
*
in_h
)
*
out_h
*
out_w
;
int
upsample_idx
=
static_cast
<
int
>
(
mask_data
[
index
]);
output_data
[
offset
+
upsample_idx
]
=
input_data
[
index
];
}
}
__global__
void
upsampleBackwardCompute
(
real
*
out_grad
,
real
*
mask_data
,
size_t
nthreads
,
size_t
in_h
,
size_t
in_w
,
size_t
out_h
,
size_t
out_w
,
real
*
input_grad
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
nthreads
)
{
int
offset
=
index
/
(
in_w
*
in_h
)
*
out_h
*
out_w
;
int
upsample_idx
=
static_cast
<
int
>
(
mask_data
[
index
]);
input_grad
[
index
]
=
out_grad
[
offset
+
upsample_idx
];
}
}
void
hl_upsample_forward
(
real
*
inputData
,
real
*
maskData
,
size_t
batchSize
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
outputH
,
size_t
outputW
,
real
*
outputData
)
{
int
num_kernels
=
batchSize
*
imgSizeH
*
imgSizeW
*
channels
;
int
blocks
=
(
num_kernels
+
1024
-
1
)
/
1024
;
upsampleForwardCompute
<<<
blocks
,
1024
,
0
,
STREAM_DEFAULT
>>>
(
inputData
,
maskData
,
num_kernels
,
imgSizeH
,
imgSizeW
,
outputH
,
outputW
,
outputData
);
CHECK_SYNC
(
"hl_upsample_forward failed"
);
}
void
hl_upsample_backward
(
real
*
outputGradData
,
real
*
maskData
,
size_t
batchSize
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
outputH
,
size_t
outputW
,
real
*
inputGradData
)
{
int
num_kernels
=
batchSize
*
imgSizeH
*
imgSizeW
*
channels
;
int
blocks
=
(
num_kernels
+
1024
-
1
)
/
1024
;
upsampleBackwardCompute
<<<
blocks
,
1024
,
0
,
STREAM_DEFAULT
>>>
(
outputGradData
,
maskData
,
num_kernels
,
imgSizeH
,
imgSizeW
,
outputH
,
outputW
,
inputGradData
);
CHECK_SYNC
(
"hl_upsample_backward failed"
);
}
paddle/gserver/layers/UpsampleLayer.cpp
0 → 100644
浏览文件 @
7dc584f5
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "UpsampleLayer.h"
#include "iostream"
namespace
paddle
{
REGISTER_LAYER
(
upsample
,
UpsampleLayer
);
size_t
UpsampleLayer
::
getOutputSize
()
{
if
(
upsampleSize_
==
0
)
{
upsampleSize_
=
imgSize_
*
scale_
-
static_cast
<
int
>
(
padOutX_
);
upsampleSizeY_
=
imgSizeY_
*
scaleY_
-
static_cast
<
int
>
(
padOutY_
);
}
return
upsampleSize_
*
upsampleSizeY_
*
channels_
;
}
bool
UpsampleLayer
::
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
)
{
Layer
::
init
(
layerMap
,
parameterMap
);
CHECK_EQ
(
inputLayers_
.
size
(),
2U
);
CHECK_EQ
(
config_
.
inputs_size
(),
2
);
const
auto
&
conf
=
config_
.
inputs
(
0
).
upsample_conf
();
const
auto
&
img_conf
=
conf
.
image_conf
();
imgSizeY_
=
img_conf
.
has_img_size_y
()
?
img_conf
.
img_size_y
()
:
img_conf
.
img_size
();
imgSize_
=
img_conf
.
img_size
();
channels_
=
img_conf
.
channels
();
CHECK
((
conf
.
has_upsample_size
())
||
(
conf
.
has_scale
()))
<<
"scale or upsample_size is required."
;
if
(
conf
.
has_upsample_size
())
{
upsampleSize_
=
conf
.
upsample_size
();
upsampleSizeY_
=
upsampleSize_
;
if
(
conf
.
has_upsample_size_y
())
{
upsampleSizeY_
=
conf
.
upsample_size_y
();
}
}
else
{
if
(
!
conf
.
has_scale_y
())
{
scale_
=
scaleY_
=
conf
.
scale_y
();
CHECK_GT
(
static_cast
<
int
>
(
scale_
),
1
);
}
else
{
scale_
=
conf
.
scale
();
scaleY_
=
conf
.
scale_y
();
}
padOutX_
=
conf
.
pad_out_x
();
padOutY_
=
conf
.
pad_out_y
();
CHECK
(
!
padOutX_
||
scale_
==
2
)
<<
"Output height padding compensation requires scale_ == 2"
;
CHECK
(
!
padOutY_
||
scaleY_
==
2
)
<<
"Output width padding compensation requires scaleY_ == 2"
;
upsampleSize_
=
upsampleSizeY_
=
0
;
}
return
true
;
}
void
UpsampleLayer
::
forward
(
PassType
passType
)
{
Layer
::
forward
(
passType
);
MatrixPtr
input
=
getInputValue
(
0
);
MatrixPtr
mask
=
inputLayers_
[
1
]
->
getOutput
(
"mask"
).
value
;
size_t
batchSize
=
input
->
getHeight
();
size_t
outSize
=
getOutputSize
();
CHECK_EQ
(
input
->
getWidth
(),
mask
->
getWidth
());
CHECK_EQ
(
mask
->
getHeight
(),
batchSize
);
resetOutput
(
batchSize
,
outSize
);
MatrixPtr
output
=
getOutputValue
();
output
->
upsampleForward
(
*
input
,
*
mask
,
imgSize_
,
imgSizeY_
,
channels_
,
upsampleSize_
,
upsampleSizeY_
);
}
void
UpsampleLayer
::
backward
(
const
UpdateCallback
&
callback
)
{
MatrixPtr
mask
=
inputLayers_
[
1
]
->
getOutput
(
"mask"
).
value
;
MatrixPtr
inputGrad
=
getInputGrad
(
0
);
MatrixPtr
outputGrad
=
getOutputGrad
();
inputGrad
->
upsampleBackward
(
*
outputGrad
,
*
mask
,
imgSize_
,
imgSizeY_
,
channels_
,
upsampleSize_
,
upsampleSizeY_
);
}
}
// namespace paddle
paddle/gserver/layers/UpsampleLayer.h
0 → 100644
浏览文件 @
7dc584f5
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "Layer.h"
#include "paddle/math/Matrix.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"
namespace
paddle
{
/**
* This layer transpose the pooling process.
* It takes two input, the first input is the input data, and
* the second is the mask data from the max-pool-with-mask layer.
*
*/
class
UpsampleLayer
:
public
Layer
{
public:
explicit
UpsampleLayer
(
const
LayerConfig
&
config
)
:
Layer
(
config
)
{}
~
UpsampleLayer
()
{}
bool
init
(
const
LayerMap
&
layerMap
,
const
ParameterMap
&
parameterMap
)
override
;
void
forward
(
PassType
passType
)
override
;
void
backward
(
const
UpdateCallback
&
callback
)
override
;
size_t
getOutputSize
();
protected:
size_t
scale_
,
scaleY_
;
size_t
upsampleSize_
,
upsampleSizeY_
;
size_t
padOutX_
,
padOutY_
;
size_t
imgSize_
,
imgSizeY_
;
size_t
channels_
;
};
}
// namespace paddle
paddle/math/Matrix.cpp
浏览文件 @
7dc584f5
...
...
@@ -1023,6 +1023,64 @@ void GpuMatrix::check(std::ostream& os, Matrix& refMat, bool printDiff) {
LOG
(
INFO
)
<<
"the diffCnt is "
<<
diffCnt
;
}
void
GpuMatrix
::
upsampleForward
(
Matrix
&
input
,
Matrix
&
mask
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
outputH
,
size_t
outputW
)
{
CHECK
(
input
.
useGpu_
==
true
)
<<
"Matrix type are not equal"
;
CHECK
(
mask
.
useGpu_
==
true
)
<<
"Matrix type are not equal"
;
real
*
inputData
=
input
.
getData
();
real
*
maskData
=
mask
.
getData
();
real
*
outData
=
data_
;
size_t
batch
=
input
.
getHeight
();
CHECK
(
imgSizeH
*
imgSizeW
*
channels
==
input
.
getWidth
());
CHECK
(
imgSizeH
*
imgSizeW
*
channels
==
mask
.
getWidth
());
CHECK_EQ
(
batch
,
this
->
getHeight
());
CHECK
(
width_
==
outputH
*
outputW
*
channels
);
hl_upsample_forward
(
inputData
,
maskData
,
batch
,
imgSizeH
,
imgSizeW
,
channels
,
outputH
,
outputW
,
outData
);
}
void
GpuMatrix
::
upsampleBackward
(
Matrix
&
outputGrad
,
Matrix
&
mask
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
outputH
,
size_t
outputW
)
{
CHECK
(
outputGrad
.
useGpu_
==
true
)
<<
"Matrix type are not equal"
;
CHECK
(
mask
.
useGpu_
==
true
)
<<
"Matrix type are not equal"
;
real
*
outputGradData
=
outputGrad
.
getData
();
real
*
maskData
=
mask
.
getData
();
real
*
inputGradData
=
data_
;
size_t
batch
=
outputGrad
.
getHeight
();
CHECK
(
imgSizeH
*
imgSizeW
==
this
->
getWidth
()
/
channels
);
CHECK_EQ
(
batch
,
this
->
getHeight
());
CHECK_EQ
(
channels
*
outputH
*
outputW
,
outputGrad
.
getWidth
());
hl_upsample_backward
(
outputGradData
,
maskData
,
batch
,
imgSizeH
,
imgSizeW
,
channels
,
outputH
,
outputW
,
inputGradData
);
}
void
GpuMatrix
::
maxPoolForward
(
Matrix
&
inputMat
,
size_t
imgSizeH
,
size_t
imgSizeW
,
...
...
@@ -1981,6 +2039,74 @@ void CpuMatrix::inverse(MatrixPtr& matInv, bool memAlloc) {
CHECK_EQ
(
info
,
0
);
}
void
CpuMatrix
::
upsampleForward
(
Matrix
&
input
,
Matrix
&
mask
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
outputH
,
size_t
outputW
)
{
real
*
inputData
=
input
.
getData
();
real
*
maskData
=
mask
.
getData
();
real
*
outData
=
data_
;
size_t
inLength
=
imgSizeH
*
imgSizeW
;
size_t
outLength
=
outputH
*
outputW
;
size_t
batch
=
input
.
getHeight
();
CHECK
(
inLength
==
input
.
getWidth
()
/
channels
);
CHECK_EQ
(
batch
,
this
->
getHeight
());
CHECK_EQ
(
channels
*
outLength
,
this
->
getWidth
());
for
(
size_t
k
=
0
;
k
<
batch
;
k
++
)
{
for
(
size_t
c
=
0
;
c
<
channels
;
c
++
)
{
for
(
size_t
i
=
0
;
i
<
inLength
;
i
++
)
{
size_t
out_index
=
static_cast
<
int
>
(
maskData
[
i
]);
if
(
out_index
>=
outLength
)
{
LOG
(
FATAL
)
<<
"upsample index "
<<
out_index
<<
" out of range."
;
}
outData
[
out_index
]
=
inputData
[
i
];
}
inputData
+=
inLength
;
maskData
+=
inLength
;
outData
+=
outLength
;
}
}
}
void
CpuMatrix
::
upsampleBackward
(
Matrix
&
outputGrad
,
Matrix
&
mask
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
outputH
,
size_t
outputW
)
{
real
*
outputGradData
=
outputGrad
.
getData
();
real
*
maskData
=
mask
.
getData
();
real
*
inputGradData
=
data_
;
size_t
inLength
=
imgSizeH
*
imgSizeW
;
size_t
outLength
=
outputH
*
outputW
;
size_t
batch
=
outputGrad
.
getHeight
();
CHECK
(
inLength
==
this
->
getWidth
()
/
channels
);
CHECK_EQ
(
batch
,
this
->
getHeight
());
CHECK_EQ
(
channels
*
outLength
,
outputGrad
.
getWidth
());
for
(
size_t
k
=
0
;
k
<
batch
;
k
++
)
{
for
(
size_t
c
=
0
;
c
<
channels
;
c
++
)
{
for
(
size_t
i
=
0
;
i
<
inLength
;
i
++
)
{
size_t
out_index
=
static_cast
<
int
>
(
maskData
[
i
]);
if
(
out_index
>=
outLength
)
{
LOG
(
FATAL
)
<<
"upsample index "
<<
out_index
<<
" out of range."
;
}
inputGradData
[
i
]
=
outputGradData
[
out_index
];
}
inputGradData
+=
inLength
;
maskData
+=
inLength
;
outputGradData
+=
outLength
;
}
}
}
void
CpuMatrix
::
maxPoolForward
(
Matrix
&
inputMat
,
size_t
imgSizeH
,
size_t
imgSizeW
,
...
...
paddle/math/Matrix.h
浏览文件 @
7dc584f5
...
...
@@ -859,6 +859,26 @@ public:
LOG
(
FATAL
)
<<
"Not implemented"
;
}
virtual
void
upsampleForward
(
Matrix
&
input
,
Matrix
&
mask
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
outputH
,
size_t
outputW
)
{
LOG
(
FATAL
)
<<
"Not implemeted"
;
}
virtual
void
upsampleBackward
(
Matrix
&
outputGrad
,
Matrix
&
mask
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
outputH
,
size_t
outputW
)
{
LOG
(
FATAL
)
<<
"Not implemeted"
;
}
/**
* Pooling forward operation, pick out the largest element
* in the sizeX of value, if the maskMatP is not NULL, it will
...
...
@@ -1417,6 +1437,22 @@ public:
void
classificationError
(
Matrix
&
output
,
IVector
&
label
,
size_t
topkSize
=
1
);
void
upsampleForward
(
Matrix
&
input
,
Matrix
&
mask
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
outputH
,
size_t
outputW
);
void
upsampleBackward
(
Matrix
&
outputGrad
,
Matrix
&
mask
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
outputH
,
size_t
outputW
);
void
maxPoolForward
(
Matrix
&
inputMat
,
size_t
imgSizeH
,
size_t
imgSizeW
,
...
...
@@ -1689,6 +1725,22 @@ public:
MatrixPtr
clone
(
size_t
height
,
size_t
width
,
bool
useGpu
=
false
);
void
upsampleForward
(
Matrix
&
input
,
Matrix
&
mask
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
outputH
,
size_t
outputW
);
void
upsampleBackward
(
Matrix
&
outputGrad
,
Matrix
&
mask
,
size_t
imgSizeH
,
size_t
imgSizeW
,
size_t
channels
,
size_t
outputH
,
size_t
outputW
);
void
maxPoolForward
(
Matrix
&
inputMat
,
size_t
imgSizeH
,
size_t
imgSizeW
,
...
...
proto/ModelConfig.proto
浏览文件 @
7dc584f5
...
...
@@ -321,6 +321,16 @@ message ClipConfig {
required
double
max
=
2
;
}
message
UpsampleConfig
{
required
ImageConfig
image_conf
=
1
;
optional
uint32
scale
=
2
[
default
=
2
];
optional
uint32
scale_y
=
3
[
default
=
2
];
optional
bool
pad_out_x
=
4
[
default
=
false
];
optional
bool
pad_out_y
=
5
[
default
=
false
];
optional
uint32
upsample_size
=
6
;
optional
uint32
upsample_size_y
=
7
;
}
message
ROIPoolConfig
{
required
uint32
pooled_width
=
1
;
required
uint32
pooled_height
=
2
;
...
...
@@ -357,6 +367,7 @@ message LayerInputConfig {
optional
ClipConfig
clip_conf
=
18
;
optional
ScaleSubRegionConfig
scale_sub_region_conf
=
19
;
optional
ROIPoolConfig
roi_pool_conf
=
20
;
optional
UpsampleConfig
upsample_conf
=
21
;
}
message
LayerConfig
{
...
...
python/paddle/trainer/config_parser.py
浏览文件 @
7dc584f5
...
...
@@ -466,6 +466,7 @@ class Input(Cfg):
maxout
=
None
,
spp
=
None
,
pad
=
None
,
upsample
=
None
,
format
=
None
,
nnz
=
None
,
is_static
=
None
,
...
...
@@ -977,6 +978,11 @@ class Pad(Cfg):
def
__init__
(
self
,
channels
,
pad_c
,
pad_h
,
pad_w
):
self
.
add_keys
(
locals
())
@
config_class
class
Upsample
(
Cfg
):
def
__init__
(
self
,
scale
,
scale_y
,
pad_out_x
,
pad_out_y
,
upsample_size
,
upsample_size_y
):
self
.
add_keys
(
locals
())
@
config_class
class
Norm
(
Cfg
):
...
...
@@ -2387,6 +2393,44 @@ class SpatialPyramidPoolLayer(LayerBase):
output_x
=
(
pow
(
4
,
spp_conf
.
pyramid_height
)
-
1
)
/
(
4
-
1
)
self
.
set_cnn_layer
(
name
,
1
,
output_x
,
spp_conf
.
image_conf
.
channels
)
@
config_layer
(
'upsample'
)
class
UpsampleLayer
(
LayerBase
):
def
__init__
(
self
,
name
,
inputs
,
**
xargs
):
super
(
UpsampleLayer
,
self
).
__init__
(
name
,
'upsample'
,
0
,
inputs
=
inputs
,
**
xargs
)
input_layer
=
self
.
get_input_layer
(
0
)
image_conf
=
self
.
config
.
inputs
[
0
].
upsample_conf
.
image_conf
image_conf
.
img_size
=
input_layer
.
width
image_conf
.
img_size_y
=
input_layer
.
height
image_conf
.
channels
=
input_layer
.
size
/
(
input_layer
.
width
*
input_layer
.
height
)
upsample
=
self
.
inputs
[
0
].
upsample
output_x
=
0
output_y
=
0
output_size
=
0
if
upsample
.
scale
:
self
.
config
.
inputs
[
0
].
upsample_conf
.
scale
=
upsample
.
scale
self
.
config
.
inputs
[
0
].
upsample_conf
.
scale_y
=
upsample
.
scale_y
output_x
=
input_layer
.
width
*
upsample
.
scale
output_y
=
input_layer
.
height
*
upsample
.
scale_y
self
.
config
.
inputs
[
0
].
upsample_conf
.
pad_out_x
=
upsample
.
pad_out_x
self
.
config
.
inputs
[
0
].
upsample_conf
.
pad_out_y
=
upsample
.
pad_out_y
if
upsample
.
upsample_size
:
self
.
config
.
inputs
[
0
].
upsample_conf
.
upsample_size
=
upsample
.
upsample_size
self
.
config
.
inputs
[
0
].
upsample_conf
.
upsample_size_y
=
upsample
.
upsample_size_y
output_x
=
upsample
.
upsample_size
output_y
=
upsample
.
upsample_size_y
output_size
=
image_conf
.
channels
*
output_x
*
output_y
self
.
set_layer_height_width
(
output_y
,
output_x
)
self
.
set_layer_depth
(
input_layer
.
depth
)
self
.
set_layer_size
(
output_size
)
@
config_layer
(
'pad'
)
class
PadLayer
(
LayerBase
):
...
...
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
7dc584f5
...
...
@@ -146,6 +146,7 @@ __all__ = [
'resize_layer'
,
'sub_seq_layer'
,
'scale_sub_region_layer'
,
'upsample_layer'
,
]
...
...
@@ -163,6 +164,7 @@ class LayerType(object):
SEQUENCE_RESHAPE
=
'seqreshape'
POOLING_MAX
=
'max'
POOLING_AVG
=
'average'
UPSAMPLE_LAYER
=
'upsample'
FC_LAYER
=
'fc'
COST
=
'cost'
COSINE_SIM_VEC
=
'cos_vm'
...
...
@@ -2879,6 +2881,81 @@ def img_pool3d_layer(input,
num_filters
=
num_channels
,
size
=
l
.
config
.
size
)
@
wrap_name_default
(
"upsample"
)
@
layer_support
()
def
upsample_layer
(
input
,
name
=
None
,
scale
=
None
,
scale_y
=
None
,
upsample_size
=
None
,
upsample_size_y
=
None
,
pad_out_x
=
False
,
pad_out_y
=
False
,
layer_attr
=
None
):
"""
The DePooling process.
Inputs should be a list of length 2. The first input is a layer,
and the second input should be the MaxWithMaskPoolingLayer
The example usage is:
.. code-block:: python
pool1 = paddle.v2.layer.img_pool(input=input, pool_size=2, stride=2,
pool_type=paddle.pooling.MaxWithMask())
upsample = paddle.v2.layer.upsample(input=[layer1, pool1])
:param name: The name of this layer. It is optional.
:type name: basestring
:param input: contains an input layer and a MaxWithMaskPoolingLayer
:type input: list | tuple | collections.Sequence
:param scale: outputSize = scale * inputSize
:type scale: int | list | tuple | .
:param scale_y: scale_y will be equal to scale, if it's value is None,
:type scale: int | None.
:param upsample_size: specify the outputSize.
:type upsample_size: int | list | tuple.
:param upsample_size_y: specify the y dimension outputSize.
:type upsample_size_y: int.
:param pad_out_x: specify exact x dimension size. This parameter only works when scale is 2
:type pad_out_x: bool.
:param pad_out_y: specify exact y dimension size. This parameter only works when scale is 2
:type pad_out_y: bool.
:param layer_attr: Extra Layer Attribute.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert
(
scale
is
not
None
)
or
(
upsample_size
is
not
None
),
\
'scale or upsample_size, there must be one to be designated'
assert
len
(
input
)
==
2
,
'layer input size must be 2'
assert
input
[
1
].
layer_type
==
LayerType
.
POOL_LAYER
,
\
'the second input should be the MaxPoolWithMaskLayer'
scale_y
=
scale
\
if
scale
is
not
None
else
scale_y
upsample_size_y
=
upsample_size
\
if
upsample_size
is
not
None
else
upsample_size_y
layer_type
=
LayerType
.
UPSAMPLE_LAYER
layer
=
Layer
(
name
=
name
,
type
=
layer_type
,
inputs
=
[
Input
(
input
[
0
].
name
,
upsample
=
Upsample
(
scale
,
scale_y
,
pad_out_x
,
pad_out_y
,
upsample_size
,
upsample_size_y
)),
Input
(
input
[
1
].
name
)
],
**
ExtraLayerAttribute
.
to_kwargs
(
layer_attr
))
sz
=
layer
.
config
.
size
return
LayerOutput
(
name
,
layer_type
=
layer_type
,
parents
=
input
,
size
=
sz
)
@
wrap_name_default
(
"spp"
)
@
layer_support
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录