Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3e552cdc
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3e552cdc
编写于
11月 29, 2017
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix gru_op related code style
上级
dcf3ffd9
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
617 addition
and
599 deletion
+617
-599
paddle/operators/gru_op.h
paddle/operators/gru_op.h
+23
-23
paddle/operators/math/detail/gru_cpu_kernel.h
paddle/operators/math/detail/gru_cpu_kernel.h
+272
-268
paddle/operators/math/detail/gru_gpu_kernel.h
paddle/operators/math/detail/gru_gpu_kernel.h
+126
-126
paddle/operators/math/detail/gru_kernel.h
paddle/operators/math/detail/gru_kernel.h
+72
-63
paddle/operators/math/gru_compute.cc
paddle/operators/math/gru_compute.cc
+33
-31
paddle/operators/math/gru_compute.cu
paddle/operators/math/gru_compute.cu
+75
-73
paddle/operators/math/gru_compute.h
paddle/operators/math/gru_compute.h
+16
-15
未找到文件。
paddle/operators/gru_op.h
浏览文件 @
3e552cdc
...
...
@@ -71,8 +71,8 @@ class GRUKernel : public framework::OpKernel<T> {
int
frame_size
=
hidden_dims
[
1
];
math
::
hl_gru_value
<
T
>
gru_value
;
gru_value
.
gate
W
eight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
state
W
eight
=
gru_value
.
gate
_w
eight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
state
_w
eight
=
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
Tensor
ordered_h0
;
const
size_t
*
order
=
batch_gate
->
lod
()[
2
].
data
();
...
...
@@ -82,9 +82,9 @@ class GRUKernel : public framework::OpKernel<T> {
// to reorder.
ReorderInitState
<
Place
,
T
>
(
context
.
device_context
(),
*
h0
,
order
,
&
ordered_h0
,
true
);
gru_value
.
prev
OutV
alue
=
ordered_h0
.
data
<
T
>
();
gru_value
.
prev
_out_v
alue
=
ordered_h0
.
data
<
T
>
();
}
else
{
gru_value
.
prev
OutV
alue
=
nullptr
;
gru_value
.
prev
_out_v
alue
=
nullptr
;
}
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
...
...
@@ -96,14 +96,14 @@ class GRUKernel : public framework::OpKernel<T> {
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
hidden_t
=
batch_hidden
->
Slice
(
bstart
,
bend
);
gru_value
.
output
V
alue
=
hidden_t
.
data
<
T
>
();
gru_value
.
gate
V
alue
=
gate_t
.
data
<
T
>
();
gru_value
.
reset
OutputV
alue
=
reset_hidden_prev_t
.
data
<
T
>
();
gru_value
.
output
_v
alue
=
hidden_t
.
data
<
T
>
();
gru_value
.
gate
_v
alue
=
gate_t
.
data
<
T
>
();
gru_value
.
reset
_output_v
alue
=
reset_hidden_prev_t
.
data
<
T
>
();
math
::
GRUUnitFunctor
<
Place
,
T
>::
compute
(
dev_ctx
,
gru_value
,
frame_size
,
cur_batch_size
,
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"activation"
)),
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"gate_activation"
)));
gru_value
.
prev
OutValue
=
gru_value
.
outputV
alue
;
gru_value
.
prev
_out_value
=
gru_value
.
output_v
alue
;
}
math
::
Batch2LoDTensorFunctor
<
Place
,
T
>
to_seq
;
...
...
@@ -169,20 +169,20 @@ class GRUGradKernel : public framework::OpKernel<T> {
to_batch
(
dev_ctx
,
*
hidden_grad
,
batch_hidden_grad
,
false
,
is_reverse
);
math
::
hl_gru_value
<
T
>
gru_value
;
gru_value
.
gate
W
eight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
state
W
eight
=
gru_value
.
gate
_w
eight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
state
_w
eight
=
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
math
::
hl_gru_grad
<
T
>
gru_grad
;
if
(
weight_grad
)
{
gru_grad
.
gate
WeightG
rad
=
gru_grad
.
gate
_weight_g
rad
=
weight_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
zero
(
dev_ctx
,
weight_grad
,
static_cast
<
T
>
(
0.0
));
gru_grad
.
state
WeightG
rad
=
gru_grad
.
state
_weight_g
rad
=
weight_grad
->
data
<
T
>
()
+
2
*
frame_size
*
frame_size
;
}
else
{
gru_grad
.
gate
WeightG
rad
=
nullptr
;
gru_grad
.
state
WeightG
rad
=
nullptr
;
gru_grad
.
gate
_weight_g
rad
=
nullptr
;
gru_grad
.
state
_weight_g
rad
=
nullptr
;
}
auto
batch_starts
=
batch_hidden_grad
.
lod
()[
0
];
...
...
@@ -193,27 +193,27 @@ class GRUGradKernel : public framework::OpKernel<T> {
int
cur_batch_size
=
bend
-
bstart
;
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
gru_value
.
gate
V
alue
=
gate_t
.
data
<
T
>
();
gru_value
.
gate
_v
alue
=
gate_t
.
data
<
T
>
();
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
gru_value
.
reset
OutputV
alue
=
reset_hidden_prev_t
.
data
<
T
>
();
gru_value
.
reset
_output_v
alue
=
reset_hidden_prev_t
.
data
<
T
>
();
Tensor
hidden_grad_t
=
batch_hidden_grad
.
Slice
(
bstart
,
bend
);
gru_grad
.
output
G
rad
=
hidden_grad_t
.
data
<
T
>
();
gru_grad
.
output
_g
rad
=
hidden_grad_t
.
data
<
T
>
();
Tensor
gate_grad_t
=
batch_gate_grad
.
Slice
(
bstart
,
bend
);
gru_grad
.
gate
G
rad
=
gate_grad_t
.
data
<
T
>
();
gru_grad
.
gate
_g
rad
=
gate_grad_t
.
data
<
T
>
();
Tensor
reset_hidden_prev_grad_t
=
batch_reset_hidden_prev_grad
.
Slice
(
bstart
,
bend
);
gru_grad
.
reset
OutputG
rad
=
reset_hidden_prev_grad_t
.
data
<
T
>
();
gru_grad
.
reset
_output_g
rad
=
reset_hidden_prev_grad_t
.
data
<
T
>
();
if
(
n
==
0
)
{
gru_value
.
prev
OutV
alue
=
h0
?
ordered_h0
.
data
<
T
>
()
:
nullptr
;
gru_grad
.
prev
OutG
rad
=
gru_value
.
prev
_out_v
alue
=
h0
?
ordered_h0
.
data
<
T
>
()
:
nullptr
;
gru_grad
.
prev
_out_g
rad
=
h0
&&
h0_grad
?
ordered_h0_grad
.
data
<
T
>
()
:
nullptr
;
}
else
{
int
bstart_pre
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
Tensor
hidden_prev_t
=
batch_hidden
->
Slice
(
bstart_pre
,
bstart
);
gru_value
.
prev
OutV
alue
=
hidden_prev_t
.
data
<
T
>
();
gru_value
.
prev
_out_v
alue
=
hidden_prev_t
.
data
<
T
>
();
Tensor
hidden_prev_grad_t
=
batch_hidden_grad
.
Slice
(
bstart_pre
,
bstart
);
gru_grad
.
prev
OutG
rad
=
hidden_prev_grad_t
.
data
<
T
>
();
gru_grad
.
prev
_out_g
rad
=
hidden_prev_grad_t
.
data
<
T
>
();
}
math
::
GRUUnitGradFunctor
<
Place
,
T
>::
compute
(
...
...
paddle/operators/math/detail/gru_cpu_kernel.h
浏览文件 @
3e552cdc
...
...
@@ -25,393 +25,397 @@ namespace detail {
#ifndef __NVCC__
template
<
class
OpResetOutput
,
typename
T
>
void
hl_naive_gru_forward_reset_output
(
OpResetOutput
op
ResetO
utput
,
T
*
gate
Value
,
T
*
resetOutputV
alue
,
T
*
prev
OutputValue
,
int
frameS
ize
,
void
hl_naive_gru_forward_reset_output
(
OpResetOutput
op
_reset_o
utput
,
T
*
gate
_value
,
T
*
reset_output_v
alue
,
T
*
prev
_output_value
,
int
frame_s
ize
,
activation_mode_t
active_gate
)
{
T
r
ValueUpdateG
ate
;
T
r
ValueResetG
ate
;
T
r
ValueResetO
utput
;
T
r
PrevO
ut
=
0
;
T
*
update
Gate
=
gateV
alue
;
T
*
reset
Gate
=
gateValue
+
frameS
ize
;
for
(
int
i
=
0
;
i
<
frame
S
ize
;
i
++
)
{
r
ValueUpdateGate
=
updateG
ate
[
i
];
r
ValueResetGate
=
resetG
ate
[
i
];
if
(
prev
OutputV
alue
)
{
r
PrevOut
=
prevOutputV
alue
[
i
];
T
r
_value_update_g
ate
;
T
r
_value_reset_g
ate
;
T
r
_value_reset_o
utput
;
T
r
_prev_o
ut
=
0
;
T
*
update
_gate
=
gate_v
alue
;
T
*
reset
_gate
=
gate_value
+
frame_s
ize
;
for
(
int
i
=
0
;
i
<
frame
_s
ize
;
i
++
)
{
r
_value_update_gate
=
update_g
ate
[
i
];
r
_value_reset_gate
=
reset_g
ate
[
i
];
if
(
prev
_output_v
alue
)
{
r
_prev_out
=
prev_output_v
alue
[
i
];
}
op
ResetOutput
(
rValueUpdateGate
,
rValueResetGate
,
rPrevO
ut
,
rValueResetO
utput
,
active_gate
);
op
_reset_output
(
r_value_update_gate
,
r_value_reset_gate
,
r_prev_o
ut
,
r_value_reset_o
utput
,
active_gate
);
update
Gate
[
i
]
=
rValueUpdateG
ate
;
reset
Gate
[
i
]
=
rValueResetG
ate
;
reset
OutputValue
[
i
]
=
rValueResetO
utput
;
update
_gate
[
i
]
=
r_value_update_g
ate
;
reset
_gate
[
i
]
=
r_value_reset_g
ate
;
reset
_output_value
[
i
]
=
r_value_reset_o
utput
;
}
}
template
<
class
OpFinalOutput
,
typename
T
>
void
hl_naive_gru_forward_final_output
(
OpFinalOutput
op
FinalO
utput
,
T
*
gate
Value
,
T
*
prevOutputV
alue
,
T
*
output
Value
,
int
frameS
ize
,
void
hl_naive_gru_forward_final_output
(
OpFinalOutput
op
_final_o
utput
,
T
*
gate
_value
,
T
*
prev_output_v
alue
,
T
*
output
_value
,
int
frame_s
ize
,
activation_mode_t
active_node
)
{
T
r
ValueUpdateG
ate
;
T
r
ValueFrameS
tate
;
T
r
PrevO
ut
=
0
;
T
r
O
utput
;
T
*
update
Gate
=
gateV
alue
;
T
*
frame
State
=
gateValue
+
frameS
ize
*
2
;
for
(
int
i
=
0
;
i
<
frame
S
ize
;
i
++
)
{
r
ValueUpdateGate
=
updateG
ate
[
i
];
r
ValueFrameState
=
frameS
tate
[
i
];
if
(
prev
OutputV
alue
)
{
r
PrevOut
=
prevOutputV
alue
[
i
];
T
r
_value_update_g
ate
;
T
r
_value_frame_s
tate
;
T
r
_prev_o
ut
=
0
;
T
r
_o
utput
;
T
*
update
_gate
=
gate_v
alue
;
T
*
frame
_state
=
gate_value
+
frame_s
ize
*
2
;
for
(
int
i
=
0
;
i
<
frame
_s
ize
;
i
++
)
{
r
_value_update_gate
=
update_g
ate
[
i
];
r
_value_frame_state
=
frame_s
tate
[
i
];
if
(
prev
_output_v
alue
)
{
r
_prev_out
=
prev_output_v
alue
[
i
];
}
op
FinalOutput
(
rValueUpdateGate
,
rValueFrameState
,
rPrevOut
,
rOutp
ut
,
active_node
);
op
_final_output
(
r_value_update_gate
,
r_value_frame_state
,
r_prev_o
ut
,
r_output
,
active_node
);
frame
State
[
i
]
=
rValueFrameS
tate
;
output
Value
[
i
]
=
rO
utput
;
frame
_state
[
i
]
=
r_value_frame_s
tate
;
output
_value
[
i
]
=
r_o
utput
;
}
}
template
<
class
OpResetOutput
,
typename
T
>
void
hl_avx_gru_forward_reset_output
(
OpResetOutput
op
ResetOutput
,
T
*
gateValue
,
T
*
resetOutputValue
,
T
*
prevOutputV
alue
,
int
frameS
ize
,
void
hl_avx_gru_forward_reset_output
(
OpResetOutput
op
_reset_output
,
T
*
gate_value
,
T
*
reset_output_v
alue
,
T
*
prev_output_value
,
int
frame_s
ize
,
activation_mode_t
active_gate
)
{
#ifdef __AVX__
__m256
r
ValueUpdateG
ate
;
__m256
r
ValueResetG
ate
;
__m256
r
ValueResetO
utput
;
__m256
r
PrevO
ut
=
_mm256_set1_ps
(
0.0
f
);
__m256
*
update
Gate
=
(
__m256
*
)
gateV
alue
;
__m256
*
reset
Gate
=
(
__m256
*
)(
gateValue
+
frameS
ize
);
for
(
int
i
=
0
;
i
<
frame
S
ize
/
8
;
i
++
)
{
r
ValueUpdateGate
=
updateG
ate
[
i
];
r
ValueResetGate
=
resetG
ate
[
i
];
if
(
prev
OutputV
alue
)
{
r
PrevOut
=
((
__m256
*
)
prevOutputV
alue
)[
i
];
__m256
r
_value_update_g
ate
;
__m256
r
_value_reset_g
ate
;
__m256
r
_value_reset_o
utput
;
__m256
r
_prev_o
ut
=
_mm256_set1_ps
(
0.0
f
);
__m256
*
update
_gate
=
(
__m256
*
)
gate_v
alue
;
__m256
*
reset
_gate
=
(
__m256
*
)(
gate_value
+
frame_s
ize
);
for
(
int
i
=
0
;
i
<
frame
_s
ize
/
8
;
i
++
)
{
r
_value_update_gate
=
update_g
ate
[
i
];
r
_value_reset_gate
=
reset_g
ate
[
i
];
if
(
prev
_output_v
alue
)
{
r
_prev_out
=
((
__m256
*
)
prev_output_v
alue
)[
i
];
}
op
ResetOutput
(
rValueUpdateGate
,
rValueResetGate
,
rPrevO
ut
,
rValueResetO
utput
,
active_gate
);
op
_reset_output
(
r_value_update_gate
,
r_value_reset_gate
,
r_prev_o
ut
,
r_value_reset_o
utput
,
active_gate
);
update
Gate
[
i
]
=
rValueUpdateG
ate
;
reset
Gate
[
i
]
=
rValueResetG
ate
;
((
__m256
*
)
reset
OutputValue
)[
i
]
=
rValueResetO
utput
;
update
_gate
[
i
]
=
r_value_update_g
ate
;
reset
_gate
[
i
]
=
r_value_reset_g
ate
;
((
__m256
*
)
reset
_output_value
)[
i
]
=
r_value_reset_o
utput
;
}
#endif
}
template
<
class
OpFinalOutput
,
typename
T
>
void
hl_avx_gru_forward_final_output
(
OpFinalOutput
op
FinalOutput
,
T
*
gateValue
,
T
*
prevOutputValue
,
T
*
outputV
alue
,
int
frameS
ize
,
void
hl_avx_gru_forward_final_output
(
OpFinalOutput
op
_final_output
,
T
*
gate_value
,
T
*
prev_output_v
alue
,
T
*
output_value
,
int
frame_s
ize
,
activation_mode_t
active_node
)
{
#ifdef __AVX__
__m256
r
ValueUpdateG
ate
;
__m256
r
ValueFrameS
tate
;
__m256
r
PrevO
ut
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
O
utput
;
__m256
*
update
Gate
=
(
__m256
*
)
gateV
alue
;
__m256
*
frame
State
=
(
__m256
*
)(
gateValue
+
frameS
ize
*
2
);
for
(
int
i
=
0
;
i
<
frame
S
ize
/
8
;
i
++
)
{
r
ValueUpdateGate
=
updateG
ate
[
i
];
r
ValueFrameState
=
frameS
tate
[
i
];
if
(
prev
OutputV
alue
)
{
r
PrevOut
=
((
__m256
*
)
prevOutputV
alue
)[
i
];
__m256
r
_value_update_g
ate
;
__m256
r
_value_frame_s
tate
;
__m256
r
_prev_o
ut
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_o
utput
;
__m256
*
update
_gate
=
(
__m256
*
)
gate_v
alue
;
__m256
*
frame
_state
=
(
__m256
*
)(
gate_value
+
frame_s
ize
*
2
);
for
(
int
i
=
0
;
i
<
frame
_s
ize
/
8
;
i
++
)
{
r
_value_update_gate
=
update_g
ate
[
i
];
r
_value_frame_state
=
frame_s
tate
[
i
];
if
(
prev
_output_v
alue
)
{
r
_prev_out
=
((
__m256
*
)
prev_output_v
alue
)[
i
];
}
op
FinalOutput
(
rValueUpdateGate
,
rValueFrameState
,
rPrevOut
,
rOutp
ut
,
active_node
);
op
_final_output
(
r_value_update_gate
,
r_value_frame_state
,
r_prev_o
ut
,
r_output
,
active_node
);
frame
State
[
i
]
=
rValueFrameS
tate
;
((
__m256
*
)
output
Value
)[
i
]
=
rO
utput
;
frame
_state
[
i
]
=
r_value_frame_s
tate
;
((
__m256
*
)
output
_value
)[
i
]
=
r_o
utput
;
}
#endif
}
template
<
class
OpResetOutput
,
typename
T
>
inline
void
forward_reset_output
(
OpResetOutput
opResetOutput
,
hl_gru_value
<
T
>
value
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_gate
)
{
for
(
int
b
=
0
;
b
<
batchSize
;
b
++
)
{
if
(
OpResetOutput
::
avx
&&
!
(
frameSize
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
inline
void
forward_reset_output
(
OpResetOutput
op_reset_output
,
hl_gru_value
<
T
>
value
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_gate
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
if
(
OpResetOutput
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
hl_avx_gru_forward_reset_output
(
op
ResetOutput
,
value
.
gateValue
,
value
.
resetOutputV
alue
,
value
.
prev
OutValue
,
frameS
ize
,
active_gate
);
op
_reset_output
,
value
.
gate_value
,
value
.
reset_output_v
alue
,
value
.
prev
_out_value
,
frame_s
ize
,
active_gate
);
}
else
{
hl_naive_gru_forward_reset_output
(
op
ResetOutput
,
value
.
gateValue
,
value
.
resetOutputV
alue
,
value
.
prev
OutValue
,
frameS
ize
,
active_gate
);
op
_reset_output
,
value
.
gate_value
,
value
.
reset_output_v
alue
,
value
.
prev
_out_value
,
frame_s
ize
,
active_gate
);
}
value
.
gate
Value
+=
frameS
ize
*
3
;
value
.
reset
OutputValue
+=
frameS
ize
;
if
(
value
.
prev
OutV
alue
)
{
value
.
prev
OutValue
+=
frameS
ize
;
value
.
gate
_value
+=
frame_s
ize
*
3
;
value
.
reset
_output_value
+=
frame_s
ize
;
if
(
value
.
prev
_out_v
alue
)
{
value
.
prev
_out_value
+=
frame_s
ize
;
}
}
}
template
<
class
OpFinalOutput
,
typename
T
>
inline
void
forward_final_output
(
OpFinalOutput
opFinalOutput
,
hl_gru_value
<
T
>
value
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
)
{
for
(
int
b
=
0
;
b
<
batchSize
;
b
++
)
{
if
(
OpFinalOutput
::
avx
&&
!
(
frameSize
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
hl_avx_gru_forward_final_output
(
opFinalOutput
,
value
.
gateValue
,
value
.
prevOutValue
,
value
.
outputValue
,
frameSize
,
active_node
);
inline
void
forward_final_output
(
OpFinalOutput
op_final_output
,
hl_gru_value
<
T
>
value
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
if
(
OpFinalOutput
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
hl_avx_gru_forward_final_output
(
op_final_output
,
value
.
gate_value
,
value
.
prev_out_value
,
value
.
output_value
,
frame_size
,
active_node
);
}
else
{
hl_naive_gru_forward_final_output
(
opFinalOutput
,
value
.
gateValue
,
value
.
prevOutValue
,
value
.
outputV
alue
,
frameS
ize
,
active_node
);
hl_naive_gru_forward_final_output
(
op_final_output
,
value
.
gate_value
,
value
.
prev_out_v
alue
,
value
.
output_value
,
frame_s
ize
,
active_node
);
}
value
.
gate
Value
+=
frameS
ize
*
3
;
value
.
output
Value
+=
frameS
ize
;
if
(
value
.
prev
OutV
alue
)
{
value
.
prev
OutValue
+=
frameS
ize
;
value
.
gate
_value
+=
frame_s
ize
*
3
;
value
.
output
_value
+=
frame_s
ize
;
if
(
value
.
prev
_out_v
alue
)
{
value
.
prev
_out_value
+=
frame_s
ize
;
}
}
}
template
<
class
OpStateGrad
,
typename
T
>
void
hl_naive_gru_backward_state_grad
(
OpStateGrad
op
StateGrad
,
T
*
gateV
alue
,
T
*
gate
Grad
,
T
*
prevOutV
alue
,
T
*
prev
OutGrad
,
T
*
outputG
rad
,
int
frame
S
ize
,
void
hl_naive_gru_backward_state_grad
(
OpStateGrad
op
_state_grad
,
T
*
gate_v
alue
,
T
*
gate
_grad
,
T
*
prev_out_v
alue
,
T
*
prev
_out_grad
,
T
*
output_g
rad
,
int
frame
_s
ize
,
activation_mode_t
active_node
)
{
T
r
UpdateGateV
alue
;
T
r
UpdateGateG
rad
;
T
r
FrameStateV
alue
;
T
r
FrameStateG
rad
;
T
r
OutG
rad
;
T
r
PrevOutV
alue
=
0
;
T
r
PrevOutG
rad
=
0
;
T
*
update
GateValue
=
gateV
alue
;
T
*
update
GateGrad
=
gateG
rad
;
T
*
frame
StateValue
=
gateValue
+
frameS
ize
*
2
;
T
*
frame
StateGrad
=
gateGrad
+
frameS
ize
*
2
;
for
(
int
i
=
0
;
i
<
frame
S
ize
;
i
++
)
{
r
UpdateGateValue
=
updateGateV
alue
[
i
];
r
FrameStateValue
=
frameStateV
alue
[
i
];
r
OutGrad
=
outputG
rad
[
i
];
if
(
prev
OutV
alue
)
{
r
PrevOutValue
=
prevOutV
alue
[
i
];
T
r
_update_gate_v
alue
;
T
r
_update_gate_g
rad
;
T
r
_frame_state_v
alue
;
T
r
_frame_state_g
rad
;
T
r
_out_g
rad
;
T
r
_prev_out_v
alue
=
0
;
T
r
_prev_out_g
rad
=
0
;
T
*
update
_gate_value
=
gate_v
alue
;
T
*
update
_gate_grad
=
gate_g
rad
;
T
*
frame
_state_value
=
gate_value
+
frame_s
ize
*
2
;
T
*
frame
_state_grad
=
gate_grad
+
frame_s
ize
*
2
;
for
(
int
i
=
0
;
i
<
frame
_s
ize
;
i
++
)
{
r
_update_gate_value
=
update_gate_v
alue
[
i
];
r
_frame_state_value
=
frame_state_v
alue
[
i
];
r
_out_grad
=
output_g
rad
[
i
];
if
(
prev
_out_v
alue
)
{
r
_prev_out_value
=
prev_out_v
alue
[
i
];
}
if
(
prev
OutG
rad
)
{
r
PrevOutGrad
=
prevOutG
rad
[
i
];
if
(
prev
_out_g
rad
)
{
r
_prev_out_grad
=
prev_out_g
rad
[
i
];
}
op
StateGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rFrameStateV
alue
,
rFrameStateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rOutG
rad
,
active_node
);
op
_state_grad
(
r_update_gate_value
,
r_update_gate_grad
,
r_frame_state_v
alue
,
r_frame_state_grad
,
r_prev_out_value
,
r_prev_out_g
rad
,
r_out_grad
,
active_node
);
update
GateGrad
[
i
]
=
rUpdateGateG
rad
;
frame
StateGrad
[
i
]
=
rFrameStateG
rad
;
if
(
prev
OutG
rad
)
{
prev
OutGrad
[
i
]
=
rPrevOutG
rad
;
update
_gate_grad
[
i
]
=
r_update_gate_g
rad
;
frame
_state_grad
[
i
]
=
r_frame_state_g
rad
;
if
(
prev
_out_g
rad
)
{
prev
_out_grad
[
i
]
=
r_prev_out_g
rad
;
}
}
}
template
<
class
OpResetGrad
,
typename
T
>
void
hl_naive_gru_backward_reset_grad
(
OpResetGrad
op
ResetGrad
,
T
*
gateV
alue
,
T
*
gate
Grad
,
T
*
prevOutV
alue
,
T
*
prev
OutGrad
,
T
*
resetOutputG
rad
,
int
frame
S
ize
,
void
hl_naive_gru_backward_reset_grad
(
OpResetGrad
op
_reset_grad
,
T
*
gate_v
alue
,
T
*
gate
_grad
,
T
*
prev_out_v
alue
,
T
*
prev
_out_grad
,
T
*
reset_output_g
rad
,
int
frame
_s
ize
,
activation_mode_t
active_gate
)
{
T
r
UpdateGateV
alue
;
T
r
UpdateGateG
rad
;
T
r
ResetGateV
alue
;
T
r
ResetGateG
rad
;
T
r
ResetOutputG
rad
=
0
;
T
r
PrevOutV
alue
=
0
;
T
r
PrevOutG
rad
=
0
;
T
*
update
GateValue
=
gateV
alue
;
T
*
update
GateGrad
=
gateG
rad
;
T
*
reset
GateValue
=
gateValue
+
frameS
ize
;
T
*
reset
GateGrad
=
gateGrad
+
frameS
ize
;
for
(
int
i
=
0
;
i
<
frame
S
ize
;
i
++
)
{
r
UpdateGateValue
=
updateGateV
alue
[
i
];
r
UpdateGateGrad
=
updateGateG
rad
[
i
];
r
ResetGateValue
=
resetGateV
alue
[
i
];
if
(
prev
OutValue
&&
prevOutG
rad
)
{
r
ResetOutputGrad
=
resetOutputG
rad
[
i
];
T
r
_update_gate_v
alue
;
T
r
_update_gate_g
rad
;
T
r
_reset_gate_v
alue
;
T
r
_reset_gate_g
rad
;
T
r
_reset_output_g
rad
=
0
;
T
r
_prev_out_v
alue
=
0
;
T
r
_prev_out_g
rad
=
0
;
T
*
update
_gate_value
=
gate_v
alue
;
T
*
update
_gate_grad
=
gate_g
rad
;
T
*
reset
_gate_value
=
gate_value
+
frame_s
ize
;
T
*
reset
_gate_grad
=
gate_grad
+
frame_s
ize
;
for
(
int
i
=
0
;
i
<
frame
_s
ize
;
i
++
)
{
r
_update_gate_value
=
update_gate_v
alue
[
i
];
r
_update_gate_grad
=
update_gate_g
rad
[
i
];
r
_reset_gate_value
=
reset_gate_v
alue
[
i
];
if
(
prev
_out_value
&&
prev_out_g
rad
)
{
r
_reset_output_grad
=
reset_output_g
rad
[
i
];
}
if
(
prev
OutV
alue
)
{
r
PrevOutValue
=
prevOutV
alue
[
i
];
if
(
prev
_out_v
alue
)
{
r
_prev_out_value
=
prev_out_v
alue
[
i
];
}
if
(
prev
OutG
rad
)
{
r
PrevOutGrad
=
prevOutG
rad
[
i
];
if
(
prev
_out_g
rad
)
{
r
_prev_out_grad
=
prev_out_g
rad
[
i
];
}
op
ResetGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rResetGateV
alue
,
rResetGateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rResetOutputG
rad
,
active_gate
);
op
_reset_grad
(
r_update_gate_value
,
r_update_gate_grad
,
r_reset_gate_v
alue
,
r_reset_gate_grad
,
r_prev_out_value
,
r_prev_out_g
rad
,
r_reset_output_grad
,
active_gate
);
update
GateGrad
[
i
]
=
rUpdateGateG
rad
;
reset
GateGrad
[
i
]
=
rResetGateG
rad
;
if
(
prev
OutG
rad
)
{
prev
OutGrad
[
i
]
=
rPrevOutG
rad
;
update
_gate_grad
[
i
]
=
r_update_gate_g
rad
;
reset
_gate_grad
[
i
]
=
r_reset_gate_g
rad
;
if
(
prev
_out_g
rad
)
{
prev
_out_grad
[
i
]
=
r_prev_out_g
rad
;
}
}
}
template
<
class
OpStateGrad
,
typename
T
>
void
hl_avx_gru_backward_state_grad
(
OpStateGrad
op
StateGrad
,
T
*
gateV
alue
,
T
*
gate
Grad
,
T
*
prevOutV
alue
,
T
*
prev
OutGrad
,
T
*
outputG
rad
,
int
frame
S
ize
,
void
hl_avx_gru_backward_state_grad
(
OpStateGrad
op
_state_grad
,
T
*
gate_v
alue
,
T
*
gate
_grad
,
T
*
prev_out_v
alue
,
T
*
prev
_out_grad
,
T
*
output_g
rad
,
int
frame
_s
ize
,
activation_mode_t
active_node
)
{
#ifdef __AVX__
__m256
r
UpdateGateV
alue
;
__m256
r
UpdateGateG
rad
;
__m256
r
FrameStateV
alue
;
__m256
r
FrameStateG
rad
;
__m256
r
OutG
rad
;
__m256
r
PrevOutV
alue
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
PrevOutG
rad
=
_mm256_set1_ps
(
0.0
f
);
__m256
*
update
GateValue
=
(
__m256
*
)
gateV
alue
;
__m256
*
update
GateGrad
=
(
__m256
*
)
gateG
rad
;
__m256
*
frame
StateValue
=
(
__m256
*
)(
gateValue
+
frameS
ize
*
2
);
__m256
*
frame
StateGrad
=
(
__m256
*
)(
gateGrad
+
frameS
ize
*
2
);
for
(
int
i
=
0
;
i
<
frame
S
ize
/
8
;
i
++
)
{
r
UpdateGateValue
=
updateGateV
alue
[
i
];
r
FrameStateValue
=
frameStateV
alue
[
i
];
r
OutGrad
=
((
__m256
*
)
outputG
rad
)[
i
];
if
(
prev
OutV
alue
)
{
r
PrevOutValue
=
((
__m256
*
)
prevOutV
alue
)[
i
];
__m256
r
_update_gate_v
alue
;
__m256
r
_update_gate_g
rad
;
__m256
r
_frame_state_v
alue
;
__m256
r
_frame_state_g
rad
;
__m256
r
_out_g
rad
;
__m256
r
_prev_out_v
alue
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_prev_out_g
rad
=
_mm256_set1_ps
(
0.0
f
);
__m256
*
update
_gate_value
=
(
__m256
*
)
gate_v
alue
;
__m256
*
update
_gate_grad
=
(
__m256
*
)
gate_g
rad
;
__m256
*
frame
_state_value
=
(
__m256
*
)(
gate_value
+
frame_s
ize
*
2
);
__m256
*
frame
_state_grad
=
(
__m256
*
)(
gate_grad
+
frame_s
ize
*
2
);
for
(
int
i
=
0
;
i
<
frame
_s
ize
/
8
;
i
++
)
{
r
_update_gate_value
=
update_gate_v
alue
[
i
];
r
_frame_state_value
=
frame_state_v
alue
[
i
];
r
_out_grad
=
((
__m256
*
)
output_g
rad
)[
i
];
if
(
prev
_out_v
alue
)
{
r
_prev_out_value
=
((
__m256
*
)
prev_out_v
alue
)[
i
];
}
if
(
prev
OutG
rad
)
{
r
PrevOutGrad
=
((
__m256
*
)
prevOutG
rad
)[
i
];
if
(
prev
_out_g
rad
)
{
r
_prev_out_grad
=
((
__m256
*
)
prev_out_g
rad
)[
i
];
}
op
StateGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rFrameStateV
alue
,
rFrameStateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rOutG
rad
,
active_node
);
op
_state_grad
(
r_update_gate_value
,
r_update_gate_grad
,
r_frame_state_v
alue
,
r_frame_state_grad
,
r_prev_out_value
,
r_prev_out_g
rad
,
r_out_grad
,
active_node
);
update
GateGrad
[
i
]
=
rUpdateGateG
rad
;
frame
StateGrad
[
i
]
=
rFrameStateG
rad
;
if
(
prev
OutG
rad
)
{
((
__m256
*
)
prev
OutGrad
)[
i
]
=
rPrevOutG
rad
;
update
_gate_grad
[
i
]
=
r_update_gate_g
rad
;
frame
_state_grad
[
i
]
=
r_frame_state_g
rad
;
if
(
prev
_out_g
rad
)
{
((
__m256
*
)
prev
_out_grad
)[
i
]
=
r_prev_out_g
rad
;
}
}
#endif
}
template
<
class
OpResetGrad
,
typename
T
>
void
hl_avx_gru_backward_reset_grad
(
OpResetGrad
op
ResetGrad
,
T
*
gateV
alue
,
T
*
gate
Grad
,
T
*
prevOutV
alue
,
T
*
prev
OutGrad
,
T
*
resetOutputG
rad
,
int
frame
S
ize
,
void
hl_avx_gru_backward_reset_grad
(
OpResetGrad
op
_reset_grad
,
T
*
gate_v
alue
,
T
*
gate
_grad
,
T
*
prev_out_v
alue
,
T
*
prev
_out_grad
,
T
*
reset_output_g
rad
,
int
frame
_s
ize
,
activation_mode_t
active_gate
)
{
#ifdef __AVX__
__m256
r
UpdateGateV
alue
;
__m256
r
UpdateGateG
rad
;
__m256
r
ResetGateV
alue
;
__m256
r
ResetGateG
rad
;
__m256
r
ResetOutputG
rad
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
PrevOutV
alue
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
PrevOutG
rad
=
_mm256_set1_ps
(
0.0
f
);
__m256
*
update
GateValue
=
(
__m256
*
)
gateV
alue
;
__m256
*
update
GateGrad
=
(
__m256
*
)
gateG
rad
;
__m256
*
reset
GateValue
=
(
__m256
*
)(
gateValue
+
frameS
ize
);
__m256
*
reset
GateGrad
=
(
__m256
*
)(
gateGrad
+
frameS
ize
);
for
(
int
i
=
0
;
i
<
frame
S
ize
/
8
;
i
++
)
{
r
UpdateGateValue
=
updateGateV
alue
[
i
];
r
UpdateGateGrad
=
updateGateG
rad
[
i
];
r
ResetGateValue
=
resetGateV
alue
[
i
];
if
(
prev
OutValue
&&
prevOutG
rad
)
{
r
ResetOutputGrad
=
((
__m256
*
)
resetOutputG
rad
)[
i
];
__m256
r
_update_gate_v
alue
;
__m256
r
_update_gate_g
rad
;
__m256
r
_reset_gate_v
alue
;
__m256
r
_reset_gate_g
rad
;
__m256
r
_reset_output_g
rad
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_prev_out_v
alue
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_prev_out_g
rad
=
_mm256_set1_ps
(
0.0
f
);
__m256
*
update
_gate_value
=
(
__m256
*
)
gate_v
alue
;
__m256
*
update
_gate_grad
=
(
__m256
*
)
gate_g
rad
;
__m256
*
reset
_gate_value
=
(
__m256
*
)(
gate_value
+
frame_s
ize
);
__m256
*
reset
_gate_grad
=
(
__m256
*
)(
gate_grad
+
frame_s
ize
);
for
(
int
i
=
0
;
i
<
frame
_s
ize
/
8
;
i
++
)
{
r
_update_gate_value
=
update_gate_v
alue
[
i
];
r
_update_gate_grad
=
update_gate_g
rad
[
i
];
r
_reset_gate_value
=
reset_gate_v
alue
[
i
];
if
(
prev
_out_value
&&
prev_out_g
rad
)
{
r
_reset_output_grad
=
((
__m256
*
)
reset_output_g
rad
)[
i
];
}
if
(
prev
OutV
alue
)
{
r
PrevOutValue
=
((
__m256
*
)
prevOutV
alue
)[
i
];
if
(
prev
_out_v
alue
)
{
r
_prev_out_value
=
((
__m256
*
)
prev_out_v
alue
)[
i
];
}
if
(
prev
OutG
rad
)
{
r
PrevOutGrad
=
((
__m256
*
)
prevOutG
rad
)[
i
];
if
(
prev
_out_g
rad
)
{
r
_prev_out_grad
=
((
__m256
*
)
prev_out_g
rad
)[
i
];
}
op
ResetGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rResetGateV
alue
,
rResetGateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rResetOutputG
rad
,
active_gate
);
op
_reset_grad
(
r_update_gate_value
,
r_update_gate_grad
,
r_reset_gate_v
alue
,
r_reset_gate_grad
,
r_prev_out_value
,
r_prev_out_g
rad
,
r_reset_output_grad
,
active_gate
);
update
GateGrad
[
i
]
=
rUpdateGateG
rad
;
reset
GateGrad
[
i
]
=
rResetGateG
rad
;
if
(
prev
OutG
rad
)
{
((
__m256
*
)
prev
OutGrad
)[
i
]
=
rPrevOutG
rad
;
update
_gate_grad
[
i
]
=
r_update_gate_g
rad
;
reset
_gate_grad
[
i
]
=
r_reset_gate_g
rad
;
if
(
prev
_out_g
rad
)
{
((
__m256
*
)
prev
_out_grad
)[
i
]
=
r_prev_out_g
rad
;
}
}
#endif
}
template
<
class
OpStateGrad
,
typename
T
>
inline
void
backward_state_grad
(
OpStateGrad
opStateGrad
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
)
{
for
(
int
b
=
0
;
b
<
batchSize
;
b
++
)
{
if
(
OpStateGrad
::
avx
&&
!
(
frameSize
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
inline
void
backward_state_grad
(
OpStateGrad
op_state_grad
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
if
(
OpStateGrad
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
hl_avx_gru_backward_state_grad
(
op
StateGrad
,
value
.
gateValue
,
grad
.
gateGrad
,
value
.
prevOutV
alue
,
grad
.
prev
OutGrad
,
grad
.
outputGrad
,
frameS
ize
,
active_node
);
op
_state_grad
,
value
.
gate_value
,
grad
.
gate_grad
,
value
.
prev_out_v
alue
,
grad
.
prev
_out_grad
,
grad
.
output_grad
,
frame_s
ize
,
active_node
);
}
else
{
hl_naive_gru_backward_state_grad
(
op
StateGrad
,
value
.
gateValue
,
grad
.
gateGrad
,
value
.
prevOutV
alue
,
grad
.
prev
OutGrad
,
grad
.
outputGrad
,
frameS
ize
,
active_node
);
op
_state_grad
,
value
.
gate_value
,
grad
.
gate_grad
,
value
.
prev_out_v
alue
,
grad
.
prev
_out_grad
,
grad
.
output_grad
,
frame_s
ize
,
active_node
);
}
value
.
gate
Value
+=
frameS
ize
*
3
;
if
(
value
.
prev
OutV
alue
)
{
value
.
prev
OutValue
+=
frameS
ize
;
value
.
gate
_value
+=
frame_s
ize
*
3
;
if
(
value
.
prev
_out_v
alue
)
{
value
.
prev
_out_value
+=
frame_s
ize
;
}
grad
.
gate
Grad
+=
frameS
ize
*
3
;
grad
.
output
Grad
+=
frameS
ize
;
if
(
grad
.
prev
OutG
rad
)
{
grad
.
prev
OutGrad
+=
frameS
ize
;
grad
.
gate
_grad
+=
frame_s
ize
*
3
;
grad
.
output
_grad
+=
frame_s
ize
;
if
(
grad
.
prev
_out_g
rad
)
{
grad
.
prev
_out_grad
+=
frame_s
ize
;
}
}
}
template
<
class
OpResetGrad
,
typename
T
>
inline
void
backward_reset_grad
(
OpResetGrad
opResetGrad
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_gate
)
{
for
(
int
b
=
0
;
b
<
batchSize
;
b
++
)
{
if
(
OpResetGrad
::
avx
&&
!
(
frameSize
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
inline
void
backward_reset_grad
(
OpResetGrad
op_reset_grad
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_gate
)
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
if
(
OpResetGrad
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
hl_avx_gru_backward_reset_grad
(
op
ResetGrad
,
value
.
gateValue
,
grad
.
gateGrad
,
value
.
prevOutV
alue
,
grad
.
prev
OutGrad
,
grad
.
resetOutputGrad
,
frameS
ize
,
active_gate
);
op
_reset_grad
,
value
.
gate_value
,
grad
.
gate_grad
,
value
.
prev_out_v
alue
,
grad
.
prev
_out_grad
,
grad
.
reset_output_grad
,
frame_s
ize
,
active_gate
);
}
else
{
hl_naive_gru_backward_reset_grad
(
op
ResetGrad
,
value
.
gateValue
,
grad
.
gateGrad
,
value
.
prevOutV
alue
,
grad
.
prev
OutGrad
,
grad
.
resetOutputGrad
,
frameS
ize
,
active_gate
);
op
_reset_grad
,
value
.
gate_value
,
grad
.
gate_grad
,
value
.
prev_out_v
alue
,
grad
.
prev
_out_grad
,
grad
.
reset_output_grad
,
frame_s
ize
,
active_gate
);
}
value
.
gate
Value
+=
frameS
ize
*
3
;
if
(
value
.
prev
OutV
alue
)
{
value
.
prev
OutValue
+=
frameS
ize
;
value
.
gate
_value
+=
frame_s
ize
*
3
;
if
(
value
.
prev
_out_v
alue
)
{
value
.
prev
_out_value
+=
frame_s
ize
;
}
grad
.
gate
Grad
+=
frameS
ize
*
3
;
grad
.
reset
OutputGrad
+=
frameS
ize
;
if
(
grad
.
prev
OutG
rad
)
{
grad
.
prev
OutGrad
+=
frameS
ize
;
grad
.
gate
_grad
+=
frame_s
ize
*
3
;
grad
.
reset
_output_grad
+=
frame_s
ize
;
if
(
grad
.
prev
_out_g
rad
)
{
grad
.
prev
_out_grad
+=
frame_s
ize
;
}
}
}
...
...
paddle/operators/math/detail/gru_gpu_kernel.h
浏览文件 @
3e552cdc
...
...
@@ -27,174 +27,174 @@ namespace math {
namespace
detail
{
/*
* threads(frame
PerBlock, batchPerB
lock)
* grid(frame
Blocks, batchB
locks)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
_blocks, batch_b
locks)
*/
template
<
class
OpResetOutput
,
bool
is
B
atch
,
typename
T
>
__global__
void
KeGruForwardResetOutput
(
OpResetOutput
op
ResetO
utput
,
T
*
gate
Value
,
T
*
resetOutputV
alue
,
T
*
prev
OutputValue
,
int
frameS
ize
,
int
batch
S
ize
,
template
<
class
OpResetOutput
,
bool
is
_b
atch
,
typename
T
>
__global__
void
KeGruForwardResetOutput
(
OpResetOutput
op
_reset_o
utput
,
T
*
gate
_value
,
T
*
reset_output_v
alue
,
T
*
prev
_output_value
,
int
frame_s
ize
,
int
batch
_s
ize
,
activation_mode_t
active_gate
)
{
const
int
frame
Idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadI
dx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
int
batch
I
dx
=
0
;
if
(
is
B
atch
)
{
batch
Idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadI
dx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
gate
Value
+=
batchIdx
*
3
*
frameS
ize
;
reset
OutputValue
+=
batchIdx
*
frameS
ize
;
const
int
frame
_idx
=
block_idx
.
x
*
block_dim
.
x
+
thread_i
dx
.
x
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
_i
dx
=
0
;
if
(
is
_b
atch
)
{
batch
_idx
=
block_idx
.
y
*
block_dim
.
y
+
thread_i
dx
.
y
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
gate
_value
+=
batch_idx
*
3
*
frame_s
ize
;
reset
_output_value
+=
batch_idx
*
frame_s
ize
;
}
T
r
PrevO
ut
=
0
;
T
r
ValueResetO
utput
;
T
r
ValueUpdateGate
=
gateValue
[
frameIdx
+
frameS
ize
*
0
];
T
r
ValueResetGate
=
gateValue
[
frameIdx
+
frameS
ize
*
1
];
T
r
_prev_o
ut
=
0
;
T
r
_value_reset_o
utput
;
T
r
_value_update_gate
=
gate_value
[
frame_idx
+
frame_s
ize
*
0
];
T
r
_value_reset_gate
=
gate_value
[
frame_idx
+
frame_s
ize
*
1
];
if
(
prev
OutputV
alue
)
{
if
(
is
Batch
)
prevOutputValue
+=
batchIdx
*
frameS
ize
;
r
PrevOut
=
prevOutputValue
[
frameI
dx
];
if
(
prev
_output_v
alue
)
{
if
(
is
_batch
)
prev_output_value
+=
batch_idx
*
frame_s
ize
;
r
_prev_out
=
prev_output_value
[
frame_i
dx
];
}
op
ResetOutput
(
rValueUpdateGate
,
rValueResetGate
,
rPrevOut
,
rValueResetOutp
ut
,
active_gate
);
op
_reset_output
(
r_value_update_gate
,
r_value_reset_gate
,
r_prev_o
ut
,
r_value_reset_output
,
active_gate
);
gate
Value
[
frameIdx
+
frameSize
*
0
]
=
rValueUpdateG
ate
;
gate
Value
[
frameIdx
+
frameSize
*
1
]
=
rValueResetG
ate
;
reset
OutputValue
[
frameIdx
]
=
rValueResetO
utput
;
gate
_value
[
frame_idx
+
frame_size
*
0
]
=
r_value_update_g
ate
;
gate
_value
[
frame_idx
+
frame_size
*
1
]
=
r_value_reset_g
ate
;
reset
_output_value
[
frame_idx
]
=
r_value_reset_o
utput
;
}
/*
* threads(frame
PerBlock, batchPerB
lock)
* grid(frame
Blocks, batchB
locks)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
_blocks, batch_b
locks)
*/
template
<
class
OpFinalOutput
,
bool
is
B
atch
,
typename
T
>
__global__
void
KeGruForwardFinalOutput
(
OpFinalOutput
op
FinalO
utput
,
T
*
gate
Value
,
T
*
prevOutputV
alue
,
T
*
output
Value
,
int
frameS
ize
,
int
batch
S
ize
,
template
<
class
OpFinalOutput
,
bool
is
_b
atch
,
typename
T
>
__global__
void
KeGruForwardFinalOutput
(
OpFinalOutput
op
_final_o
utput
,
T
*
gate
_value
,
T
*
prev_output_v
alue
,
T
*
output
_value
,
int
frame_s
ize
,
int
batch
_s
ize
,
activation_mode_t
active_node
)
{
const
int
frame
Idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadI
dx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
int
batch
I
dx
=
0
;
if
(
is
B
atch
)
{
batch
Idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadI
dx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
gate
Value
+=
batchIdx
*
3
*
frameS
ize
;
output
Value
+=
batchIdx
*
frameS
ize
;
const
int
frame
_idx
=
block_idx
.
x
*
block_dim
.
x
+
thread_i
dx
.
x
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
_i
dx
=
0
;
if
(
is
_b
atch
)
{
batch
_idx
=
block_idx
.
y
*
block_dim
.
y
+
thread_i
dx
.
y
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
gate
_value
+=
batch_idx
*
3
*
frame_s
ize
;
output
_value
+=
batch_idx
*
frame_s
ize
;
}
T
r
O
utput
;
T
r
PrevO
ut
=
0
;
T
r
ValueUpdateGate
=
gateValue
[
frameIdx
+
frameS
ize
*
0
];
T
r
ValueFrameState
=
gateValue
[
frameIdx
+
frameS
ize
*
2
];
T
r
_o
utput
;
T
r
_prev_o
ut
=
0
;
T
r
_value_update_gate
=
gate_value
[
frame_idx
+
frame_s
ize
*
0
];
T
r
_value_frame_state
=
gate_value
[
frame_idx
+
frame_s
ize
*
2
];
if
(
prev
OutputV
alue
)
{
if
(
is
Batch
)
prevOutputValue
+=
batchIdx
*
frameS
ize
;
r
PrevOut
=
prevOutputValue
[
frameI
dx
];
if
(
prev
_output_v
alue
)
{
if
(
is
_batch
)
prev_output_value
+=
batch_idx
*
frame_s
ize
;
r
_prev_out
=
prev_output_value
[
frame_i
dx
];
}
op
FinalOutput
(
rValueUpdateGate
,
rValueFrameState
,
rPrevOut
,
rOutp
ut
,
active_node
);
op
_final_output
(
r_value_update_gate
,
r_value_frame_state
,
r_prev_o
ut
,
r_output
,
active_node
);
gate
Value
[
frameIdx
+
frameSize
*
2
]
=
rValueFrameS
tate
;
output
Value
[
frameIdx
]
=
rO
utput
;
gate
_value
[
frame_idx
+
frame_size
*
2
]
=
r_value_frame_s
tate
;
output
_value
[
frame_idx
]
=
r_o
utput
;
}
/*
* threads(frame
PerBlock, batchPerB
lock)
* grid(frame
Blocks, batchB
locks)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
_blocks, batch_b
locks)
*/
template
<
class
OpStateGrad
,
bool
is
B
atch
,
typename
T
>
__global__
void
KeGruBackwardStateGrad
(
OpStateGrad
op
StateGrad
,
T
*
gateV
alue
,
T
*
gate
Grad
,
T
*
prevOutV
alue
,
T
*
prev
OutGrad
,
T
*
outputG
rad
,
int
frame
Size
,
int
batchS
ize
,
template
<
class
OpStateGrad
,
bool
is
_b
atch
,
typename
T
>
__global__
void
KeGruBackwardStateGrad
(
OpStateGrad
op
_state_grad
,
T
*
gate_v
alue
,
T
*
gate
_grad
,
T
*
prev_out_v
alue
,
T
*
prev
_out_grad
,
T
*
output_g
rad
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
)
{
const
int
frame
Idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadI
dx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
int
batch
I
dx
=
0
;
if
(
is
B
atch
)
{
batch
Idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadI
dx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
gate
Value
+=
batchIdx
*
3
*
frameS
ize
;
gate
Grad
+=
batchIdx
*
3
*
frameS
ize
;
output
Grad
+=
batchIdx
*
frameS
ize
;
const
int
frame
_idx
=
block_idx
.
x
*
block_dim
.
x
+
thread_i
dx
.
x
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
_i
dx
=
0
;
if
(
is
_b
atch
)
{
batch
_idx
=
block_idx
.
y
*
block_dim
.
y
+
thread_i
dx
.
y
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
gate
_value
+=
batch_idx
*
3
*
frame_s
ize
;
gate
_grad
+=
batch_idx
*
3
*
frame_s
ize
;
output
_grad
+=
batch_idx
*
frame_s
ize
;
}
T
r
UpdateGateG
rad
;
T
r
FrameStateG
rad
;
T
r
PrevOutV
alue
=
0
;
T
r
PrevOutG
rad
=
0
;
T
r
UpdateGateValue
=
gateValue
[
frameIdx
+
frameS
ize
*
0
];
T
r
FrameStateValue
=
gateValue
[
frameIdx
+
frameS
ize
*
2
];
T
r
OutGrad
=
outputGrad
[
frameI
dx
];
T
r
_update_gate_g
rad
;
T
r
_frame_state_g
rad
;
T
r
_prev_out_v
alue
=
0
;
T
r
_prev_out_g
rad
=
0
;
T
r
_update_gate_value
=
gate_value
[
frame_idx
+
frame_s
ize
*
0
];
T
r
_frame_state_value
=
gate_value
[
frame_idx
+
frame_s
ize
*
2
];
T
r
_out_grad
=
output_grad
[
frame_i
dx
];
if
(
prev
OutValue
&&
prevOutG
rad
)
{
if
(
is
Batch
)
prevOutValue
+=
batchIdx
*
frameS
ize
;
r
PrevOutValue
=
prevOutValue
[
frameI
dx
];
if
(
prev
_out_value
&&
prev_out_g
rad
)
{
if
(
is
_batch
)
prev_out_value
+=
batch_idx
*
frame_s
ize
;
r
_prev_out_value
=
prev_out_value
[
frame_i
dx
];
if
(
is
Batch
)
prevOutGrad
+=
batchIdx
*
frameS
ize
;
r
PrevOutGrad
=
prevOutGrad
[
frameI
dx
];
if
(
is
_batch
)
prev_out_grad
+=
batch_idx
*
frame_s
ize
;
r
_prev_out_grad
=
prev_out_grad
[
frame_i
dx
];
}
op
StateGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rFrameStateV
alue
,
rFrameStateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rOutG
rad
,
active_node
);
op
_state_grad
(
r_update_gate_value
,
r_update_gate_grad
,
r_frame_state_v
alue
,
r_frame_state_grad
,
r_prev_out_value
,
r_prev_out_g
rad
,
r_out_grad
,
active_node
);
gate
Grad
[
frameIdx
+
frameSize
*
0
]
=
rUpdateGateG
rad
;
gate
Grad
[
frameIdx
+
frameSize
*
2
]
=
rFrameStateG
rad
;
if
(
prev
OutG
rad
)
{
prev
OutGrad
[
frameIdx
]
=
rPrevOutG
rad
;
gate
_grad
[
frame_idx
+
frame_size
*
0
]
=
r_update_gate_g
rad
;
gate
_grad
[
frame_idx
+
frame_size
*
2
]
=
r_frame_state_g
rad
;
if
(
prev
_out_g
rad
)
{
prev
_out_grad
[
frame_idx
]
=
r_prev_out_g
rad
;
}
}
/*
* threads(frame
PerBlock, batchPerB
lock)
* grid(frame
Blocks, batchB
locks)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
_blocks, batch_b
locks)
*/
template
<
class
OpResetGrad
,
bool
is
B
atch
,
typename
T
>
__global__
void
KeGruBackwardResetGrad
(
OpResetGrad
op
ResetGrad
,
T
*
gateV
alue
,
T
*
gate
Grad
,
T
*
prevOutV
alue
,
T
*
prev
OutGrad
,
T
*
resetOutputG
rad
,
int
frame
Size
,
int
batchS
ize
,
template
<
class
OpResetGrad
,
bool
is
_b
atch
,
typename
T
>
__global__
void
KeGruBackwardResetGrad
(
OpResetGrad
op
_reset_grad
,
T
*
gate_v
alue
,
T
*
gate
_grad
,
T
*
prev_out_v
alue
,
T
*
prev
_out_grad
,
T
*
reset_output_g
rad
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_gate
)
{
const
int
frame
Idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadI
dx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
int
batch
I
dx
=
0
;
if
(
is
B
atch
)
{
batch
Idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadI
dx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
gate
Value
+=
batchIdx
*
3
*
frameS
ize
;
gate
Grad
+=
batchIdx
*
3
*
frameS
ize
;
reset
OutputGrad
+=
batchIdx
*
frameS
ize
;
const
int
frame
_idx
=
block_idx
.
x
*
block_dim
.
x
+
thread_i
dx
.
x
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
_i
dx
=
0
;
if
(
is
_b
atch
)
{
batch
_idx
=
block_idx
.
y
*
block_dim
.
y
+
thread_i
dx
.
y
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
gate
_value
+=
batch_idx
*
3
*
frame_s
ize
;
gate
_grad
+=
batch_idx
*
3
*
frame_s
ize
;
reset
_output_grad
+=
batch_idx
*
frame_s
ize
;
}
T
r
ResetGateG
rad
;
T
r
PrevOutV
alue
=
0
;
T
r
PrevOutG
rad
=
0
;
T
r
ResetOutputG
rad
=
0
;
T
r
UpdateGateValue
=
gateValue
[
frameIdx
+
frameS
ize
*
0
];
T
r
UpdateGateGrad
=
gateGrad
[
frameIdx
+
frameS
ize
*
0
];
T
r
ResetGateValue
=
gateValue
[
frameIdx
+
frameS
ize
*
1
];
if
(
prev
OutValue
&&
prevOutG
rad
)
{
if
(
is
Batch
)
prevOutValue
+=
batchIdx
*
frameS
ize
;
if
(
is
Batch
)
prevOutGrad
+=
batchIdx
*
frameS
ize
;
r
PrevOutValue
=
prevOutValue
[
frameI
dx
];
r
PrevOutGrad
=
prevOutGrad
[
frameI
dx
];
r
ResetOutputGrad
=
resetOutputGrad
[
frameI
dx
];
T
r
_reset_gate_g
rad
;
T
r
_prev_out_v
alue
=
0
;
T
r
_prev_out_g
rad
=
0
;
T
r
_reset_output_g
rad
=
0
;
T
r
_update_gate_value
=
gate_value
[
frame_idx
+
frame_s
ize
*
0
];
T
r
_update_gate_grad
=
gate_grad
[
frame_idx
+
frame_s
ize
*
0
];
T
r
_reset_gate_value
=
gate_value
[
frame_idx
+
frame_s
ize
*
1
];
if
(
prev
_out_value
&&
prev_out_g
rad
)
{
if
(
is
_batch
)
prev_out_value
+=
batch_idx
*
frame_s
ize
;
if
(
is
_batch
)
prev_out_grad
+=
batch_idx
*
frame_s
ize
;
r
_prev_out_value
=
prev_out_value
[
frame_i
dx
];
r
_prev_out_grad
=
prev_out_grad
[
frame_i
dx
];
r
_reset_output_grad
=
reset_output_grad
[
frame_i
dx
];
}
op
ResetGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rResetGateV
alue
,
rResetGateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rResetOutputG
rad
,
active_gate
);
op
_reset_grad
(
r_update_gate_value
,
r_update_gate_grad
,
r_reset_gate_v
alue
,
r_reset_gate_grad
,
r_prev_out_value
,
r_prev_out_g
rad
,
r_reset_output_grad
,
active_gate
);
gate
Grad
[
frameIdx
+
frameSize
*
0
]
=
rUpdateGateG
rad
;
gate
Grad
[
frameIdx
+
frameSize
*
1
]
=
rResetGateG
rad
;
if
(
prev
OutG
rad
)
{
prev
OutGrad
[
frameIdx
]
=
rPrevOutG
rad
;
gate
_grad
[
frame_idx
+
frame_size
*
0
]
=
r_update_gate_g
rad
;
gate
_grad
[
frame_idx
+
frame_size
*
1
]
=
r_reset_gate_g
rad
;
if
(
prev
_out_g
rad
)
{
prev
_out_grad
[
frame_idx
]
=
r_prev_out_g
rad
;
}
}
}
// namespace detail
...
...
paddle/operators/math/detail/gru_kernel.h
浏览文件 @
3e552cdc
...
...
@@ -28,23 +28,25 @@ namespace forward {
template
<
typename
T
>
class
gru_resetOutput
{
public:
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
valueResetGate
,
T
&
prevOut
,
T
&
valueResetOutput
,
activation_mode_t
actGate
)
{
valueUpdateGate
=
activation
(
valueUpdateGate
,
actGate
);
valueResetGate
=
activation
(
valueResetGate
,
actGate
);
valueResetOutput
=
prevOut
*
valueResetGate
;
HOSTDEVICE
void
operator
()(
T
&
value_update_gate
,
T
&
value_reset_gate
,
T
&
prev_out
,
T
&
value_reset_output
,
activation_mode_t
act_gate
)
{
value_update_gate
=
activation
(
value_update_gate
,
act_gate
);
value_reset_gate
=
activation
(
value_reset_gate
,
act_gate
);
value_reset_output
=
prev_out
*
value_reset_gate
;
}
#ifndef __NVCC__
#ifndef __AVX__
static
const
bool
avx
=
false
;
#else
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
valueResetGate
,
__m256
&
prevOut
,
__m256
&
valueResetOutput
,
activation_mode_t
actGate
)
{
valueUpdateGate
=
activation
(
valueUpdateGate
,
actGate
);
valueResetGate
=
activation
(
valueResetGate
,
actGate
);
valueResetOutput
=
_mm256_mul_ps
(
prevOut
,
valueResetGate
);
HOSTDEVICE
void
operator
()(
__m256
&
value_update_gate
,
__m256
&
value_reset_gate
,
__m256
&
prev_out
,
__m256
&
value_reset_output
,
activation_mode_t
act_gate
)
{
value_update_gate
=
activation
(
value_update_gate
,
act_gate
);
value_reset_gate
=
activation
(
value_reset_gate
,
act_gate
);
value_reset_output
=
_mm256_mul_ps
(
prev_out
,
value_reset_gate
);
}
#endif
#endif
...
...
@@ -53,24 +55,26 @@ class gru_resetOutput {
template
<
typename
T
>
class
gru_finalOutput
{
public:
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
valueFrameState
,
T
&
prevOut
,
T
&
valueOutput
,
activation_mode_t
actInput
)
{
valueFrameState
=
activation
(
valueFrameState
,
actInput
);
valueOutput
=
prevOut
-
(
valueUpdateGate
*
prevOut
)
+
(
valueUpdateGate
*
valueFrameState
);
HOSTDEVICE
void
operator
()(
T
&
value_update_gate
,
T
&
value_frame_state
,
T
&
prev_out
,
T
&
value_output
,
activation_mode_t
act_input
)
{
value_frame_state
=
activation
(
value_frame_state
,
act_input
);
value_output
=
prev_out
-
(
value_update_gate
*
prev_out
)
+
(
value_update_gate
*
value_frame_state
);
}
#ifndef __NVCC__
#ifndef __AVX__
static
const
bool
avx
=
false
;
#else
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
valueFrameState
,
__m256
&
prevOut
,
__m256
&
valueOutput
,
activation_mode_t
actInput
)
{
valueFrameState
=
activation
(
valueFrameState
,
actInput
);
valueOutput
=
_mm256_add_ps
(
_mm256_sub_ps
(
prevOut
,
_mm256_mul_ps
(
valueUpdateGate
,
prevOut
)),
_mm256_mul_ps
(
valueUpdateGate
,
valueFrameState
));
HOSTDEVICE
void
operator
()(
__m256
&
value_update_gate
,
__m256
&
value_frame_state
,
__m256
&
prev_out
,
__m256
&
value_output
,
activation_mode_t
act_input
)
{
value_frame_state
=
activation
(
value_frame_state
,
act_input
);
value_output
=
_mm256_add_ps
(
_mm256_sub_ps
(
prev_out
,
_mm256_mul_ps
(
value_update_gate
,
prev_out
)),
_mm256_mul_ps
(
value_update_gate
,
value_frame_state
));
}
#endif
#endif
...
...
@@ -82,34 +86,37 @@ namespace backward {
template
<
typename
T
>
class
gru_stateGrad
{
public:
HOSTDEVICE
void
operator
()(
T
&
value
UpdateGate
,
T
&
gradUpdateG
ate
,
T
&
value
FrameState
,
T
&
gradFrameS
tate
,
T
&
value
PrevOut
,
T
&
gradPrevOut
,
T
&
gradOutp
ut
,
activation_mode_t
actI
nput
)
{
grad
UpdateGate
=
(
gradOutput
*
valueFrameS
tate
);
grad
UpdateGate
-=
(
gradOutput
*
valuePrevO
ut
);
grad
PrevOut
-=
(
gradOutput
*
valueUpdateG
ate
);
grad
PrevOut
+=
gradO
utput
;
grad
FrameState
=
activation
(
gradOutput
*
valueUpdateGate
,
valueFrameState
,
actI
nput
);
HOSTDEVICE
void
operator
()(
T
&
value
_update_gate
,
T
&
grad_update_g
ate
,
T
&
value
_frame_state
,
T
&
grad_frame_s
tate
,
T
&
value
_prev_out
,
T
&
grad_prev_o
ut
,
T
&
grad_output
,
activation_mode_t
act_i
nput
)
{
grad
_update_gate
=
(
grad_output
*
value_frame_s
tate
);
grad
_update_gate
-=
(
grad_output
*
value_prev_o
ut
);
grad
_prev_out
-=
(
grad_output
*
value_update_g
ate
);
grad
_prev_out
+=
grad_o
utput
;
grad
_frame_state
=
activation
(
grad_output
*
value_update_gate
,
value_frame_state
,
act_i
nput
);
}
#ifndef __NVCC__
#ifndef __AVX__
static
const
bool
avx
=
false
;
#else
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
gradUpdateGate
,
__m256
&
valueFrameState
,
__m256
&
gradFrameState
,
__m256
&
valuePrevOut
,
__m256
&
gradPrevOut
,
__m256
&
gradOutput
,
activation_mode_t
actInput
)
{
gradUpdateGate
=
_mm256_mul_ps
(
gradOutput
,
valueFrameState
);
gradUpdateGate
=
_mm256_sub_ps
(
gradUpdateGate
,
_mm256_mul_ps
(
gradOutput
,
valuePrevOut
));
gradPrevOut
=
_mm256_add_ps
(
_mm256_sub_ps
(
gradPrevOut
,
_mm256_mul_ps
(
gradOutput
,
valueUpdateGate
)),
gradOutput
);
gradFrameState
=
activation
(
_mm256_mul_ps
(
gradOutput
,
valueUpdateGate
),
valueFrameState
,
actInput
);
HOSTDEVICE
void
operator
()(
__m256
&
value_update_gate
,
__m256
&
grad_update_gate
,
__m256
&
value_frame_state
,
__m256
&
grad_frame_state
,
__m256
&
value_prev_out
,
__m256
&
grad_prev_out
,
__m256
&
grad_output
,
activation_mode_t
act_input
)
{
grad_update_gate
=
_mm256_mul_ps
(
grad_output
,
value_frame_state
);
grad_update_gate
=
_mm256_sub_ps
(
grad_update_gate
,
_mm256_mul_ps
(
grad_output
,
value_prev_out
));
grad_prev_out
=
_mm256_add_ps
(
_mm256_sub_ps
(
grad_prev_out
,
_mm256_mul_ps
(
grad_output
,
value_update_gate
)),
grad_output
);
grad_frame_state
=
activation
(
_mm256_mul_ps
(
grad_output
,
value_update_gate
),
value_frame_state
,
act_input
);
}
#endif
#endif
...
...
@@ -118,30 +125,32 @@ class gru_stateGrad {
template
<
typename
T
>
class
gru_resetGrad
{
public:
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
gradUpdateGate
,
T
&
valueResetGate
,
T
&
gradResetGate
,
T
&
valuePrevOut
,
T
&
gradPrevOut
,
T
&
gradResetOutput
,
activation_mode_t
actGate
)
{
gradResetGate
=
(
gradResetOutput
*
valuePrevOut
);
gradPrevOut
+=
(
gradResetOutput
*
valueResetGate
);
gradUpdateGate
=
activation
(
gradUpdateGate
,
valueUpdateGate
,
actGate
);
gradResetGate
=
activation
(
gradResetGate
,
valueResetGate
,
actGate
);
HOSTDEVICE
void
operator
()(
T
&
value_update_gate
,
T
&
grad_update_gate
,
T
&
value_reset_gate
,
T
&
grad_reset_gate
,
T
&
value_prev_out
,
T
&
grad_prev_out
,
T
&
grad_reset_output
,
activation_mode_t
act_gate
)
{
grad_reset_gate
=
(
grad_reset_output
*
value_prev_out
);
grad_prev_out
+=
(
grad_reset_output
*
value_reset_gate
);
grad_update_gate
=
activation
(
grad_update_gate
,
value_update_gate
,
act_gate
);
grad_reset_gate
=
activation
(
grad_reset_gate
,
value_reset_gate
,
act_gate
);
}
#ifndef __NVCC__
#ifndef __AVX__
static
const
bool
avx
=
false
;
#else
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
gradUpdateGate
,
__m256
&
valueResetGate
,
__m256
&
gradResetGate
,
__m256
&
valuePrevOut
,
__m256
&
gradPrevOut
,
__m256
&
gradResetOutput
,
activation_mode_t
actGate
)
{
gradResetGate
=
_mm256_mul_ps
(
gradResetOutput
,
valuePrevOut
);
gradPrevOut
=
_mm256_add_ps
(
gradPrevOut
,
_mm256_mul_ps
(
gradResetOutput
,
valueResetGate
));
gradUpdateGate
=
activation
(
gradUpdateGate
,
valueUpdateGate
,
actGate
);
gradResetGate
=
activation
(
gradResetGate
,
valueResetGate
,
actGate
);
HOSTDEVICE
void
operator
()(
__m256
&
value_update_gate
,
__m256
&
grad_update_gate
,
__m256
&
value_reset_gate
,
__m256
&
grad_reset_gate
,
__m256
&
value_prev_out
,
__m256
&
grad_prev_out
,
__m256
&
grad_reset_output
,
activation_mode_t
act_gate
)
{
grad_reset_gate
=
_mm256_mul_ps
(
grad_reset_output
,
value_prev_out
);
grad_prev_out
=
_mm256_add_ps
(
grad_prev_out
,
_mm256_mul_ps
(
grad_reset_output
,
value_reset_gate
));
grad_update_gate
=
activation
(
grad_update_gate
,
value_update_gate
,
act_gate
);
grad_reset_gate
=
activation
(
grad_reset_gate
,
value_reset_gate
,
act_gate
);
}
#endif
#endif
...
...
paddle/operators/math/gru_compute.cc
浏览文件 @
3e552cdc
...
...
@@ -21,29 +21,29 @@ namespace math {
template
<
typename
T
>
struct
GRUUnitFunctor
<
platform
::
CPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
int
frame
Size
,
int
batchS
ize
,
hl_gru_value
<
T
>
value
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
#ifndef __NVCC__
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
false
,
batch
Size
,
frameSize
*
2
,
frameS
ize
,
1
,
value
.
prev
OutValue
,
frameSize
,
value
.
gateWeight
,
frameSize
*
2
,
1
,
value
.
gateValue
,
frameS
ize
*
3
);
context
,
false
,
false
,
batch
_size
,
frame_size
*
2
,
frame_s
ize
,
1
,
value
.
prev
_out_value
,
frame_size
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
value
.
gate_value
,
frame_s
ize
*
3
);
}
detail
::
forward_reset_output
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
,
frame
Size
,
batchS
ize
,
active_gate
);
frame
_size
,
batch_s
ize
,
active_gate
);
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
false
,
batch
Size
,
frameSize
,
frameS
ize
,
1
,
value
.
reset
OutputValue
,
frameSize
,
value
.
stateWeight
,
frameSize
,
1
,
value
.
gateValue
+
frameSize
*
2
,
frameS
ize
*
3
);
context
,
false
,
false
,
batch
_size
,
frame_size
,
frame_s
ize
,
1
,
value
.
reset
_output_value
,
frame_size
,
value
.
state_weight
,
frame_size
,
1
,
value
.
gate_value
+
frame_size
*
2
,
frame_s
ize
*
3
);
}
detail
::
forward_final_output
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
,
frame
Size
,
batchS
ize
,
active_node
);
frame
_size
,
batch_s
ize
,
active_node
);
#endif
}
};
...
...
@@ -51,41 +51,43 @@ struct GRUUnitFunctor<platform::CPUPlace, T> {
template
<
typename
T
>
struct
GRUUnitGradFunctor
<
platform
::
CPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
#ifndef __NVCC__
detail
::
backward_state_grad
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
,
grad
,
frame
Size
,
batchS
ize
,
active_node
);
grad
,
frame
_size
,
batch_s
ize
,
active_node
);
if
(
value
.
prev
OutValue
&&
grad
.
prevOutG
rad
)
{
if
(
value
.
prev
_out_value
&&
grad
.
prev_out_g
rad
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
true
,
batch
Size
,
frameSize
,
frameS
ize
,
1
,
grad
.
gate
Grad
+
frameSize
*
2
,
frameSize
*
3
,
value
.
stateW
eight
,
frame
Size
,
0
,
grad
.
resetOutputGrad
,
frameS
ize
);
context
,
false
,
true
,
batch
_size
,
frame_size
,
frame_s
ize
,
1
,
grad
.
gate
_grad
+
frame_size
*
2
,
frame_size
*
3
,
value
.
state_w
eight
,
frame
_size
,
0
,
grad
.
reset_output_grad
,
frame_s
ize
);
if
(
grad
.
state
WeightG
rad
)
{
if
(
grad
.
state
_weight_g
rad
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
true
,
false
,
frameSize
,
frameSize
,
batchSize
,
1
,
value
.
resetOutputValue
,
frameSize
,
grad
.
gateGrad
+
frameSize
*
2
,
frameSize
*
3
,
1
,
grad
.
stateWeightGrad
,
frameSize
);
context
,
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
value
.
reset_output_value
,
frame_size
,
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
1
,
grad
.
state_weight_grad
,
frame_size
);
}
}
detail
::
backward_reset_grad
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
,
grad
,
frame
Size
,
batchS
ize
,
active_gate
);
grad
,
frame
_size
,
batch_s
ize
,
active_gate
);
if
(
grad
.
prev
OutGrad
&&
value
.
prevOutV
alue
)
{
if
(
grad
.
prev
_out_grad
&&
value
.
prev_out_v
alue
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
true
,
batch
Size
,
frameSize
,
frameS
ize
*
2
,
1
,
grad
.
gate
Grad
,
frameSize
*
3
,
value
.
gateWeight
,
frameS
ize
*
2
,
1
,
grad
.
prev
OutGrad
,
frameS
ize
);
context
,
false
,
true
,
batch
_size
,
frame_size
,
frame_s
ize
*
2
,
1
,
grad
.
gate
_grad
,
frame_size
*
3
,
value
.
gate_weight
,
frame_s
ize
*
2
,
1
,
grad
.
prev
_out_grad
,
frame_s
ize
);
if
(
grad
.
gate
WeightG
rad
)
{
if
(
grad
.
gate
_weight_g
rad
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
true
,
false
,
frame
Size
,
frameSize
*
2
,
batchS
ize
,
1
,
value
.
prev
OutValue
,
frameSize
,
grad
.
gateGrad
,
frameS
ize
*
3
,
1
,
grad
.
gate
WeightGrad
,
frameS
ize
*
2
);
context
,
true
,
false
,
frame
_size
,
frame_size
*
2
,
batch_s
ize
,
1
,
value
.
prev
_out_value
,
frame_size
,
grad
.
gate_grad
,
frame_s
ize
*
3
,
1
,
grad
.
gate
_weight_grad
,
frame_s
ize
*
2
);
}
}
#endif
...
...
paddle/operators/math/gru_compute.cu
浏览文件 @
3e552cdc
...
...
@@ -21,66 +21,66 @@ namespace math {
template
<
typename
T
>
struct
GRUUnitFunctor
<
platform
::
GPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
int
frame
Size
,
int
batchS
ize
,
hl_gru_value
<
T
>
value
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
dim3
threads
;
dim3
grid
;
if
(
batch
S
ize
==
1
)
{
int
frame
PerBlock
=
frameSize
<=
1024
?
frameS
ize
:
1024
;
int
frame
Blocks
=
(
frameS
ize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame
PerB
lock
,
1
);
grid
=
dim3
(
frame
B
locks
,
1
);
if
(
batch
_s
ize
==
1
)
{
int
frame
_per_block
=
frame_size
<=
1024
?
frame_s
ize
:
1024
;
int
frame
_blocks
=
(
frame_s
ize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame
_per_b
lock
,
1
);
grid
=
dim3
(
frame
_b
locks
,
1
);
}
else
{
threads
=
dim3
(
32
,
32
);
grid
=
dim3
((
frame
Size
+
32
-
1
)
/
32
,
(
batchS
ize
+
32
-
1
)
/
32
);
grid
=
dim3
((
frame
_size
+
32
-
1
)
/
32
,
(
batch_s
ize
+
32
-
1
)
/
32
);
}
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
false
,
batch
Size
,
frameSize
*
2
,
frameS
ize
,
1
,
value
.
prev
OutValue
,
frameSize
,
value
.
gateWeight
,
frameSize
*
2
,
1
,
value
.
gateValue
,
frameS
ize
*
3
);
context
,
false
,
false
,
batch
_size
,
frame_size
*
2
,
frame_s
ize
,
1
,
value
.
prev
_out_value
,
frame_size
,
value
.
gate_weight
,
frame_size
*
2
,
1
,
value
.
gate_value
,
frame_s
ize
*
3
);
}
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
detail
::
KeGruForwardResetOutput
<
detail
::
forward
::
gru_resetOutput
<
T
>
,
/* is
B
atch= */
false
,
/* is
_b
atch= */
false
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gate
V
alue
,
value
.
reset
OutputValue
,
value
.
prevOutValue
,
frameSize
,
batchS
ize
,
active_gate
);
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gate
_v
alue
,
value
.
reset
_output_value
,
value
.
prev_out_value
,
frame_s
ize
,
batch_size
,
active_gate
);
}
else
{
detail
::
KeGruForwardResetOutput
<
detail
::
forward
::
gru_resetOutput
<
T
>
,
/* is
B
atch= */
true
,
/* is
_b
atch= */
true
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gate
V
alue
,
value
.
reset
OutputValue
,
value
.
prevOutValue
,
frameSize
,
batchS
ize
,
active_gate
);
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gate
_v
alue
,
value
.
reset
_output_value
,
value
.
prev_out_value
,
frame_s
ize
,
batch_size
,
active_gate
);
}
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
false
,
batch
Size
,
frameSize
,
frameS
ize
,
1
,
value
.
reset
OutputValue
,
frameSize
,
value
.
stateWeight
,
frameSize
,
1
,
value
.
gateValue
+
frameSize
*
2
,
frameS
ize
*
3
);
context
,
false
,
false
,
batch
_size
,
frame_size
,
frame_s
ize
,
1
,
value
.
reset
_output_value
,
frame_size
,
value
.
state_weight
,
frame_size
,
1
,
value
.
gate_value
+
frame_size
*
2
,
frame_s
ize
*
3
);
}
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
/* is
B
atch= */
false
,
/* is
_b
atch= */
false
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate
V
alue
,
value
.
prev
OutValue
,
value
.
outputValue
,
frameSize
,
batchS
ize
,
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate
_v
alue
,
value
.
prev
_out_value
,
value
.
output_value
,
frame_size
,
batch_s
ize
,
active_node
);
}
else
{
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
/* is
B
atch= */
true
,
/* is
_b
atch= */
true
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate
V
alue
,
value
.
prev
OutValue
,
value
.
outputValue
,
frameSize
,
batchS
ize
,
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate
_v
alue
,
value
.
prev
_out_value
,
value
.
output_value
,
frame_size
,
batch_s
ize
,
active_node
);
}
}
...
...
@@ -89,80 +89,82 @@ struct GRUUnitFunctor<platform::GPUPlace, T> {
template
<
typename
T
>
struct
GRUUnitGradFunctor
<
platform
::
GPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
dim3
threads
;
dim3
grid
;
if
(
batch
S
ize
==
1
)
{
int
frame
PerBlock
=
frameSize
<=
1024
?
frameS
ize
:
1024
;
int
frame
Blocks
=
(
frameS
ize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame
PerB
lock
,
1
);
grid
=
dim3
(
frame
B
locks
,
1
);
if
(
batch
_s
ize
==
1
)
{
int
frame
_per_block
=
frame_size
<=
1024
?
frame_s
ize
:
1024
;
int
frame
_blocks
=
(
frame_s
ize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame
_per_b
lock
,
1
);
grid
=
dim3
(
frame
_b
locks
,
1
);
}
else
{
threads
=
dim3
(
32
,
32
);
grid
=
dim3
((
frame
Size
+
32
-
1
)
/
32
,
(
batchS
ize
+
32
-
1
)
/
32
);
grid
=
dim3
((
frame
_size
+
32
-
1
)
/
32
,
(
batch_s
ize
+
32
-
1
)
/
32
);
}
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
detail
::
KeGruBackwardStateGrad
<
detail
::
backward
::
gru_stateGrad
<
T
>
,
/* is
B
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gate
Value
,
grad
.
gateGrad
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
outputGrad
,
frameSize
,
batchS
ize
,
active_node
);
/* is
_b
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gate
_value
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
grad
.
output_grad
,
frame_size
,
batch_s
ize
,
active_node
);
}
else
{
detail
::
KeGruBackwardStateGrad
<
detail
::
backward
::
gru_stateGrad
<
T
>
,
/* is
B
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gate
Value
,
grad
.
gateGrad
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
outputGrad
,
frameSize
,
batchS
ize
,
active_node
);
/* is
_b
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gate
_value
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
grad
.
output_grad
,
frame_size
,
batch_s
ize
,
active_node
);
}
if
(
value
.
prev
OutValue
&&
grad
.
prevOutG
rad
)
{
if
(
value
.
prev
_out_value
&&
grad
.
prev_out_g
rad
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
true
,
batch
Size
,
frameSize
,
frameS
ize
,
1
,
grad
.
gate
Grad
+
frameSize
*
2
,
frameSize
*
3
,
value
.
stateW
eight
,
frame
Size
,
0
,
grad
.
resetOutputGrad
,
frameS
ize
);
context
,
false
,
true
,
batch
_size
,
frame_size
,
frame_s
ize
,
1
,
grad
.
gate
_grad
+
frame_size
*
2
,
frame_size
*
3
,
value
.
state_w
eight
,
frame
_size
,
0
,
grad
.
reset_output_grad
,
frame_s
ize
);
if
(
grad
.
state
WeightG
rad
)
{
if
(
grad
.
state
_weight_g
rad
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
true
,
false
,
frameSize
,
frameSize
,
batchSize
,
1
,
value
.
resetOutputValue
,
frameSize
,
grad
.
gateGrad
+
frameSize
*
2
,
frameSize
*
3
,
1
,
grad
.
stateWeightGrad
,
frameSize
);
context
,
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
value
.
reset_output_value
,
frame_size
,
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
1
,
grad
.
state_weight_grad
,
frame_size
);
}
}
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
detail
::
KeGruBackwardResetGrad
<
detail
::
backward
::
gru_resetGrad
<
T
>
,
/* is
B
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gate
Value
,
grad
.
gateGrad
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
resetOutputGrad
,
frameSize
,
batchS
ize
,
active_gate
);
/* is
_b
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gate
_value
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
grad
.
reset_output_grad
,
frame_size
,
batch_s
ize
,
active_gate
);
}
else
{
detail
::
KeGruBackwardResetGrad
<
detail
::
backward
::
gru_resetGrad
<
T
>
,
/* is
B
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gate
Value
,
grad
.
gateGrad
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
resetOutputGrad
,
frameSize
,
batchS
ize
,
active_gate
);
/* is
_b
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gate
_value
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
grad
.
reset_output_grad
,
frame_size
,
batch_s
ize
,
active_gate
);
}
if
(
grad
.
prev
OutGrad
&&
value
.
prevOutV
alue
)
{
if
(
grad
.
prev
_out_grad
&&
value
.
prev_out_v
alue
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
true
,
batch
Size
,
frameSize
,
frameS
ize
*
2
,
1
,
grad
.
gate
Grad
,
frameSize
*
3
,
value
.
gateWeight
,
frameS
ize
*
2
,
1
,
grad
.
prev
OutGrad
,
frameS
ize
);
context
,
false
,
true
,
batch
_size
,
frame_size
,
frame_s
ize
*
2
,
1
,
grad
.
gate
_grad
,
frame_size
*
3
,
value
.
gate_weight
,
frame_s
ize
*
2
,
1
,
grad
.
prev
_out_grad
,
frame_s
ize
);
if
(
grad
.
gate
WeightG
rad
)
{
if
(
grad
.
gate
_weight_g
rad
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
true
,
false
,
frame
Size
,
frameSize
*
2
,
batchS
ize
,
1
,
value
.
prev
OutValue
,
frameSize
,
grad
.
gateGrad
,
frameS
ize
*
3
,
1
,
grad
.
gate
WeightGrad
,
frameS
ize
*
2
);
context
,
true
,
false
,
frame
_size
,
frame_size
*
2
,
batch_s
ize
,
1
,
value
.
prev
_out_value
,
frame_size
,
grad
.
gate_grad
,
frame_s
ize
*
3
,
1
,
grad
.
gate
_weight_grad
,
frame_s
ize
*
2
);
}
}
}
...
...
paddle/operators/math/gru_compute.h
浏览文件 @
3e552cdc
...
...
@@ -22,28 +22,28 @@ namespace math {
// TODO(guosheng): refine code style in gru_compute
template
<
typename
T
>
struct
hl_gru_value
{
T
*
gate
W
eight
;
T
*
state
W
eight
;
T
*
gate
V
alue
;
T
*
reset
OutputV
alue
;
T
*
output
V
alue
;
T
*
prev
OutV
alue
;
T
*
gate
_w
eight
;
T
*
state
_w
eight
;
T
*
gate
_v
alue
;
T
*
reset
_output_v
alue
;
T
*
output
_v
alue
;
T
*
prev
_out_v
alue
;
};
template
<
typename
T
>
struct
hl_gru_grad
{
T
*
gate
WeightG
rad
;
T
*
state
WeightG
rad
;
T
*
gate
G
rad
;
T
*
reset
OutputG
rad
;
T
*
output
G
rad
;
T
*
prev
OutG
rad
;
T
*
gate
_weight_g
rad
;
T
*
state
_weight_g
rad
;
T
*
gate
_g
rad
;
T
*
reset
_output_g
rad
;
T
*
output
_g
rad
;
T
*
prev
_out_g
rad
;
};
template
<
typename
Place
,
typename
T
>
struct
GRUUnitFunctor
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
int
frame
Size
,
int
batchS
ize
,
hl_gru_value
<
T
>
value
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
);
};
...
...
@@ -51,8 +51,9 @@ struct GRUUnitFunctor {
template
<
typename
Place
,
typename
T
>
struct
GRUUnitGradFunctor
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
int
batchSize
,
activation_mode_t
active_node
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
);
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录