Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
9b16e540
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9b16e540
编写于
1月 13, 2019
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update gru_grad_op
test=develop
上级
e477d789
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
55 addition
and
34 deletion
+55
-34
paddle/fluid/operators/math/detail/gru_cpu_kernel.h
paddle/fluid/operators/math/detail/gru_cpu_kernel.h
+30
-22
paddle/fluid/operators/math/detail/gru_kernel.h
paddle/fluid/operators/math/detail/gru_kernel.h
+21
-10
paddle/fluid/operators/math/gru_compute.cc
paddle/fluid/operators/math/gru_compute.cc
+4
-2
未找到文件。
paddle/fluid/operators/math/detail/gru_cpu_kernel.h
浏览文件 @
9b16e540
...
...
@@ -256,7 +256,8 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
T
*
gate_grad
,
T
*
prev_out_value
,
T
*
prev_out_grad
,
T
*
output_grad
,
int
frame_size
,
ActivationType
active_node
)
{
ActivationType
active_node
,
bool
origin_mode
)
{
T
r_update_gate_value
;
T
r_update_gate_grad
;
T
r_frame_state_value
;
...
...
@@ -282,7 +283,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
op_state_grad
(
&
r_update_gate_value
,
&
r_update_gate_grad
,
&
r_frame_state_value
,
&
r_frame_state_grad
,
&
r_prev_out_value
,
&
r_prev_out_grad
,
&
r_out_grad
,
active_node
);
&
r_prev_out_grad
,
&
r_out_grad
,
active_node
,
origin_mode
);
update_gate_grad
[
i
]
=
r_update_gate_grad
;
frame_state_grad
[
i
]
=
r_frame_state_grad
;
...
...
@@ -297,7 +298,8 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
T
*
gate_grad
,
T
*
prev_out_value
,
T
*
prev_out_grad
,
T
*
reset_output_grad
,
int
frame_size
,
ActivationType
active_gate
)
{
ActivationType
active_gate
,
bool
origin_mode
)
{
T
r_update_gate_value
;
T
r_update_gate_grad
;
T
r_reset_gate_value
;
...
...
@@ -327,7 +329,8 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
op_reset_grad
(
&
r_update_gate_value
,
&
r_update_gate_grad
,
&
r_reset_gate_value
,
&
r_reset_gate_grad
,
&
r_prev_out_value
,
&
r_prev_out_grad
,
&
r_reset_output_grad
,
active_gate
);
&
r_prev_out_grad
,
&
r_reset_output_grad
,
active_gate
,
origin_mode
);
update_gate_grad
[
i
]
=
r_update_gate_grad
;
reset_gate_grad
[
i
]
=
r_reset_gate_grad
;
...
...
@@ -341,8 +344,8 @@ template <class OpStateGrad, typename T>
void
hl_avx_gru_backward_state_grad
(
OpStateGrad
op_state_grad
,
T
*
gate_value
,
T
*
gate_grad
,
T
*
prev_out_value
,
T
*
prev_out_grad
,
T
*
output_grad
,
int
frame_size
,
ActivationType
active_n
ode
)
{
int
frame_size
,
ActivationType
active_node
,
bool
origin_m
ode
)
{
#ifdef __AVX__
__m256
r_update_gate_value
;
__m256
r_update_gate_grad
;
...
...
@@ -371,7 +374,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
op_state_grad
(
&
r_update_gate_value
,
&
r_update_gate_grad
,
&
r_frame_state_value
,
&
r_frame_state_grad
,
&
r_prev_out_value
,
&
r_prev_out_grad
,
&
r_out_grad
,
active_node
);
&
r_prev_out_grad
,
&
r_out_grad
,
active_node
,
origin_mode
);
update_gate_grad
[
i
]
=
r_update_gate_grad
;
frame_state_grad
[
i
]
=
r_frame_state_grad
;
...
...
@@ -386,8 +389,8 @@ template <class OpResetGrad, typename T>
void
hl_avx_gru_backward_reset_grad
(
OpResetGrad
op_reset_grad
,
T
*
gate_value
,
T
*
gate_grad
,
T
*
prev_out_value
,
T
*
prev_out_grad
,
T
*
reset_output_grad
,
int
frame_size
,
ActivationType
active_gat
e
)
{
int
frame_size
,
ActivationType
active_gate
,
bool
origin_mod
e
)
{
#ifdef __AVX__
__m256
r_update_gate_value
;
__m256
r_update_gate_grad
;
...
...
@@ -419,7 +422,8 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
op_reset_grad
(
&
r_update_gate_value
,
&
r_update_gate_grad
,
&
r_reset_gate_value
,
&
r_reset_gate_grad
,
&
r_prev_out_value
,
&
r_prev_out_grad
,
&
r_reset_output_grad
,
active_gate
);
&
r_prev_out_grad
,
&
r_reset_output_grad
,
active_gate
,
origin_mode
);
update_gate_grad
[
i
]
=
r_update_gate_grad
;
reset_gate_grad
[
i
]
=
r_reset_gate_grad
;
...
...
@@ -434,16 +438,18 @@ template <class OpStateGrad, typename T>
inline
void
backward_state_grad
(
OpStateGrad
op_state_grad
,
GRUMetaValue
<
T
>
value
,
GRUMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
ActivationType
active_node
)
{
ActivationType
active_node
,
bool
origin_mode
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
if
(
OpStateGrad
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
hl_avx_gru_backward_state_grad
(
op_state_grad
,
value
.
gate_value
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
grad
.
output_grad
,
frame_size
,
active_node
);
hl_avx_gru_backward_state_grad
(
op_state_grad
,
value
.
gate_value
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
grad
.
output_grad
,
frame_size
,
active_node
,
origin_mode
);
}
else
{
hl_naive_gru_backward_state_grad
(
op_state_grad
,
value
.
gate_value
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
grad
.
output_grad
,
frame_size
,
active_node
);
hl_naive_gru_backward_state_grad
(
op_state_grad
,
value
.
gate_value
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
grad
.
output_grad
,
frame_size
,
active_node
,
origin_mode
);
}
value
.
gate_value
+=
frame_size
*
3
;
...
...
@@ -463,16 +469,18 @@ template <class OpResetGrad, typename T>
inline
void
backward_reset_grad
(
OpResetGrad
op_reset_grad
,
GRUMetaValue
<
T
>
value
,
GRUMetaGrad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
ActivationType
active_gate
)
{
ActivationType
active_gate
,
bool
origin_mode
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
if
(
OpResetGrad
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
hl_avx_gru_backward_reset_grad
(
op_reset_grad
,
value
.
gate_value
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
grad
.
reset_output_grad
,
frame_size
,
active_gate
);
hl_avx_gru_backward_reset_grad
(
op_reset_grad
,
value
.
gate_value
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
grad
.
reset_output_grad
,
frame_size
,
active_gate
,
origin_mode
);
}
else
{
hl_naive_gru_backward_reset_grad
(
op_reset_grad
,
value
.
gate_value
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
grad
.
reset_output_grad
,
frame_size
,
active_gate
);
grad
.
prev_out_grad
,
grad
.
reset_output_grad
,
frame_size
,
active_gate
,
origin_mode
);
}
value
.
gate_value
+=
frame_size
*
3
;
...
...
paddle/fluid/operators/math/detail/gru_kernel.h
浏览文件 @
9b16e540
...
...
@@ -103,14 +103,24 @@ class gru_stateGrad {
HOSTDEVICE
void
operator
()(
T
*
value_update_gate
,
T
*
grad_update_gate
,
T
*
value_frame_state
,
T
*
grad_frame_state
,
T
*
value_prev_out
,
T
*
grad_prev_out
,
T
*
grad_output
,
ActivationType
act_input
)
{
*
grad_update_gate
=
(
*
grad_output
*
(
*
value_frame_state
));
*
grad_update_gate
-=
(
*
grad_output
*
(
*
value_prev_out
));
*
grad_prev_out
-=
(
*
grad_output
*
(
*
value_update_gate
));
*
grad_prev_out
+=
*
grad_output
;
T
*
grad_output
,
ActivationType
act_input
,
bool
origin_mode
)
{
if
(
origin_mode
)
{
*
grad_update_gate
=
(
*
grad_output
)
*
((
*
value_prev_out
)
-
(
*
value_frame_state
));
*
grad_prev_out
+=
(
*
grad_output
*
(
*
value_update_gate
));
*
grad_frame_state
=
activation
(
*
grad_output
*
(
static_cast
<
T
>
(
1.0
)
-
(
*
value_update_gate
)),
*
value_frame_state
,
act_input
);
}
else
{
*
grad_update_gate
=
(
*
grad_output
)
*
((
*
value_frame_state
)
-
(
*
value_prev_out
));
*
grad_prev_out
+=
(
*
grad_output
*
(
static_cast
<
T
>
(
1.0
)
-
*
value_update_gate
));
*
grad_frame_state
=
activation
(
*
grad_output
*
(
*
value_update_gate
),
*
value_frame_state
,
act_input
);
}
}
#ifndef __NVCC__
#ifndef __AVX__
static
const
bool
avx
=
false
;
...
...
@@ -121,7 +131,7 @@ class gru_stateGrad {
__m256
*
value_frame_state
,
__m256
*
grad_frame_state
,
__m256
*
value_prev_out
,
__m256
*
grad_prev_out
,
__m256
*
grad_output
,
ActivationType
act_input
)
{
ActivationType
act_input
,
bool
origin_mode
)
{
*
grad_update_gate
=
_mm256_mul_ps
(
*
grad_output
,
*
value_frame_state
);
*
grad_update_gate
=
_mm256_sub_ps
(
*
grad_update_gate
,
_mm256_mul_ps
(
*
grad_output
,
*
value_prev_out
));
...
...
@@ -143,7 +153,8 @@ class gru_resetGrad {
HOSTDEVICE
void
operator
()(
T
*
value_update_gate
,
T
*
grad_update_gate
,
T
*
value_reset_gate
,
T
*
grad_reset_gate
,
T
*
value_prev_out
,
T
*
grad_prev_out
,
T
*
grad_reset_output
,
ActivationType
act_gate
)
{
T
*
grad_reset_output
,
ActivationType
act_gate
,
bool
origin_mode
)
{
*
grad_reset_gate
=
(
*
grad_reset_output
*
(
*
value_prev_out
));
*
grad_prev_out
+=
(
*
grad_reset_output
*
(
*
value_reset_gate
));
*
grad_update_gate
=
...
...
@@ -160,7 +171,7 @@ class gru_resetGrad {
__m256
*
grad_update_gate
,
__m256
*
value_reset_gate
,
__m256
*
grad_reset_gate
,
__m256
*
value_prev_out
,
__m256
*
grad_prev_out
,
__m256
*
grad_reset_output
,
ActivationType
act_gate
)
{
ActivationType
act_gate
,
bool
origin_mode
)
{
*
grad_reset_gate
=
_mm256_mul_ps
(
*
grad_reset_output
,
*
value_prev_out
);
*
grad_prev_out
=
_mm256_add_ps
(
*
grad_prev_out
,
_mm256_mul_ps
(
*
grad_reset_output
,
*
value_reset_gate
));
...
...
paddle/fluid/operators/math/gru_compute.cc
浏览文件 @
9b16e540
...
...
@@ -60,7 +60,8 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
bool
origin_mode
)
{
#ifndef __NVCC__
detail
::
backward_state_grad
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
,
grad
,
frame_size
,
batch_size
,
active_node
);
grad
,
frame_size
,
batch_size
,
active_node
,
origin_mode
);
auto
blas
=
math
::
GetBlas
<
platform
::
CPUDeviceContext
,
T
>
(
context
);
if
(
value
.
prev_out_value
&&
grad
.
prev_out_grad
)
{
blas
.
GEMM
(
false
,
true
,
batch_size
,
frame_size
,
frame_size
,
1
,
...
...
@@ -77,7 +78,8 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
}
detail
::
backward_reset_grad
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
,
grad
,
frame_size
,
batch_size
,
active_gate
);
grad
,
frame_size
,
batch_size
,
active_gate
,
origin_mode
);
if
(
grad
.
prev_out_grad
&&
value
.
prev_out_value
)
{
blas
.
GEMM
(
false
,
true
,
batch_size
,
frame_size
,
frame_size
*
2
,
1
,
grad
.
gate_grad
,
frame_size
*
3
,
value
.
gate_weight
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录