Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
b1f16d25
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b1f16d25
编写于
3月 29, 2017
作者:
T
Tao Luo
提交者:
GitHub
3月 29, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #1713 from luotao1/avg
package sequenceAvgBackward
上级
06b2e4d2
53da530d
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
133 addition
and
47 deletion
+133
-47
paddle/cuda/include/hl_sequence.h
paddle/cuda/include/hl_sequence.h
+6
-0
paddle/cuda/include/stub/hl_sequence_stub.h
paddle/cuda/include/stub/hl_sequence_stub.h
+6
-0
paddle/cuda/src/hl_cuda_sequence.cu
paddle/cuda/src/hl_cuda_sequence.cu
+48
-3
paddle/gserver/layers/AverageLayer.cpp
paddle/gserver/layers/AverageLayer.cpp
+3
-39
paddle/gserver/layers/AverageLayer.h
paddle/gserver/layers/AverageLayer.h
+0
-2
paddle/math/Matrix.cpp
paddle/math/Matrix.cpp
+49
-0
paddle/math/Matrix.h
paddle/math/Matrix.h
+8
-0
paddle/math/tests/test_matrixCompare.cpp
paddle/math/tests/test_matrixCompare.cpp
+13
-3
未找到文件。
paddle/cuda/include/hl_sequence.h
浏览文件 @
b1f16d25
...
@@ -159,4 +159,10 @@ extern void hl_sequence_avg_forward(real* dst,
...
@@ -159,4 +159,10 @@ extern void hl_sequence_avg_forward(real* dst,
int
width
,
int
width
,
const
int
mode
);
const
int
mode
);
extern
void
hl_sequence_avg_backward
(
real
*
dst
,
real
*
src
,
const
int
*
starts
,
int
height
,
int
width
,
const
int
mode
);
#endif
/* HL_SEQUENCE_H_ */
#endif
/* HL_SEQUENCE_H_ */
paddle/cuda/include/stub/hl_sequence_stub.h
浏览文件 @
b1f16d25
...
@@ -57,4 +57,10 @@ inline void hl_sequence_avg_forward(real* dst,
...
@@ -57,4 +57,10 @@ inline void hl_sequence_avg_forward(real* dst,
int
width
,
int
width
,
const
int
mode
)
{}
const
int
mode
)
{}
inline
void
hl_sequence_avg_backward
(
real
*
dst
,
real
*
src
,
const
int
*
starts
,
int
height
,
int
width
,
const
int
mode
)
{}
#endif // HL_SEQUENCE_STUB_H_
#endif // HL_SEQUENCE_STUB_H_
paddle/cuda/src/hl_cuda_sequence.cu
浏览文件 @
b1f16d25
...
@@ -325,12 +325,12 @@ __global__ void KeSequenceAvgForward(real* dst,
...
@@ -325,12 +325,12 @@ __global__ void KeSequenceAvgForward(real* dst,
int
seqLength
=
end
-
start
;
int
seqLength
=
end
-
start
;
if
(
seqLength
==
0
)
return
;
if
(
seqLength
==
0
)
return
;
real
sum
=
0.0
;
real
sum
=
0.0
;
for
(
int
i
=
0
;
i
<
seqLength
;
i
++
)
{
for
(
int
i
=
start
;
i
<
end
;
i
++
)
{
sum
+=
src
[
(
start
+
i
)
*
width
+
col
];
sum
+=
src
[
i
*
width
+
col
];
}
}
sum
=
mode
==
1
?
sum
:
sum
=
mode
==
1
?
sum
:
(
mode
==
0
?
sum
/
seqLength
:
sum
*
my_rsqrt
((
real
)
seqLength
));
(
mode
==
0
?
sum
/
seqLength
:
sum
*
my_rsqrt
((
real
)
seqLength
));
dst
[
row
*
width
+
col
]
=
sum
;
dst
[
gid
]
=
sum
;
}
}
}
}
...
@@ -354,3 +354,48 @@ void hl_sequence_avg_forward(real* dst,
...
@@ -354,3 +354,48 @@ void hl_sequence_avg_forward(real* dst,
(
dst
,
src
,
starts
,
height
,
width
,
mode
);
(
dst
,
src
,
starts
,
height
,
width
,
mode
);
CHECK_SYNC
(
"hl_sequence_avg_forward failed"
);
CHECK_SYNC
(
"hl_sequence_avg_forward failed"
);
}
}
__global__
void
KeSequenceAvgBackward
(
real
*
dst
,
real
*
src
,
const
int
*
starts
,
int
height
,
int
width
,
const
int
mode
)
{
int
gid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
row
=
gid
/
width
;
int
col
=
gid
%
width
;
if
(
gid
<
height
*
width
)
{
int
start
=
starts
[
row
];
int
end
=
starts
[
row
+
1
];
int
seqLength
=
end
-
start
;
if
(
seqLength
==
0
)
return
;
real
grad
=
src
[
gid
];
grad
=
mode
==
1
?
grad
:
(
mode
==
0
?
grad
/
seqLength
:
grad
*
my_rsqrt
((
real
)
seqLength
));
for
(
int
i
=
start
;
i
<
end
;
i
++
)
{
dst
[
i
*
width
+
col
]
+=
grad
;
}
}
}
void
hl_sequence_avg_backward
(
real
*
dst
,
real
*
src
,
const
int
*
starts
,
int
height
,
int
width
,
const
int
mode
)
{
CHECK_NOTNULL
(
dst
);
CHECK_NOTNULL
(
src
);
CHECK_NOTNULL
(
starts
);
int
block
=
512
;
int
grid
=
DIVUP
(
width
*
height
,
512
);
CHECK
(
mode
==
0
||
mode
==
1
||
mode
==
2
)
<<
"mode error in hl_sequence_avg_backward!"
;
KeSequenceAvgBackward
<<<
grid
,
block
,
0
,
STREAM_DEFAULT
>>>
(
dst
,
src
,
starts
,
height
,
width
,
mode
);
CHECK_SYNC
(
"hl_sequence_avg_backward failed"
);
}
paddle/gserver/layers/AverageLayer.cpp
浏览文件 @
b1f16d25
...
@@ -26,8 +26,6 @@ bool AverageLayer::init(const LayerMap& layerMap,
...
@@ -26,8 +26,6 @@ bool AverageLayer::init(const LayerMap& layerMap,
const
ParameterMap
&
parameterMap
)
{
const
ParameterMap
&
parameterMap
)
{
SequencePoolLayer
::
init
(
layerMap
,
parameterMap
);
SequencePoolLayer
::
init
(
layerMap
,
parameterMap
);
dataMtx_
=
Matrix
::
create
(
nullptr
,
1
,
1
,
false
,
useGpu_
);
outMtx_
=
Matrix
::
create
(
nullptr
,
1
,
getSize
(),
false
,
useGpu_
);
// average strategy
// average strategy
if
(
config_
.
average_strategy
()
==
"average"
)
{
if
(
config_
.
average_strategy
()
==
"average"
)
{
mode_
=
kAverage
;
mode_
=
kAverage
;
...
@@ -60,43 +58,9 @@ void AverageLayer::forward(PassType passType) {
...
@@ -60,43 +58,9 @@ void AverageLayer::forward(PassType passType) {
void
AverageLayer
::
backward
(
const
UpdateCallback
&
callback
)
{
void
AverageLayer
::
backward
(
const
UpdateCallback
&
callback
)
{
SequencePoolLayer
::
backward
(
callback
);
SequencePoolLayer
::
backward
(
callback
);
const
int
*
starts
=
startPositions_
->
getData
(
false
);
if
(
getInputGrad
(
0
))
{
MatrixPtr
grad
=
getInputGrad
(
0
);
getInputGrad
(
0
)
->
sequenceAvgBackward
(
*
getOutputGrad
(),
*
startPositions_
->
getVector
(
useGpu_
),
mode_
);
if
(
grad
)
{
size_t
dim
=
getSize
();
real
*
gradientData
=
getInputGrad
(
0
)
->
getData
();
real
*
gradient
=
getOutputGrad
()
->
getData
();
size_t
numSequences
=
startPositions_
->
getSize
()
-
1
;
for
(
size_t
sequenceId
=
0
;
sequenceId
<
numSequences
;
++
sequenceId
)
{
// TODO(Dangqingqing) optimization for GPU
int
sequenceLength
=
starts
[
sequenceId
+
1
]
-
starts
[
sequenceId
];
if
(
0
==
sequenceLength
)
{
// empty sequence
continue
;
}
dataMtx_
->
setData
(
gradientData
+
starts
[
sequenceId
]
*
dim
,
sequenceLength
,
dim
);
outMtx_
->
setData
(
gradient
+
sequenceId
*
dim
);
switch
(
mode_
)
{
case
kAverage
:
{
// plain average
dataMtx_
->
addBias
(
*
outMtx_
,
1.0
f
/
sequenceLength
);
break
;
}
case
kSum
:
{
// sum instead of average
dataMtx_
->
addBias
(
*
outMtx_
,
1.0
f
);
break
;
}
case
kAverageSquareRootN
:
{
// divide by square root of sequenceLength
dataMtx_
->
addBias
(
*
outMtx_
,
1.0
f
/
sqrt
(
sequenceLength
));
break
;
}
default:
{
LOG
(
FATAL
)
<<
"should not reach here"
;
}
}
}
}
}
}
}
...
...
paddle/gserver/layers/AverageLayer.h
浏览文件 @
b1f16d25
...
@@ -45,8 +45,6 @@ public:
...
@@ -45,8 +45,6 @@ public:
void
backward
(
const
UpdateCallback
&
callback
=
nullptr
)
override
;
void
backward
(
const
UpdateCallback
&
callback
=
nullptr
)
override
;
protected:
protected:
MatrixPtr
outMtx_
;
MatrixPtr
dataMtx_
;
int
mode_
;
int
mode_
;
};
};
}
// namespace paddle
}
// namespace paddle
paddle/math/Matrix.cpp
浏览文件 @
b1f16d25
...
@@ -483,6 +483,20 @@ void GpuMatrix::sequenceAvgForward(Matrix& a,
...
@@ -483,6 +483,20 @@ void GpuMatrix::sequenceAvgForward(Matrix& a,
hl_sequence_avg_forward
(
dst
,
src
,
starts
,
height
,
width
,
mode
);
hl_sequence_avg_forward
(
dst
,
src
,
starts
,
height
,
width
,
mode
);
}
}
void
GpuMatrix
::
sequenceAvgBackward
(
Matrix
&
a
,
const
IVector
&
startsPos
,
int
mode
)
{
size_t
height
=
a
.
getHeight
();
size_t
width
=
getWidth
();
CHECK_EQ
(
height
,
startsPos
.
getSize
()
-
1
);
CHECK_EQ
(
width
,
a
.
getWidth
());
real
*
dst
=
getData
();
real
*
src
=
a
.
getData
();
const
int
*
starts
=
startsPos
.
getData
();
hl_sequence_avg_backward
(
dst
,
src
,
starts
,
height
,
width
,
mode
);
}
/* this = scaleAB*(a*b) + scaleT*this */
/* this = scaleAB*(a*b) + scaleT*this */
void
GpuMatrix
::
mul
(
const
GpuMatrix
&
a
,
void
GpuMatrix
::
mul
(
const
GpuMatrix
&
a
,
const
GpuMatrix
&
b
,
const
GpuMatrix
&
b
,
...
@@ -2304,6 +2318,41 @@ void CpuMatrix::sequenceAvgForward(Matrix& a,
...
@@ -2304,6 +2318,41 @@ void CpuMatrix::sequenceAvgForward(Matrix& a,
}
}
}
}
void
CpuMatrix
::
sequenceAvgBackward
(
Matrix
&
a
,
const
IVector
&
startsPos
,
int
mode
)
{
size_t
height
=
a
.
getHeight
();
size_t
width
=
getWidth
();
CHECK_EQ
(
height
,
startsPos
.
getSize
()
-
1
);
CHECK_EQ
(
width
,
a
.
getWidth
());
real
*
dst
=
getData
();
real
*
src
=
a
.
getData
();
const
int
*
starts
=
startsPos
.
getData
();
MatrixPtr
outMtx
=
Matrix
::
create
(
nullptr
,
1
,
width
,
false
,
false
);
MatrixPtr
dataMtx
=
Matrix
::
create
(
nullptr
,
1
,
width
,
false
,
false
);
for
(
size_t
i
=
0
;
i
<
height
;
++
i
)
{
int
sequenceLength
=
starts
[
i
+
1
]
-
starts
[
i
];
if
(
0
==
sequenceLength
)
{
// empty sequence
continue
;
}
outMtx
->
setData
(
dst
+
starts
[
i
]
*
width
,
sequenceLength
,
width
);
dataMtx
->
setData
(
src
+
i
*
width
);
if
(
mode
==
0
)
{
// plain average
outMtx
->
addBias
(
*
dataMtx
,
1.0
f
/
sequenceLength
);
}
else
if
(
mode
==
1
)
{
// sum instead of average
outMtx
->
addBias
(
*
dataMtx
,
1.0
f
);
}
else
if
(
mode
==
2
)
{
// divide by square root of sequenceLength
outMtx
->
addBias
(
*
dataMtx
,
1.0
f
/
std
::
sqrt
(
sequenceLength
));
}
else
{
LOG
(
FATAL
)
<<
"should not reach here"
;
}
}
}
/* this = scaleAB*(a*b) + scaleT*this*/
/* this = scaleAB*(a*b) + scaleT*this*/
void
CpuMatrix
::
mul
(
const
Matrix
&
a
,
void
CpuMatrix
::
mul
(
const
Matrix
&
a
,
const
Matrix
&
b
,
const
Matrix
&
b
,
...
...
paddle/math/Matrix.h
浏览文件 @
b1f16d25
...
@@ -461,6 +461,12 @@ public:
...
@@ -461,6 +461,12 @@ public:
LOG
(
FATAL
)
<<
"Not implemented"
;
LOG
(
FATAL
)
<<
"Not implemented"
;
}
}
virtual
void
sequenceAvgBackward
(
Matrix
&
a
,
const
IVector
&
startsPos
,
int
mode
)
{
LOG
(
FATAL
)
<<
"Not implemented"
;
}
/**
/**
* @code
* @code
* this = scaleAB*(a*b) + scaleT*this
* this = scaleAB*(a*b) + scaleT*this
...
@@ -1203,6 +1209,7 @@ public:
...
@@ -1203,6 +1209,7 @@ public:
void
collectSharedBias
(
Matrix
&
a
,
real
scale
);
void
collectSharedBias
(
Matrix
&
a
,
real
scale
);
void
sequenceAvgForward
(
Matrix
&
a
,
const
IVector
&
startsPos
,
int
mode
);
void
sequenceAvgForward
(
Matrix
&
a
,
const
IVector
&
startsPos
,
int
mode
);
void
sequenceAvgBackward
(
Matrix
&
a
,
const
IVector
&
startsPos
,
int
mode
);
/**
/**
* @code
* @code
...
@@ -1619,6 +1626,7 @@ public:
...
@@ -1619,6 +1626,7 @@ public:
void
collectSharedBias
(
Matrix
&
a
,
real
scale
);
void
collectSharedBias
(
Matrix
&
a
,
real
scale
);
void
sequenceAvgForward
(
Matrix
&
a
,
const
IVector
&
startsPos
,
int
mode
);
void
sequenceAvgForward
(
Matrix
&
a
,
const
IVector
&
startsPos
,
int
mode
);
void
sequenceAvgBackward
(
Matrix
&
a
,
const
IVector
&
startsPos
,
int
mode
);
/**
/**
* @code
* @code
...
...
paddle/math/tests/test_matrixCompare.cpp
浏览文件 @
b1f16d25
...
@@ -685,7 +685,7 @@ TEST(SMatrix, topK) {
...
@@ -685,7 +685,7 @@ TEST(SMatrix, topK) {
}
}
}
}
void
testMatrixSequenceAvg
Forward
(
int
batchSize
,
int
inputDim
,
int
mode
)
{
void
testMatrixSequenceAvg
(
int
batchSize
,
int
inputDim
,
int
mode
)
{
MatrixPtr
cpuInput
=
std
::
make_shared
<
CpuMatrix
>
(
batchSize
,
inputDim
);
MatrixPtr
cpuInput
=
std
::
make_shared
<
CpuMatrix
>
(
batchSize
,
inputDim
);
MatrixPtr
gpuInput
=
std
::
make_shared
<
GpuMatrix
>
(
batchSize
,
inputDim
);
MatrixPtr
gpuInput
=
std
::
make_shared
<
GpuMatrix
>
(
batchSize
,
inputDim
);
cpuInput
->
randomizeUniform
();
cpuInput
->
randomizeUniform
();
...
@@ -706,15 +706,25 @@ void testMatrixSequenceAvgForward(int batchSize, int inputDim, int mode) {
...
@@ -706,15 +706,25 @@ void testMatrixSequenceAvgForward(int batchSize, int inputDim, int mode) {
gpuOutput
->
sequenceAvgForward
(
*
gpuInput
,
*
gpuSequence
,
mode
);
gpuOutput
->
sequenceAvgForward
(
*
gpuInput
,
*
gpuSequence
,
mode
);
TensorCheckErr
(
*
cpuOutput
,
*
gpuOutput
);
TensorCheckErr
(
*
cpuOutput
,
*
gpuOutput
);
MatrixPtr
cpuInGrad
=
std
::
make_shared
<
CpuMatrix
>
(
batchSize
,
inputDim
);
MatrixPtr
gpuInGrad
=
std
::
make_shared
<
GpuMatrix
>
(
batchSize
,
inputDim
);
cpuInGrad
->
randomizeUniform
();
gpuInGrad
->
copyFrom
(
*
cpuInGrad
);
cpuInGrad
->
sequenceAvgBackward
(
*
cpuOutput
,
*
cpuSequence
,
mode
);
gpuInGrad
->
sequenceAvgBackward
(
*
gpuOutput
,
*
gpuSequence
,
mode
);
TensorCheckErr
(
*
cpuInGrad
,
*
gpuInGrad
);
}
}
TEST
(
Matrix
,
sequenceAvg
Forward
)
{
TEST
(
Matrix
,
sequenceAvg
)
{
for
(
auto
batchSize
:
{
10
,
128
,
6000
})
{
for
(
auto
batchSize
:
{
10
,
128
,
6000
})
{
for
(
auto
inputDim
:
{
32
,
100
,
512
})
{
for
(
auto
inputDim
:
{
32
,
100
,
512
})
{
for
(
auto
mode
:
{
0
,
1
,
2
})
{
for
(
auto
mode
:
{
0
,
1
,
2
})
{
VLOG
(
3
)
<<
" batchSize="
<<
batchSize
<<
" inputDim="
<<
inputDim
VLOG
(
3
)
<<
" batchSize="
<<
batchSize
<<
" inputDim="
<<
inputDim
<<
" mode="
<<
mode
;
<<
" mode="
<<
mode
;
testMatrixSequenceAvg
Forward
(
batchSize
,
inputDim
,
mode
);
testMatrixSequenceAvg
(
batchSize
,
inputDim
,
mode
);
}
}
}
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录