Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
3e552cdc
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3e552cdc
编写于
11月 29, 2017
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix gru_op related code style
上级
dcf3ffd9
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
617 addition
and
599 deletion
+617
-599
paddle/operators/gru_op.h
paddle/operators/gru_op.h
+23
-23
paddle/operators/math/detail/gru_cpu_kernel.h
paddle/operators/math/detail/gru_cpu_kernel.h
+272
-268
paddle/operators/math/detail/gru_gpu_kernel.h
paddle/operators/math/detail/gru_gpu_kernel.h
+126
-126
paddle/operators/math/detail/gru_kernel.h
paddle/operators/math/detail/gru_kernel.h
+72
-63
paddle/operators/math/gru_compute.cc
paddle/operators/math/gru_compute.cc
+33
-31
paddle/operators/math/gru_compute.cu
paddle/operators/math/gru_compute.cu
+75
-73
paddle/operators/math/gru_compute.h
paddle/operators/math/gru_compute.h
+16
-15
未找到文件。
paddle/operators/gru_op.h
浏览文件 @
3e552cdc
...
@@ -71,8 +71,8 @@ class GRUKernel : public framework::OpKernel<T> {
...
@@ -71,8 +71,8 @@ class GRUKernel : public framework::OpKernel<T> {
int
frame_size
=
hidden_dims
[
1
];
int
frame_size
=
hidden_dims
[
1
];
math
::
hl_gru_value
<
T
>
gru_value
;
math
::
hl_gru_value
<
T
>
gru_value
;
gru_value
.
gate
W
eight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
gate
_w
eight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
state
W
eight
=
gru_value
.
state
_w
eight
=
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
Tensor
ordered_h0
;
Tensor
ordered_h0
;
const
size_t
*
order
=
batch_gate
->
lod
()[
2
].
data
();
const
size_t
*
order
=
batch_gate
->
lod
()[
2
].
data
();
...
@@ -82,9 +82,9 @@ class GRUKernel : public framework::OpKernel<T> {
...
@@ -82,9 +82,9 @@ class GRUKernel : public framework::OpKernel<T> {
// to reorder.
// to reorder.
ReorderInitState
<
Place
,
T
>
(
context
.
device_context
(),
*
h0
,
order
,
ReorderInitState
<
Place
,
T
>
(
context
.
device_context
(),
*
h0
,
order
,
&
ordered_h0
,
true
);
&
ordered_h0
,
true
);
gru_value
.
prev
OutV
alue
=
ordered_h0
.
data
<
T
>
();
gru_value
.
prev
_out_v
alue
=
ordered_h0
.
data
<
T
>
();
}
else
{
}
else
{
gru_value
.
prev
OutV
alue
=
nullptr
;
gru_value
.
prev
_out_v
alue
=
nullptr
;
}
}
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
auto
batch_starts
=
batch_gate
->
lod
()[
0
];
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
size_t
num_batch
=
batch_starts
.
size
()
-
1
;
...
@@ -96,14 +96,14 @@ class GRUKernel : public framework::OpKernel<T> {
...
@@ -96,14 +96,14 @@ class GRUKernel : public framework::OpKernel<T> {
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
hidden_t
=
batch_hidden
->
Slice
(
bstart
,
bend
);
Tensor
hidden_t
=
batch_hidden
->
Slice
(
bstart
,
bend
);
gru_value
.
output
V
alue
=
hidden_t
.
data
<
T
>
();
gru_value
.
output
_v
alue
=
hidden_t
.
data
<
T
>
();
gru_value
.
gate
V
alue
=
gate_t
.
data
<
T
>
();
gru_value
.
gate
_v
alue
=
gate_t
.
data
<
T
>
();
gru_value
.
reset
OutputV
alue
=
reset_hidden_prev_t
.
data
<
T
>
();
gru_value
.
reset
_output_v
alue
=
reset_hidden_prev_t
.
data
<
T
>
();
math
::
GRUUnitFunctor
<
Place
,
T
>::
compute
(
math
::
GRUUnitFunctor
<
Place
,
T
>::
compute
(
dev_ctx
,
gru_value
,
frame_size
,
cur_batch_size
,
dev_ctx
,
gru_value
,
frame_size
,
cur_batch_size
,
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"activation"
)),
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"activation"
)),
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"gate_activation"
)));
math
::
ActiveType
(
context
.
Attr
<
std
::
string
>
(
"gate_activation"
)));
gru_value
.
prev
OutValue
=
gru_value
.
outputV
alue
;
gru_value
.
prev
_out_value
=
gru_value
.
output_v
alue
;
}
}
math
::
Batch2LoDTensorFunctor
<
Place
,
T
>
to_seq
;
math
::
Batch2LoDTensorFunctor
<
Place
,
T
>
to_seq
;
...
@@ -169,20 +169,20 @@ class GRUGradKernel : public framework::OpKernel<T> {
...
@@ -169,20 +169,20 @@ class GRUGradKernel : public framework::OpKernel<T> {
to_batch
(
dev_ctx
,
*
hidden_grad
,
batch_hidden_grad
,
false
,
is_reverse
);
to_batch
(
dev_ctx
,
*
hidden_grad
,
batch_hidden_grad
,
false
,
is_reverse
);
math
::
hl_gru_value
<
T
>
gru_value
;
math
::
hl_gru_value
<
T
>
gru_value
;
gru_value
.
gate
W
eight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
gate
_w
eight
=
const_cast
<
T
*>
(
weight_data
);
gru_value
.
state
W
eight
=
gru_value
.
state
_w
eight
=
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
const_cast
<
T
*>
(
weight_data
+
2
*
frame_size
*
frame_size
);
math
::
hl_gru_grad
<
T
>
gru_grad
;
math
::
hl_gru_grad
<
T
>
gru_grad
;
if
(
weight_grad
)
{
if
(
weight_grad
)
{
gru_grad
.
gate
WeightG
rad
=
gru_grad
.
gate
_weight_g
rad
=
weight_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
weight_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
zero
(
dev_ctx
,
weight_grad
,
static_cast
<
T
>
(
0.0
));
zero
(
dev_ctx
,
weight_grad
,
static_cast
<
T
>
(
0.0
));
gru_grad
.
state
WeightG
rad
=
gru_grad
.
state
_weight_g
rad
=
weight_grad
->
data
<
T
>
()
+
2
*
frame_size
*
frame_size
;
weight_grad
->
data
<
T
>
()
+
2
*
frame_size
*
frame_size
;
}
else
{
}
else
{
gru_grad
.
gate
WeightG
rad
=
nullptr
;
gru_grad
.
gate
_weight_g
rad
=
nullptr
;
gru_grad
.
state
WeightG
rad
=
nullptr
;
gru_grad
.
state
_weight_g
rad
=
nullptr
;
}
}
auto
batch_starts
=
batch_hidden_grad
.
lod
()[
0
];
auto
batch_starts
=
batch_hidden_grad
.
lod
()[
0
];
...
@@ -193,27 +193,27 @@ class GRUGradKernel : public framework::OpKernel<T> {
...
@@ -193,27 +193,27 @@ class GRUGradKernel : public framework::OpKernel<T> {
int
cur_batch_size
=
bend
-
bstart
;
int
cur_batch_size
=
bend
-
bstart
;
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
gate_t
=
batch_gate
->
Slice
(
bstart
,
bend
);
gru_value
.
gate
V
alue
=
gate_t
.
data
<
T
>
();
gru_value
.
gate
_v
alue
=
gate_t
.
data
<
T
>
();
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
Tensor
reset_hidden_prev_t
=
batch_reset_hidden_prev
->
Slice
(
bstart
,
bend
);
gru_value
.
reset
OutputV
alue
=
reset_hidden_prev_t
.
data
<
T
>
();
gru_value
.
reset
_output_v
alue
=
reset_hidden_prev_t
.
data
<
T
>
();
Tensor
hidden_grad_t
=
batch_hidden_grad
.
Slice
(
bstart
,
bend
);
Tensor
hidden_grad_t
=
batch_hidden_grad
.
Slice
(
bstart
,
bend
);
gru_grad
.
output
G
rad
=
hidden_grad_t
.
data
<
T
>
();
gru_grad
.
output
_g
rad
=
hidden_grad_t
.
data
<
T
>
();
Tensor
gate_grad_t
=
batch_gate_grad
.
Slice
(
bstart
,
bend
);
Tensor
gate_grad_t
=
batch_gate_grad
.
Slice
(
bstart
,
bend
);
gru_grad
.
gate
G
rad
=
gate_grad_t
.
data
<
T
>
();
gru_grad
.
gate
_g
rad
=
gate_grad_t
.
data
<
T
>
();
Tensor
reset_hidden_prev_grad_t
=
Tensor
reset_hidden_prev_grad_t
=
batch_reset_hidden_prev_grad
.
Slice
(
bstart
,
bend
);
batch_reset_hidden_prev_grad
.
Slice
(
bstart
,
bend
);
gru_grad
.
reset
OutputG
rad
=
reset_hidden_prev_grad_t
.
data
<
T
>
();
gru_grad
.
reset
_output_g
rad
=
reset_hidden_prev_grad_t
.
data
<
T
>
();
if
(
n
==
0
)
{
if
(
n
==
0
)
{
gru_value
.
prev
OutV
alue
=
h0
?
ordered_h0
.
data
<
T
>
()
:
nullptr
;
gru_value
.
prev
_out_v
alue
=
h0
?
ordered_h0
.
data
<
T
>
()
:
nullptr
;
gru_grad
.
prev
OutG
rad
=
gru_grad
.
prev
_out_g
rad
=
h0
&&
h0_grad
?
ordered_h0_grad
.
data
<
T
>
()
:
nullptr
;
h0
&&
h0_grad
?
ordered_h0_grad
.
data
<
T
>
()
:
nullptr
;
}
else
{
}
else
{
int
bstart_pre
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
bstart_pre
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
Tensor
hidden_prev_t
=
batch_hidden
->
Slice
(
bstart_pre
,
bstart
);
Tensor
hidden_prev_t
=
batch_hidden
->
Slice
(
bstart_pre
,
bstart
);
gru_value
.
prev
OutV
alue
=
hidden_prev_t
.
data
<
T
>
();
gru_value
.
prev
_out_v
alue
=
hidden_prev_t
.
data
<
T
>
();
Tensor
hidden_prev_grad_t
=
batch_hidden_grad
.
Slice
(
bstart_pre
,
bstart
);
Tensor
hidden_prev_grad_t
=
batch_hidden_grad
.
Slice
(
bstart_pre
,
bstart
);
gru_grad
.
prev
OutG
rad
=
hidden_prev_grad_t
.
data
<
T
>
();
gru_grad
.
prev
_out_g
rad
=
hidden_prev_grad_t
.
data
<
T
>
();
}
}
math
::
GRUUnitGradFunctor
<
Place
,
T
>::
compute
(
math
::
GRUUnitGradFunctor
<
Place
,
T
>::
compute
(
...
...
paddle/operators/math/detail/gru_cpu_kernel.h
浏览文件 @
3e552cdc
...
@@ -25,393 +25,397 @@ namespace detail {
...
@@ -25,393 +25,397 @@ namespace detail {
#ifndef __NVCC__
#ifndef __NVCC__
template
<
class
OpResetOutput
,
typename
T
>
template
<
class
OpResetOutput
,
typename
T
>
void
hl_naive_gru_forward_reset_output
(
OpResetOutput
op
ResetO
utput
,
void
hl_naive_gru_forward_reset_output
(
OpResetOutput
op
_reset_o
utput
,
T
*
gate
Value
,
T
*
resetOutputV
alue
,
T
*
gate
_value
,
T
*
reset_output_v
alue
,
T
*
prev
OutputValue
,
int
frameS
ize
,
T
*
prev
_output_value
,
int
frame_s
ize
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
T
r
ValueUpdateG
ate
;
T
r
_value_update_g
ate
;
T
r
ValueResetG
ate
;
T
r
_value_reset_g
ate
;
T
r
ValueResetO
utput
;
T
r
_value_reset_o
utput
;
T
r
PrevO
ut
=
0
;
T
r
_prev_o
ut
=
0
;
T
*
update
Gate
=
gateV
alue
;
T
*
update
_gate
=
gate_v
alue
;
T
*
reset
Gate
=
gateValue
+
frameS
ize
;
T
*
reset
_gate
=
gate_value
+
frame_s
ize
;
for
(
int
i
=
0
;
i
<
frame
S
ize
;
i
++
)
{
for
(
int
i
=
0
;
i
<
frame
_s
ize
;
i
++
)
{
r
ValueUpdateGate
=
updateG
ate
[
i
];
r
_value_update_gate
=
update_g
ate
[
i
];
r
ValueResetGate
=
resetG
ate
[
i
];
r
_value_reset_gate
=
reset_g
ate
[
i
];
if
(
prev
OutputV
alue
)
{
if
(
prev
_output_v
alue
)
{
r
PrevOut
=
prevOutputV
alue
[
i
];
r
_prev_out
=
prev_output_v
alue
[
i
];
}
}
op
ResetOutput
(
rValueUpdateGate
,
rValueResetGate
,
rPrevO
ut
,
op
_reset_output
(
r_value_update_gate
,
r_value_reset_gate
,
r_prev_o
ut
,
rValueResetO
utput
,
active_gate
);
r_value_reset_o
utput
,
active_gate
);
update
Gate
[
i
]
=
rValueUpdateG
ate
;
update
_gate
[
i
]
=
r_value_update_g
ate
;
reset
Gate
[
i
]
=
rValueResetG
ate
;
reset
_gate
[
i
]
=
r_value_reset_g
ate
;
reset
OutputValue
[
i
]
=
rValueResetO
utput
;
reset
_output_value
[
i
]
=
r_value_reset_o
utput
;
}
}
}
}
template
<
class
OpFinalOutput
,
typename
T
>
template
<
class
OpFinalOutput
,
typename
T
>
void
hl_naive_gru_forward_final_output
(
OpFinalOutput
op
FinalO
utput
,
void
hl_naive_gru_forward_final_output
(
OpFinalOutput
op
_final_o
utput
,
T
*
gate
Value
,
T
*
prevOutputV
alue
,
T
*
gate
_value
,
T
*
prev_output_v
alue
,
T
*
output
Value
,
int
frameS
ize
,
T
*
output
_value
,
int
frame_s
ize
,
activation_mode_t
active_node
)
{
activation_mode_t
active_node
)
{
T
r
ValueUpdateG
ate
;
T
r
_value_update_g
ate
;
T
r
ValueFrameS
tate
;
T
r
_value_frame_s
tate
;
T
r
PrevO
ut
=
0
;
T
r
_prev_o
ut
=
0
;
T
r
O
utput
;
T
r
_o
utput
;
T
*
update
Gate
=
gateV
alue
;
T
*
update
_gate
=
gate_v
alue
;
T
*
frame
State
=
gateValue
+
frameS
ize
*
2
;
T
*
frame
_state
=
gate_value
+
frame_s
ize
*
2
;
for
(
int
i
=
0
;
i
<
frame
S
ize
;
i
++
)
{
for
(
int
i
=
0
;
i
<
frame
_s
ize
;
i
++
)
{
r
ValueUpdateGate
=
updateG
ate
[
i
];
r
_value_update_gate
=
update_g
ate
[
i
];
r
ValueFrameState
=
frameS
tate
[
i
];
r
_value_frame_state
=
frame_s
tate
[
i
];
if
(
prev
OutputV
alue
)
{
if
(
prev
_output_v
alue
)
{
r
PrevOut
=
prevOutputV
alue
[
i
];
r
_prev_out
=
prev_output_v
alue
[
i
];
}
}
op
FinalOutput
(
rValueUpdateGate
,
rValueFrameState
,
rPrevOut
,
rOutp
ut
,
op
_final_output
(
r_value_update_gate
,
r_value_frame_state
,
r_prev_o
ut
,
active_node
);
r_output
,
active_node
);
frame
State
[
i
]
=
rValueFrameS
tate
;
frame
_state
[
i
]
=
r_value_frame_s
tate
;
output
Value
[
i
]
=
rO
utput
;
output
_value
[
i
]
=
r_o
utput
;
}
}
}
}
template
<
class
OpResetOutput
,
typename
T
>
template
<
class
OpResetOutput
,
typename
T
>
void
hl_avx_gru_forward_reset_output
(
OpResetOutput
op
ResetOutput
,
T
*
gateValue
,
void
hl_avx_gru_forward_reset_output
(
OpResetOutput
op
_reset_output
,
T
*
resetOutputValue
,
T
*
prevOutputV
alue
,
T
*
gate_value
,
T
*
reset_output_v
alue
,
int
frameS
ize
,
T
*
prev_output_value
,
int
frame_s
ize
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
#ifdef __AVX__
#ifdef __AVX__
__m256
r
ValueUpdateG
ate
;
__m256
r
_value_update_g
ate
;
__m256
r
ValueResetG
ate
;
__m256
r
_value_reset_g
ate
;
__m256
r
ValueResetO
utput
;
__m256
r
_value_reset_o
utput
;
__m256
r
PrevO
ut
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_prev_o
ut
=
_mm256_set1_ps
(
0.0
f
);
__m256
*
update
Gate
=
(
__m256
*
)
gateV
alue
;
__m256
*
update
_gate
=
(
__m256
*
)
gate_v
alue
;
__m256
*
reset
Gate
=
(
__m256
*
)(
gateValue
+
frameS
ize
);
__m256
*
reset
_gate
=
(
__m256
*
)(
gate_value
+
frame_s
ize
);
for
(
int
i
=
0
;
i
<
frame
S
ize
/
8
;
i
++
)
{
for
(
int
i
=
0
;
i
<
frame
_s
ize
/
8
;
i
++
)
{
r
ValueUpdateGate
=
updateG
ate
[
i
];
r
_value_update_gate
=
update_g
ate
[
i
];
r
ValueResetGate
=
resetG
ate
[
i
];
r
_value_reset_gate
=
reset_g
ate
[
i
];
if
(
prev
OutputV
alue
)
{
if
(
prev
_output_v
alue
)
{
r
PrevOut
=
((
__m256
*
)
prevOutputV
alue
)[
i
];
r
_prev_out
=
((
__m256
*
)
prev_output_v
alue
)[
i
];
}
}
op
ResetOutput
(
rValueUpdateGate
,
rValueResetGate
,
rPrevO
ut
,
op
_reset_output
(
r_value_update_gate
,
r_value_reset_gate
,
r_prev_o
ut
,
rValueResetO
utput
,
active_gate
);
r_value_reset_o
utput
,
active_gate
);
update
Gate
[
i
]
=
rValueUpdateG
ate
;
update
_gate
[
i
]
=
r_value_update_g
ate
;
reset
Gate
[
i
]
=
rValueResetG
ate
;
reset
_gate
[
i
]
=
r_value_reset_g
ate
;
((
__m256
*
)
reset
OutputValue
)[
i
]
=
rValueResetO
utput
;
((
__m256
*
)
reset
_output_value
)[
i
]
=
r_value_reset_o
utput
;
}
}
#endif
#endif
}
}
template
<
class
OpFinalOutput
,
typename
T
>
template
<
class
OpFinalOutput
,
typename
T
>
void
hl_avx_gru_forward_final_output
(
OpFinalOutput
op
FinalOutput
,
T
*
gateValue
,
void
hl_avx_gru_forward_final_output
(
OpFinalOutput
op
_final_output
,
T
*
prevOutputValue
,
T
*
outputV
alue
,
T
*
gate_value
,
T
*
prev_output_v
alue
,
int
frameS
ize
,
T
*
output_value
,
int
frame_s
ize
,
activation_mode_t
active_node
)
{
activation_mode_t
active_node
)
{
#ifdef __AVX__
#ifdef __AVX__
__m256
r
ValueUpdateG
ate
;
__m256
r
_value_update_g
ate
;
__m256
r
ValueFrameS
tate
;
__m256
r
_value_frame_s
tate
;
__m256
r
PrevO
ut
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_prev_o
ut
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
O
utput
;
__m256
r
_o
utput
;
__m256
*
update
Gate
=
(
__m256
*
)
gateV
alue
;
__m256
*
update
_gate
=
(
__m256
*
)
gate_v
alue
;
__m256
*
frame
State
=
(
__m256
*
)(
gateValue
+
frameS
ize
*
2
);
__m256
*
frame
_state
=
(
__m256
*
)(
gate_value
+
frame_s
ize
*
2
);
for
(
int
i
=
0
;
i
<
frame
S
ize
/
8
;
i
++
)
{
for
(
int
i
=
0
;
i
<
frame
_s
ize
/
8
;
i
++
)
{
r
ValueUpdateGate
=
updateG
ate
[
i
];
r
_value_update_gate
=
update_g
ate
[
i
];
r
ValueFrameState
=
frameS
tate
[
i
];
r
_value_frame_state
=
frame_s
tate
[
i
];
if
(
prev
OutputV
alue
)
{
if
(
prev
_output_v
alue
)
{
r
PrevOut
=
((
__m256
*
)
prevOutputV
alue
)[
i
];
r
_prev_out
=
((
__m256
*
)
prev_output_v
alue
)[
i
];
}
}
op
FinalOutput
(
rValueUpdateGate
,
rValueFrameState
,
rPrevOut
,
rOutp
ut
,
op
_final_output
(
r_value_update_gate
,
r_value_frame_state
,
r_prev_o
ut
,
active_node
);
r_output
,
active_node
);
frame
State
[
i
]
=
rValueFrameS
tate
;
frame
_state
[
i
]
=
r_value_frame_s
tate
;
((
__m256
*
)
output
Value
)[
i
]
=
rO
utput
;
((
__m256
*
)
output
_value
)[
i
]
=
r_o
utput
;
}
}
#endif
#endif
}
}
template
<
class
OpResetOutput
,
typename
T
>
template
<
class
OpResetOutput
,
typename
T
>
inline
void
forward_reset_output
(
OpResetOutput
opResetOutput
,
inline
void
forward_reset_output
(
OpResetOutput
op_reset_output
,
hl_gru_value
<
T
>
value
,
int
frameSize
,
hl_gru_value
<
T
>
value
,
int
frame_size
,
int
batchSize
,
activation_mode_t
active_gate
)
{
int
batch_size
,
for
(
int
b
=
0
;
b
<
batchSize
;
b
++
)
{
activation_mode_t
active_gate
)
{
if
(
OpResetOutput
::
avx
&&
!
(
frameSize
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
if
(
OpResetOutput
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
hl_avx_gru_forward_reset_output
(
hl_avx_gru_forward_reset_output
(
op
ResetOutput
,
value
.
gateValue
,
value
.
resetOutputV
alue
,
op
_reset_output
,
value
.
gate_value
,
value
.
reset_output_v
alue
,
value
.
prev
OutValue
,
frameS
ize
,
active_gate
);
value
.
prev
_out_value
,
frame_s
ize
,
active_gate
);
}
else
{
}
else
{
hl_naive_gru_forward_reset_output
(
hl_naive_gru_forward_reset_output
(
op
ResetOutput
,
value
.
gateValue
,
value
.
resetOutputV
alue
,
op
_reset_output
,
value
.
gate_value
,
value
.
reset_output_v
alue
,
value
.
prev
OutValue
,
frameS
ize
,
active_gate
);
value
.
prev
_out_value
,
frame_s
ize
,
active_gate
);
}
}
value
.
gate
Value
+=
frameS
ize
*
3
;
value
.
gate
_value
+=
frame_s
ize
*
3
;
value
.
reset
OutputValue
+=
frameS
ize
;
value
.
reset
_output_value
+=
frame_s
ize
;
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
value
.
prev
OutValue
+=
frameS
ize
;
value
.
prev
_out_value
+=
frame_s
ize
;
}
}
}
}
}
}
template
<
class
OpFinalOutput
,
typename
T
>
template
<
class
OpFinalOutput
,
typename
T
>
inline
void
forward_final_output
(
OpFinalOutput
opFinalOutput
,
inline
void
forward_final_output
(
OpFinalOutput
op_final_output
,
hl_gru_value
<
T
>
value
,
int
frameSize
,
hl_gru_value
<
T
>
value
,
int
frame_size
,
int
batchSize
,
activation_mode_t
active_node
)
{
int
batch_size
,
for
(
int
b
=
0
;
b
<
batchSize
;
b
++
)
{
activation_mode_t
active_node
)
{
if
(
OpFinalOutput
::
avx
&&
!
(
frameSize
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
hl_avx_gru_forward_final_output
(
opFinalOutput
,
value
.
gateValue
,
if
(
OpFinalOutput
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
value
.
prevOutValue
,
value
.
outputValue
,
hl_avx_gru_forward_final_output
(
op_final_output
,
value
.
gate_value
,
frameSize
,
active_node
);
value
.
prev_out_value
,
value
.
output_value
,
frame_size
,
active_node
);
}
else
{
}
else
{
hl_naive_gru_forward_final_output
(
opFinalOutput
,
value
.
gateValue
,
hl_naive_gru_forward_final_output
(
value
.
prevOutValue
,
value
.
outputV
alue
,
op_final_output
,
value
.
gate_value
,
value
.
prev_out_v
alue
,
frameS
ize
,
active_node
);
value
.
output_value
,
frame_s
ize
,
active_node
);
}
}
value
.
gate
Value
+=
frameS
ize
*
3
;
value
.
gate
_value
+=
frame_s
ize
*
3
;
value
.
output
Value
+=
frameS
ize
;
value
.
output
_value
+=
frame_s
ize
;
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
value
.
prev
OutValue
+=
frameS
ize
;
value
.
prev
_out_value
+=
frame_s
ize
;
}
}
}
}
}
}
template
<
class
OpStateGrad
,
typename
T
>
template
<
class
OpStateGrad
,
typename
T
>
void
hl_naive_gru_backward_state_grad
(
OpStateGrad
op
StateGrad
,
T
*
gateV
alue
,
void
hl_naive_gru_backward_state_grad
(
OpStateGrad
op
_state_grad
,
T
*
gate_v
alue
,
T
*
gate
Grad
,
T
*
prevOutV
alue
,
T
*
gate
_grad
,
T
*
prev_out_v
alue
,
T
*
prev
OutGrad
,
T
*
outputG
rad
,
T
*
prev
_out_grad
,
T
*
output_g
rad
,
int
frame
S
ize
,
int
frame
_s
ize
,
activation_mode_t
active_node
)
{
activation_mode_t
active_node
)
{
T
r
UpdateGateV
alue
;
T
r
_update_gate_v
alue
;
T
r
UpdateGateG
rad
;
T
r
_update_gate_g
rad
;
T
r
FrameStateV
alue
;
T
r
_frame_state_v
alue
;
T
r
FrameStateG
rad
;
T
r
_frame_state_g
rad
;
T
r
OutG
rad
;
T
r
_out_g
rad
;
T
r
PrevOutV
alue
=
0
;
T
r
_prev_out_v
alue
=
0
;
T
r
PrevOutG
rad
=
0
;
T
r
_prev_out_g
rad
=
0
;
T
*
update
GateValue
=
gateV
alue
;
T
*
update
_gate_value
=
gate_v
alue
;
T
*
update
GateGrad
=
gateG
rad
;
T
*
update
_gate_grad
=
gate_g
rad
;
T
*
frame
StateValue
=
gateValue
+
frameS
ize
*
2
;
T
*
frame
_state_value
=
gate_value
+
frame_s
ize
*
2
;
T
*
frame
StateGrad
=
gateGrad
+
frameS
ize
*
2
;
T
*
frame
_state_grad
=
gate_grad
+
frame_s
ize
*
2
;
for
(
int
i
=
0
;
i
<
frame
S
ize
;
i
++
)
{
for
(
int
i
=
0
;
i
<
frame
_s
ize
;
i
++
)
{
r
UpdateGateValue
=
updateGateV
alue
[
i
];
r
_update_gate_value
=
update_gate_v
alue
[
i
];
r
FrameStateValue
=
frameStateV
alue
[
i
];
r
_frame_state_value
=
frame_state_v
alue
[
i
];
r
OutGrad
=
outputG
rad
[
i
];
r
_out_grad
=
output_g
rad
[
i
];
if
(
prev
OutV
alue
)
{
if
(
prev
_out_v
alue
)
{
r
PrevOutValue
=
prevOutV
alue
[
i
];
r
_prev_out_value
=
prev_out_v
alue
[
i
];
}
}
if
(
prev
OutG
rad
)
{
if
(
prev
_out_g
rad
)
{
r
PrevOutGrad
=
prevOutG
rad
[
i
];
r
_prev_out_grad
=
prev_out_g
rad
[
i
];
}
}
op
StateGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rFrameStateV
alue
,
op
_state_grad
(
r_update_gate_value
,
r_update_gate_grad
,
r_frame_state_v
alue
,
rFrameStateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rOutG
rad
,
r_frame_state_grad
,
r_prev_out_value
,
r_prev_out_g
rad
,
active_node
);
r_out_grad
,
active_node
);
update
GateGrad
[
i
]
=
rUpdateGateG
rad
;
update
_gate_grad
[
i
]
=
r_update_gate_g
rad
;
frame
StateGrad
[
i
]
=
rFrameStateG
rad
;
frame
_state_grad
[
i
]
=
r_frame_state_g
rad
;
if
(
prev
OutG
rad
)
{
if
(
prev
_out_g
rad
)
{
prev
OutGrad
[
i
]
=
rPrevOutG
rad
;
prev
_out_grad
[
i
]
=
r_prev_out_g
rad
;
}
}
}
}
}
}
template
<
class
OpResetGrad
,
typename
T
>
template
<
class
OpResetGrad
,
typename
T
>
void
hl_naive_gru_backward_reset_grad
(
OpResetGrad
op
ResetGrad
,
T
*
gateV
alue
,
void
hl_naive_gru_backward_reset_grad
(
OpResetGrad
op
_reset_grad
,
T
*
gate_v
alue
,
T
*
gate
Grad
,
T
*
prevOutV
alue
,
T
*
gate
_grad
,
T
*
prev_out_v
alue
,
T
*
prev
OutGrad
,
T
*
resetOutputG
rad
,
T
*
prev
_out_grad
,
T
*
reset_output_g
rad
,
int
frame
S
ize
,
int
frame
_s
ize
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
T
r
UpdateGateV
alue
;
T
r
_update_gate_v
alue
;
T
r
UpdateGateG
rad
;
T
r
_update_gate_g
rad
;
T
r
ResetGateV
alue
;
T
r
_reset_gate_v
alue
;
T
r
ResetGateG
rad
;
T
r
_reset_gate_g
rad
;
T
r
ResetOutputG
rad
=
0
;
T
r
_reset_output_g
rad
=
0
;
T
r
PrevOutV
alue
=
0
;
T
r
_prev_out_v
alue
=
0
;
T
r
PrevOutG
rad
=
0
;
T
r
_prev_out_g
rad
=
0
;
T
*
update
GateValue
=
gateV
alue
;
T
*
update
_gate_value
=
gate_v
alue
;
T
*
update
GateGrad
=
gateG
rad
;
T
*
update
_gate_grad
=
gate_g
rad
;
T
*
reset
GateValue
=
gateValue
+
frameS
ize
;
T
*
reset
_gate_value
=
gate_value
+
frame_s
ize
;
T
*
reset
GateGrad
=
gateGrad
+
frameS
ize
;
T
*
reset
_gate_grad
=
gate_grad
+
frame_s
ize
;
for
(
int
i
=
0
;
i
<
frame
S
ize
;
i
++
)
{
for
(
int
i
=
0
;
i
<
frame
_s
ize
;
i
++
)
{
r
UpdateGateValue
=
updateGateV
alue
[
i
];
r
_update_gate_value
=
update_gate_v
alue
[
i
];
r
UpdateGateGrad
=
updateGateG
rad
[
i
];
r
_update_gate_grad
=
update_gate_g
rad
[
i
];
r
ResetGateValue
=
resetGateV
alue
[
i
];
r
_reset_gate_value
=
reset_gate_v
alue
[
i
];
if
(
prev
OutValue
&&
prevOutG
rad
)
{
if
(
prev
_out_value
&&
prev_out_g
rad
)
{
r
ResetOutputGrad
=
resetOutputG
rad
[
i
];
r
_reset_output_grad
=
reset_output_g
rad
[
i
];
}
}
if
(
prev
OutV
alue
)
{
if
(
prev
_out_v
alue
)
{
r
PrevOutValue
=
prevOutV
alue
[
i
];
r
_prev_out_value
=
prev_out_v
alue
[
i
];
}
}
if
(
prev
OutG
rad
)
{
if
(
prev
_out_g
rad
)
{
r
PrevOutGrad
=
prevOutG
rad
[
i
];
r
_prev_out_grad
=
prev_out_g
rad
[
i
];
}
}
op
ResetGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rResetGateV
alue
,
op
_reset_grad
(
r_update_gate_value
,
r_update_gate_grad
,
r_reset_gate_v
alue
,
rResetGateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rResetOutputG
rad
,
r_reset_gate_grad
,
r_prev_out_value
,
r_prev_out_g
rad
,
active_gate
);
r_reset_output_grad
,
active_gate
);
update
GateGrad
[
i
]
=
rUpdateGateG
rad
;
update
_gate_grad
[
i
]
=
r_update_gate_g
rad
;
reset
GateGrad
[
i
]
=
rResetGateG
rad
;
reset
_gate_grad
[
i
]
=
r_reset_gate_g
rad
;
if
(
prev
OutG
rad
)
{
if
(
prev
_out_g
rad
)
{
prev
OutGrad
[
i
]
=
rPrevOutG
rad
;
prev
_out_grad
[
i
]
=
r_prev_out_g
rad
;
}
}
}
}
}
}
template
<
class
OpStateGrad
,
typename
T
>
template
<
class
OpStateGrad
,
typename
T
>
void
hl_avx_gru_backward_state_grad
(
OpStateGrad
op
StateGrad
,
T
*
gateV
alue
,
void
hl_avx_gru_backward_state_grad
(
OpStateGrad
op
_state_grad
,
T
*
gate_v
alue
,
T
*
gate
Grad
,
T
*
prevOutV
alue
,
T
*
gate
_grad
,
T
*
prev_out_v
alue
,
T
*
prev
OutGrad
,
T
*
outputG
rad
,
T
*
prev
_out_grad
,
T
*
output_g
rad
,
int
frame
S
ize
,
int
frame
_s
ize
,
activation_mode_t
active_node
)
{
activation_mode_t
active_node
)
{
#ifdef __AVX__
#ifdef __AVX__
__m256
r
UpdateGateV
alue
;
__m256
r
_update_gate_v
alue
;
__m256
r
UpdateGateG
rad
;
__m256
r
_update_gate_g
rad
;
__m256
r
FrameStateV
alue
;
__m256
r
_frame_state_v
alue
;
__m256
r
FrameStateG
rad
;
__m256
r
_frame_state_g
rad
;
__m256
r
OutG
rad
;
__m256
r
_out_g
rad
;
__m256
r
PrevOutV
alue
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_prev_out_v
alue
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
PrevOutG
rad
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_prev_out_g
rad
=
_mm256_set1_ps
(
0.0
f
);
__m256
*
update
GateValue
=
(
__m256
*
)
gateV
alue
;
__m256
*
update
_gate_value
=
(
__m256
*
)
gate_v
alue
;
__m256
*
update
GateGrad
=
(
__m256
*
)
gateG
rad
;
__m256
*
update
_gate_grad
=
(
__m256
*
)
gate_g
rad
;
__m256
*
frame
StateValue
=
(
__m256
*
)(
gateValue
+
frameS
ize
*
2
);
__m256
*
frame
_state_value
=
(
__m256
*
)(
gate_value
+
frame_s
ize
*
2
);
__m256
*
frame
StateGrad
=
(
__m256
*
)(
gateGrad
+
frameS
ize
*
2
);
__m256
*
frame
_state_grad
=
(
__m256
*
)(
gate_grad
+
frame_s
ize
*
2
);
for
(
int
i
=
0
;
i
<
frame
S
ize
/
8
;
i
++
)
{
for
(
int
i
=
0
;
i
<
frame
_s
ize
/
8
;
i
++
)
{
r
UpdateGateValue
=
updateGateV
alue
[
i
];
r
_update_gate_value
=
update_gate_v
alue
[
i
];
r
FrameStateValue
=
frameStateV
alue
[
i
];
r
_frame_state_value
=
frame_state_v
alue
[
i
];
r
OutGrad
=
((
__m256
*
)
outputG
rad
)[
i
];
r
_out_grad
=
((
__m256
*
)
output_g
rad
)[
i
];
if
(
prev
OutV
alue
)
{
if
(
prev
_out_v
alue
)
{
r
PrevOutValue
=
((
__m256
*
)
prevOutV
alue
)[
i
];
r
_prev_out_value
=
((
__m256
*
)
prev_out_v
alue
)[
i
];
}
}
if
(
prev
OutG
rad
)
{
if
(
prev
_out_g
rad
)
{
r
PrevOutGrad
=
((
__m256
*
)
prevOutG
rad
)[
i
];
r
_prev_out_grad
=
((
__m256
*
)
prev_out_g
rad
)[
i
];
}
}
op
StateGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rFrameStateV
alue
,
op
_state_grad
(
r_update_gate_value
,
r_update_gate_grad
,
r_frame_state_v
alue
,
rFrameStateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rOutG
rad
,
r_frame_state_grad
,
r_prev_out_value
,
r_prev_out_g
rad
,
active_node
);
r_out_grad
,
active_node
);
update
GateGrad
[
i
]
=
rUpdateGateG
rad
;
update
_gate_grad
[
i
]
=
r_update_gate_g
rad
;
frame
StateGrad
[
i
]
=
rFrameStateG
rad
;
frame
_state_grad
[
i
]
=
r_frame_state_g
rad
;
if
(
prev
OutG
rad
)
{
if
(
prev
_out_g
rad
)
{
((
__m256
*
)
prev
OutGrad
)[
i
]
=
rPrevOutG
rad
;
((
__m256
*
)
prev
_out_grad
)[
i
]
=
r_prev_out_g
rad
;
}
}
}
}
#endif
#endif
}
}
template
<
class
OpResetGrad
,
typename
T
>
template
<
class
OpResetGrad
,
typename
T
>
void
hl_avx_gru_backward_reset_grad
(
OpResetGrad
op
ResetGrad
,
T
*
gateV
alue
,
void
hl_avx_gru_backward_reset_grad
(
OpResetGrad
op
_reset_grad
,
T
*
gate_v
alue
,
T
*
gate
Grad
,
T
*
prevOutV
alue
,
T
*
gate
_grad
,
T
*
prev_out_v
alue
,
T
*
prev
OutGrad
,
T
*
resetOutputG
rad
,
T
*
prev
_out_grad
,
T
*
reset_output_g
rad
,
int
frame
S
ize
,
int
frame
_s
ize
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
#ifdef __AVX__
#ifdef __AVX__
__m256
r
UpdateGateV
alue
;
__m256
r
_update_gate_v
alue
;
__m256
r
UpdateGateG
rad
;
__m256
r
_update_gate_g
rad
;
__m256
r
ResetGateV
alue
;
__m256
r
_reset_gate_v
alue
;
__m256
r
ResetGateG
rad
;
__m256
r
_reset_gate_g
rad
;
__m256
r
ResetOutputG
rad
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_reset_output_g
rad
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
PrevOutV
alue
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_prev_out_v
alue
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
PrevOutG
rad
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_prev_out_g
rad
=
_mm256_set1_ps
(
0.0
f
);
__m256
*
update
GateValue
=
(
__m256
*
)
gateV
alue
;
__m256
*
update
_gate_value
=
(
__m256
*
)
gate_v
alue
;
__m256
*
update
GateGrad
=
(
__m256
*
)
gateG
rad
;
__m256
*
update
_gate_grad
=
(
__m256
*
)
gate_g
rad
;
__m256
*
reset
GateValue
=
(
__m256
*
)(
gateValue
+
frameS
ize
);
__m256
*
reset
_gate_value
=
(
__m256
*
)(
gate_value
+
frame_s
ize
);
__m256
*
reset
GateGrad
=
(
__m256
*
)(
gateGrad
+
frameS
ize
);
__m256
*
reset
_gate_grad
=
(
__m256
*
)(
gate_grad
+
frame_s
ize
);
for
(
int
i
=
0
;
i
<
frame
S
ize
/
8
;
i
++
)
{
for
(
int
i
=
0
;
i
<
frame
_s
ize
/
8
;
i
++
)
{
r
UpdateGateValue
=
updateGateV
alue
[
i
];
r
_update_gate_value
=
update_gate_v
alue
[
i
];
r
UpdateGateGrad
=
updateGateG
rad
[
i
];
r
_update_gate_grad
=
update_gate_g
rad
[
i
];
r
ResetGateValue
=
resetGateV
alue
[
i
];
r
_reset_gate_value
=
reset_gate_v
alue
[
i
];
if
(
prev
OutValue
&&
prevOutG
rad
)
{
if
(
prev
_out_value
&&
prev_out_g
rad
)
{
r
ResetOutputGrad
=
((
__m256
*
)
resetOutputG
rad
)[
i
];
r
_reset_output_grad
=
((
__m256
*
)
reset_output_g
rad
)[
i
];
}
}
if
(
prev
OutV
alue
)
{
if
(
prev
_out_v
alue
)
{
r
PrevOutValue
=
((
__m256
*
)
prevOutV
alue
)[
i
];
r
_prev_out_value
=
((
__m256
*
)
prev_out_v
alue
)[
i
];
}
}
if
(
prev
OutG
rad
)
{
if
(
prev
_out_g
rad
)
{
r
PrevOutGrad
=
((
__m256
*
)
prevOutG
rad
)[
i
];
r
_prev_out_grad
=
((
__m256
*
)
prev_out_g
rad
)[
i
];
}
}
op
ResetGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rResetGateV
alue
,
op
_reset_grad
(
r_update_gate_value
,
r_update_gate_grad
,
r_reset_gate_v
alue
,
rResetGateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rResetOutputG
rad
,
r_reset_gate_grad
,
r_prev_out_value
,
r_prev_out_g
rad
,
active_gate
);
r_reset_output_grad
,
active_gate
);
update
GateGrad
[
i
]
=
rUpdateGateG
rad
;
update
_gate_grad
[
i
]
=
r_update_gate_g
rad
;
reset
GateGrad
[
i
]
=
rResetGateG
rad
;
reset
_gate_grad
[
i
]
=
r_reset_gate_g
rad
;
if
(
prev
OutG
rad
)
{
if
(
prev
_out_g
rad
)
{
((
__m256
*
)
prev
OutGrad
)[
i
]
=
rPrevOutG
rad
;
((
__m256
*
)
prev
_out_grad
)[
i
]
=
r_prev_out_g
rad
;
}
}
}
}
#endif
#endif
}
}
template
<
class
OpStateGrad
,
typename
T
>
template
<
class
OpStateGrad
,
typename
T
>
inline
void
backward_state_grad
(
OpStateGrad
opStateGrad
,
hl_gru_value
<
T
>
value
,
inline
void
backward_state_grad
(
OpStateGrad
op_state_grad
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
batchSize
,
activation_mode_t
active_node
)
{
int
frame_size
,
int
batch_size
,
for
(
int
b
=
0
;
b
<
batchSize
;
b
++
)
{
activation_mode_t
active_node
)
{
if
(
OpStateGrad
::
avx
&&
!
(
frameSize
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
if
(
OpStateGrad
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
hl_avx_gru_backward_state_grad
(
hl_avx_gru_backward_state_grad
(
op
StateGrad
,
value
.
gateValue
,
grad
.
gateGrad
,
value
.
prevOutV
alue
,
op
_state_grad
,
value
.
gate_value
,
grad
.
gate_grad
,
value
.
prev_out_v
alue
,
grad
.
prev
OutGrad
,
grad
.
outputGrad
,
frameS
ize
,
active_node
);
grad
.
prev
_out_grad
,
grad
.
output_grad
,
frame_s
ize
,
active_node
);
}
else
{
}
else
{
hl_naive_gru_backward_state_grad
(
hl_naive_gru_backward_state_grad
(
op
StateGrad
,
value
.
gateValue
,
grad
.
gateGrad
,
value
.
prevOutV
alue
,
op
_state_grad
,
value
.
gate_value
,
grad
.
gate_grad
,
value
.
prev_out_v
alue
,
grad
.
prev
OutGrad
,
grad
.
outputGrad
,
frameS
ize
,
active_node
);
grad
.
prev
_out_grad
,
grad
.
output_grad
,
frame_s
ize
,
active_node
);
}
}
value
.
gate
Value
+=
frameS
ize
*
3
;
value
.
gate
_value
+=
frame_s
ize
*
3
;
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
value
.
prev
OutValue
+=
frameS
ize
;
value
.
prev
_out_value
+=
frame_s
ize
;
}
}
grad
.
gate
Grad
+=
frameS
ize
*
3
;
grad
.
gate
_grad
+=
frame_s
ize
*
3
;
grad
.
output
Grad
+=
frameS
ize
;
grad
.
output
_grad
+=
frame_s
ize
;
if
(
grad
.
prev
OutG
rad
)
{
if
(
grad
.
prev
_out_g
rad
)
{
grad
.
prev
OutGrad
+=
frameS
ize
;
grad
.
prev
_out_grad
+=
frame_s
ize
;
}
}
}
}
}
}
template
<
class
OpResetGrad
,
typename
T
>
template
<
class
OpResetGrad
,
typename
T
>
inline
void
backward_reset_grad
(
OpResetGrad
opResetGrad
,
hl_gru_value
<
T
>
value
,
inline
void
backward_reset_grad
(
OpResetGrad
op_reset_grad
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
batchSize
,
activation_mode_t
active_gate
)
{
int
frame_size
,
int
batch_size
,
for
(
int
b
=
0
;
b
<
batchSize
;
b
++
)
{
activation_mode_t
active_gate
)
{
if
(
OpResetGrad
::
avx
&&
!
(
frameSize
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
for
(
int
b
=
0
;
b
<
batch_size
;
b
++
)
{
if
(
OpResetGrad
::
avx
&&
!
(
frame_size
&
(
8
-
1
))
&&
(
sizeof
(
T
)
==
4
))
{
hl_avx_gru_backward_reset_grad
(
hl_avx_gru_backward_reset_grad
(
op
ResetGrad
,
value
.
gateValue
,
grad
.
gateGrad
,
value
.
prevOutV
alue
,
op
_reset_grad
,
value
.
gate_value
,
grad
.
gate_grad
,
value
.
prev_out_v
alue
,
grad
.
prev
OutGrad
,
grad
.
resetOutputGrad
,
frameS
ize
,
active_gate
);
grad
.
prev
_out_grad
,
grad
.
reset_output_grad
,
frame_s
ize
,
active_gate
);
}
else
{
}
else
{
hl_naive_gru_backward_reset_grad
(
hl_naive_gru_backward_reset_grad
(
op
ResetGrad
,
value
.
gateValue
,
grad
.
gateGrad
,
value
.
prevOutV
alue
,
op
_reset_grad
,
value
.
gate_value
,
grad
.
gate_grad
,
value
.
prev_out_v
alue
,
grad
.
prev
OutGrad
,
grad
.
resetOutputGrad
,
frameS
ize
,
active_gate
);
grad
.
prev
_out_grad
,
grad
.
reset_output_grad
,
frame_s
ize
,
active_gate
);
}
}
value
.
gate
Value
+=
frameS
ize
*
3
;
value
.
gate
_value
+=
frame_s
ize
*
3
;
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
value
.
prev
OutValue
+=
frameS
ize
;
value
.
prev
_out_value
+=
frame_s
ize
;
}
}
grad
.
gate
Grad
+=
frameS
ize
*
3
;
grad
.
gate
_grad
+=
frame_s
ize
*
3
;
grad
.
reset
OutputGrad
+=
frameS
ize
;
grad
.
reset
_output_grad
+=
frame_s
ize
;
if
(
grad
.
prev
OutG
rad
)
{
if
(
grad
.
prev
_out_g
rad
)
{
grad
.
prev
OutGrad
+=
frameS
ize
;
grad
.
prev
_out_grad
+=
frame_s
ize
;
}
}
}
}
}
}
...
...
paddle/operators/math/detail/gru_gpu_kernel.h
浏览文件 @
3e552cdc
...
@@ -27,174 +27,174 @@ namespace math {
...
@@ -27,174 +27,174 @@ namespace math {
namespace
detail
{
namespace
detail
{
/*
/*
* threads(frame
PerBlock, batchPerB
lock)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
Blocks, batchB
locks)
* grid(frame
_blocks, batch_b
locks)
*/
*/
template
<
class
OpResetOutput
,
bool
is
B
atch
,
typename
T
>
template
<
class
OpResetOutput
,
bool
is
_b
atch
,
typename
T
>
__global__
void
KeGruForwardResetOutput
(
OpResetOutput
op
ResetO
utput
,
__global__
void
KeGruForwardResetOutput
(
OpResetOutput
op
_reset_o
utput
,
T
*
gate
Value
,
T
*
resetOutputV
alue
,
T
*
gate
_value
,
T
*
reset_output_v
alue
,
T
*
prev
OutputValue
,
int
frameS
ize
,
T
*
prev
_output_value
,
int
frame_s
ize
,
int
batch
S
ize
,
int
batch
_s
ize
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
const
int
frame
Idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadI
dx
.
x
;
const
int
frame
_idx
=
block_idx
.
x
*
block_dim
.
x
+
thread_i
dx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
I
dx
=
0
;
int
batch
_i
dx
=
0
;
if
(
is
B
atch
)
{
if
(
is
_b
atch
)
{
batch
Idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadI
dx
.
y
;
batch
_idx
=
block_idx
.
y
*
block_dim
.
y
+
thread_i
dx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
gate
Value
+=
batchIdx
*
3
*
frameS
ize
;
gate
_value
+=
batch_idx
*
3
*
frame_s
ize
;
reset
OutputValue
+=
batchIdx
*
frameS
ize
;
reset
_output_value
+=
batch_idx
*
frame_s
ize
;
}
}
T
r
PrevO
ut
=
0
;
T
r
_prev_o
ut
=
0
;
T
r
ValueResetO
utput
;
T
r
_value_reset_o
utput
;
T
r
ValueUpdateGate
=
gateValue
[
frameIdx
+
frameS
ize
*
0
];
T
r
_value_update_gate
=
gate_value
[
frame_idx
+
frame_s
ize
*
0
];
T
r
ValueResetGate
=
gateValue
[
frameIdx
+
frameS
ize
*
1
];
T
r
_value_reset_gate
=
gate_value
[
frame_idx
+
frame_s
ize
*
1
];
if
(
prev
OutputV
alue
)
{
if
(
prev
_output_v
alue
)
{
if
(
is
Batch
)
prevOutputValue
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_output_value
+=
batch_idx
*
frame_s
ize
;
r
PrevOut
=
prevOutputValue
[
frameI
dx
];
r
_prev_out
=
prev_output_value
[
frame_i
dx
];
}
}
op
ResetOutput
(
rValueUpdateGate
,
rValueResetGate
,
rPrevOut
,
rValueResetOutp
ut
,
op
_reset_output
(
r_value_update_gate
,
r_value_reset_gate
,
r_prev_o
ut
,
active_gate
);
r_value_reset_output
,
active_gate
);
gate
Value
[
frameIdx
+
frameSize
*
0
]
=
rValueUpdateG
ate
;
gate
_value
[
frame_idx
+
frame_size
*
0
]
=
r_value_update_g
ate
;
gate
Value
[
frameIdx
+
frameSize
*
1
]
=
rValueResetG
ate
;
gate
_value
[
frame_idx
+
frame_size
*
1
]
=
r_value_reset_g
ate
;
reset
OutputValue
[
frameIdx
]
=
rValueResetO
utput
;
reset
_output_value
[
frame_idx
]
=
r_value_reset_o
utput
;
}
}
/*
/*
* threads(frame
PerBlock, batchPerB
lock)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
Blocks, batchB
locks)
* grid(frame
_blocks, batch_b
locks)
*/
*/
template
<
class
OpFinalOutput
,
bool
is
B
atch
,
typename
T
>
template
<
class
OpFinalOutput
,
bool
is
_b
atch
,
typename
T
>
__global__
void
KeGruForwardFinalOutput
(
OpFinalOutput
op
FinalO
utput
,
__global__
void
KeGruForwardFinalOutput
(
OpFinalOutput
op
_final_o
utput
,
T
*
gate
Value
,
T
*
prevOutputV
alue
,
T
*
gate
_value
,
T
*
prev_output_v
alue
,
T
*
output
Value
,
int
frameS
ize
,
T
*
output
_value
,
int
frame_s
ize
,
int
batch
S
ize
,
int
batch
_s
ize
,
activation_mode_t
active_node
)
{
activation_mode_t
active_node
)
{
const
int
frame
Idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadI
dx
.
x
;
const
int
frame
_idx
=
block_idx
.
x
*
block_dim
.
x
+
thread_i
dx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
I
dx
=
0
;
int
batch
_i
dx
=
0
;
if
(
is
B
atch
)
{
if
(
is
_b
atch
)
{
batch
Idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadI
dx
.
y
;
batch
_idx
=
block_idx
.
y
*
block_dim
.
y
+
thread_i
dx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
gate
Value
+=
batchIdx
*
3
*
frameS
ize
;
gate
_value
+=
batch_idx
*
3
*
frame_s
ize
;
output
Value
+=
batchIdx
*
frameS
ize
;
output
_value
+=
batch_idx
*
frame_s
ize
;
}
}
T
r
O
utput
;
T
r
_o
utput
;
T
r
PrevO
ut
=
0
;
T
r
_prev_o
ut
=
0
;
T
r
ValueUpdateGate
=
gateValue
[
frameIdx
+
frameS
ize
*
0
];
T
r
_value_update_gate
=
gate_value
[
frame_idx
+
frame_s
ize
*
0
];
T
r
ValueFrameState
=
gateValue
[
frameIdx
+
frameS
ize
*
2
];
T
r
_value_frame_state
=
gate_value
[
frame_idx
+
frame_s
ize
*
2
];
if
(
prev
OutputV
alue
)
{
if
(
prev
_output_v
alue
)
{
if
(
is
Batch
)
prevOutputValue
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_output_value
+=
batch_idx
*
frame_s
ize
;
r
PrevOut
=
prevOutputValue
[
frameI
dx
];
r
_prev_out
=
prev_output_value
[
frame_i
dx
];
}
}
op
FinalOutput
(
rValueUpdateGate
,
rValueFrameState
,
rPrevOut
,
rOutp
ut
,
op
_final_output
(
r_value_update_gate
,
r_value_frame_state
,
r_prev_o
ut
,
active_node
);
r_output
,
active_node
);
gate
Value
[
frameIdx
+
frameSize
*
2
]
=
rValueFrameS
tate
;
gate
_value
[
frame_idx
+
frame_size
*
2
]
=
r_value_frame_s
tate
;
output
Value
[
frameIdx
]
=
rO
utput
;
output
_value
[
frame_idx
]
=
r_o
utput
;
}
}
/*
/*
* threads(frame
PerBlock, batchPerB
lock)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
Blocks, batchB
locks)
* grid(frame
_blocks, batch_b
locks)
*/
*/
template
<
class
OpStateGrad
,
bool
is
B
atch
,
typename
T
>
template
<
class
OpStateGrad
,
bool
is
_b
atch
,
typename
T
>
__global__
void
KeGruBackwardStateGrad
(
OpStateGrad
op
StateGrad
,
T
*
gateV
alue
,
__global__
void
KeGruBackwardStateGrad
(
OpStateGrad
op
_state_grad
,
T
*
gate_v
alue
,
T
*
gate
Grad
,
T
*
prevOutV
alue
,
T
*
gate
_grad
,
T
*
prev_out_v
alue
,
T
*
prev
OutGrad
,
T
*
outputG
rad
,
T
*
prev
_out_grad
,
T
*
output_g
rad
,
int
frame
Size
,
int
batchS
ize
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
)
{
activation_mode_t
active_node
)
{
const
int
frame
Idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadI
dx
.
x
;
const
int
frame
_idx
=
block_idx
.
x
*
block_dim
.
x
+
thread_i
dx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
I
dx
=
0
;
int
batch
_i
dx
=
0
;
if
(
is
B
atch
)
{
if
(
is
_b
atch
)
{
batch
Idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadI
dx
.
y
;
batch
_idx
=
block_idx
.
y
*
block_dim
.
y
+
thread_i
dx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
gate
Value
+=
batchIdx
*
3
*
frameS
ize
;
gate
_value
+=
batch_idx
*
3
*
frame_s
ize
;
gate
Grad
+=
batchIdx
*
3
*
frameS
ize
;
gate
_grad
+=
batch_idx
*
3
*
frame_s
ize
;
output
Grad
+=
batchIdx
*
frameS
ize
;
output
_grad
+=
batch_idx
*
frame_s
ize
;
}
}
T
r
UpdateGateG
rad
;
T
r
_update_gate_g
rad
;
T
r
FrameStateG
rad
;
T
r
_frame_state_g
rad
;
T
r
PrevOutV
alue
=
0
;
T
r
_prev_out_v
alue
=
0
;
T
r
PrevOutG
rad
=
0
;
T
r
_prev_out_g
rad
=
0
;
T
r
UpdateGateValue
=
gateValue
[
frameIdx
+
frameS
ize
*
0
];
T
r
_update_gate_value
=
gate_value
[
frame_idx
+
frame_s
ize
*
0
];
T
r
FrameStateValue
=
gateValue
[
frameIdx
+
frameS
ize
*
2
];
T
r
_frame_state_value
=
gate_value
[
frame_idx
+
frame_s
ize
*
2
];
T
r
OutGrad
=
outputGrad
[
frameI
dx
];
T
r
_out_grad
=
output_grad
[
frame_i
dx
];
if
(
prev
OutValue
&&
prevOutG
rad
)
{
if
(
prev
_out_value
&&
prev_out_g
rad
)
{
if
(
is
Batch
)
prevOutValue
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_out_value
+=
batch_idx
*
frame_s
ize
;
r
PrevOutValue
=
prevOutValue
[
frameI
dx
];
r
_prev_out_value
=
prev_out_value
[
frame_i
dx
];
if
(
is
Batch
)
prevOutGrad
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_out_grad
+=
batch_idx
*
frame_s
ize
;
r
PrevOutGrad
=
prevOutGrad
[
frameI
dx
];
r
_prev_out_grad
=
prev_out_grad
[
frame_i
dx
];
}
}
op
StateGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rFrameStateV
alue
,
op
_state_grad
(
r_update_gate_value
,
r_update_gate_grad
,
r_frame_state_v
alue
,
rFrameStateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rOutG
rad
,
r_frame_state_grad
,
r_prev_out_value
,
r_prev_out_g
rad
,
active_node
);
r_out_grad
,
active_node
);
gate
Grad
[
frameIdx
+
frameSize
*
0
]
=
rUpdateGateG
rad
;
gate
_grad
[
frame_idx
+
frame_size
*
0
]
=
r_update_gate_g
rad
;
gate
Grad
[
frameIdx
+
frameSize
*
2
]
=
rFrameStateG
rad
;
gate
_grad
[
frame_idx
+
frame_size
*
2
]
=
r_frame_state_g
rad
;
if
(
prev
OutG
rad
)
{
if
(
prev
_out_g
rad
)
{
prev
OutGrad
[
frameIdx
]
=
rPrevOutG
rad
;
prev
_out_grad
[
frame_idx
]
=
r_prev_out_g
rad
;
}
}
}
}
/*
/*
* threads(frame
PerBlock, batchPerB
lock)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
Blocks, batchB
locks)
* grid(frame
_blocks, batch_b
locks)
*/
*/
template
<
class
OpResetGrad
,
bool
is
B
atch
,
typename
T
>
template
<
class
OpResetGrad
,
bool
is
_b
atch
,
typename
T
>
__global__
void
KeGruBackwardResetGrad
(
OpResetGrad
op
ResetGrad
,
T
*
gateV
alue
,
__global__
void
KeGruBackwardResetGrad
(
OpResetGrad
op
_reset_grad
,
T
*
gate_v
alue
,
T
*
gate
Grad
,
T
*
prevOutV
alue
,
T
*
gate
_grad
,
T
*
prev_out_v
alue
,
T
*
prev
OutGrad
,
T
*
resetOutputG
rad
,
T
*
prev
_out_grad
,
T
*
reset_output_g
rad
,
int
frame
Size
,
int
batchS
ize
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
const
int
frame
Idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadI
dx
.
x
;
const
int
frame
_idx
=
block_idx
.
x
*
block_dim
.
x
+
thread_i
dx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
I
dx
=
0
;
int
batch
_i
dx
=
0
;
if
(
is
B
atch
)
{
if
(
is
_b
atch
)
{
batch
Idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadI
dx
.
y
;
batch
_idx
=
block_idx
.
y
*
block_dim
.
y
+
thread_i
dx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
gate
Value
+=
batchIdx
*
3
*
frameS
ize
;
gate
_value
+=
batch_idx
*
3
*
frame_s
ize
;
gate
Grad
+=
batchIdx
*
3
*
frameS
ize
;
gate
_grad
+=
batch_idx
*
3
*
frame_s
ize
;
reset
OutputGrad
+=
batchIdx
*
frameS
ize
;
reset
_output_grad
+=
batch_idx
*
frame_s
ize
;
}
}
T
r
ResetGateG
rad
;
T
r
_reset_gate_g
rad
;
T
r
PrevOutV
alue
=
0
;
T
r
_prev_out_v
alue
=
0
;
T
r
PrevOutG
rad
=
0
;
T
r
_prev_out_g
rad
=
0
;
T
r
ResetOutputG
rad
=
0
;
T
r
_reset_output_g
rad
=
0
;
T
r
UpdateGateValue
=
gateValue
[
frameIdx
+
frameS
ize
*
0
];
T
r
_update_gate_value
=
gate_value
[
frame_idx
+
frame_s
ize
*
0
];
T
r
UpdateGateGrad
=
gateGrad
[
frameIdx
+
frameS
ize
*
0
];
T
r
_update_gate_grad
=
gate_grad
[
frame_idx
+
frame_s
ize
*
0
];
T
r
ResetGateValue
=
gateValue
[
frameIdx
+
frameS
ize
*
1
];
T
r
_reset_gate_value
=
gate_value
[
frame_idx
+
frame_s
ize
*
1
];
if
(
prev
OutValue
&&
prevOutG
rad
)
{
if
(
prev
_out_value
&&
prev_out_g
rad
)
{
if
(
is
Batch
)
prevOutValue
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_out_value
+=
batch_idx
*
frame_s
ize
;
if
(
is
Batch
)
prevOutGrad
+=
batchIdx
*
frameS
ize
;
if
(
is
_batch
)
prev_out_grad
+=
batch_idx
*
frame_s
ize
;
r
PrevOutValue
=
prevOutValue
[
frameI
dx
];
r
_prev_out_value
=
prev_out_value
[
frame_i
dx
];
r
PrevOutGrad
=
prevOutGrad
[
frameI
dx
];
r
_prev_out_grad
=
prev_out_grad
[
frame_i
dx
];
r
ResetOutputGrad
=
resetOutputGrad
[
frameI
dx
];
r
_reset_output_grad
=
reset_output_grad
[
frame_i
dx
];
}
}
op
ResetGrad
(
rUpdateGateValue
,
rUpdateGateGrad
,
rResetGateV
alue
,
op
_reset_grad
(
r_update_gate_value
,
r_update_gate_grad
,
r_reset_gate_v
alue
,
rResetGateGrad
,
rPrevOutValue
,
rPrevOutGrad
,
rResetOutputG
rad
,
r_reset_gate_grad
,
r_prev_out_value
,
r_prev_out_g
rad
,
active_gate
);
r_reset_output_grad
,
active_gate
);
gate
Grad
[
frameIdx
+
frameSize
*
0
]
=
rUpdateGateG
rad
;
gate
_grad
[
frame_idx
+
frame_size
*
0
]
=
r_update_gate_g
rad
;
gate
Grad
[
frameIdx
+
frameSize
*
1
]
=
rResetGateG
rad
;
gate
_grad
[
frame_idx
+
frame_size
*
1
]
=
r_reset_gate_g
rad
;
if
(
prev
OutG
rad
)
{
if
(
prev
_out_g
rad
)
{
prev
OutGrad
[
frameIdx
]
=
rPrevOutG
rad
;
prev
_out_grad
[
frame_idx
]
=
r_prev_out_g
rad
;
}
}
}
}
}
// namespace detail
}
// namespace detail
...
...
paddle/operators/math/detail/gru_kernel.h
浏览文件 @
3e552cdc
...
@@ -28,23 +28,25 @@ namespace forward {
...
@@ -28,23 +28,25 @@ namespace forward {
template
<
typename
T
>
template
<
typename
T
>
class
gru_resetOutput
{
class
gru_resetOutput
{
public:
public:
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
valueResetGate
,
T
&
prevOut
,
HOSTDEVICE
void
operator
()(
T
&
value_update_gate
,
T
&
value_reset_gate
,
T
&
valueResetOutput
,
activation_mode_t
actGate
)
{
T
&
prev_out
,
T
&
value_reset_output
,
valueUpdateGate
=
activation
(
valueUpdateGate
,
actGate
);
activation_mode_t
act_gate
)
{
valueResetGate
=
activation
(
valueResetGate
,
actGate
);
value_update_gate
=
activation
(
value_update_gate
,
act_gate
);
valueResetOutput
=
prevOut
*
valueResetGate
;
value_reset_gate
=
activation
(
value_reset_gate
,
act_gate
);
value_reset_output
=
prev_out
*
value_reset_gate
;
}
}
#ifndef __NVCC__
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
static
const
bool
avx
=
false
;
static
const
bool
avx
=
false
;
#else
#else
static
const
bool
avx
=
true
;
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
valueResetGate
,
HOSTDEVICE
void
operator
()(
__m256
&
value_update_gate
,
__m256
&
prevOut
,
__m256
&
valueResetOutput
,
__m256
&
value_reset_gate
,
__m256
&
prev_out
,
activation_mode_t
actGate
)
{
__m256
&
value_reset_output
,
valueUpdateGate
=
activation
(
valueUpdateGate
,
actGate
);
activation_mode_t
act_gate
)
{
valueResetGate
=
activation
(
valueResetGate
,
actGate
);
value_update_gate
=
activation
(
value_update_gate
,
act_gate
);
valueResetOutput
=
_mm256_mul_ps
(
prevOut
,
valueResetGate
);
value_reset_gate
=
activation
(
value_reset_gate
,
act_gate
);
value_reset_output
=
_mm256_mul_ps
(
prev_out
,
value_reset_gate
);
}
}
#endif
#endif
#endif
#endif
...
@@ -53,24 +55,26 @@ class gru_resetOutput {
...
@@ -53,24 +55,26 @@ class gru_resetOutput {
template
<
typename
T
>
template
<
typename
T
>
class
gru_finalOutput
{
class
gru_finalOutput
{
public:
public:
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
valueFrameState
,
T
&
prevOut
,
HOSTDEVICE
void
operator
()(
T
&
value_update_gate
,
T
&
value_frame_state
,
T
&
valueOutput
,
activation_mode_t
actInput
)
{
T
&
prev_out
,
T
&
value_output
,
valueFrameState
=
activation
(
valueFrameState
,
actInput
);
activation_mode_t
act_input
)
{
valueOutput
=
prevOut
-
(
valueUpdateGate
*
prevOut
)
+
value_frame_state
=
activation
(
value_frame_state
,
act_input
);
(
valueUpdateGate
*
valueFrameState
);
value_output
=
prev_out
-
(
value_update_gate
*
prev_out
)
+
(
value_update_gate
*
value_frame_state
);
}
}
#ifndef __NVCC__
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
static
const
bool
avx
=
false
;
static
const
bool
avx
=
false
;
#else
#else
static
const
bool
avx
=
true
;
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
valueFrameState
,
HOSTDEVICE
void
operator
()(
__m256
&
value_update_gate
,
__m256
&
prevOut
,
__m256
&
valueOutput
,
__m256
&
value_frame_state
,
__m256
&
prev_out
,
activation_mode_t
actInput
)
{
__m256
&
value_output
,
valueFrameState
=
activation
(
valueFrameState
,
actInput
);
activation_mode_t
act_input
)
{
valueOutput
=
_mm256_add_ps
(
value_frame_state
=
activation
(
value_frame_state
,
act_input
);
_mm256_sub_ps
(
prevOut
,
_mm256_mul_ps
(
valueUpdateGate
,
prevOut
)),
value_output
=
_mm256_add_ps
(
_mm256_mul_ps
(
valueUpdateGate
,
valueFrameState
));
_mm256_sub_ps
(
prev_out
,
_mm256_mul_ps
(
value_update_gate
,
prev_out
)),
_mm256_mul_ps
(
value_update_gate
,
value_frame_state
));
}
}
#endif
#endif
#endif
#endif
...
@@ -82,34 +86,37 @@ namespace backward {
...
@@ -82,34 +86,37 @@ namespace backward {
template
<
typename
T
>
template
<
typename
T
>
class
gru_stateGrad
{
class
gru_stateGrad
{
public:
public:
HOSTDEVICE
void
operator
()(
T
&
value
UpdateGate
,
T
&
gradUpdateG
ate
,
HOSTDEVICE
void
operator
()(
T
&
value
_update_gate
,
T
&
grad_update_g
ate
,
T
&
value
FrameState
,
T
&
gradFrameS
tate
,
T
&
value
_frame_state
,
T
&
grad_frame_s
tate
,
T
&
value
PrevOut
,
T
&
gradPrevOut
,
T
&
gradOutp
ut
,
T
&
value
_prev_out
,
T
&
grad_prev_o
ut
,
activation_mode_t
actI
nput
)
{
T
&
grad_output
,
activation_mode_t
act_i
nput
)
{
grad
UpdateGate
=
(
gradOutput
*
valueFrameS
tate
);
grad
_update_gate
=
(
grad_output
*
value_frame_s
tate
);
grad
UpdateGate
-=
(
gradOutput
*
valuePrevO
ut
);
grad
_update_gate
-=
(
grad_output
*
value_prev_o
ut
);
grad
PrevOut
-=
(
gradOutput
*
valueUpdateG
ate
);
grad
_prev_out
-=
(
grad_output
*
value_update_g
ate
);
grad
PrevOut
+=
gradO
utput
;
grad
_prev_out
+=
grad_o
utput
;
grad
FrameState
=
grad
_frame_state
=
activation
(
grad_output
*
value_update_gate
,
activation
(
gradOutput
*
valueUpdateGate
,
valueFrameState
,
actI
nput
);
value_frame_state
,
act_i
nput
);
}
}
#ifndef __NVCC__
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
static
const
bool
avx
=
false
;
static
const
bool
avx
=
false
;
#else
#else
static
const
bool
avx
=
true
;
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
gradUpdateGate
,
HOSTDEVICE
void
operator
()(
__m256
&
value_update_gate
,
__m256
&
valueFrameState
,
__m256
&
gradFrameState
,
__m256
&
grad_update_gate
,
__m256
&
valuePrevOut
,
__m256
&
gradPrevOut
,
__m256
&
value_frame_state
,
__m256
&
gradOutput
,
activation_mode_t
actInput
)
{
__m256
&
grad_frame_state
,
__m256
&
value_prev_out
,
gradUpdateGate
=
_mm256_mul_ps
(
gradOutput
,
valueFrameState
);
__m256
&
grad_prev_out
,
__m256
&
grad_output
,
gradUpdateGate
=
activation_mode_t
act_input
)
{
_mm256_sub_ps
(
gradUpdateGate
,
_mm256_mul_ps
(
gradOutput
,
valuePrevOut
));
grad_update_gate
=
_mm256_mul_ps
(
grad_output
,
value_frame_state
);
gradPrevOut
=
_mm256_add_ps
(
grad_update_gate
=
_mm256_sub_ps
(
_mm256_sub_ps
(
gradPrevOut
,
_mm256_mul_ps
(
gradOutput
,
valueUpdateGate
)),
grad_update_gate
,
_mm256_mul_ps
(
grad_output
,
value_prev_out
));
gradOutput
);
grad_prev_out
=
_mm256_add_ps
(
gradFrameState
=
activation
(
_mm256_mul_ps
(
gradOutput
,
valueUpdateGate
),
_mm256_sub_ps
(
grad_prev_out
,
valueFrameState
,
actInput
);
_mm256_mul_ps
(
grad_output
,
value_update_gate
)),
grad_output
);
grad_frame_state
=
activation
(
_mm256_mul_ps
(
grad_output
,
value_update_gate
),
value_frame_state
,
act_input
);
}
}
#endif
#endif
#endif
#endif
...
@@ -118,30 +125,32 @@ class gru_stateGrad {
...
@@ -118,30 +125,32 @@ class gru_stateGrad {
template
<
typename
T
>
template
<
typename
T
>
class
gru_resetGrad
{
class
gru_resetGrad
{
public:
public:
HOSTDEVICE
void
operator
()(
T
&
valueUpdateGate
,
T
&
gradUpdateGate
,
HOSTDEVICE
void
operator
()(
T
&
value_update_gate
,
T
&
grad_update_gate
,
T
&
valueResetGate
,
T
&
gradResetGate
,
T
&
value_reset_gate
,
T
&
grad_reset_gate
,
T
&
valuePrevOut
,
T
&
gradPrevOut
,
T
&
value_prev_out
,
T
&
grad_prev_out
,
T
&
gradResetOutput
,
activation_mode_t
actGate
)
{
T
&
grad_reset_output
,
activation_mode_t
act_gate
)
{
gradResetGate
=
(
gradResetOutput
*
valuePrevOut
);
grad_reset_gate
=
(
grad_reset_output
*
value_prev_out
);
gradPrevOut
+=
(
gradResetOutput
*
valueResetGate
);
grad_prev_out
+=
(
grad_reset_output
*
value_reset_gate
);
gradUpdateGate
=
activation
(
gradUpdateGate
,
valueUpdateGate
,
actGate
);
grad_update_gate
=
gradResetGate
=
activation
(
gradResetGate
,
valueResetGate
,
actGate
);
activation
(
grad_update_gate
,
value_update_gate
,
act_gate
);
grad_reset_gate
=
activation
(
grad_reset_gate
,
value_reset_gate
,
act_gate
);
}
}
#ifndef __NVCC__
#ifndef __NVCC__
#ifndef __AVX__
#ifndef __AVX__
static
const
bool
avx
=
false
;
static
const
bool
avx
=
false
;
#else
#else
static
const
bool
avx
=
true
;
static
const
bool
avx
=
true
;
HOSTDEVICE
void
operator
()(
__m256
&
valueUpdateGate
,
__m256
&
gradUpdateGate
,
HOSTDEVICE
void
operator
()(
__m256
&
value_update_gate
,
__m256
&
valueResetGate
,
__m256
&
gradResetGate
,
__m256
&
grad_update_gate
,
__m256
&
value_reset_gate
,
__m256
&
valuePrevOut
,
__m256
&
gradPrevOut
,
__m256
&
grad_reset_gate
,
__m256
&
value_prev_out
,
__m256
&
gradResetOutput
,
__m256
&
grad_prev_out
,
__m256
&
grad_reset_output
,
activation_mode_t
actGate
)
{
activation_mode_t
act_gate
)
{
gradResetGate
=
_mm256_mul_ps
(
gradResetOutput
,
valuePrevOut
);
grad_reset_gate
=
_mm256_mul_ps
(
grad_reset_output
,
value_prev_out
);
gradPrevOut
=
_mm256_add_ps
(
gradPrevOut
,
grad_prev_out
=
_mm256_add_ps
(
_mm256_mul_ps
(
gradResetOutput
,
valueResetGate
));
grad_prev_out
,
_mm256_mul_ps
(
grad_reset_output
,
value_reset_gate
));
gradUpdateGate
=
activation
(
gradUpdateGate
,
valueUpdateGate
,
actGate
);
grad_update_gate
=
gradResetGate
=
activation
(
gradResetGate
,
valueResetGate
,
actGate
);
activation
(
grad_update_gate
,
value_update_gate
,
act_gate
);
grad_reset_gate
=
activation
(
grad_reset_gate
,
value_reset_gate
,
act_gate
);
}
}
#endif
#endif
#endif
#endif
...
...
paddle/operators/math/gru_compute.cc
浏览文件 @
3e552cdc
...
@@ -21,29 +21,29 @@ namespace math {
...
@@ -21,29 +21,29 @@ namespace math {
template
<
typename
T
>
template
<
typename
T
>
struct
GRUUnitFunctor
<
platform
::
CPUPlace
,
T
>
{
struct
GRUUnitFunctor
<
platform
::
CPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
int
frame
Size
,
int
batchS
ize
,
hl_gru_value
<
T
>
value
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
#ifndef __NVCC__
#ifndef __NVCC__
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
false
,
batch
Size
,
frameSize
*
2
,
frameS
ize
,
1
,
context
,
false
,
false
,
batch
_size
,
frame_size
*
2
,
frame_s
ize
,
1
,
value
.
prev
OutValue
,
frameSize
,
value
.
gateWeight
,
frameSize
*
2
,
1
,
value
.
prev
_out_value
,
frame_size
,
value
.
gate_weight
,
frame_size
*
2
,
value
.
gateValue
,
frameS
ize
*
3
);
1
,
value
.
gate_value
,
frame_s
ize
*
3
);
}
}
detail
::
forward_reset_output
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
,
detail
::
forward_reset_output
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
,
frame
Size
,
batchS
ize
,
active_gate
);
frame
_size
,
batch_s
ize
,
active_gate
);
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
false
,
batch
Size
,
frameSize
,
frameS
ize
,
1
,
context
,
false
,
false
,
batch
_size
,
frame_size
,
frame_s
ize
,
1
,
value
.
reset
OutputValue
,
frameSize
,
value
.
stateWeight
,
frameSize
,
1
,
value
.
reset
_output_value
,
frame_size
,
value
.
state_weight
,
frame_size
,
value
.
gateValue
+
frameSize
*
2
,
frameS
ize
*
3
);
1
,
value
.
gate_value
+
frame_size
*
2
,
frame_s
ize
*
3
);
}
}
detail
::
forward_final_output
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
,
detail
::
forward_final_output
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
,
frame
Size
,
batchS
ize
,
active_node
);
frame
_size
,
batch_s
ize
,
active_node
);
#endif
#endif
}
}
};
};
...
@@ -51,41 +51,43 @@ struct GRUUnitFunctor<platform::CPUPlace, T> {
...
@@ -51,41 +51,43 @@ struct GRUUnitFunctor<platform::CPUPlace, T> {
template
<
typename
T
>
template
<
typename
T
>
struct
GRUUnitGradFunctor
<
platform
::
CPUPlace
,
T
>
{
struct
GRUUnitGradFunctor
<
platform
::
CPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
batchSize
,
activation_mode_t
active_node
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
#ifndef __NVCC__
#ifndef __NVCC__
detail
::
backward_state_grad
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
,
detail
::
backward_state_grad
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
,
grad
,
frame
Size
,
batchS
ize
,
active_node
);
grad
,
frame
_size
,
batch_s
ize
,
active_node
);
if
(
value
.
prev
OutValue
&&
grad
.
prevOutG
rad
)
{
if
(
value
.
prev
_out_value
&&
grad
.
prev_out_g
rad
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
true
,
batch
Size
,
frameSize
,
frameS
ize
,
1
,
context
,
false
,
true
,
batch
_size
,
frame_size
,
frame_s
ize
,
1
,
grad
.
gate
Grad
+
frameSize
*
2
,
frameSize
*
3
,
value
.
stateW
eight
,
grad
.
gate
_grad
+
frame_size
*
2
,
frame_size
*
3
,
value
.
state_w
eight
,
frame
Size
,
0
,
grad
.
resetOutputGrad
,
frameS
ize
);
frame
_size
,
0
,
grad
.
reset_output_grad
,
frame_s
ize
);
if
(
grad
.
state
WeightG
rad
)
{
if
(
grad
.
state
_weight_g
rad
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
true
,
false
,
frameSize
,
frameSize
,
batchSize
,
1
,
context
,
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
value
.
resetOutputValue
,
frameSize
,
grad
.
gateGrad
+
frameSize
*
2
,
value
.
reset_output_value
,
frame_size
,
frameSize
*
3
,
1
,
grad
.
stateWeightGrad
,
frameSize
);
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
1
,
grad
.
state_weight_grad
,
frame_size
);
}
}
}
}
detail
::
backward_reset_grad
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
,
detail
::
backward_reset_grad
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
,
grad
,
frame
Size
,
batchS
ize
,
active_gate
);
grad
,
frame
_size
,
batch_s
ize
,
active_gate
);
if
(
grad
.
prev
OutGrad
&&
value
.
prevOutV
alue
)
{
if
(
grad
.
prev
_out_grad
&&
value
.
prev_out_v
alue
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
false
,
true
,
batch
Size
,
frameSize
,
frameS
ize
*
2
,
1
,
context
,
false
,
true
,
batch
_size
,
frame_size
,
frame_s
ize
*
2
,
1
,
grad
.
gate
Grad
,
frameSize
*
3
,
value
.
gateWeight
,
frameS
ize
*
2
,
1
,
grad
.
gate
_grad
,
frame_size
*
3
,
value
.
gate_weight
,
frame_s
ize
*
2
,
1
,
grad
.
prev
OutGrad
,
frameS
ize
);
grad
.
prev
_out_grad
,
frame_s
ize
);
if
(
grad
.
gate
WeightG
rad
)
{
if
(
grad
.
gate
_weight_g
rad
)
{
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
math
::
gemm
<
platform
::
CPUPlace
,
T
>
(
context
,
true
,
false
,
frame
Size
,
frameSize
*
2
,
batchS
ize
,
1
,
context
,
true
,
false
,
frame
_size
,
frame_size
*
2
,
batch_s
ize
,
1
,
value
.
prev
OutValue
,
frameSize
,
grad
.
gateGrad
,
frameS
ize
*
3
,
1
,
value
.
prev
_out_value
,
frame_size
,
grad
.
gate_grad
,
frame_s
ize
*
3
,
1
,
grad
.
gate
WeightGrad
,
frameS
ize
*
2
);
grad
.
gate
_weight_grad
,
frame_s
ize
*
2
);
}
}
}
}
#endif
#endif
...
...
paddle/operators/math/gru_compute.cu
浏览文件 @
3e552cdc
...
@@ -21,66 +21,66 @@ namespace math {
...
@@ -21,66 +21,66 @@ namespace math {
template
<
typename
T
>
template
<
typename
T
>
struct
GRUUnitFunctor
<
platform
::
GPUPlace
,
T
>
{
struct
GRUUnitFunctor
<
platform
::
GPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
int
frame
Size
,
int
batchS
ize
,
hl_gru_value
<
T
>
value
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
auto
stream
=
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
dim3
threads
;
dim3
threads
;
dim3
grid
;
dim3
grid
;
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
int
frame
PerBlock
=
frameSize
<=
1024
?
frameS
ize
:
1024
;
int
frame
_per_block
=
frame_size
<=
1024
?
frame_s
ize
:
1024
;
int
frame
Blocks
=
(
frameS
ize
+
1024
-
1
)
/
1024
;
int
frame
_blocks
=
(
frame_s
ize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame
PerB
lock
,
1
);
threads
=
dim3
(
frame
_per_b
lock
,
1
);
grid
=
dim3
(
frame
B
locks
,
1
);
grid
=
dim3
(
frame
_b
locks
,
1
);
}
else
{
}
else
{
threads
=
dim3
(
32
,
32
);
threads
=
dim3
(
32
,
32
);
grid
=
dim3
((
frame
Size
+
32
-
1
)
/
32
,
(
batchS
ize
+
32
-
1
)
/
32
);
grid
=
dim3
((
frame
_size
+
32
-
1
)
/
32
,
(
batch_s
ize
+
32
-
1
)
/
32
);
}
}
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
false
,
batch
Size
,
frameSize
*
2
,
frameS
ize
,
1
,
context
,
false
,
false
,
batch
_size
,
frame_size
*
2
,
frame_s
ize
,
1
,
value
.
prev
OutValue
,
frameSize
,
value
.
gateWeight
,
frameSize
*
2
,
1
,
value
.
prev
_out_value
,
frame_size
,
value
.
gate_weight
,
frame_size
*
2
,
value
.
gateValue
,
frameS
ize
*
3
);
1
,
value
.
gate_value
,
frame_s
ize
*
3
);
}
}
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
detail
::
KeGruForwardResetOutput
<
detail
::
forward
::
gru_resetOutput
<
T
>
,
detail
::
KeGruForwardResetOutput
<
detail
::
forward
::
gru_resetOutput
<
T
>
,
/* is
B
atch= */
false
,
/* is
_b
atch= */
false
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gate
V
alue
,
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gate
_v
alue
,
value
.
reset
OutputValue
,
value
.
prevOutValue
,
frameSize
,
batchS
ize
,
value
.
reset
_output_value
,
value
.
prev_out_value
,
frame_s
ize
,
active_gate
);
batch_size
,
active_gate
);
}
else
{
}
else
{
detail
::
KeGruForwardResetOutput
<
detail
::
forward
::
gru_resetOutput
<
T
>
,
detail
::
KeGruForwardResetOutput
<
detail
::
forward
::
gru_resetOutput
<
T
>
,
/* is
B
atch= */
true
,
/* is
_b
atch= */
true
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gate
V
alue
,
detail
::
forward
::
gru_resetOutput
<
T
>
(),
value
.
gate
_v
alue
,
value
.
reset
OutputValue
,
value
.
prevOutValue
,
frameSize
,
batchS
ize
,
value
.
reset
_output_value
,
value
.
prev_out_value
,
frame_s
ize
,
active_gate
);
batch_size
,
active_gate
);
}
}
if
(
value
.
prev
OutV
alue
)
{
if
(
value
.
prev
_out_v
alue
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
false
,
batch
Size
,
frameSize
,
frameS
ize
,
1
,
context
,
false
,
false
,
batch
_size
,
frame_size
,
frame_s
ize
,
1
,
value
.
reset
OutputValue
,
frameSize
,
value
.
stateWeight
,
frameSize
,
1
,
value
.
reset
_output_value
,
frame_size
,
value
.
state_weight
,
frame_size
,
value
.
gateValue
+
frameSize
*
2
,
frameS
ize
*
3
);
1
,
value
.
gate_value
+
frame_size
*
2
,
frame_s
ize
*
3
);
}
}
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
/* is
B
atch= */
false
,
/* is
_b
atch= */
false
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate
V
alue
,
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate
_v
alue
,
value
.
prev
OutValue
,
value
.
outputValue
,
frameSize
,
batchS
ize
,
value
.
prev
_out_value
,
value
.
output_value
,
frame_size
,
batch_s
ize
,
active_node
);
active_node
);
}
else
{
}
else
{
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
detail
::
KeGruForwardFinalOutput
<
detail
::
forward
::
gru_finalOutput
<
T
>
,
/* is
B
atch= */
true
,
/* is
_b
atch= */
true
,
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate
V
alue
,
detail
::
forward
::
gru_finalOutput
<
T
>
(),
value
.
gate
_v
alue
,
value
.
prev
OutValue
,
value
.
outputValue
,
frameSize
,
batchS
ize
,
value
.
prev
_out_value
,
value
.
output_value
,
frame_size
,
batch_s
ize
,
active_node
);
active_node
);
}
}
}
}
...
@@ -89,80 +89,82 @@ struct GRUUnitFunctor<platform::GPUPlace, T> {
...
@@ -89,80 +89,82 @@ struct GRUUnitFunctor<platform::GPUPlace, T> {
template
<
typename
T
>
template
<
typename
T
>
struct
GRUUnitGradFunctor
<
platform
::
GPUPlace
,
T
>
{
struct
GRUUnitGradFunctor
<
platform
::
GPUPlace
,
T
>
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
batchSize
,
activation_mode_t
active_node
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
)
{
activation_mode_t
active_gate
)
{
auto
stream
=
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
dim3
threads
;
dim3
threads
;
dim3
grid
;
dim3
grid
;
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
int
frame
PerBlock
=
frameSize
<=
1024
?
frameS
ize
:
1024
;
int
frame
_per_block
=
frame_size
<=
1024
?
frame_s
ize
:
1024
;
int
frame
Blocks
=
(
frameS
ize
+
1024
-
1
)
/
1024
;
int
frame
_blocks
=
(
frame_s
ize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame
PerB
lock
,
1
);
threads
=
dim3
(
frame
_per_b
lock
,
1
);
grid
=
dim3
(
frame
B
locks
,
1
);
grid
=
dim3
(
frame
_b
locks
,
1
);
}
else
{
}
else
{
threads
=
dim3
(
32
,
32
);
threads
=
dim3
(
32
,
32
);
grid
=
dim3
((
frame
Size
+
32
-
1
)
/
32
,
(
batchS
ize
+
32
-
1
)
/
32
);
grid
=
dim3
((
frame
_size
+
32
-
1
)
/
32
,
(
batch_s
ize
+
32
-
1
)
/
32
);
}
}
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
detail
::
KeGruBackwardStateGrad
<
detail
::
KeGruBackwardStateGrad
<
detail
::
backward
::
gru_stateGrad
<
T
>
,
detail
::
backward
::
gru_stateGrad
<
T
>
,
/* is
B
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* is
_b
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gate
Value
,
grad
.
gateGrad
,
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gate
_value
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
outputGrad
,
frameSize
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
batchS
ize
,
active_node
);
grad
.
output_grad
,
frame_size
,
batch_s
ize
,
active_node
);
}
else
{
}
else
{
detail
::
KeGruBackwardStateGrad
<
detail
::
KeGruBackwardStateGrad
<
detail
::
backward
::
gru_stateGrad
<
T
>
,
detail
::
backward
::
gru_stateGrad
<
T
>
,
/* is
B
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* is
_b
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gate
Value
,
grad
.
gateGrad
,
detail
::
backward
::
gru_stateGrad
<
T
>
(),
value
.
gate
_value
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
outputGrad
,
frameSize
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
batchS
ize
,
active_node
);
grad
.
output_grad
,
frame_size
,
batch_s
ize
,
active_node
);
}
}
if
(
value
.
prev
OutValue
&&
grad
.
prevOutG
rad
)
{
if
(
value
.
prev
_out_value
&&
grad
.
prev_out_g
rad
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
true
,
batch
Size
,
frameSize
,
frameS
ize
,
1
,
context
,
false
,
true
,
batch
_size
,
frame_size
,
frame_s
ize
,
1
,
grad
.
gate
Grad
+
frameSize
*
2
,
frameSize
*
3
,
value
.
stateW
eight
,
grad
.
gate
_grad
+
frame_size
*
2
,
frame_size
*
3
,
value
.
state_w
eight
,
frame
Size
,
0
,
grad
.
resetOutputGrad
,
frameS
ize
);
frame
_size
,
0
,
grad
.
reset_output_grad
,
frame_s
ize
);
if
(
grad
.
state
WeightG
rad
)
{
if
(
grad
.
state
_weight_g
rad
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
true
,
false
,
frameSize
,
frameSize
,
batchSize
,
1
,
context
,
true
,
false
,
frame_size
,
frame_size
,
batch_size
,
1
,
value
.
resetOutputValue
,
frameSize
,
grad
.
gateGrad
+
frameSize
*
2
,
value
.
reset_output_value
,
frame_size
,
frameSize
*
3
,
1
,
grad
.
stateWeightGrad
,
frameSize
);
grad
.
gate_grad
+
frame_size
*
2
,
frame_size
*
3
,
1
,
grad
.
state_weight_grad
,
frame_size
);
}
}
}
}
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
detail
::
KeGruBackwardResetGrad
<
detail
::
KeGruBackwardResetGrad
<
detail
::
backward
::
gru_resetGrad
<
T
>
,
detail
::
backward
::
gru_resetGrad
<
T
>
,
/* is
B
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* is
_b
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gate
Value
,
grad
.
gateGrad
,
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gate
_value
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
resetOutputGrad
,
frameSize
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
batchS
ize
,
active_gate
);
grad
.
reset_output_grad
,
frame_size
,
batch_s
ize
,
active_gate
);
}
else
{
}
else
{
detail
::
KeGruBackwardResetGrad
<
detail
::
KeGruBackwardResetGrad
<
detail
::
backward
::
gru_resetGrad
<
T
>
,
detail
::
backward
::
gru_resetGrad
<
T
>
,
/* is
B
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
/* is
_b
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gate
Value
,
grad
.
gateGrad
,
detail
::
backward
::
gru_resetGrad
<
T
>
(),
value
.
gate
_value
,
value
.
prevOutValue
,
grad
.
prevOutGrad
,
grad
.
resetOutputGrad
,
frameSize
,
grad
.
gate_grad
,
value
.
prev_out_value
,
grad
.
prev_out_grad
,
batchS
ize
,
active_gate
);
grad
.
reset_output_grad
,
frame_size
,
batch_s
ize
,
active_gate
);
}
}
if
(
grad
.
prev
OutGrad
&&
value
.
prevOutV
alue
)
{
if
(
grad
.
prev
_out_grad
&&
value
.
prev_out_v
alue
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
false
,
true
,
batch
Size
,
frameSize
,
frameS
ize
*
2
,
1
,
context
,
false
,
true
,
batch
_size
,
frame_size
,
frame_s
ize
*
2
,
1
,
grad
.
gate
Grad
,
frameSize
*
3
,
value
.
gateWeight
,
frameS
ize
*
2
,
1
,
grad
.
gate
_grad
,
frame_size
*
3
,
value
.
gate_weight
,
frame_s
ize
*
2
,
1
,
grad
.
prev
OutGrad
,
frameS
ize
);
grad
.
prev
_out_grad
,
frame_s
ize
);
if
(
grad
.
gate
WeightG
rad
)
{
if
(
grad
.
gate
_weight_g
rad
)
{
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
math
::
gemm
<
platform
::
GPUPlace
,
T
>
(
context
,
true
,
false
,
frame
Size
,
frameSize
*
2
,
batchS
ize
,
1
,
context
,
true
,
false
,
frame
_size
,
frame_size
*
2
,
batch_s
ize
,
1
,
value
.
prev
OutValue
,
frameSize
,
grad
.
gateGrad
,
frameS
ize
*
3
,
1
,
value
.
prev
_out_value
,
frame_size
,
grad
.
gate_grad
,
frame_s
ize
*
3
,
1
,
grad
.
gate
WeightGrad
,
frameS
ize
*
2
);
grad
.
gate
_weight_grad
,
frame_s
ize
*
2
);
}
}
}
}
}
}
...
...
paddle/operators/math/gru_compute.h
浏览文件 @
3e552cdc
...
@@ -22,28 +22,28 @@ namespace math {
...
@@ -22,28 +22,28 @@ namespace math {
// TODO(guosheng): refine code style in gru_compute
// TODO(guosheng): refine code style in gru_compute
template
<
typename
T
>
template
<
typename
T
>
struct
hl_gru_value
{
struct
hl_gru_value
{
T
*
gate
W
eight
;
T
*
gate
_w
eight
;
T
*
state
W
eight
;
T
*
state
_w
eight
;
T
*
gate
V
alue
;
T
*
gate
_v
alue
;
T
*
reset
OutputV
alue
;
T
*
reset
_output_v
alue
;
T
*
output
V
alue
;
T
*
output
_v
alue
;
T
*
prev
OutV
alue
;
T
*
prev
_out_v
alue
;
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
hl_gru_grad
{
struct
hl_gru_grad
{
T
*
gate
WeightG
rad
;
T
*
gate
_weight_g
rad
;
T
*
state
WeightG
rad
;
T
*
state
_weight_g
rad
;
T
*
gate
G
rad
;
T
*
gate
_g
rad
;
T
*
reset
OutputG
rad
;
T
*
reset
_output_g
rad
;
T
*
output
G
rad
;
T
*
output
_g
rad
;
T
*
prev
OutG
rad
;
T
*
prev
_out_g
rad
;
};
};
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
struct
GRUUnitFunctor
{
struct
GRUUnitFunctor
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
int
frame
Size
,
int
batchS
ize
,
hl_gru_value
<
T
>
value
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
);
activation_mode_t
active_gate
);
};
};
...
@@ -51,8 +51,9 @@ struct GRUUnitFunctor {
...
@@ -51,8 +51,9 @@ struct GRUUnitFunctor {
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
struct
GRUUnitGradFunctor
{
struct
GRUUnitGradFunctor
{
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
static
void
compute
(
const
platform
::
DeviceContext
&
context
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
frameSize
,
hl_gru_value
<
T
>
value
,
hl_gru_grad
<
T
>
grad
,
int
batchSize
,
activation_mode_t
active_node
,
int
frame_size
,
int
batch_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
);
activation_mode_t
active_gate
);
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录