Commit 4b7bd642 (unverified)

Merge pull request #7102 from guoshengCS/refine-act-GRU

Refine the activation type in the GRU operator related

Author: Guo Sheng, Jan 02, 2018
Committer: GitHub, Jan 02, 2018
Parents: f58fe6d3, 443391ce

Showing 8 changed files with 66 additions and 87 deletions (+66 −87)
paddle/operators/gru_op.h                      +16  -9
paddle/operators/math/detail/gru_cpu_kernel.h  +16  -18
paddle/operators/math/detail/gru_gpu_kernel.h   +4  -6
paddle/operators/math/detail/gru_kernel.h       +8  -9
paddle/operators/math/gru_compute.cc            +6  -6
paddle/operators/math/gru_compute.cu            +6  -6
paddle/operators/math/gru_compute.h            +10  -11
paddle/operators/math/lstm_compute.h            +0  -22

paddle/operators/gru_op.h

@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
+#include "paddle/operators/math/detail/activation_functions.h"
 #include "paddle/operators/math/gru_compute.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/sequence2batch.h"
@@ -70,7 +71,7 @@ class GRUKernel : public framework::OpKernel<T> {
     }
     int frame_size = hidden_dims[1];
-    math::hl_gru_value<T> gru_value;
+    math::GRUMetaValue<T> gru_value;
     gru_value.gate_weight = const_cast<T*>(weight_data);
     gru_value.state_weight =
         const_cast<T*>(weight_data + 2 * frame_size * frame_size);
@@ -89,6 +90,10 @@ class GRUKernel : public framework::OpKernel<T> {
     }
     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
+    auto active_node = math::detail::GetActivationType(
+        context.Attr<std::string>("activation"));
+    auto active_gate = math::detail::GetActivationType(
+        context.Attr<std::string>("gate_activation"));
     for (size_t n = 0; n < num_batch; n++) {
       int bstart = static_cast<int>(batch_starts[n]);
       int bend = static_cast<int>(batch_starts[n + 1]);
@@ -101,9 +106,8 @@ class GRUKernel : public framework::OpKernel<T> {
       gru_value.gate_value = gate_t.data<T>();
       gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
       math::GRUUnitFunctor<DeviceContext, T>::compute(
-          dev_ctx, gru_value, frame_size, cur_batch_size,
-          math::ActiveType(context.Attr<std::string>("activation")),
-          math::ActiveType(context.Attr<std::string>("gate_activation")));
+          dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
+          active_gate);
       gru_value.prev_out_value = gru_value.output_value;
     }
@@ -170,12 +174,12 @@ class GRUGradKernel : public framework::OpKernel<T> {
     batch_hidden_grad.set_lod(batch_hidden->lod());
     to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse);
-    math::hl_gru_value<T> gru_value;
+    math::GRUMetaValue<T> gru_value;
     gru_value.gate_weight = const_cast<T*>(weight_data);
     gru_value.state_weight =
         const_cast<T*>(weight_data + 2 * frame_size * frame_size);
-    math::hl_gru_grad<T> gru_grad;
+    math::GRUMetaGrad<T> gru_grad;
     if (weight_grad) {
       gru_grad.gate_weight_grad =
           weight_grad->mutable_data<T>(context.GetPlace());
@@ -189,6 +193,10 @@ class GRUGradKernel : public framework::OpKernel<T> {
     auto batch_starts = batch_hidden_grad.lod()[0];
     size_t num_batch = batch_starts.size() - 1;
+    auto active_node = math::detail::GetActivationType(
+        context.Attr<std::string>("activation"));
+    auto active_gate = math::detail::GetActivationType(
+        context.Attr<std::string>("gate_activation"));
     for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
       int bstart = static_cast<int>(batch_starts[n]);
       int bend = static_cast<int>(batch_starts[n + 1]);
@@ -219,9 +227,8 @@ class GRUGradKernel : public framework::OpKernel<T> {
       }
       math::GRUUnitGradFunctor<DeviceContext, T>::compute(
-          dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size,
-          math::ActiveType(context.Attr<std::string>("activation")),
-          math::ActiveType(context.Attr<std::string>("gate_activation")));
+          dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node,
+          active_gate);
     }
     if (input_grad) {
       input_grad->mutable_data<T>(context.GetPlace());
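Taken together, the gru_op.h hunks hoist the activation handling: each kernel now resolves the "activation" and "gate_activation" attribute strings to enum values once, before the batch loop, instead of having math::ActiveType re-parse the strings on every GRUUnitFunctor::compute call. A minimal sketch of that hoisting pattern (Act, Parse, and RunBatches are illustrative placeholders, not Paddle API):

    #include <string>

    enum class Act { kSigmoid, kTanh };

    // Stand-in for math::detail::GetActivationType.
    inline Act Parse(const std::string &s) {
      return s == "sigmoid" ? Act::kSigmoid : Act::kTanh;
    }

    void RunBatches(const std::string &act_attr, size_t num_batch) {
      Act active_node = Parse(act_attr);  // one string comparison in total
      for (size_t n = 0; n < num_batch; n++) {
        // every batch iteration reuses the cached enum value
        (void)active_node;
      }
    }
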
paddle/operators/math/detail/gru_cpu_kernel.h

@@ -28,7 +28,7 @@ template <class OpResetOutput, typename T>
 void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
                                        T *gate_value, T *reset_output_value,
                                        T *prev_output_value, int frame_size,
-                                       activation_mode_t active_gate) {
+                                       ActivationType active_gate) {
   T r_value_update_gate;
   T r_value_reset_gate;
   T r_value_reset_output;
@@ -56,7 +56,7 @@ template <class OpFinalOutput, typename T>
 void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
                                        T *gate_value, T *prev_output_value,
                                        T *output_value, int frame_size,
-                                       activation_mode_t active_node) {
+                                       ActivationType active_node) {
   T r_value_update_gate;
   T r_value_frame_state;
   T r_prev_out = 0;
@@ -83,7 +83,7 @@ template <class OpResetOutput, typename T>
 void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
                                      T *gate_value, T *reset_output_value,
                                      T *prev_output_value, int frame_size,
-                                     activation_mode_t active_gate) {
+                                     ActivationType active_gate) {
 #ifdef __AVX__
   __m256 r_value_update_gate;
   __m256 r_value_reset_gate;
@@ -113,7 +113,7 @@ template <class OpFinalOutput, typename T>
 void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
                                      T *gate_value, T *prev_output_value,
                                      T *output_value, int frame_size,
-                                     activation_mode_t active_node) {
+                                     ActivationType active_node) {
 #ifdef __AVX__
   __m256 r_value_update_gate;
   __m256 r_value_frame_state;
@@ -140,9 +140,8 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
 template <class OpResetOutput, typename T>
 inline void forward_reset_output(OpResetOutput op_reset_output,
-                                 hl_gru_value<T> value, int frame_size,
-                                 int batch_size,
-                                 activation_mode_t active_gate) {
+                                 GRUMetaValue<T> value, int frame_size,
+                                 int batch_size, ActivationType active_gate) {
   for (int b = 0; b < batch_size; b++) {
     if (OpResetOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
       hl_avx_gru_forward_reset_output(
@@ -164,9 +163,8 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
 template <class OpFinalOutput, typename T>
 inline void forward_final_output(OpFinalOutput op_final_output,
-                                 hl_gru_value<T> value, int frame_size,
-                                 int batch_size,
-                                 activation_mode_t active_node) {
+                                 GRUMetaValue<T> value, int frame_size,
+                                 int batch_size, ActivationType active_node) {
   for (int b = 0; b < batch_size; b++) {
     if (OpFinalOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
       hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
@@ -191,7 +189,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
                                       T *gate_grad, T *prev_out_value,
                                       T *prev_out_grad, T *output_grad,
                                       int frame_size,
-                                      activation_mode_t active_node) {
+                                      ActivationType active_node) {
   T r_update_gate_value;
   T r_update_gate_grad;
   T r_frame_state_value;
@@ -232,7 +230,7 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
                                       T *gate_grad, T *prev_out_value,
                                       T *prev_out_grad, T *reset_output_grad,
                                       int frame_size,
-                                      activation_mode_t active_gate) {
+                                      ActivationType active_gate) {
   T r_update_gate_value;
   T r_update_gate_grad;
   T r_reset_gate_value;
@@ -277,7 +275,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
                                     T *gate_grad, T *prev_out_value,
                                     T *prev_out_grad, T *output_grad,
                                     int frame_size,
-                                    activation_mode_t active_node) {
+                                    ActivationType active_node) {
 #ifdef __AVX__
   __m256 r_update_gate_value;
   __m256 r_update_gate_grad;
@@ -320,7 +318,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
                                     T *gate_grad, T *prev_out_value,
                                     T *prev_out_grad, T *reset_output_grad,
                                     int frame_size,
-                                    activation_mode_t active_gate) {
+                                    ActivationType active_gate) {
 #ifdef __AVX__
   __m256 r_update_gate_value;
   __m256 r_update_gate_grad;
@@ -364,9 +362,9 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
 template <class OpStateGrad, typename T>
 inline void backward_state_grad(OpStateGrad op_state_grad,
-                                hl_gru_value<T> value, hl_gru_grad<T> grad,
+                                GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                                 int frame_size, int batch_size,
-                                activation_mode_t active_node) {
+                                ActivationType active_node) {
   for (int b = 0; b < batch_size; b++) {
     if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
       hl_avx_gru_backward_state_grad(
@@ -393,9 +391,9 @@ inline void backward_state_grad(OpStateGrad op_state_grad,
 template <class OpResetGrad, typename T>
 inline void backward_reset_grad(OpResetGrad op_reset_grad,
-                                hl_gru_value<T> value, hl_gru_grad<T> grad,
+                                GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                                 int frame_size, int batch_size,
-                                activation_mode_t active_gate) {
+                                ActivationType active_gate) {
   for (int b = 0; b < batch_size; b++) {
     if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
       hl_avx_gru_backward_reset_grad(
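One guard in the dispatch above is unchanged by this commit but easy to misread: the AVX kernels run only when the functor's avx flag is set, frame_size & (8 - 1) is zero, and sizeof(T) == 4, i.e. when a frame is a whole number of 8-float __m256 registers. A self-contained check of the power-of-two remainder trick the guard relies on (plain C++, not Paddle code):

    #include <cassert>

    int main() {
      // For non-negative n, n & (8 - 1) equals n % 8, so the guard passes
      // exactly when frame_size is a multiple of 8.
      for (int frame_size = 0; frame_size < 64; ++frame_size) {
        assert((frame_size & (8 - 1)) == frame_size % 8);
      }
      return 0;
    }
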
paddle/operators/math/detail/gru_gpu_kernel.h

@@ -19,8 +19,6 @@ limitations under the License. */
 #include "paddle/platform/cuda_helper.h"
 #include "paddle/platform/device_context.h"
-#include <glog/logging.h>
 namespace paddle {
 namespace operators {
 namespace math {
@@ -35,7 +33,7 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
                                         T *gate_value, T *reset_output_value,
                                         T *prev_output_value, int frame_size,
                                         int batch_size,
-                                        activation_mode_t active_gate) {
+                                        ActivationType active_gate) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
@@ -74,7 +72,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
                                         T *gate_value, T *prev_output_value,
                                         T *output_value, int frame_size,
                                         int batch_size,
-                                        activation_mode_t active_node) {
+                                        ActivationType active_node) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
   int batch_idx = 0;
@@ -111,7 +109,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
                                        T *gate_grad, T *prev_out_value,
                                        T *prev_out_grad, T *output_grad,
                                        int frame_size, int batch_size,
-                                       activation_mode_t active_node) {
+                                       ActivationType active_node) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
   int batch_idx = 0;
@@ -159,7 +157,7 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
                                        T *gate_grad, T *prev_out_value,
                                        T *prev_out_grad, T *reset_output_grad,
                                        int frame_size, int batch_size,
-                                       activation_mode_t active_gate) {
+                                       ActivationType active_gate) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
   int batch_idx = 0;
paddle/operators/math/detail/gru_kernel.h

@@ -30,7 +30,7 @@ class gru_resetOutput {
  public:
   HOSTDEVICE void operator()(T &value_update_gate, T &value_reset_gate,
                              T &prev_out, T &value_reset_output,
-                             activation_mode_t act_gate) {
+                             ActivationType act_gate) {
     value_update_gate = activation(value_update_gate, act_gate);
     value_reset_gate = activation(value_reset_gate, act_gate);
     value_reset_output = prev_out * value_reset_gate;
@@ -43,7 +43,7 @@ class gru_resetOutput {
   HOSTDEVICE void operator()(__m256 &value_update_gate,
                              __m256 &value_reset_gate, __m256 &prev_out,
                              __m256 &value_reset_output,
-                             activation_mode_t act_gate) {
+                             ActivationType act_gate) {
     value_update_gate = activation(value_update_gate, act_gate);
     value_reset_gate = activation(value_reset_gate, act_gate);
     value_reset_output = _mm256_mul_ps(prev_out, value_reset_gate);
@@ -57,7 +57,7 @@ class gru_finalOutput {
  public:
   HOSTDEVICE void operator()(T &value_update_gate, T &value_frame_state,
                              T &prev_out, T &value_output,
-                             activation_mode_t act_input) {
+                             ActivationType act_input) {
     value_frame_state = activation(value_frame_state, act_input);
     value_output = prev_out - (value_update_gate * prev_out) +
                    (value_update_gate * value_frame_state);
@@ -69,8 +69,7 @@ class gru_finalOutput {
   static const bool avx = true;
   HOSTDEVICE void operator()(__m256 &value_update_gate,
                              __m256 &value_frame_state, __m256 &prev_out,
-                             __m256 &value_output,
-                             activation_mode_t act_input) {
+                             __m256 &value_output, ActivationType act_input) {
     value_frame_state = activation(value_frame_state, act_input);
     value_output = _mm256_add_ps(
         _mm256_sub_ps(prev_out, _mm256_mul_ps(value_update_gate, prev_out)),
@@ -89,7 +88,7 @@ class gru_stateGrad {
   HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate,
                              T &value_frame_state, T &grad_frame_state,
                              T &value_prev_out, T &grad_prev_out,
-                             T &grad_output, activation_mode_t act_input) {
+                             T &grad_output, ActivationType act_input) {
     grad_update_gate = (grad_output * value_frame_state);
     grad_update_gate -= (grad_output * value_prev_out);
     grad_prev_out -= (grad_output * value_update_gate);
@@ -107,7 +106,7 @@ class gru_stateGrad {
                              __m256 &value_frame_state,
                              __m256 &grad_frame_state, __m256 &value_prev_out,
                              __m256 &grad_prev_out, __m256 &grad_output,
-                             activation_mode_t act_input) {
+                             ActivationType act_input) {
     grad_update_gate = _mm256_mul_ps(grad_output, value_frame_state);
     grad_update_gate = _mm256_sub_ps(
         grad_update_gate, _mm256_mul_ps(grad_output, value_prev_out));
@@ -128,7 +127,7 @@ class gru_resetGrad {
   HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate,
                              T &value_reset_gate, T &grad_reset_gate,
                              T &value_prev_out, T &grad_prev_out,
-                             T &grad_reset_output, activation_mode_t act_gate) {
+                             T &grad_reset_output, ActivationType act_gate) {
     grad_reset_gate = (grad_reset_output * value_prev_out);
     grad_prev_out += (grad_reset_output * value_reset_gate);
     grad_update_gate =
@@ -144,7 +143,7 @@ class gru_resetGrad {
                              __m256 &grad_update_gate, __m256 &value_reset_gate,
                              __m256 &grad_reset_gate, __m256 &value_prev_out,
                              __m256 &grad_prev_out, __m256 &grad_reset_output,
-                             activation_mode_t act_gate) {
+                             ActivationType act_gate) {
     grad_reset_gate = _mm256_mul_ps(grad_reset_output, value_prev_out);
     grad_prev_out = _mm256_add_ps(
         grad_prev_out, _mm256_mul_ps(grad_reset_output, value_reset_gate));
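The functors in gru_kernel.h carry the GRU cell algebra: gru_resetOutput applies the gate activation to the update and reset gates and forms r ⊙ h_prev, while gru_finalOutput applies the node activation to the candidate state and blends h = h_prev - u ⊙ h_prev + u ⊙ h~. A scalar sketch of the same arithmetic (sigmoid and tanh are picked for illustration; the real functors dispatch on the ActivationType argument, and the candidate pre-activation is produced by the state-weight GEMM over the reset output):

    #include <cmath>

    // One frame element. u_in and r_in are gate pre-activations; s_in is the
    // candidate pre-activation, assumed to already include the reset-gated
    // recurrent term.
    float GruFrameStep(float u_in, float r_in, float s_in, float prev_out) {
      auto sigmoid = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };
      float u = sigmoid(u_in);              // update gate
      float r = sigmoid(r_in);              // reset gate
      float reset_output = r * prev_out;    // what gru_resetOutput writes back
      (void)reset_output;                   // consumed by the GEMM that makes s_in
      float frame_state = std::tanh(s_in);  // gru_finalOutput: node activation
      return prev_out - u * prev_out + u * frame_state;  // blended output
    }
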
paddle/operators/math/gru_compute.cc

@@ -21,9 +21,9 @@ namespace math {
 template <typename T>
 struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
   static void compute(const platform::CPUDeviceContext &context,
-                      hl_gru_value<T> value, int frame_size, int batch_size,
-                      activation_mode_t active_node,
-                      activation_mode_t active_gate) {
+                      GRUMetaValue<T> value, int frame_size, int batch_size,
+                      const detail::ActivationType active_node,
+                      const detail::ActivationType active_gate) {
 #ifndef __NVCC__
     if (value.prev_out_value) {
       math::gemm<platform::CPUDeviceContext, T>(
@@ -51,10 +51,10 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
 template <typename T>
 struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
   static void compute(const platform::CPUDeviceContext &context,
-                      hl_gru_value<T> value, hl_gru_grad<T> grad,
+                      GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                       int frame_size, int batch_size,
-                      activation_mode_t active_node,
-                      activation_mode_t active_gate) {
+                      const detail::ActivationType active_node,
+                      const detail::ActivationType active_gate) {
 #ifndef __NVCC__
     detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
                                 grad, frame_size, batch_size, active_node);
paddle/operators/math/gru_compute.cu

@@ -21,9 +21,9 @@ namespace math {
 template <typename T>
 struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
   static void compute(const platform::CUDADeviceContext &context,
-                      hl_gru_value<T> value, int frame_size, int batch_size,
-                      activation_mode_t active_node,
-                      activation_mode_t active_gate) {
+                      GRUMetaValue<T> value, int frame_size, int batch_size,
+                      const detail::ActivationType active_node,
+                      const detail::ActivationType active_gate) {
     auto stream = context.stream();
     dim3 threads;
     dim3 grid;
@@ -88,10 +88,10 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
 template <typename T>
 struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
   static void compute(const platform::CUDADeviceContext &context,
-                      hl_gru_value<T> value, hl_gru_grad<T> grad,
+                      GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                       int frame_size, int batch_size,
-                      activation_mode_t active_node,
-                      activation_mode_t active_gate) {
+                      const detail::ActivationType active_node,
+                      const detail::ActivationType active_gate) {
     auto stream = context.stream();
     dim3 threads;
     dim3 grid;
paddle/operators/math/gru_compute.h

@@ -11,7 +11,7 @@ limitations under the License. */
 #pragma once
-#include "paddle/operators/math/lstm_compute.h"
+#include "paddle/operators/math/detail/activation_functions.h"
 #include "paddle/platform/device_context.h"
 #include "paddle/platform/enforce.h"
@@ -19,9 +19,8 @@ namespace paddle {
 namespace operators {
 namespace math {
-// TODO(guosheng): refine code style in gru_compute
 template <typename T>
-struct hl_gru_value {
+struct GRUMetaValue {
   T *gate_weight;
   T *state_weight;
   T *gate_value;
@@ -31,7 +30,7 @@ struct hl_gru_value {
 };
 template <typename T>
-struct hl_gru_grad {
+struct GRUMetaGrad {
   T *gate_weight_grad;
   T *state_weight_grad;
   T *gate_grad;
@@ -42,18 +41,18 @@ struct hl_gru_grad {
 template <typename DeviceContext, typename T>
 struct GRUUnitFunctor {
-  static void compute(const DeviceContext &context, hl_gru_value<T> value,
+  static void compute(const DeviceContext &context, GRUMetaValue<T> value,
                       int frame_size, int batch_size,
-                      activation_mode_t active_node,
-                      activation_mode_t active_gate);
+                      const detail::ActivationType active_node,
+                      const detail::ActivationType active_gate);
 };

 template <typename DeviceContext, typename T>
 struct GRUUnitGradFunctor {
-  static void compute(const DeviceContext &context, hl_gru_value<T> value,
-                      hl_gru_grad<T> grad, int frame_size, int batch_size,
-                      activation_mode_t active_node,
-                      activation_mode_t active_gate);
+  static void compute(const DeviceContext &context, GRUMetaValue<T> value,
+                      GRUMetaGrad<T> grad, int frame_size, int batch_size,
+                      const detail::ActivationType active_node,
+                      const detail::ActivationType active_gate);
 };
 }  // namespace math
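Note that GRUMetaValue and GRUMetaGrad stay plain bundles of raw pointers after the rename, which is why the compute signatures above can take them by value: copying the struct copies a handful of pointers, never tensor data, and the caller is free to rebind fields between calls (as gru_op.h does with gate_value and prev_out_value). A trimmed sketch of the idea (fields reduced; not the full Paddle struct):

    template <typename T>
    struct MetaValueSketch {  // illustrative stand-in for GRUMetaValue<T>
      T *gate_weight;   // first 2 * frame_size * frame_size weights
      T *state_weight;  // view at weight_data + 2 * frame_size * frame_size
      T *gate_value;    // rebound to the current batch each iteration
    };
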
paddle/operators/math/lstm_compute.h

@@ -22,14 +22,6 @@ namespace paddle {
 namespace operators {
 namespace math {
-typedef enum {
-  HL_ACTIVATION_SIGMOID = 0,
-  HL_ACTIVATION_RELU = 1,
-  HL_ACTIVATION_TANH = 2,
-  HL_ACTIVATION_LINEAR = 3,
-  HL_ACTIVATION_END
-} activation_mode_t;
 template <class T>
 struct LstmMetaValue {
   T *gate_value;
@@ -54,20 +46,6 @@ struct LstmMetaGrad {
   T *check_og_grad;
 };

-inline activation_mode_t ActiveType(const std::string &type) {
-  if (type == "sigmoid") {
-    return HL_ACTIVATION_SIGMOID;
-  } else if (type == "relu") {
-    return HL_ACTIVATION_RELU;
-  } else if (type == "tanh") {
-    return HL_ACTIVATION_TANH;
-  } else if (type == "linear" || type == "identity" || type == "") {
-    return HL_ACTIVATION_LINEAR;
-  } else {
-    PADDLE_THROW("Do not support activation type.");
-  }
-}
 template <typename DeviceContext, typename T>
 class LstmUnitFunctor {
  public:
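This last hunk is the point of the cleanup: the ad-hoc activation_mode_t enum and the string-to-enum helper ActiveType no longer live in lstm_compute.h. The GRU code now includes detail/activation_functions.h and calls math::detail::GetActivationType instead (see the gru_op.h and gru_compute.h hunks). The replacement presumably mirrors the removed helper; a hedged sketch, since activation_functions.h itself is not shown in this diff and the enum and enumerator names below are assumptions:

    #include <stdexcept>
    #include <string>

    enum class ActivationTypeSketch { kSigmoid, kReLU, kTanh, kIdentity };

    // Assumed shape of the GetActivationType replacement; the accepted
    // strings are taken from the removed ActiveType above.
    inline ActivationTypeSketch GetActivationTypeSketch(const std::string &type) {
      if (type == "sigmoid") return ActivationTypeSketch::kSigmoid;
      if (type == "relu") return ActivationTypeSketch::kReLU;
      if (type == "tanh") return ActivationTypeSketch::kTanh;
      if (type == "linear" || type == "identity" || type.empty())
        return ActivationTypeSketch::kIdentity;
      throw std::invalid_argument("Do not support activation type: " + type);
    }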