Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
db175755
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
db175755
编写于
11月 08, 2016
作者:
L
liaogang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Follow comments
上级
5a1e7dbc
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
88 addition
and
69 deletion
+88
-69
paddle/cuda/src/hl_cuda_cnn.cu
paddle/cuda/src/hl_cuda_cnn.cu
+8
-8
paddle/gserver/layers/BilinearInterpLayer.cpp
paddle/gserver/layers/BilinearInterpLayer.cpp
+7
-2
paddle/gserver/layers/BilinearInterpLayer.h
paddle/gserver/layers/BilinearInterpLayer.h
+1
-0
paddle/gserver/tests/test_LayerGrad.cpp
paddle/gserver/tests/test_LayerGrad.cpp
+10
-0
paddle/math/Matrix.cpp
paddle/math/Matrix.cpp
+30
-39
paddle/math/Matrix.h
paddle/math/Matrix.h
+18
-6
paddle/math/tests/test_matrixCompare.cpp
paddle/math/tests/test_matrixCompare.cpp
+6
-5
python/paddle/trainer_config_helpers/layers.py
python/paddle/trainer_config_helpers/layers.py
+8
-9
未找到文件。
paddle/cuda/src/hl_cuda_cnn.cu
浏览文件 @
db175755
...
...
@@ -532,8 +532,7 @@ void hl_CMRNorm_backward(size_t frameCnt, const real* inV,
CHECK_SYNC
(
"hl_CMRNorm_backward"
);
}
__global__
void
KeBilinearInterpFw
(
const
size_t
nthreads
,
const
real
*
in
,
__global__
void
KeBilinearInterpFw
(
const
real
*
in
,
const
size_t
inImgH
,
const
size_t
inImgW
,
const
size_t
inputH
,
...
...
@@ -546,6 +545,7 @@ __global__ void KeBilinearInterpFw(const size_t nthreads,
const
size_t
numChannels
,
const
real
ratioH
,
const
real
ratioW
)
{
int
nthreads
=
outputH
*
outputW
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
nthreads
)
{
int
outIdH
=
tid
/
outputW
;
...
...
@@ -593,13 +593,12 @@ void hl_bilinear_forward(const real* inData,
int
blocks
=
(
threadNum
+
1024
-
1
)
/
1024
;
KeBilinearInterpFw
<<<
blocks
,
1024
,
0
,
STREAM_DEFAULT
>>>
(
threadNum
,
inData
,
inImgH
,
inImgW
,
inputH
,
inputW
,
outData
,
outImg
H
,
outImg
W
,
outputH
,
outputW
,
numChannels
,
ratioH
,
ratioW
);
inData
,
inImgH
,
inImgW
,
inputH
,
inputW
,
outData
,
outImgH
,
outImgW
,
outputH
,
outputW
,
numChannels
,
ratioH
,
ratioW
);
CHECK_SYNC
(
"hl_bilinear_forward failed"
);
}
__global__
void
KeBilinearInterpBw
(
const
size_t
nthreads
,
real
*
in
,
__global__
void
KeBilinearInterpBw
(
real
*
in
,
const
size_t
inImgH
,
const
size_t
inImgW
,
const
size_t
inputH
,
...
...
@@ -612,6 +611,7 @@ __global__ void KeBilinearInterpBw(const size_t nthreads,
const
size_t
numChannels
,
const
real
ratioH
,
const
real
ratioW
)
{
int
nthreads
=
outputH
*
outputW
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
nthreads
)
{
int
outIdH
=
tid
/
outputW
;
...
...
@@ -659,8 +659,8 @@ void hl_bilinear_backward(real* inGrad,
int
blocks
=
(
threadNum
+
1024
-
1
)
/
1024
;
KeBilinearInterpBw
<<<
blocks
,
1024
,
0
,
STREAM_DEFAULT
>>>
(
threadNum
,
inGrad
,
inImgH
,
inImgW
,
inputH
,
inputW
,
outGrad
,
outImg
H
,
outImg
W
,
outputH
,
outputW
,
numChannels
,
ratioH
,
ratioW
);
inGrad
,
inImgH
,
inImgW
,
inputH
,
inputW
,
outGrad
,
outImgH
,
outImgW
,
outputH
,
outputW
,
numChannels
,
ratioH
,
ratioW
);
CHECK_SYNC
(
"hl_bilinear_backward failed"
);
}
...
...
paddle/gserver/layers/BilinearInterpLayer.cpp
浏览文件 @
db175755
...
...
@@ -40,6 +40,11 @@ size_t BilinearInterpLayer::getSize() {
CHECK
(
inImgH_
>
0
&&
inImgW_
>
0
);
CHECK
(
numChannels_
);
ratioH_
=
(
outImgH_
>
1
)
?
static_cast
<
real
>
(
inImgH_
-
1
)
/
(
outImgH_
-
1
)
:
0.
f
;
ratioW_
=
(
outImgW_
>
1
)
?
static_cast
<
real
>
(
inImgW_
-
1
)
/
(
outImgW_
-
1
)
:
0.
f
;
getOutput
().
setFrameHeight
(
outImgH_
);
getOutput
().
setFrameWidth
(
outImgW_
);
return
outImgH_
*
outImgW_
*
numChannels_
;
...
...
@@ -70,7 +75,7 @@ void BilinearInterpLayer::forward(PassType passType) {
{
REGISTER_TIMER_INFO
(
"FwBilinearInterpTimer"
,
getName
().
c_str
());
outV
->
bilinearForward
(
*
inV
,
inImgH_
,
inImgW_
,
outImgH_
,
outImgW_
,
numChannels_
);
numChannels_
,
ratioH_
,
ratioW_
);
}
}
...
...
@@ -83,7 +88,7 @@ void BilinearInterpLayer::backward(const UpdateCallback& callback) {
REGISTER_TIMER_INFO
(
"BwBilinearInterpTimer"
,
getName
().
c_str
());
if
(
inputG
)
{
inputG
->
bilinearBackward
(
*
outG
,
outImgH_
,
outImgW_
,
inImgH_
,
inImgW_
,
numChannels_
);
numChannels_
,
ratioH_
,
ratioW_
);
}
}
}
...
...
paddle/gserver/layers/BilinearInterpLayer.h
浏览文件 @
db175755
...
...
@@ -29,6 +29,7 @@ class BilinearInterpLayer : public Layer {
protected:
size_t
outImgH_
,
outImgW_
;
size_t
inImgH_
,
inImgW_
;
real
ratioH_
,
ratioW_
;
size_t
numChannels_
;
public:
...
...
paddle/gserver/tests/test_LayerGrad.cpp
浏览文件 @
db175755
...
...
@@ -50,6 +50,16 @@ TEST(Layer, BilinearInterpLayer) {
for
(
auto
useGpu
:
{
false
,
true
})
{
testLayerGrad
(
config
,
"bilinear_interp"
,
10
,
false
,
useGpu
);
}
bilinear
->
set_img_size_x
(
32
);
bilinear
->
set_img_size_y
(
32
);
bilinear
->
set_out_size_x
(
32
);
bilinear
->
set_out_size_y
(
32
);
bilinear
->
set_num_channels
(
4
);
for
(
auto
useGpu
:
{
false
,
true
})
{
testLayerGrad
(
config
,
"bilinear_interp"
,
10
,
false
,
useGpu
);
}
}
TEST
(
Operator
,
dot_mul
)
{
...
...
paddle/math/Matrix.cpp
浏览文件 @
db175755
...
...
@@ -1227,7 +1227,9 @@ void GpuMatrix::bilinearForward(const Matrix& in,
const
size_t
inImgW
,
const
size_t
outImgH
,
const
size_t
outImgW
,
const
size_t
numChannels
)
{
const
size_t
numChannels
,
const
real
ratioH
,
const
real
ratioW
)
{
CHECK
(
dynamic_cast
<
const
GpuMatrix
*>
(
&
in
));
const
size_t
outputW
=
getWidth
();
...
...
@@ -1238,11 +1240,6 @@ void GpuMatrix::bilinearForward(const Matrix& in,
real
*
outData
=
getData
();
const
real
*
inData
=
in
.
getData
();
real
ratioH
=
(
outImgH
>
1
)
?
static_cast
<
real
>
(
inImgH
-
1
)
/
(
outImgH
-
1
)
:
0.
f
;
real
ratioW
=
(
outImgW
>
1
)
?
static_cast
<
real
>
(
inImgW
-
1
)
/
(
outImgW
-
1
)
:
0.
f
;
if
(
inImgH
==
outImgW
&&
inImgW
==
outImgW
)
{
this
->
copyFrom
(
in
);
}
else
{
...
...
@@ -1258,7 +1255,9 @@ void GpuMatrix::bilinearBackward(const Matrix& out,
const
size_t
outImgW
,
const
size_t
inImgH
,
const
size_t
inImgW
,
const
size_t
numChannels
)
{
const
size_t
numChannels
,
const
real
ratioH
,
const
real
ratioW
)
{
CHECK
(
dynamic_cast
<
const
GpuMatrix
*>
(
&
out
));
const
size_t
inputW
=
getWidth
();
...
...
@@ -1269,13 +1268,8 @@ void GpuMatrix::bilinearBackward(const Matrix& out,
real
*
inGrad
=
getData
();
const
real
*
outGrad
=
out
.
getData
();
real
ratioH
=
(
outImgH
>
1
)
?
static_cast
<
real
>
(
inImgH
-
1
)
/
(
outImgH
-
1
)
:
0.
f
;
real
ratioW
=
(
outImgW
>
1
)
?
static_cast
<
real
>
(
inImgW
-
1
)
/
(
outImgW
-
1
)
:
0.
f
;
if
(
outImgH
==
inImgH
&&
outImgW
==
inImgW
)
{
this
->
add
Bias
(
const_cast
<
Matrix
&>
(
out
),
1.
f
);
this
->
add
(
const_cast
<
Matrix
&>
(
out
)
);
}
else
{
hl_bilinear_backward
(
inGrad
,
inImgH
,
inImgW
,
inputH
,
inputW
,
outGrad
,
...
...
@@ -3908,7 +3902,9 @@ void CpuMatrix::bilinearForward(const Matrix& in,
const
size_t
inImgW
,
const
size_t
outImgH
,
const
size_t
outImgW
,
const
size_t
numChannels
)
{
const
size_t
numChannels
,
const
real
ratioH
,
const
real
ratioW
)
{
CHECK
(
dynamic_cast
<
const
CpuMatrix
*>
(
&
in
));
size_t
outputW
=
getWidth
();
...
...
@@ -3920,11 +3916,6 @@ void CpuMatrix::bilinearForward(const Matrix& in,
real
*
outData
=
getData
();
const
real
*
inData
=
in
.
getData
();
const
real
ratioH
=
(
outImgH
>
1
)
?
static_cast
<
real
>
(
inImgH
-
1
)
/
(
outImgH
-
1
)
:
0.
f
;
const
real
ratioW
=
(
outImgW
>
1
)
?
static_cast
<
real
>
(
inImgW
-
1
)
/
(
outImgW
-
1
)
:
0.
f
;
if
(
inImgH
==
outImgH
&&
inImgW
==
outImgW
)
{
this
->
copyFrom
(
in
);
}
else
{
...
...
@@ -3932,21 +3923,23 @@ void CpuMatrix::bilinearForward(const Matrix& in,
for
(
size_t
i
=
0
;
i
<
outImgH
;
++
i
)
{
// loop for images
size_t
h
=
ratioH
*
i
;
size_t
hid
=
(
h
<
inImgH
-
1
)
?
1
:
0
;
real
hlambda
=
ratioH
*
i
-
h
;
real
h1lambda
=
ratioH
*
i
-
h
;
real
h2lambda
=
1
-
h1lambda
;
for
(
size_t
j
=
0
;
j
<
outImgW
;
++
j
)
{
size_t
w
=
ratioW
*
j
;
size_t
wid
=
(
w
<
inImgW
-
1
)
?
1
:
0
;
real
wlambda
=
ratioW
*
j
-
w
;
real
w1lambda
=
ratioW
*
j
-
w
;
real
w2lambda
=
1
-
w1lambda
;
// calculate four position for bilinear interpolation
const
real
*
inPos
=
&
inData
[
k
*
inputW
+
h
*
inImgW
+
w
];
real
*
outPos
=
&
outData
[
k
*
outputW
+
i
*
outImgW
+
j
];
for
(
size_t
c
=
0
;
c
<
numChannels
;
++
c
)
{
// loop for channels
// bilinear interpolation
outPos
[
0
]
=
(
1.
f
-
hlambda
)
*
((
1.
f
-
wlambda
)
*
inPos
[
0
]
+
w
lambda
*
inPos
[
wid
])
+
h
lambda
*
((
1.
f
-
wlambda
)
*
inPos
[
hid
*
inImgW
]
+
wlambda
*
inPos
[
hid
*
inImgW
+
wid
]);
outPos
[
0
]
=
h2lambda
*
(
w2lambda
*
inPos
[
0
]
+
w1
lambda
*
inPos
[
wid
])
+
h
1lambda
*
(
w2lambda
*
inPos
[
hid
*
inImgW
]
+
w
1
lambda
*
inPos
[
hid
*
inImgW
+
wid
]);
inPos
+=
inImgH
*
inImgW
;
outPos
+=
outImgH
*
outImgW
;
}
...
...
@@ -3961,7 +3954,9 @@ void CpuMatrix::bilinearBackward(const Matrix& out,
const
size_t
outImgW
,
const
size_t
inImgH
,
const
size_t
inImgW
,
const
size_t
numChannels
)
{
const
size_t
numChannels
,
const
real
ratioH
,
const
real
ratioW
)
{
CHECK
(
dynamic_cast
<
const
CpuMatrix
*>
(
&
out
));
size_t
inputW
=
getWidth
();
...
...
@@ -3973,32 +3968,28 @@ void CpuMatrix::bilinearBackward(const Matrix& out,
real
*
inGrad
=
getData
();
const
real
*
outGrad
=
out
.
getData
();
const
real
ratioH
=
(
outImgH
>
1
)
?
static_cast
<
real
>
(
inImgH
-
1
)
/
(
outImgH
-
1
)
:
0.
f
;
const
real
ratioW
=
(
outImgW
>
1
)
?
static_cast
<
real
>
(
inImgW
-
1
)
/
(
outImgW
-
1
)
:
0.
f
;
if
(
inImgH
==
outImgH
&&
inImgW
==
outImgW
)
{
this
->
add
Bias
(
const_cast
<
Matrix
&>
(
out
),
1.
f
);
this
->
add
(
const_cast
<
Matrix
&>
(
out
)
);
}
else
{
for
(
size_t
k
=
0
;
k
<
batchSize
;
++
k
)
{
// loop for batches
for
(
size_t
i
=
0
;
i
<
outImgH
;
++
i
)
{
// loop for images
size_t
h
=
ratioH
*
i
;
size_t
hid
=
(
h
<
inImgH
-
1
)
?
1
:
0
;
real
hlambda
=
ratioH
*
i
-
h
;
real
h
1
lambda
=
ratioH
*
i
-
h
;
real
h2lambda
=
1
-
h1lambda
;
for
(
size_t
j
=
0
;
j
<
outImgW
;
++
j
)
{
size_t
w
=
ratioW
*
j
;
size_t
wid
=
(
w
<
inImgW
-
1
)
?
1
:
0
;
real
wlambda
=
ratioW
*
j
-
w
;
real
w1lambda
=
ratioW
*
j
-
w
;
real
w2lambda
=
1
-
w1lambda
;
real
*
inPos
=
&
inGrad
[
k
*
inputW
+
h
*
inImgW
+
w
];
const
real
*
outPos
=
&
outGrad
[
k
*
outputW
+
i
*
outImgW
+
j
];
for
(
size_t
c
=
0
;
c
<
numChannels
;
++
c
)
{
// loop for channels
inPos
[
0
]
+=
(
1.
f
-
hlambda
)
*
(
1.
f
-
wlambda
)
*
outPos
[
0
];
inPos
[
wid
]
+=
(
1.
f
-
hlambda
)
*
w
lambda
*
outPos
[
0
];
inPos
[
hid
*
inImgW
]
+=
h
lambda
*
(
1.
f
-
wlambda
)
*
outPos
[
0
];
inPos
[
hid
*
inImgW
+
wid
]
+=
h
lambda
*
w
lambda
*
outPos
[
0
];
inPos
[
0
]
+=
h2lambda
*
w2lambda
*
outPos
[
0
];
inPos
[
wid
]
+=
h2lambda
*
w1
lambda
*
outPos
[
0
];
inPos
[
hid
*
inImgW
]
+=
h
1lambda
*
w2lambda
*
outPos
[
0
];
inPos
[
hid
*
inImgW
+
wid
]
+=
h
1lambda
*
w1
lambda
*
outPos
[
0
];
inPos
+=
inImgH
*
inImgW
;
outPos
+=
outImgH
*
outImgW
;
}
...
...
paddle/math/Matrix.h
浏览文件 @
db175755
...
...
@@ -997,7 +997,9 @@ public:
const
size_t
inImgW
,
const
size_t
outImgH
,
const
size_t
outImgW
,
const
size_t
numChannels
)
{
const
size_t
numChannels
,
const
real
ratioH
,
const
real
ratioW
)
{
LOG
(
FATAL
)
<<
"Not implemented"
;
}
virtual
void
bilinearBackward
(
const
Matrix
&
out
,
...
...
@@ -1005,7 +1007,9 @@ public:
const
size_t
outImgW
,
const
size_t
inImgH
,
const
size_t
inImgW
,
const
size_t
numChannels
)
{
const
size_t
numChannels
,
const
real
ratioH
,
const
real
ratioW
)
{
LOG
(
FATAL
)
<<
"Not implemented"
;
}
};
...
...
@@ -1283,14 +1287,18 @@ public:
const
size_t
inImgW
,
const
size_t
outImgH
,
const
size_t
outImgW
,
const
size_t
numChannels
);
const
size_t
numChannels
,
const
real
ratioH
,
const
real
ratioW
);
void
bilinearBackward
(
const
Matrix
&
out
,
const
size_t
outImgH
,
const
size_t
outImgW
,
const
size_t
inImgH
,
const
size_t
inImgW
,
const
size_t
numChannels
);
const
size_t
numChannels
,
const
real
ratioH
,
const
real
ratioW
);
};
class
CpuMatrix
:
public
Matrix
{
...
...
@@ -1583,14 +1591,18 @@ public:
const
size_t
inImgW
,
const
size_t
outImgH
,
const
size_t
outImgW
,
const
size_t
numChannels
);
const
size_t
numChannels
,
const
real
ratioH
,
const
real
ratioW
);
void
bilinearBackward
(
const
Matrix
&
out
,
const
size_t
outImgH
,
const
size_t
outImgW
,
const
size_t
inImgH
,
const
size_t
inImgW
,
const
size_t
numChannels
);
const
size_t
numChannels
,
const
real
ratioH
,
const
real
ratioW
);
};
class
SharedCpuMatrix
:
public
CpuMatrix
{
...
...
paddle/math/tests/test_matrixCompare.cpp
浏览文件 @
db175755
...
...
@@ -94,7 +94,8 @@ void testBilinearFwdBwd(int numSamples, int imgSizeH, int imgSizeW,
int
channels
)
{
int
inWidth
=
imgSizeH
*
imgSizeW
*
channels
;
int
outWidth
=
2
*
imgSizeH
*
2
*
imgSizeW
*
channels
;
real
ratioH
=
0.5
;
real
ratioW
=
0.5
;
// forward
MatrixPtr
input
=
CpuMatrix
::
create
(
numSamples
,
inWidth
,
false
,
false
);
MatrixPtr
inputGpu
=
GpuMatrix
::
create
(
numSamples
,
inWidth
,
false
,
true
);
...
...
@@ -107,9 +108,9 @@ void testBilinearFwdBwd(int numSamples, int imgSizeH, int imgSizeW,
inputGpu
->
copyFrom
(
*
input
);
target
->
bilinearForward
(
*
input
,
imgSizeH
,
imgSizeW
,
2
*
imgSizeH
,
2
*
imgSizeW
,
channels
);
2
*
imgSizeH
,
2
*
imgSizeW
,
channels
,
ratioH
,
ratioW
);
targetGpu
->
bilinearForward
(
*
inputGpu
,
imgSizeH
,
imgSizeW
,
2
*
imgSizeH
,
2
*
imgSizeW
,
channels
);
2
*
imgSizeH
,
2
*
imgSizeW
,
channels
,
ratioH
,
ratioW
);
// check
targetCheck
->
copyFrom
(
*
targetGpu
);
...
...
@@ -131,9 +132,9 @@ void testBilinearFwdBwd(int numSamples, int imgSizeH, int imgSizeW,
targetGpuGrad
->
copyFrom
(
*
targetGrad
);
inputGrad
->
bilinearBackward
(
*
targetGrad
,
2
*
imgSizeH
,
2
*
imgSizeW
,
imgSizeH
,
imgSizeW
,
channels
);
imgSizeH
,
imgSizeW
,
channels
,
ratioH
,
ratioW
);
inputGpuGrad
->
bilinearBackward
(
*
targetGpuGrad
,
2
*
imgSizeH
,
2
*
imgSizeW
,
imgSizeH
,
imgSizeW
,
channels
);
imgSizeH
,
imgSizeW
,
channels
,
ratioH
,
ratioW
);
// check
targetCheckGrad
->
copyFrom
(
*
inputGpuGrad
);
...
...
python/paddle/trainer_config_helpers/layers.py
浏览文件 @
db175755
...
...
@@ -1272,19 +1272,17 @@ def bilinear_interp_layer(input,
.. code-block:: python
bilinear = bilinear_interp_layer(input,
out_size_x,
out_size_y)
bilinear = bilinear_interp_layer(input=layer1, out_size_x=64, out_size_y=64)
:para
input: A input layer.
:para
m
input: A input layer.
:type input: LayerOutput.
:para
out_size_x: bilinear interpolation output width.
:para
m
out_size_x: bilinear interpolation output width.
:type out_size_x: int|None
:para
out_size_y: bilinear interpolation output height.
:para
m
out_size_y: bilinear interpolation output height.
:type out_size_y: int|None
:para
name: The layer's name, which cna not be specified.
:para
m
name: The layer's name, which cna not be specified.
:type name: None|basestring
:para
layer_attr: Extra Layer attribute.
:para
m
layer_attr: Extra Layer attribute.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput
...
...
@@ -1301,7 +1299,8 @@ def bilinear_interp_layer(input,
num_channels
=
num_channels
)),
type
=
LayerType
.
BILINEAR_INTERP_LAYER
,
**
ExtraLayerAttribute
.
to_kwargs
(
layer_attr
))
return
LayerOutput
(
name
,
LayerType
.
BILINEAR_INTERP_LAYER
,
parents
=
[
input
])
return
LayerOutput
(
name
,
LayerType
.
BILINEAR_INTERP_LAYER
,
parents
=
[
input
],
num_filters
=
num_channels
)
@
wrap_name_default
()
@
layer_support
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录