机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Commit ddfff3a7
Authored on Oct 30, 2016 by liaogang
Parent: 86bb5ef1

Add bilinear interpolation layer
Changes: 15 files changed, 781 additions (+) and 4 deletions (-)
Changed files:
  doc/ui/api/trainer_config_helpers/layers.rst (+6 -0)
  paddle/cuda/include/hl_cnn.h (+56 -0)
  paddle/cuda/include/stub/hl_cnn_stub.h (+24 -0)
  paddle/cuda/src/hl_cuda_cnn.cu (+133 -1)
  paddle/gserver/layers/BilinearInterpLayer.cpp (+87 -0)
  paddle/gserver/layers/BilinearInterpLayer.h (+45 -0)
  paddle/gserver/tests/test_LayerGrad.cpp (+20 -0)
  paddle/math/Matrix.cpp (+154 -0)
  paddle/math/Matrix.h (+44 -0)
  paddle/math/tests/test_matrixCompare.cpp (+66 -0)
  proto/ModelConfig.proto.m4 (+10 -0)
  python/paddle/trainer/config_parser.py (+35 -0)
  python/paddle/trainer_config_helpers/layers.py (+67 -2)
  python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh (+1 -1)
  python/paddle/trainer_config_helpers/tests/configs/test_bilinear_interp.py (+33 -0)
doc/ui/api/trainer_config_helpers/layers.rst

@@ -263,6 +263,12 @@ interpolation_layer
    :members: interpolation_layer
    :noindex:

bilinear_interp_layer
---------------------
.. automodule:: paddle.trainer_config_helpers.layers
    :members: bilinear_interp_layer
    :noindex:

power_layer
-----------
.. automodule:: paddle.trainer_config_helpers.layers
...
paddle/cuda/include/hl_cnn.h

@@ -240,4 +240,60 @@ extern void hl_CMRNorm_backward(
                                 size_t channels, size_t height,
                                 size_t width, size_t sizeX,
                                 real alpha, real beta);

/**
 * @brief Bilinear interpolation forward.
 *
 * @param[in]  inData      input value.
 * @param[in]  inImgH      input image height.
 * @param[in]  inImgW      input image width.
 * @param[in]  inputH      input batchSize.
 * @param[in]  inputW      input image data dim.
 * @param[out] outData     output value.
 * @param[in]  outImgH     output image height.
 * @param[in]  outImgW     output image width.
 * @param[in]  outputH     output batchSize.
 * @param[in]  outputW     output image data dim.
 * @param[in]  numChannels number of channels.
 */
extern void hl_bilinear_forward(const real* inData,
                                const size_t inImgH,
                                const size_t inImgW,
                                const size_t inputH,
                                const size_t inputW,
                                real* outData,
                                const size_t outImgH,
                                const size_t outImgW,
                                const size_t outputH,
                                const size_t outputW,
                                const size_t numChannels);

/**
 * @brief Bilinear interpolation backward.
 *
 * @param[out] inGrad      input gradient.
 * @param[in]  inImgH      input image height.
 * @param[in]  inImgW      input image width.
 * @param[in]  inputH      input batchSize.
 * @param[in]  inputW      input image data dim.
 * @param[in]  outGrad     output gradient.
 * @param[in]  outImgH     output image height.
 * @param[in]  outImgW     output image width.
 * @param[in]  outputH     output batchSize.
 * @param[in]  outputW     output image data dim.
 * @param[in]  numChannels number of channels.
 */
extern void hl_bilinear_backward(real* inGrad,
                                 const size_t inImgH,
                                 const size_t inImgW,
                                 const size_t inputH,
                                 const size_t inputW,
                                 const real* outGrad,
                                 const size_t outImgH,
                                 const size_t outImgW,
                                 const size_t outputH,
                                 const size_t outputW,
                                 const size_t numChannels);

#endif /* HL_CNN_H_ */
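The inputH/inputW naming in these comments is easy to misread: inputH is the number of samples in the batch and inputW is the flattened per-sample dimension. A minimal sketch of the layout these functions assume, using hypothetical NumPy values (not part of the commit):

import numpy as np

# Hypothetical sizes for illustration; any N, C, H, W would do.
N, C, H, W = 10, 4, 32, 32
batch = np.random.rand(N, C, H, W).astype(np.float32)

# Each sample is flattened channel-major into one row, so the matrix
# handed to hl_bilinear_forward is N x (C*H*W):
inData = batch.reshape(N, C * H * W)
inputH = inData.shape[0]  # "input batchSize"       -> N
inputW = inData.shape[1]  # "input image data dim"  -> C * H * W
inImgH, inImgW = H, W     # spatial size of one channel plane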
paddle/cuda/include/stub/hl_cnn_stub.h

@@ -89,4 +89,28 @@ inline void hl_CMRNorm_backward(
                                size_t channels, size_t height,
                                size_t width, size_t sizeX,
                                real alpha, real beta) {}

inline void hl_bilinear_forward(const real* inData,
                                const size_t inImgH,
                                const size_t inImgW,
                                const size_t inputH,
                                const size_t inputW,
                                real* outData,
                                const size_t outImgH,
                                const size_t outImgW,
                                const size_t outputH,
                                const size_t outputW,
                                const size_t numChannels) {}

inline void hl_bilinear_backward(real* inGrad,
                                 const size_t inImgH,
                                 const size_t inImgW,
                                 const size_t inputH,
                                 const size_t inputW,
                                 const real* outGrad,
                                 const size_t outImgH,
                                 const size_t outImgW,
                                 const size_t outputH,
                                 const size_t outputW,
                                 const size_t numChannels) {}

#endif  // HL_CNN_STUB_H_
paddle/cuda/src/hl_cuda_cnn.cu

@@ -522,7 +522,7 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
                         size_t height, size_t width, size_t sizeX,
                         real alpha, real beta) {
  size_t threadsNum = frameCnt * height * width;
  size_t blocksX = (threadsNum + 1024 - 1) / 1024;
  size_t blocksY = 1;
  dim3 threads(1024, 1);
  dim3 grid(blocksX, blocksY);
...
@@ -531,3 +531,135 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
                        height, width, sizeX, alpha, beta, inDiff);
  CHECK_SYNC("hl_CMRNorm_backward");
}

__global__ void KeBilinearInterpFw(const size_t nthreads,
                                   const real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < nthreads) {
    int outIdH = tid / (outputW / numChannels);
    int outIdW = tid % (outputW / numChannels);
    int inIdH = ratioH * (outIdW / outImgW);
    int hId = (inIdH < inImgH - 1) ? 1 : 0;
    real hlambda = ratioH * (outIdW / outImgW) - inIdH;

    int inIdW = ratioW * (tid % outImgW);
    int wId = (inIdW < inImgW - 1) ? 1 : 0;
    real wlambda = ratioW * (tid % outImgW) - inIdW;

    const real* inPos = &in[outIdH * inputW + inIdH * inImgW + inIdW];
    real* outPos = &out[outIdH * outputW + outIdW];
    for (int c = 0; c < numChannels; ++c) {
      // bilinear interpolation
      outPos[0] = (1.f - hlambda) *
                      ((1.f - wlambda) * inPos[0] + wlambda * inPos[wId]) +
                  hlambda * ((1.f - wlambda) * inPos[hId * inImgW] +
                             wlambda * inPos[hId * inImgW + wId]);
      inPos += inImgH * inImgW;
      outPos += outImgH * outImgW;
    }
  }
}

void hl_bilinear_forward(const real* inData,
                         const size_t inImgH,
                         const size_t inImgW,
                         const size_t inputH,
                         const size_t inputW,
                         real* outData,
                         const size_t outImgH,
                         const size_t outImgW,
                         const size_t outputH,
                         const size_t outputW,
                         const size_t numChannels) {
  int threadNum = outputH * outImgH * outImgW;
  int blocks = (threadNum + 1024 - 1) / 1024;

  real ratioH =
      (outImgH > 1) ? static_cast<float>(inImgH - 1) / (outImgH - 1) : 0.f;
  real ratioW =
      (outImgW > 1) ? static_cast<float>(inImgW - 1) / (outImgW - 1) : 0.f;

  KeBilinearInterpFw<<<blocks, 1024, 0, STREAM_DEFAULT>>>(
      threadNum, inData, inImgH, inImgW, inputH, inputW, outData,
      outImgH, outImgW, outputH, outputW, numChannels, ratioH, ratioW);
  CHECK_SYNC("hl_bilinear_forward failed");
}

__global__ void KeBilinearInterpBw(const size_t nthreads,
                                   real* in,
                                   const size_t inImgH,
                                   const size_t inImgW,
                                   const size_t inputH,
                                   const size_t inputW,
                                   const real* out,
                                   const size_t outImgH,
                                   const size_t outImgW,
                                   const size_t outputH,
                                   const size_t outputW,
                                   const size_t numChannels,
                                   const real ratioH,
                                   const real ratioW) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < nthreads) {
    int outIdH = tid / (outputW / numChannels);
    int outIdW = tid % (outputW / numChannels);
    int inIdH = ratioH * (outIdW / outImgW);
    int hId = (inIdH < inImgH - 1) ? 1 : 0;
    real hlambda = ratioH * (outIdW / outImgW) - inIdH;

    int inIdW = ratioW * (tid % outImgW);
    int wId = (inIdW < inImgW - 1) ? 1 : 0;
    real wlambda = ratioW * (tid % outImgW) - inIdW;

    const real* outPos = &out[outIdH * outputW + outIdW];
    real* inPos = &in[outIdH * inputW + inIdH * inImgW + inIdW];
    for (int c = 0; c < numChannels; ++c) {
      atomicAdd(&inPos[0], (1.f - hlambda) * (1.f - wlambda) * outPos[0]);
      atomicAdd(&inPos[wId], (1.f - hlambda) * wlambda * outPos[0]);
      atomicAdd(&inPos[hId * inImgW], hlambda * (1.f - wlambda) * outPos[0]);
      atomicAdd(&inPos[hId * inImgW + wId], hlambda * wlambda * outPos[0]);
      inPos += inImgH * inImgW;
      outPos += outImgH * outImgW;
    }
  }
}

void hl_bilinear_backward(real* inGrad,
                          const size_t inImgH,
                          const size_t inImgW,
                          const size_t inputH,
                          const size_t inputW,
                          const real* outGrad,
                          const size_t outImgH,
                          const size_t outImgW,
                          const size_t outputH,
                          const size_t outputW,
                          const size_t numChannels) {
  int threadNum = outputH * outImgH * outImgW;
  int blocks = (threadNum + 1024 - 1) / 1024;

  real ratioH =
      (outImgH > 1) ? static_cast<float>(inImgH - 1) / (outImgH - 1) : 0.f;
  real ratioW =
      (outImgW > 1) ? static_cast<float>(inImgW - 1) / (outImgW - 1) : 0.f;

  KeBilinearInterpBw<<<blocks, 1024, 0, STREAM_DEFAULT>>>(
      threadNum, inGrad, inImgH, inImgW, inputH, inputW, outGrad,
      outImgH, outImgW, outputH, outputW, numChannels, ratioH, ratioW);
  CHECK_SYNC("hl_bilinear_backward failed");
}
\ No newline at end of file
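As a cross-check of the kernel math, here is a rough NumPy restatement of the per-pixel computation in KeBilinearInterpFw, under the N x (C*H*W) layout sketched earlier. It is an illustration only, not code from this commit:

import numpy as np

def bilinear_forward_ref(inp, in_h, in_w, out_h, out_w, channels):
    """Unoptimized NumPy restatement of KeBilinearInterpFw (illustration)."""
    n = inp.shape[0]
    x = inp.reshape(n, channels, in_h, in_w)
    out = np.zeros((n, channels, out_h, out_w), dtype=inp.dtype)
    # Same "align corners" ratios as hl_bilinear_forward:
    ratio_h = (in_h - 1.0) / (out_h - 1.0) if out_h > 1 else 0.0
    ratio_w = (in_w - 1.0) / (out_w - 1.0) if out_w > 1 else 0.0
    for i in range(out_h):
        h = int(ratio_h * i)            # top source row
        hid = 1 if h < in_h - 1 else 0  # step to next row, clamped at border
        hlam = ratio_h * i - h          # vertical interpolation weight
        for j in range(out_w):
            w = int(ratio_w * j)
            wid = 1 if w < in_w - 1 else 0
            wlam = ratio_w * j - w
            # blend the four neighboring source pixels
            out[:, :, i, j] = (
                (1 - hlam) * ((1 - wlam) * x[:, :, h, w] +
                              wlam * x[:, :, h, w + wid]) +
                hlam * ((1 - wlam) * x[:, :, h + hid, w] +
                        wlam * x[:, :, h + hid, w + wid]))
    return out.reshape(n, channels * out_h * out_w)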
paddle/gserver/layers/BilinearInterpLayer.cpp
new file mode 100644

/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "BilinearInterpLayer.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h"

namespace paddle {

REGISTER_LAYER(bilinear_interp, BilinearInterpLayer);

size_t BilinearInterpLayer::getDataDimSize() {
  getOutput().setFrameHeight(outImgH_);
  getOutput().setFrameWidth(outImgW_);
  return outImgH_ * outImgW_ * numChannels_;
}

bool BilinearInterpLayer::init(const LayerMap& layerMap,
                               const ParameterMap& parameterMap) {
  /* Initialize the basic parent class */
  Layer::init(layerMap, parameterMap);

  CHECK_EQ(1, config_.inputs_size());
  const BilinearInterpConfig& conf = config_.inputs(0).bilinear_interp_conf();
  inImgH_ = inputLayers_[0]->getOutput().getFrameHeight();
  inImgW_ = inputLayers_[0]->getOutput().getFrameWidth();
  if (inImgH_ == 0) {
    inImgH_ = conf.img_size_y();
  }
  if (inImgW_ == 0) {
    inImgW_ = conf.img_size_x();
  }
  outImgH_ = conf.out_size_y();
  outImgW_ = conf.out_size_x();
  numChannels_ = conf.num_channels();

  CHECK(outImgH_ > 0 && outImgW_ > 0);
  CHECK(inImgH_ > 0 && inImgW_ > 0);
  CHECK(numChannels_);

  return true;
}

void BilinearInterpLayer::forward(PassType passType) {
  Layer::forward(passType);

  size_t batchSize = getInput(0).getBatchSize();
  size_t size = getDataDimSize();
  {
    REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
    resetOutput(batchSize, size);
  }

  MatrixPtr inV = getInputValue(0);
  MatrixPtr outV = getOutputValue();
  {
    REGISTER_TIMER_INFO("FwBilinearInterpTimer", getName().c_str());
    outV->bilinearForward(*inV, inImgH_, inImgW_, outImgH_, outImgW_,
                          numChannels_);
  }
}

void BilinearInterpLayer::backward(const UpdateCallback& callback) {
  (void) callback;

  MatrixPtr inputG = getInputGrad(0);
  MatrixPtr outG = getOutputGrad();
  {
    REGISTER_TIMER_INFO("BwBilinearInterpTimer", getName().c_str());
    if (inputG) {
      inputG->bilinearBackward(*outG, outImgH_, outImgW_, inImgH_, inImgW_,
                               numChannels_);
    }
  }
}
}  // namespace paddle
paddle/gserver/layers/BilinearInterpLayer.h
new file mode 100644

/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "Layer.h"
#include "paddle/math/Matrix.h"

namespace paddle {

/**
 * @brief A layer for bilinear interpolation which is
 *        used on conv layer output.
 *
 * @note  The config file api is bilinear_interp_layer.
 */
class BilinearInterpLayer : public Layer {
protected:
  size_t outImgH_, outImgW_;
  size_t inImgH_, inImgW_;
  size_t numChannels_;

public:
  explicit BilinearInterpLayer(const LayerConfig& config) : Layer(config) {}

  virtual ~BilinearInterpLayer() {}

  size_t getDataDimSize();
  bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
  void forward(PassType passType);
  void backward(const UpdateCallback& callback = nullptr);
};
}  // namespace paddle
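To make the layer's shape contract concrete, here is a small sketch (not part of the commit) using the sizes from the gradient test below; n is assumed to match the test's batch argument:

# Illustration only: shape contract of BilinearInterpLayer, with the sizes
# used by TEST(Layer, BilinearInterpLayer) below (n assumed to be the
# test's batch argument).
n, c = 10, 4
in_h = in_w = 32
out_h = out_w = 64

input_size = c * in_h * in_w      # 4096, the INPUT_DATA size in the test
output_size = c * out_h * out_w   # 16384, what getDataDimSize() returns
print((n, input_size), "->", (n, output_size))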
paddle/gserver/tests/test_LayerGrad.cpp

@@ -31,6 +31,26 @@ P_DECLARE_double(checkgrad_eps);
P_DECLARE_bool(thread_local_rand_use_global_seed);
P_DECLARE_bool(prev_batch_state);

TEST(Layer, BilinearInterpLayer) {
  TestConfig config;
  config.layerConfig.set_type("bilinear_interp");
  config.biasSize = 0;
  config.inputDefs.push_back({INPUT_DATA, "layer_0", 4096, 0});

  LayerInputConfig* input = config.layerConfig.add_inputs();
  BilinearInterpConfig* bilinear = input->mutable_bilinear_interp_conf();

  bilinear->set_img_size_x(32);
  bilinear->set_img_size_y(32);
  bilinear->set_out_size_x(64);
  bilinear->set_out_size_y(64);
  bilinear->set_num_channels(4);

  for (auto useGpu : {false, true}) {
    testLayerGrad(config, "bilinear_interp", 10, false, useGpu);
  }
}

TEST(Operator, dot_mul) {
  TestConfig config;
  config.layerConfig.set_size(10);
...
paddle/math/Matrix.cpp

@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/utils/Logging.h"

#include <string.h>
#include "hl_cnn.h"
#include "hl_gpu.h"
#include "hl_table_apply.h"
#include "hl_top_k.h"
...
@@ -1144,6 +1145,56 @@ void GpuMatrix::addColumnVector(const Matrix& b) {
  BaseMatrix::addColVector(const_cast<Matrix&>(b));
}

void GpuMatrix::bilinearForward(const Matrix& in,
                                const size_t inImgH,
                                const size_t inImgW,
                                const size_t outImgH,
                                const size_t outImgW,
                                const size_t numChannels) {
  CHECK(dynamic_cast<const GpuMatrix*>(&in));

  const size_t outputW = getWidth();
  const size_t outputH = getHeight();
  const size_t inputW = in.getWidth();
  const size_t inputH = in.getHeight();

  real* outData = getData();
  const real* inData = in.getData();

  if (inImgH == outImgH && inImgW == outImgW) {
    this->copyFrom(in);
  } else {
    hl_bilinear_forward(inData, inImgH, inImgW, inputH, inputW, outData,
                        outImgH, outImgW, outputH, outputW, numChannels);
  }
}

void GpuMatrix::bilinearBackward(const Matrix& out,
                                 const size_t outImgH,
                                 const size_t outImgW,
                                 const size_t inImgH,
                                 const size_t inImgW,
                                 const size_t numChannels) {
  CHECK(dynamic_cast<const GpuMatrix*>(&out));

  const size_t inputW = getWidth();
  const size_t inputH = getHeight();
  const size_t outputW = out.getWidth();
  const size_t outputH = out.getHeight();

  real* inGrad = getData();
  const real* outGrad = out.getData();

  if (outImgH == inImgH && outImgW == inImgW) {
    this->copyFrom(out);
  } else {
    hl_bilinear_backward(inGrad, inImgH, inImgW, inputH, inputW, outGrad,
                         outImgH, outImgW, outputH, outputW, numChannels);
  }
}

/**
 * CpuMatrix
 */
...
@@ -3598,6 +3649,109 @@ void CpuMatrix::classificationErrorMulti(Matrix& output, Matrix& label,
  }
}

void CpuMatrix::bilinearForward(const Matrix& in,
                                const size_t inImgH,
                                const size_t inImgW,
                                const size_t outImgH,
                                const size_t outImgW,
                                const size_t numChannels) {
  CHECK(dynamic_cast<const CpuMatrix*>(&in));

  size_t outputW = getWidth();
  size_t outputH = getHeight();
  size_t inputW = in.getWidth();
  size_t inputH = in.getHeight();

  real* outData = getData();
  const real* inData = in.getData();

  const real ratioH =
      (outImgH > 1) ? static_cast<real>(inImgH - 1) / (outImgH - 1) : 0.f;
  const real ratioW =
      (outImgW > 1) ? static_cast<real>(inImgW - 1) / (outImgW - 1) : 0.f;

  if (inImgH == outImgH && inImgW == outImgW) {
    this->copyFrom(in);
  } else {
    for (int k = 0; k < outputH; ++k) {  // loop for batches
      for (int i = 0; i < outImgH; ++i) {  // loop for images
        int h = ratioH * i;
        int hid = (h < inImgH - 1) ? 1 : 0;
        real hlambda = ratioH * i - h;
        for (int j = 0; j < outImgW; ++j) {
          int w = ratioW * j;
          int wid = (w < inImgW - 1) ? 1 : 0;
          real wlambda = ratioW * j - w;
          // calculate four position for bilinear interpolation
          const real* inPos = &inData[k * inputW + h * inImgW + w];
          real* outPos = &outData[k * outputW + i * outImgW + j];
          for (int c = 0; c < numChannels; ++c) {  // loop for channels
            // bilinear interpolation
            outPos[0] =
                (1.f - hlambda) *
                    ((1.f - wlambda) * inPos[0] + wlambda * inPos[wid]) +
                hlambda * ((1.f - wlambda) * inPos[hid * inImgW] +
                           wlambda * inPos[hid * inImgW + wid]);
            inPos += inImgH * inImgW;
            outPos += outImgH * outImgW;
          }
        }
      }
    }
  }
}

void CpuMatrix::bilinearBackward(const Matrix& out,
                                 const size_t outImgH,
                                 const size_t outImgW,
                                 const size_t inImgH,
                                 const size_t inImgW,
                                 const size_t numChannels) {
  CHECK(dynamic_cast<const CpuMatrix*>(&out));

  size_t inputW = getWidth();
  size_t inputH = getHeight();
  size_t outputW = out.getWidth();
  size_t outputH = out.getHeight();

  real* inGrad = getData();
  const real* outGrad = out.getData();

  const real ratioH =
      (outImgH > 1) ? static_cast<real>(inImgH - 1) / (outImgH - 1) : 0.f;
  const real ratioW =
      (outImgW > 1) ? static_cast<real>(inImgW - 1) / (outImgW - 1) : 0.f;

  if (inImgH == outImgH && inImgW == outImgW) {
    this->copyFrom(out);
  } else {
    for (int k = 0; k < outputH; ++k) {  // loop for batches
      for (int i = 0; i < outImgH; ++i) {  // loop for images
        int h = ratioH * i;
        int hid = (h < inImgH - 1) ? 1 : 0;
        real hlambda = ratioH * i - h;
        for (int j = 0; j < outImgW; ++j) {
          int w = ratioW * j;
          int wid = (w < inImgW - 1) ? 1 : 0;
          real wlambda = ratioW * j - w;
          real* inPos = &inGrad[k * inputW + h * inImgW + w];
          const real* outPos = &outGrad[k * outputW + i * outImgW + j];
          for (int c = 0; c < numChannels; ++c) {  // loop for channels
            inPos[0] += (1.f - hlambda) * (1.f - wlambda) * outPos[0];
            inPos[wid] += (1.f - hlambda) * wlambda * outPos[0];
            inPos[hid * inImgW] += hlambda * (1.f - wlambda) * outPos[0];
            inPos[hid * inImgW + wid] += hlambda * wlambda * outPos[0];
            inPos += inImgH * inImgW;
            outPos += outImgH * outImgW;
          }
        }
      }
    }
  }
}

////////////////////////////////////////////////////////////////
//                 functions executed via cpu                 //
////////////////////////////////////////////////////////////////
...
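The CPU backward pass above scatters each output gradient back to the four source pixels with the same weights the forward pass used to blend them. A hedged NumPy restatement (illustration only, not code from this commit):

import numpy as np

def bilinear_backward_ref(out_grad, in_h, in_w, out_h, out_w, channels):
    """Unoptimized NumPy restatement of CpuMatrix::bilinearBackward."""
    n = out_grad.shape[0]
    g_out = out_grad.reshape(n, channels, out_h, out_w)
    g_in = np.zeros((n, channels, in_h, in_w), dtype=out_grad.dtype)
    ratio_h = (in_h - 1.0) / (out_h - 1.0) if out_h > 1 else 0.0
    ratio_w = (in_w - 1.0) / (out_w - 1.0) if out_w > 1 else 0.0
    for i in range(out_h):
        h = int(ratio_h * i)
        hid = 1 if h < in_h - 1 else 0
        hlam = ratio_h * i - h
        for j in range(out_w):
            w = int(ratio_w * j)
            wid = 1 if w < in_w - 1 else 0
            wlam = ratio_w * j - w
            # each output gradient is distributed to four source pixels
            g = g_out[:, :, i, j]
            g_in[:, :, h, w] += (1 - hlam) * (1 - wlam) * g
            g_in[:, :, h, w + wid] += (1 - hlam) * wlam * g
            g_in[:, :, h + hid, w] += hlam * (1 - wlam) * g
            g_in[:, :, h + hid, w + wid] += hlam * wlam * g
    return g_in.reshape(n, channels * in_h * in_w)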
paddle/math/Matrix.h

@@ -930,6 +930,22 @@ public:
  virtual void paramReluBackwardDiff(Matrix& oGrad, Matrix& data, Matrix& W) {
    LOG(FATAL) << "Not implemented";
  }

  virtual void bilinearForward(const Matrix& in,
                               const size_t inImgH,
                               const size_t inImgW,
                               const size_t outImgH,
                               const size_t outImgW,
                               const size_t numChannels) {
    LOG(FATAL) << "Not implemented";
  }
  virtual void bilinearBackward(const Matrix& out,
                                const size_t outImgH,
                                const size_t outImgW,
                                const size_t inImgH,
                                const size_t inImgW,
                                const size_t numChannels) {
    LOG(FATAL) << "Not implemented";
  }
};

inline std::ostream& operator<<(std::ostream& os, const Matrix& mat) {
...
@@ -1191,6 +1207,20 @@ public:
                             int contextLength,
                             int contextStart, int totalPad,
                             size_t beginPad);

  void bilinearForward(const Matrix& in,
                       const size_t inImgH,
                       const size_t inImgW,
                       const size_t outImgH,
                       const size_t outImgW,
                       const size_t numChannels);
  void bilinearBackward(const Matrix& out,
                        const size_t outImgH,
                        const size_t outImgW,
                        const size_t inImgH,
                        const size_t inImgW,
                        const size_t numChannels);
};

class CpuMatrix : public Matrix {
...
@@ -1469,6 +1499,20 @@ public:
  void multiBinaryLabelCrossEntropy(Matrix& output, Matrix& label);
  void multiBinaryLabelCrossEntropyBp(Matrix& output, Matrix& label);
  void classificationErrorMulti(Matrix& output, Matrix& label, real threshold);

  void bilinearForward(const Matrix& in,
                       const size_t inImgH,
                       const size_t inImgW,
                       const size_t outImgH,
                       const size_t outImgW,
                       const size_t numChannels);
  void bilinearBackward(const Matrix& out,
                        const size_t outImgH,
                        const size_t outImgW,
                        const size_t inImgH,
                        const size_t inImgW,
                        const size_t numChannels);
};

class SharedCpuMatrix : public CpuMatrix {
...
paddle/math/tests/test_matrixCompare.cpp

@@ -88,6 +88,72 @@ void MatrixCheckErr(const Matrix& matrix1, const Matrix& matrix2) {
  EXPECT_EQ(count, 0) << "There are " << count << " different element.";
}

void testBilinearFwdBwd(int numSamples, int imgSizeH, int imgSizeW,
                        int channels) {
  int inWidth = imgSizeH * imgSizeW * channels;
  int outWidth = 2 * imgSizeH * 2 * imgSizeW * channels;

  // forward
  MatrixPtr input = CpuMatrix::create(numSamples, inWidth, false, false);
  MatrixPtr inputGpu = GpuMatrix::create(numSamples, inWidth, false, true);

  MatrixPtr target = CpuMatrix::create(numSamples, outWidth, false, false);
  MatrixPtr targetGpu = GpuMatrix::create(numSamples, outWidth, false, true);
  MatrixPtr targetCheck = CpuMatrix::create(numSamples, outWidth, false, false);

  input->randomizeUniform();
  inputGpu->copyFrom(*input);

  target->bilinearForward(*input, imgSizeH, imgSizeW,
                          2 * imgSizeH, 2 * imgSizeW, channels);
  targetGpu->bilinearForward(*inputGpu, imgSizeH, imgSizeW,
                             2 * imgSizeH, 2 * imgSizeW, channels);

  // check
  targetCheck->copyFrom(*targetGpu);
  MatrixCheckErr(*target, *targetCheck);

  // backward
  MatrixPtr inputGrad = CpuMatrix::create(numSamples, inWidth, false, false);
  MatrixPtr inputGpuGrad = GpuMatrix::create(numSamples, inWidth, false, true);

  MatrixPtr targetGrad = CpuMatrix::create(numSamples, outWidth, false, false);
  MatrixPtr targetGpuGrad = GpuMatrix::create(numSamples, outWidth, false, true);
  MatrixPtr targetCheckGrad =
      CpuMatrix::create(numSamples, inWidth, false, false);

  inputGrad->randomizeUniform();
  targetGrad->randomizeUniform();
  inputGpuGrad->copyFrom(*inputGrad);
  targetGpuGrad->copyFrom(*targetGrad);

  inputGrad->bilinearBackward(*targetGrad, 2 * imgSizeH, 2 * imgSizeW,
                              imgSizeH, imgSizeW, channels);
  inputGpuGrad->bilinearBackward(*targetGpuGrad, 2 * imgSizeH, 2 * imgSizeW,
                                 imgSizeH, imgSizeW, channels);

  // check
  targetCheckGrad->copyFrom(*inputGpuGrad);
  MatrixCheckErr(*inputGrad, *targetCheckGrad);
}

TEST(Matrix, BilinearFwdBwd) {
  for (auto numSamples : {5, 10}) {
    for (auto channels : {8, 16}) {
      for (auto imgSizeH : {14, 28}) {
        for (auto imgSizeW : {16, 30}) {
          VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
                  << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
          testBilinearFwdBwd(numSamples, imgSizeH, imgSizeW, channels);
        }
      }
    }
  }
}

void testMatrixProjectionForward(int contextStart, int contextLength,
                                 bool padding, int batchSize, int inputDim) {
  MatrixPtr cpuInput = std::make_shared<CpuMatrix>(batchSize, inputDim);
...
proto/ModelConfig.proto.m4

@@ -203,6 +203,15 @@ message OperatorConfig {
  optional int32 num_filters = 7;
}

message BilinearInterpConfig {
  // The size of the input feature map.
  required uint32 img_size_x = 1;
  required uint32 img_size_y = 2;
  // The size of the output feature map.
  required uint32 out_size_x = 3;
  required uint32 out_size_y = 4;
  required uint32 num_channels = 5;
}

message ImageConfig {
  // The image data dimensionality.
...
@@ -225,6 +234,7 @@ message LayerInputConfig {
  // If the input layer has multi-output.
  // Set the argument name.
  optional string input_layer_argument = 9;
  optional BilinearInterpConfig bilinear_interp_conf = 10;
}

message LayerConfig {
...
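The Python config class added in config_parser.py below fills these proto fields one-to-one through parse_bilinear(). A small sketch of that mapping, using the 32 to 64 sizes from this commit's tests (illustration only):

# Illustration only: the field-for-field copy parse_bilinear() performs.
bilinear_cfg = dict(img_size_x=32, img_size_y=32,   # input feature map size
                    out_size_x=64, out_size_y=64,   # output feature map size
                    num_channels=16)
# parse_bilinear(bilinear, input_layer_name, bilinear_conf) assigns each of
# these keys to the matching BilinearInterpConfig field shown above.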
python/paddle/trainer/config_parser.py

@@ -461,6 +461,7 @@ class Input(Cfg):
            sparse_update=None,
            gradient_clipping_threshold=None,
            conv=None,
            bilinear_interp=None,
            norm=None,
            pool=None,
            image=None,
...
@@ -723,6 +724,18 @@ class Conv(Cfg):
        if output_x is not None:
            config_assert(output_x <= 0)

# please refer to the comments in proto/ModelConfig.proto
@config_class
class BilinearInterp(Cfg):
    def __init__(
            self,
            img_size_x=None,
            img_size_y=None,
            out_size_x=None,
            out_size_y=None,
            num_channels=None):
        self.add_keys(locals())

# please refer to the comments in proto/ModelConfig.proto
@config_class
class Pool(Cfg):
...
@@ -953,6 +966,13 @@ def TestData(data_config, async_load_data=None):
                      " Data definition")
    g_config.test_data_config.async_load_data = async_load_data

def parse_bilinear(bilinear, input_layer_name, bilinear_conf):
    bilinear_conf.img_size_x = bilinear.img_size_x
    bilinear_conf.img_size_y = bilinear.img_size_y
    bilinear_conf.out_size_x = bilinear.out_size_x
    bilinear_conf.out_size_y = bilinear.out_size_y
    bilinear_conf.num_channels = bilinear.num_channels

def parse_pool(pool, input_layer_name, pool_conf):
    pool_conf.pool_type = pool.pool_type
    config_assert(pool.pool_type in ['max-projection', 'avg-projection',
...
@@ -2306,6 +2326,21 @@ class InterpolationLayer(LayerBase):
        config_assert(input_layer1.size == input_layer2.size,
                      'the two vector inputs should be of the same size')

@config_layer('bilinear_interp')
class BilinearInterpLayer(LayerBase):
    def __init__(
            self,
            name,
            inputs,
            device=None):
        super(BilinearInterpLayer, self).__init__(
            name, 'bilinear_interp', 0, inputs=inputs, device=device)
        input_layer = self.get_input_layer(0)
        self.set_layer_size(input_layer.size)
        parse_bilinear(self.inputs[0].bilinear_interp,
                       input_layer.name,
                       self.config.inputs[0].bilinear_interp_conf)

@config_layer('sum_to_one_norm')
class SumToOneNormLayer(LayerBase):
    def __init__(
...
python/paddle/trainer_config_helpers/layers.py

@@ -40,8 +40,8 @@ __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel",
           'img_cmrnorm_layer', 'addto_layer',
           'concat_layer', 'lstm_step_layer', 'recurrent_group',
           'memory', 'StaticInput', 'expand_layer', 'scaling_layer',
           'power_layer', 'interpolation_layer', 'bilinear_interp_layer',
           'trans_layer', 'sum_to_one_norm_layer',
           'get_output_layer', 'LayerType', 'context_projection',
           'beam_search', 'maxid_layer', 'GeneratedInput', 'SubsequenceInput',
           'gru_step_layer', 'recurrent_layer',
...
@@ -92,6 +92,7 @@ class LayerType(object):
    EXPAND_LAYER = 'expand'
    INTERPOLATION_LAYER = 'interpolation'
    BILINEAR_INTERP_LAYER = 'bilinear_interp'
    POWER_LAYER = 'power'
    SCALING_LAYER = 'scaling'
    TRANS_LAYER = 'trans'
...
@@ -1252,6 +1253,70 @@ def interpolation_layer(input, weight, name=None, layer_attr=None):
                       size=input[0].size)


@wrap_name_default()
@layer_support()
def bilinear_interp_layer(input,
                          img_size_x=None,
                          img_size_y=None,
                          out_size_x=None,
                          out_size_y=None,
                          num_channels=None,
                          name=None,
                          layer_attr=None):
    """
    This layer implements bilinear interpolation on the output of a conv
    layer. See Wikipedia: https://en.wikipedia.org/wiki/Bilinear_interpolation

    The simple usage is:

    .. code-block:: python

       bilinear = bilinear_interp_layer(input,
                                        img_size_x,
                                        img_size_y,
                                        out_size_x,
                                        out_size_y,
                                        num_channels)

    :param input: The input layer.
    :type input: LayerOutput.
    :param img_size_x: previous layer output width.
    :type img_size_x: int|None
    :param img_size_y: previous layer output height.
    :type img_size_y: int|None
    :param out_size_x: bilinear interpolation output width.
    :type out_size_x: int|None
    :param out_size_y: bilinear interpolation output height.
    :type out_size_y: int|None
    :param num_channels: number of channels of the input layer. If None,
                         it will be set automatically from the previous
                         output.
    :type num_channels: int|None
    :param name: The layer's name. It is optional.
    :type name: None|basestring
    :param layer_attr: extra layer attributes.
    :type layer_attr: ExtraLayerAttribute
    :return: LayerOutput object.
    :rtype: LayerOutput
    """
    assert input.layer_type == LayerType.CONV_LAYER
    assert isinstance(input.activation, LinearActivation)
    assert img_size_x > 0 and img_size_y > 0
    assert out_size_x > 0 and out_size_y > 0
    if num_channels is None:
        assert input.num_filters is not None
        num_channels = input.num_filters
    Layer(name=name,
          inputs=Input(input.name,
                       bilinear_interp=BilinearInterp(
                           img_size_x=img_size_x,
                           img_size_y=img_size_y,
                           out_size_x=out_size_x,
                           out_size_y=out_size_y,
                           num_channels=num_channels)),
          type=LayerType.BILINEAR_INTERP_LAYER,
          **ExtraLayerAttribute.to_kwargs(layer_attr))
    return LayerOutput(name, LayerType.BILINEAR_INTERP_LAYER, parents=[input])


@wrap_name_default()
@layer_support()
def power_layer(input, weight, name=None, layer_attr=None):
...
python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh

@@ -8,7 +8,7 @@ configs=(test_fc layer_activations projections test_print_layer
test_sequence_pooling test_lstmemory_layer test_grumemory_layer
last_first_seq test_expand_layer test_ntm_layers test_hsigmoid
img_layers util_layers simple_rnn_layers unused_layers test_cost_layers
test_rnn_group test_bilinear_interp)

for conf in ${configs[*]}
...
python/paddle/trainer_config_helpers/tests/configs/test_bilinear_interp.py
new file mode 100644

from paddle.trainer_config_helpers import *

settings(batch_size=1000, learning_rate=1e-5)

data = data_layer(name='data', size=2304)

conv = img_conv_layer(input=data,
                      filter_size=3,
                      num_channels=1,
                      num_filters=16,
                      padding=1,
                      act=LinearActivation(),
                      bias_attr=True)

bilinear = bilinear_interp_layer(input=conv,
                                 img_size_x=32,
                                 img_size_y=32,
                                 out_size_x=64,
                                 out_size_y=64,
                                 num_channels=16)

pool = img_pool_layer(input=bilinear,
                      num_channels=4,
                      pool_size=2,
                      stride=2,
                      pool_type=MaxPooling())

fc = fc_layer(input=pool, size=384, bias_attr=False)

outputs(fc)
\ No newline at end of file