Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e5b51c4d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e5b51c4d
编写于
12月 03, 2017
作者:
Q
qingqing01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make lstm_op follow google code style.
上级
d89061c3
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
505 addition
and
492 deletion
+505
-492
paddle/operators/lstm_op.h
paddle/operators/lstm_op.h
+35
-35
paddle/operators/math/detail/lstm_cpu_kernel.h
paddle/operators/math/detail/lstm_cpu_kernel.h
+216
-210
paddle/operators/math/detail/lstm_gpu_kernel.h
paddle/operators/math/detail/lstm_gpu_kernel.h
+154
-151
paddle/operators/math/detail/lstm_kernel.h
paddle/operators/math/detail/lstm_kernel.h
+66
-62
paddle/operators/math/lstm_compute.cc
paddle/operators/math/lstm_compute.cc
+18
-18
paddle/operators/math/lstm_compute.h
paddle/operators/math/lstm_compute.h
+16
-16
未找到文件。
paddle/operators/lstm_op.h
浏览文件 @
e5b51c4d
...
...
@@ -73,15 +73,15 @@ class LSTMKernel : public framework::OpKernel<T> {
T
*
bias_data
=
const_cast
<
T
*>
(
bias
->
data
<
T
>
());
// the code style in LstmMetaValue will be updated later.
lstm_value
.
check
I
g
=
bias_data
+
4
*
frame_size
;
lstm_value
.
check
Fg
=
lstm_value
.
checkI
g
+
frame_size
;
lstm_value
.
check
Og
=
lstm_value
.
checkF
g
+
frame_size
;
lstm_value
.
check
_i
g
=
bias_data
+
4
*
frame_size
;
lstm_value
.
check
_fg
=
lstm_value
.
check_i
g
+
frame_size
;
lstm_value
.
check
_og
=
lstm_value
.
check_f
g
+
frame_size
;
}
else
{
lstm_value
.
check
I
g
=
nullptr
;
lstm_value
.
check
F
g
=
nullptr
;
lstm_value
.
check
O
g
=
nullptr
;
lstm_value
.
check
_i
g
=
nullptr
;
lstm_value
.
check
_f
g
=
nullptr
;
lstm_value
.
check
_o
g
=
nullptr
;
}
lstm_value
.
prev
StateV
alue
=
nullptr
;
lstm_value
.
prev
_state_v
alue
=
nullptr
;
Tensor
ordered_c0
;
const
size_t
*
order
=
batch_gate
->
lod
()[
2
].
data
();
if
(
cell_t0
)
{
...
...
@@ -90,7 +90,7 @@ class LSTMKernel : public framework::OpKernel<T> {
// to reorder.
ReorderInitState
<
Place
,
T
>
(
device_ctx
,
*
cell_t0
,
order
,
&
ordered_c0
,
true
);
lstm_value
.
prev
StateV
alue
=
ordered_c0
.
data
<
T
>
();
lstm_value
.
prev
_state_v
alue
=
ordered_c0
.
data
<
T
>
();
}
// Use the local variable as here.
...
...
@@ -140,14 +140,14 @@ class LSTMKernel : public framework::OpKernel<T> {
static_cast
<
T
>
(
1.0
));
}
lstm_value
.
gate
V
alue
=
gate_t
.
data
<
T
>
();
lstm_value
.
output
V
alue
=
out_t
.
data
<
T
>
();
lstm_value
.
state
V
alue
=
cell_t
.
data
<
T
>
();
lstm_value
.
state
ActiveV
alue
=
cell_pre_act_t
.
data
<
T
>
();
lstm_value
.
gate
_v
alue
=
gate_t
.
data
<
T
>
();
lstm_value
.
output
_v
alue
=
out_t
.
data
<
T
>
();
lstm_value
.
state
_v
alue
=
cell_t
.
data
<
T
>
();
lstm_value
.
state
_active_v
alue
=
cell_pre_act_t
.
data
<
T
>
();
math
::
LstmUnitFunctor
<
Place
,
T
>::
compute
(
device_ctx
,
lstm_value
,
frame_size
,
cur_batch_size
,
gate_act
,
cell_act
,
cand_act
);
lstm_value
.
prev
StateValue
=
lstm_value
.
stateV
alue
;
lstm_value
.
prev
_state_value
=
lstm_value
.
state_v
alue
;
}
math
::
Batch2LoDTensorFunctor
<
Place
,
T
>
to_seq
;
...
...
@@ -214,13 +214,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
math
::
LstmMetaValue
<
T
>
lstm_value
;
if
(
bias
&&
ctx
.
Attr
<
bool
>
(
"use_peepholes"
))
{
T
*
bias_data
=
const_cast
<
T
*>
(
bias
->
data
<
T
>
());
lstm_value
.
check
I
g
=
bias_data
+
4
*
frame_size
;
lstm_value
.
check
Fg
=
lstm_value
.
checkI
g
+
frame_size
;
lstm_value
.
check
Og
=
lstm_value
.
checkF
g
+
frame_size
;
lstm_value
.
check
_i
g
=
bias_data
+
4
*
frame_size
;
lstm_value
.
check
_fg
=
lstm_value
.
check_i
g
+
frame_size
;
lstm_value
.
check
_og
=
lstm_value
.
check_f
g
+
frame_size
;
}
else
{
lstm_value
.
check
I
g
=
nullptr
;
lstm_value
.
check
F
g
=
nullptr
;
lstm_value
.
check
O
g
=
nullptr
;
lstm_value
.
check
_i
g
=
nullptr
;
lstm_value
.
check
_f
g
=
nullptr
;
lstm_value
.
check
_o
g
=
nullptr
;
}
math
::
LstmMetaGrad
<
T
>
lstm_grad
;
...
...
@@ -231,13 +231,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
}
if
(
bias
&&
bias_g
&&
ctx
.
Attr
<
bool
>
(
"use_peepholes"
))
{
T
*
bias_g_data
=
bias_g
->
data
<
T
>
();
lstm_grad
.
check
IgG
rad
=
bias_g_data
+
4
*
frame_size
;
lstm_grad
.
check
FgGrad
=
lstm_grad
.
checkIgG
rad
+
frame_size
;
lstm_grad
.
check
OgGrad
=
lstm_grad
.
checkFgG
rad
+
frame_size
;
lstm_grad
.
check
_ig_g
rad
=
bias_g_data
+
4
*
frame_size
;
lstm_grad
.
check
_fg_grad
=
lstm_grad
.
check_ig_g
rad
+
frame_size
;
lstm_grad
.
check
_og_grad
=
lstm_grad
.
check_fg_g
rad
+
frame_size
;
}
else
{
lstm_grad
.
check
IgG
rad
=
nullptr
;
lstm_grad
.
check
FgG
rad
=
nullptr
;
lstm_grad
.
check
OgG
rad
=
nullptr
;
lstm_grad
.
check
_ig_g
rad
=
nullptr
;
lstm_grad
.
check
_fg_g
rad
=
nullptr
;
lstm_grad
.
check
_og_g
rad
=
nullptr
;
}
math
::
LoDTensor2BatchFunctor
<
Place
,
T
>
to_batch
;
...
...
@@ -276,26 +276,26 @@ class LSTMGradKernel : public framework::OpKernel<T> {
Tensor
gate
=
batch_gate
->
Slice
(
bstart
,
bend
);
Tensor
cell
=
batch_cell
.
Slice
(
bstart
,
bend
);
Tensor
cell_pre_act
=
batch_cell_pre_act
->
Slice
(
bstart
,
bend
);
lstm_value
.
gate
V
alue
=
gate
.
data
<
T
>
();
lstm_value
.
state
V
alue
=
cell
.
data
<
T
>
();
lstm_value
.
state
ActiveV
alue
=
cell_pre_act
.
data
<
T
>
();
lstm_value
.
gate
_v
alue
=
gate
.
data
<
T
>
();
lstm_value
.
state
_v
alue
=
cell
.
data
<
T
>
();
lstm_value
.
state
_active_v
alue
=
cell_pre_act
.
data
<
T
>
();
Tensor
out_g
=
batch_hidden_g
.
Slice
(
bstart
,
bend
);
Tensor
gate_g
=
batch_gate_g
.
Slice
(
bstart
,
bend
);
Tensor
cell_g
=
batch_cell_g
.
Slice
(
bstart
,
bend
);
lstm_grad
.
state
G
rad
=
cell_g
.
data
<
T
>
();
lstm_grad
.
gate
G
rad
=
gate_g
.
data
<
T
>
();
lstm_grad
.
output
G
rad
=
out_g
.
data
<
T
>
();
lstm_grad
.
state
_g
rad
=
cell_g
.
data
<
T
>
();
lstm_grad
.
gate
_g
rad
=
gate_g
.
data
<
T
>
();
lstm_grad
.
output
_g
rad
=
out_g
.
data
<
T
>
();
if
(
n
>
0
)
{
int
bstart_pre
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
Tensor
cell_pre
=
batch_cell
.
Slice
(
bstart_pre
,
bstart
);
Tensor
cell_pre_g
=
batch_cell_g
.
Slice
(
bstart_pre
,
bstart
);
lstm_value
.
prev
StateV
alue
=
cell_pre
.
data
<
T
>
();
lstm_grad
.
prev
StateG
rad
=
cell_pre_g
.
data
<
T
>
();
lstm_value
.
prev
_state_v
alue
=
cell_pre
.
data
<
T
>
();
lstm_grad
.
prev
_state_g
rad
=
cell_pre_g
.
data
<
T
>
();
}
else
{
lstm_value
.
prev
StateV
alue
=
c0
?
ordered_c0
.
data
<
T
>
()
:
nullptr
;
lstm_grad
.
prev
StateG
rad
=
c0_g
?
ordered_c0_g
.
data
<
T
>
()
:
nullptr
;
lstm_value
.
prev
_state_v
alue
=
c0
?
ordered_c0
.
data
<
T
>
()
:
nullptr
;
lstm_grad
.
prev
_state_g
rad
=
c0_g
?
ordered_c0_g
.
data
<
T
>
()
:
nullptr
;
}
int
cur_batch_size
=
bend
-
bstart
;
...
...
paddle/operators/math/detail/lstm_cpu_kernel.h
浏览文件 @
e5b51c4d
...
...
@@ -26,278 +26,284 @@ namespace detail {
template
<
class
T
,
class
Op
>
void
naive_lstm_forward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame
S
ize
,
int
frame
_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
T
r
ValueI
n
;
T
r
ValueI
g
;
T
r
ValueF
g
;
T
r
ValueO
g
;
T
r
C
heckI
;
T
r
C
heckF
;
T
r
C
heckO
;
T
r
S
tate
;
T
r
PrevS
tate
=
0
;
T
r
StateA
tv
;
T
r
O
ut
;
T
*
value
In
=
value
.
gateV
alue
;
T
*
value
Ig
=
value
.
gateValue
+
frameS
ize
;
T
*
value
Fg
=
value
.
gateValue
+
frameS
ize
*
2
;
T
*
value
Og
=
value
.
gateValue
+
frameS
ize
*
3
;
for
(
int
i
=
0
;
i
<
frame
S
ize
;
i
++
)
{
r
ValueIn
=
valueI
n
[
i
];
r
ValueIg
=
valueI
g
[
i
];
r
ValueFg
=
valueF
g
[
i
];
r
ValueOg
=
valueO
g
[
i
];
r
CheckI
=
value
.
checkIg
?
value
.
checkI
g
[
i
]
:
0
;
r
CheckF
=
value
.
checkFg
?
value
.
checkF
g
[
i
]
:
0
;
r
CheckO
=
value
.
checkOg
?
value
.
checkO
g
[
i
]
:
0
;
if
(
value
.
prev
StateV
alue
)
{
r
PrevState
=
value
.
prevStateV
alue
[
i
];
T
r
_value_i
n
;
T
r
_value_i
g
;
T
r
_value_f
g
;
T
r
_value_o
g
;
T
r
_c
heckI
;
T
r
_c
heckF
;
T
r
_c
heckO
;
T
r
_s
tate
;
T
r
_prev_s
tate
=
0
;
T
r
_state_a
tv
;
T
r
_o
ut
;
T
*
value
_in
=
value
.
gate_v
alue
;
T
*
value
_ig
=
value
.
gate_value
+
frame_s
ize
;
T
*
value
_fg
=
value
.
gate_value
+
frame_s
ize
*
2
;
T
*
value
_og
=
value
.
gate_value
+
frame_s
ize
*
3
;
for
(
int
i
=
0
;
i
<
frame
_s
ize
;
i
++
)
{
r
_value_in
=
value_i
n
[
i
];
r
_value_ig
=
value_i
g
[
i
];
r
_value_fg
=
value_f
g
[
i
];
r
_value_og
=
value_o
g
[
i
];
r
_checkI
=
value
.
check_ig
?
value
.
check_i
g
[
i
]
:
0
;
r
_checkF
=
value
.
check_fg
?
value
.
check_f
g
[
i
]
:
0
;
r
_checkO
=
value
.
check_og
?
value
.
check_o
g
[
i
]
:
0
;
if
(
value
.
prev
_state_v
alue
)
{
r
_prev_state
=
value
.
prev_state_v
alue
[
i
];
}
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rPrevState
,
rState
,
rStateAtv
,
rOut
,
rCheckI
,
rCheckF
,
rCheckO
,
active_node
,
active_gate
,
active_state
);
valueIn
[
i
]
=
rValueIn
;
valueIg
[
i
]
=
rValueIg
;
valueFg
[
i
]
=
rValueFg
;
valueOg
[
i
]
=
rValueOg
;
value
.
stateValue
[
i
]
=
rState
;
value
.
stateActiveValue
[
i
]
=
rStateAtv
;
value
.
outputValue
[
i
]
=
rOut
;
op
(
r_value_in
,
r_value_ig
,
r_value_fg
,
r_value_og
,
r_prev_state
,
r_state
,
r_state_atv
,
r_out
,
r_checkI
,
r_checkF
,
r_checkO
,
active_node
,
active_gate
,
active_state
);
value_in
[
i
]
=
r_value_in
;
value_ig
[
i
]
=
r_value_ig
;
value_fg
[
i
]
=
r_value_fg
;
value_og
[
i
]
=
r_value_og
;
value
.
state_value
[
i
]
=
r_state
;
value
.
state_active_value
[
i
]
=
r_state_atv
;
value
.
output_value
[
i
]
=
r_out
;
}
}
template
<
class
T
,
class
Op
>
void
naive_lstm_backward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame
S
ize
,
LstmMetaGrad
<
T
>
grad
,
int
frame
_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
T
r
ValueI
n
;
T
r
ValueI
g
;
T
r
ValueF
g
;
T
r
ValueO
g
;
T
r
GradI
n
;
T
r
GradI
g
;
T
r
GradF
g
;
T
r
GradO
g
;
T
r
PrevS
tate
=
0
;
T
r
PrevStateG
rad
;
T
r
S
tate
;
T
r
StateG
rad
;
T
r
StateA
tv
;
T
r
OutputG
rad
;
T
r
C
heckI
;
T
r
C
heckF
;
T
r
C
heckO
;
T
r
C
heckIGrad
;
T
r
C
heckFGrad
;
T
r
C
heckOGrad
;
T
*
value
In
=
value
.
gateV
alue
;
T
*
value
Ig
=
value
.
gateValue
+
frameS
ize
;
T
*
value
Fg
=
value
.
gateValue
+
frameS
ize
*
2
;
T
*
value
Og
=
value
.
gateValue
+
frameS
ize
*
3
;
T
*
grad
In
=
grad
.
gateG
rad
;
T
*
grad
Ig
=
grad
.
gateGrad
+
frameS
ize
;
T
*
grad
Fg
=
grad
.
gateGrad
+
frameS
ize
*
2
;
T
*
grad
Og
=
grad
.
gateGrad
+
frameS
ize
*
3
;
for
(
int
i
=
0
;
i
<
frame
S
ize
;
i
++
)
{
r
ValueIn
=
valueI
n
[
i
];
r
ValueIg
=
valueI
g
[
i
];
r
ValueFg
=
valueF
g
[
i
];
r
ValueOg
=
valueO
g
[
i
];
r
CheckI
=
value
.
checkIg
?
value
.
checkI
g
[
i
]
:
0
;
r
CheckF
=
value
.
checkFg
?
value
.
checkF
g
[
i
]
:
0
;
r
CheckO
=
value
.
checkOg
?
value
.
checkO
g
[
i
]
:
0
;
r
State
=
value
.
stateV
alue
[
i
];
r
StateAtv
=
value
.
stateActiveV
alue
[
i
];
r
OutputGrad
=
grad
.
outputG
rad
[
i
];
r
StateGrad
=
grad
.
stateG
rad
[
i
];
if
(
value
.
prev
StateV
alue
)
{
r
PrevState
=
value
.
prevStateV
alue
[
i
];
T
r
_value_i
n
;
T
r
_value_i
g
;
T
r
_value_f
g
;
T
r
_value_o
g
;
T
r
_grad_i
n
;
T
r
_grad_i
g
;
T
r
_grad_f
g
;
T
r
_grad_o
g
;
T
r
_prev_s
tate
=
0
;
T
r
_prev_state_g
rad
;
T
r
_s
tate
;
T
r
_state_g
rad
;
T
r
_state_a
tv
;
T
r
_output_g
rad
;
T
r
_c
heckI
;
T
r
_c
heckF
;
T
r
_c
heckO
;
T
r
_c
heckIGrad
;
T
r
_c
heckFGrad
;
T
r
_c
heckOGrad
;
T
*
value
_in
=
value
.
gate_v
alue
;
T
*
value
_ig
=
value
.
gate_value
+
frame_s
ize
;
T
*
value
_fg
=
value
.
gate_value
+
frame_s
ize
*
2
;
T
*
value
_og
=
value
.
gate_value
+
frame_s
ize
*
3
;
T
*
grad
_in
=
grad
.
gate_g
rad
;
T
*
grad
_ig
=
grad
.
gate_grad
+
frame_s
ize
;
T
*
grad
_fg
=
grad
.
gate_grad
+
frame_s
ize
*
2
;
T
*
grad
_og
=
grad
.
gate_grad
+
frame_s
ize
*
3
;
for
(
int
i
=
0
;
i
<
frame
_s
ize
;
i
++
)
{
r
_value_in
=
value_i
n
[
i
];
r
_value_ig
=
value_i
g
[
i
];
r
_value_fg
=
value_f
g
[
i
];
r
_value_og
=
value_o
g
[
i
];
r
_checkI
=
value
.
check_ig
?
value
.
check_i
g
[
i
]
:
0
;
r
_checkF
=
value
.
check_fg
?
value
.
check_f
g
[
i
]
:
0
;
r
_checkO
=
value
.
check_og
?
value
.
check_o
g
[
i
]
:
0
;
r
_state
=
value
.
state_v
alue
[
i
];
r
_state_atv
=
value
.
state_active_v
alue
[
i
];
r
_output_grad
=
grad
.
output_g
rad
[
i
];
r
_state_grad
=
grad
.
state_g
rad
[
i
];
if
(
value
.
prev
_state_v
alue
)
{
r
_prev_state
=
value
.
prev_state_v
alue
[
i
];
}
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rGradIn
,
rGradIg
,
rGradFg
,
rGradOg
,
rPrevState
,
rPrevStateGrad
,
rState
,
rStateGrad
,
rStateAtv
,
rOutputGrad
,
rCheckI
,
rCheckF
,
rCheckO
,
rCheckIGrad
,
rCheckFGrad
,
rCheckOGrad
,
active_node
,
active_gate
,
active_state
);
gradIn
[
i
]
=
rGradIn
;
gradIg
[
i
]
=
rGradIg
;
gradFg
[
i
]
=
rGradFg
;
gradOg
[
i
]
=
rGradOg
;
grad
.
stateGrad
[
i
]
=
rStateGrad
;
if
(
grad
.
prevStateGrad
)
grad
.
prevStateGrad
[
i
]
=
rPrevStateGrad
;
if
(
value
.
prevStateValue
)
{
if
(
grad
.
checkIgGrad
)
grad
.
checkIgGrad
[
i
]
+=
rCheckIGrad
;
if
(
grad
.
checkFgGrad
)
grad
.
checkFgGrad
[
i
]
+=
rCheckFGrad
;
op
(
r_value_in
,
r_value_ig
,
r_value_fg
,
r_value_og
,
r_grad_in
,
r_grad_ig
,
r_grad_fg
,
r_grad_og
,
r_prev_state
,
r_prev_state_grad
,
r_state
,
r_state_grad
,
r_state_atv
,
r_output_grad
,
r_checkI
,
r_checkF
,
r_checkO
,
r_checkIGrad
,
r_checkFGrad
,
r_checkOGrad
,
active_node
,
active_gate
,
active_state
);
grad_in
[
i
]
=
r_grad_in
;
grad_ig
[
i
]
=
r_grad_ig
;
grad_fg
[
i
]
=
r_grad_fg
;
grad_og
[
i
]
=
r_grad_og
;
grad
.
state_grad
[
i
]
=
r_state_grad
;
if
(
grad
.
prev_state_grad
)
grad
.
prev_state_grad
[
i
]
=
r_prev_state_grad
;
if
(
value
.
prev_state_value
)
{
if
(
grad
.
check_ig_grad
)
grad
.
check_ig_grad
[
i
]
+=
r_checkIGrad
;
if
(
grad
.
check_fg_grad
)
grad
.
check_fg_grad
[
i
]
+=
r_checkFGrad
;
}
if
(
grad
.
check
OgGrad
)
grad
.
checkOgGrad
[
i
]
+=
rC
heckOGrad
;
if
(
grad
.
check
_og_grad
)
grad
.
check_og_grad
[
i
]
+=
r_c
heckOGrad
;
}
}
template
<
class
T
,
class
Op
>
void
avx_lstm_forward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frameSize
,
void
avx_lstm_forward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame_size
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
#ifdef __AVX__
__m256
r
ValueI
n
;
__m256
r
ValueI
g
;
__m256
r
ValueF
g
;
__m256
r
ValueO
g
;
__m256
r
C
heckI
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
C
heckF
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
C
heckO
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
S
tate
;
__m256
r
PrevS
tate
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
StateA
tv
;
__m256
r
O
ut
;
__m256
*
value
In
=
(
__m256
*
)
value
.
gateV
alue
;
__m256
*
value
Ig
=
(
__m256
*
)(
value
.
gateValue
+
frameS
ize
);
__m256
*
value
Fg
=
(
__m256
*
)(
value
.
gateValue
+
frameS
ize
*
2
);
__m256
*
value
Og
=
(
__m256
*
)(
value
.
gateValue
+
frameS
ize
*
3
);
for
(
int
i
=
0
;
i
<
frame
S
ize
/
8
;
i
++
)
{
r
ValueIn
=
valueI
n
[
i
];
r
ValueIg
=
valueI
g
[
i
];
r
ValueFg
=
valueF
g
[
i
];
r
ValueOg
=
valueO
g
[
i
];
if
(
value
.
check
I
g
)
{
r
CheckI
=
((
__m256
*
)
value
.
checkI
g
)[
i
];
r
CheckF
=
((
__m256
*
)
value
.
checkF
g
)[
i
];
r
CheckO
=
((
__m256
*
)
value
.
checkO
g
)[
i
];
__m256
r
_value_i
n
;
__m256
r
_value_i
g
;
__m256
r
_value_f
g
;
__m256
r
_value_o
g
;
__m256
r
_c
heckI
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_c
heckF
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_c
heckO
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_s
tate
;
__m256
r
_prev_s
tate
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_state_a
tv
;
__m256
r
_o
ut
;
__m256
*
value
_in
=
(
__m256
*
)
value
.
gate_v
alue
;
__m256
*
value
_ig
=
(
__m256
*
)(
value
.
gate_value
+
frame_s
ize
);
__m256
*
value
_fg
=
(
__m256
*
)(
value
.
gate_value
+
frame_s
ize
*
2
);
__m256
*
value
_og
=
(
__m256
*
)(
value
.
gate_value
+
frame_s
ize
*
3
);
for
(
int
i
=
0
;
i
<
frame
_s
ize
/
8
;
i
++
)
{
r
_value_in
=
value_i
n
[
i
];
r
_value_ig
=
value_i
g
[
i
];
r
_value_fg
=
value_f
g
[
i
];
r
_value_og
=
value_o
g
[
i
];
if
(
value
.
check
_i
g
)
{
r
_checkI
=
((
__m256
*
)
value
.
check_i
g
)[
i
];
r
_checkF
=
((
__m256
*
)
value
.
check_f
g
)[
i
];
r
_checkO
=
((
__m256
*
)
value
.
check_o
g
)[
i
];
}
if
(
value
.
prev
StateV
alue
)
{
r
PrevState
=
((
__m256
*
)
value
.
prevStateV
alue
)[
i
];
if
(
value
.
prev
_state_v
alue
)
{
r
_prev_state
=
((
__m256
*
)
value
.
prev_state_v
alue
)[
i
];
}
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rPrevState
,
rState
,
rStateAtv
,
rOut
,
rCheckI
,
rCheckF
,
rCheckO
,
active_node
,
active_gate
,
active_state
);
valueIn
[
i
]
=
rValueIn
;
valueIg
[
i
]
=
rValueIg
;
valueFg
[
i
]
=
rValueFg
;
valueOg
[
i
]
=
rValueOg
;
((
__m256
*
)
value
.
stateValue
)[
i
]
=
rState
;
((
__m256
*
)
value
.
stateActiveValue
)[
i
]
=
rStateAtv
;
((
__m256
*
)
value
.
outputValue
)[
i
]
=
rOut
;
op
(
r_value_in
,
r_value_ig
,
r_value_fg
,
r_value_og
,
r_prev_state
,
r_state
,
r_state_atv
,
r_out
,
r_checkI
,
r_checkF
,
r_checkO
,
active_node
,
active_gate
,
active_state
);
value_in
[
i
]
=
r_value_in
;
value_ig
[
i
]
=
r_value_ig
;
value_fg
[
i
]
=
r_value_fg
;
value_og
[
i
]
=
r_value_og
;
((
__m256
*
)
value
.
state_value
)[
i
]
=
r_state
;
((
__m256
*
)
value
.
state_active_value
)[
i
]
=
r_state_atv
;
((
__m256
*
)
value
.
output_value
)[
i
]
=
r_out
;
}
#endif
}
template
<
class
T
,
class
Op
>
void
avx_lstm_backward_one_sequence
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame
S
ize
,
LstmMetaGrad
<
T
>
grad
,
int
frame
_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
#ifdef __AVX__
__m256
r
ValueI
n
;
__m256
r
ValueI
g
;
__m256
r
ValueF
g
;
__m256
r
ValueO
g
;
__m256
r
GradI
n
;
__m256
r
GradI
g
;
__m256
r
GradF
g
;
__m256
r
GradO
g
;
__m256
r
PrevS
tate
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
PrevStateG
rad
;
__m256
r
StateG
rad
;
__m256
r
S
tate
;
__m256
r
StateA
tv
;
__m256
r
OutputG
rad
;
__m256
r
C
heckI
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
C
heckF
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
C
heckO
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
C
heckIGrad
;
__m256
r
C
heckFGrad
;
__m256
r
C
heckOGrad
;
__m256
*
value
In
=
(
__m256
*
)
value
.
gateV
alue
;
__m256
*
value
Ig
=
(
__m256
*
)(
value
.
gateValue
+
frameS
ize
);
__m256
*
value
Fg
=
(
__m256
*
)(
value
.
gateValue
+
frameS
ize
*
2
);
__m256
*
value
Og
=
(
__m256
*
)(
value
.
gateValue
+
frameS
ize
*
3
);
__m256
*
grad
In
=
(
__m256
*
)
grad
.
gateG
rad
;
__m256
*
grad
Ig
=
(
__m256
*
)(
grad
.
gateGrad
+
frameS
ize
);
__m256
*
grad
Fg
=
(
__m256
*
)(
grad
.
gateGrad
+
frameS
ize
*
2
);
__m256
*
grad
Og
=
(
__m256
*
)(
grad
.
gateGrad
+
frameS
ize
*
3
);
for
(
int
i
=
0
;
i
<
frame
S
ize
/
8
;
i
++
)
{
r
ValueIn
=
valueI
n
[
i
];
r
ValueIg
=
valueI
g
[
i
];
r
ValueFg
=
valueF
g
[
i
];
r
ValueOg
=
valueO
g
[
i
];
if
(
value
.
check
I
g
)
{
r
CheckI
=
((
__m256
*
)
value
.
checkI
g
)[
i
];
r
CheckF
=
((
__m256
*
)
value
.
checkF
g
)[
i
];
r
CheckO
=
((
__m256
*
)
value
.
checkO
g
)[
i
];
__m256
r
_value_i
n
;
__m256
r
_value_i
g
;
__m256
r
_value_f
g
;
__m256
r
_value_o
g
;
__m256
r
_grad_i
n
;
__m256
r
_grad_i
g
;
__m256
r
_grad_f
g
;
__m256
r
_grad_o
g
;
__m256
r
_prev_s
tate
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_prev_state_g
rad
;
__m256
r
_state_g
rad
;
__m256
r
_s
tate
;
__m256
r
_state_a
tv
;
__m256
r
_output_g
rad
;
__m256
r
_c
heckI
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_c
heckF
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_c
heckO
=
_mm256_set1_ps
(
0.0
f
);
__m256
r
_c
heckIGrad
;
__m256
r
_c
heckFGrad
;
__m256
r
_c
heckOGrad
;
__m256
*
value
_in
=
(
__m256
*
)
value
.
gate_v
alue
;
__m256
*
value
_ig
=
(
__m256
*
)(
value
.
gate_value
+
frame_s
ize
);
__m256
*
value
_fg
=
(
__m256
*
)(
value
.
gate_value
+
frame_s
ize
*
2
);
__m256
*
value
_og
=
(
__m256
*
)(
value
.
gate_value
+
frame_s
ize
*
3
);
__m256
*
grad
_in
=
(
__m256
*
)
grad
.
gate_g
rad
;
__m256
*
grad
_ig
=
(
__m256
*
)(
grad
.
gate_grad
+
frame_s
ize
);
__m256
*
grad
_fg
=
(
__m256
*
)(
grad
.
gate_grad
+
frame_s
ize
*
2
);
__m256
*
grad
_og
=
(
__m256
*
)(
grad
.
gate_grad
+
frame_s
ize
*
3
);
for
(
int
i
=
0
;
i
<
frame
_s
ize
/
8
;
i
++
)
{
r
_value_in
=
value_i
n
[
i
];
r
_value_ig
=
value_i
g
[
i
];
r
_value_fg
=
value_f
g
[
i
];
r
_value_og
=
value_o
g
[
i
];
if
(
value
.
check
_i
g
)
{
r
_checkI
=
((
__m256
*
)
value
.
check_i
g
)[
i
];
r
_checkF
=
((
__m256
*
)
value
.
check_f
g
)[
i
];
r
_checkO
=
((
__m256
*
)
value
.
check_o
g
)[
i
];
}
r
State
=
((
__m256
*
)
value
.
stateV
alue
)[
i
];
r
StateAtv
=
((
__m256
*
)
value
.
stateActiveV
alue
)[
i
];
r
OutputGrad
=
((
__m256
*
)
grad
.
outputG
rad
)[
i
];
r
StateGrad
=
((
__m256
*
)
grad
.
stateG
rad
)[
i
];
if
(
value
.
prev
StateV
alue
)
{
r
PrevState
=
((
__m256
*
)
value
.
prevStateV
alue
)[
i
];
r
_state
=
((
__m256
*
)
value
.
state_v
alue
)[
i
];
r
_state_atv
=
((
__m256
*
)
value
.
state_active_v
alue
)[
i
];
r
_output_grad
=
((
__m256
*
)
grad
.
output_g
rad
)[
i
];
r
_state_grad
=
((
__m256
*
)
grad
.
state_g
rad
)[
i
];
if
(
value
.
prev
_state_v
alue
)
{
r
_prev_state
=
((
__m256
*
)
value
.
prev_state_v
alue
)[
i
];
}
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rGradIn
,
rGradIg
,
rGradFg
,
rGradOg
,
rPrevState
,
rPrevStateGrad
,
rState
,
rStateGrad
,
rStateAtv
,
rOutputGrad
,
rCheckI
,
rCheckF
,
rCheckO
,
rCheckIGrad
,
rCheckFGrad
,
rCheckOGrad
,
active_node
,
active_gate
,
active_state
);
gradIn
[
i
]
=
rGradIn
;
gradIg
[
i
]
=
rGradIg
;
gradFg
[
i
]
=
rGradFg
;
gradOg
[
i
]
=
rGradOg
;
((
__m256
*
)
grad
.
stateGrad
)[
i
]
=
rStateGrad
;
if
(
grad
.
prevStateGrad
)
((
__m256
*
)
grad
.
prevStateGrad
)[
i
]
=
rPrevStateGrad
;
if
(
value
.
prevStateValue
)
{
if
(
grad
.
checkIgGrad
)
((
__m256
*
)
grad
.
checkIgGrad
)[
i
]
+=
rCheckIGrad
;
if
(
grad
.
checkFgGrad
)
((
__m256
*
)
grad
.
checkFgGrad
)[
i
]
+=
rCheckFGrad
;
op
(
r_value_in
,
r_value_ig
,
r_value_fg
,
r_value_og
,
r_grad_in
,
r_grad_ig
,
r_grad_fg
,
r_grad_og
,
r_prev_state
,
r_prev_state_grad
,
r_state
,
r_state_grad
,
r_state_atv
,
r_output_grad
,
r_checkI
,
r_checkF
,
r_checkO
,
r_checkIGrad
,
r_checkFGrad
,
r_checkOGrad
,
active_node
,
active_gate
,
active_state
);
grad_in
[
i
]
=
r_grad_in
;
grad_ig
[
i
]
=
r_grad_ig
;
grad_fg
[
i
]
=
r_grad_fg
;
grad_og
[
i
]
=
r_grad_og
;
((
__m256
*
)
grad
.
state_grad
)[
i
]
=
r_state_grad
;
if
(
grad
.
prev_state_grad
)
((
__m256
*
)
grad
.
prev_state_grad
)[
i
]
=
r_prev_state_grad
;
if
(
value
.
prev_state_value
)
{
if
(
grad
.
check_ig_grad
)
((
__m256
*
)
grad
.
check_ig_grad
)[
i
]
+=
r_checkIGrad
;
if
(
grad
.
check_fg_grad
)
((
__m256
*
)
grad
.
check_fg_grad
)[
i
]
+=
r_checkFGrad
;
}
if
(
grad
.
check
OgGrad
)
((
__m256
*
)
grad
.
checkOgGrad
)[
i
]
+=
rC
heckOGrad
;
if
(
grad
.
check
_og_grad
)
((
__m256
*
)
grad
.
check_og_grad
)[
i
]
+=
r_c
heckOGrad
;
}
#endif
}
template
<
class
T
,
class
Op
>
void
cpu_lstm_forward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame
S
ize
,
void
cpu_lstm_forward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame
_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
if
(
Op
::
avx
&&
!
(
frame
S
ize
&
(
8
-
1
))
&&
(
std
::
is_same
<
T
,
float
>::
value
))
{
avx_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frame
S
ize
,
active_node
,
if
(
Op
::
avx
&&
!
(
frame
_s
ize
&
(
8
-
1
))
&&
(
std
::
is_same
<
T
,
float
>::
value
))
{
avx_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frame
_s
ize
,
active_node
,
active_gate
,
active_state
);
}
else
{
naive_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frame
S
ize
,
active_node
,
naive_lstm_forward_one_sequence
<
T
>
(
op
,
value
,
frame
_s
ize
,
active_node
,
active_gate
,
active_state
);
}
}
template
<
class
T
,
class
Op
>
void
cpu_lstm_backward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame
S
ize
,
activation_mode_t
active_node
,
int
frame
_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
if
(
Op
::
avx
&&
!
(
frame
S
ize
&
(
8
-
1
))
&&
(
std
::
is_same
<
T
,
float
>::
value
))
{
avx_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frame
S
ize
,
active_node
,
if
(
Op
::
avx
&&
!
(
frame
_s
ize
&
(
8
-
1
))
&&
(
std
::
is_same
<
T
,
float
>::
value
))
{
avx_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frame
_s
ize
,
active_node
,
active_gate
,
active_state
);
}
else
{
naive_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frame
Size
,
active_nod
e
,
active_gate
,
active_state
);
naive_lstm_backward_one_sequence
<
T
>
(
op
,
value
,
grad
,
frame
_siz
e
,
active_
node
,
active_
gate
,
active_state
);
}
}
...
...
paddle/operators/math/detail/lstm_gpu_kernel.h
浏览文件 @
e5b51c4d
...
...
@@ -26,189 +26,192 @@ namespace math {
namespace
detail
{
/*
* threads(frame
PerBlock, batchPerB
lock)
* grid(frame
Blocks, batchB
locks)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
_blocks, batch_b
locks)
*/
template
<
class
T
,
class
Op
,
bool
is
B
atch
>
__global__
void
KeLstmForward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame
S
ize
,
int
batch
S
ize
,
activation_mode_t
active_node
,
template
<
class
T
,
class
Op
,
bool
is
_b
atch
>
__global__
void
KeLstmForward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame
_s
ize
,
int
batch
_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
const
int
frame
I
dx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
int
batch
I
dx
=
0
;
if
(
is
B
atch
)
{
batch
I
dx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
value
.
gate
Value
+=
batchIdx
*
frameS
ize
*
4
;
value
.
output
Value
+=
batchIdx
*
frameS
ize
;
value
.
state
Value
+=
batchIdx
*
frameS
ize
;
value
.
state
ActiveValue
+=
batchIdx
*
frameS
ize
;
const
int
frame
_i
dx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
_i
dx
=
0
;
if
(
is
_b
atch
)
{
batch
_i
dx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
value
.
gate
_value
+=
batch_idx
*
frame_s
ize
*
4
;
value
.
output
_value
+=
batch_idx
*
frame_s
ize
;
value
.
state
_value
+=
batch_idx
*
frame_s
ize
;
value
.
state
_active_value
+=
batch_idx
*
frame_s
ize
;
}
T
r
S
tate
;
T
r
PrevS
tate
=
0
;
T
r
StateA
tv
;
T
r
O
ut
;
T
r
ValueI
n
;
T
r
ValueI
g
;
T
r
ValueF
g
;
T
r
ValueO
g
;
T
r
CheckI
=
value
.
checkIg
?
value
.
checkIg
[
frameI
dx
]
:
0
;
T
r
CheckF
=
value
.
checkFg
?
value
.
checkFg
[
frameI
dx
]
:
0
;
T
r
CheckO
=
value
.
checkOg
?
value
.
checkOg
[
frameI
dx
]
:
0
;
r
ValueIn
=
value
.
gateValue
[
frameI
dx
];
r
ValueIg
=
value
.
gateValue
[
frameIdx
+
frameS
ize
];
r
ValueFg
=
value
.
gateValue
[
frameIdx
+
frameS
ize
*
2
];
r
ValueOg
=
value
.
gateValue
[
frameIdx
+
frameS
ize
*
3
];
if
(
value
.
prev
StateV
alue
)
{
if
(
is
Batch
)
value
.
prevStateValue
+=
batchIdx
*
frameS
ize
;
r
PrevState
=
value
.
prevStateValue
[
frameI
dx
];
T
r
_s
tate
;
T
r
_prev_s
tate
=
0
;
T
r
_state_a
tv
;
T
r
_o
ut
;
T
r
_value_i
n
;
T
r
_value_i
g
;
T
r
_value_f
g
;
T
r
_value_o
g
;
T
r
_checkI
=
value
.
check_ig
?
value
.
check_ig
[
frame_i
dx
]
:
0
;
T
r
_checkF
=
value
.
check_fg
?
value
.
check_fg
[
frame_i
dx
]
:
0
;
T
r
_checkO
=
value
.
check_og
?
value
.
check_og
[
frame_i
dx
]
:
0
;
r
_value_in
=
value
.
gate_value
[
frame_i
dx
];
r
_value_ig
=
value
.
gate_value
[
frame_idx
+
frame_s
ize
];
r
_value_fg
=
value
.
gate_value
[
frame_idx
+
frame_s
ize
*
2
];
r
_value_og
=
value
.
gate_value
[
frame_idx
+
frame_s
ize
*
3
];
if
(
value
.
prev
_state_v
alue
)
{
if
(
is
_batch
)
value
.
prev_state_value
+=
batch_idx
*
frame_s
ize
;
r
_prev_state
=
value
.
prev_state_value
[
frame_i
dx
];
}
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rPrevState
,
rState
,
rStateAtv
,
rOut
,
rCheckI
,
rCheckF
,
rCheckO
,
active_node
,
active_gate
,
active_state
);
op
(
r_value_in
,
r_value_ig
,
r_value_fg
,
r_value_og
,
r_prev_state
,
r_state
,
r_state_atv
,
r_out
,
r_checkI
,
r_checkF
,
r_checkO
,
active_node
,
active_gate
,
active_state
);
value
.
gate
Value
[
frameIdx
]
=
rValueI
n
;
value
.
gate
Value
[
frameIdx
+
frameSize
]
=
rValueI
g
;
value
.
gate
Value
[
frameIdx
+
frameSize
*
2
]
=
rValueF
g
;
value
.
gate
Value
[
frameIdx
+
frameSize
*
3
]
=
rValueO
g
;
value
.
gate
_value
[
frame_idx
]
=
r_value_i
n
;
value
.
gate
_value
[
frame_idx
+
frame_size
]
=
r_value_i
g
;
value
.
gate
_value
[
frame_idx
+
frame_size
*
2
]
=
r_value_f
g
;
value
.
gate
_value
[
frame_idx
+
frame_size
*
3
]
=
r_value_o
g
;
value
.
state
Value
[
frameIdx
]
=
rS
tate
;
value
.
state
ActiveValue
[
frameIdx
]
=
rStateA
tv
;
value
.
output
Value
[
frameIdx
]
=
rO
ut
;
value
.
state
_value
[
frame_idx
]
=
r_s
tate
;
value
.
state
_active_value
[
frame_idx
]
=
r_state_a
tv
;
value
.
output
_value
[
frame_idx
]
=
r_o
ut
;
}
/*
* threads(frame
PerBlock, batchPerB
lock)
* grid(frame
Blocks, batchB
locks)
* threads(frame
_per_block, batch_per_b
lock)
* grid(frame
_blocks, batch_b
locks)
*/
template
<
class
T
,
class
Op
,
bool
is
B
atch
>
template
<
class
T
,
class
Op
,
bool
is
_b
atch
>
__global__
void
KeLstmBackward
(
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame
S
ize
,
int
batch
S
ize
,
activation_mode_t
active_node
,
LstmMetaGrad
<
T
>
grad
,
int
frame
_s
ize
,
int
batch
_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
const
int
frame
I
dx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frame
Idx
>=
frameS
ize
)
return
;
int
batch
I
dx
=
0
;
if
(
is
B
atch
)
{
batch
I
dx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
batch
Idx
>=
batchS
ize
)
return
;
value
.
gate
Value
+=
batchIdx
*
frameS
ize
*
4
;
value
.
state
Value
+=
batchIdx
*
frameS
ize
;
value
.
state
ActiveValue
+=
batchIdx
*
frameS
ize
;
grad
.
gate
Grad
+=
batchIdx
*
frameS
ize
*
4
;
grad
.
state
Grad
+=
batchIdx
*
frameS
ize
;
grad
.
output
Grad
+=
batchIdx
*
frameS
ize
;
const
int
frame
_i
dx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
frame
_idx
>=
frame_s
ize
)
return
;
int
batch
_i
dx
=
0
;
if
(
is
_b
atch
)
{
batch
_i
dx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
batch
_idx
>=
batch_s
ize
)
return
;
value
.
gate
_value
+=
batch_idx
*
frame_s
ize
*
4
;
value
.
state
_value
+=
batch_idx
*
frame_s
ize
;
value
.
state
_active_value
+=
batch_idx
*
frame_s
ize
;
grad
.
gate
_grad
+=
batch_idx
*
frame_s
ize
*
4
;
grad
.
state
_grad
+=
batch_idx
*
frame_s
ize
;
grad
.
output
_grad
+=
batch_idx
*
frame_s
ize
;
}
T
r
ValueI
n
;
T
r
ValueI
g
;
T
r
ValueF
g
;
T
r
ValueO
g
;
T
r
GradI
n
;
T
r
GradI
g
;
T
r
GradF
g
;
T
r
GradO
g
;
T
r
PrevS
tate
=
0
;
T
r
PrevStateG
rad
;
T
r
S
tate
;
T
r
StateG
rad
;
T
r
StateA
tv
;
T
r
OutputG
rad
;
T
r
CheckI
=
value
.
checkIg
?
value
.
checkIg
[
frameI
dx
]
:
0
;
T
r
CheckF
=
value
.
checkFg
?
value
.
checkFg
[
frameI
dx
]
:
0
;
T
r
CheckO
=
value
.
checkOg
?
value
.
checkOg
[
frameI
dx
]
:
0
;
T
r
C
heckIGrad
;
T
r
C
heckFGrad
;
T
r
C
heckOGrad
;
r
ValueIn
=
value
.
gateValue
[
frameI
dx
];
r
ValueIg
=
value
.
gateValue
[
frameIdx
+
frameS
ize
];
r
ValueFg
=
value
.
gateValue
[
frameIdx
+
frameS
ize
*
2
];
r
ValueOg
=
value
.
gateValue
[
frameIdx
+
frameS
ize
*
3
];
r
State
=
value
.
stateValue
[
frameI
dx
];
r
StateAtv
=
value
.
stateActiveValue
[
frameI
dx
];
r
OutputGrad
=
grad
.
outputGrad
[
frameI
dx
];
r
StateGrad
=
grad
.
stateGrad
[
frameI
dx
];
if
(
value
.
prev
StateV
alue
)
{
if
(
is
Batch
)
value
.
prevStateValue
+=
batchIdx
*
frameS
ize
;
r
PrevState
=
value
.
prevStateValue
[
frameI
dx
];
T
r
_value_i
n
;
T
r
_value_i
g
;
T
r
_value_f
g
;
T
r
_value_o
g
;
T
r
_grad_i
n
;
T
r
_grad_i
g
;
T
r
_grad_f
g
;
T
r
_grad_o
g
;
T
r
_prev_s
tate
=
0
;
T
r
_prev_state_g
rad
;
T
r
_s
tate
;
T
r
_state_g
rad
;
T
r
_state_a
tv
;
T
r
_output_g
rad
;
T
r
_checkI
=
value
.
check_ig
?
value
.
check_ig
[
frame_i
dx
]
:
0
;
T
r
_checkF
=
value
.
check_fg
?
value
.
check_fg
[
frame_i
dx
]
:
0
;
T
r
_checkO
=
value
.
check_og
?
value
.
check_og
[
frame_i
dx
]
:
0
;
T
r
_c
heckIGrad
;
T
r
_c
heckFGrad
;
T
r
_c
heckOGrad
;
r
_value_in
=
value
.
gate_value
[
frame_i
dx
];
r
_value_ig
=
value
.
gate_value
[
frame_idx
+
frame_s
ize
];
r
_value_fg
=
value
.
gate_value
[
frame_idx
+
frame_s
ize
*
2
];
r
_value_og
=
value
.
gate_value
[
frame_idx
+
frame_s
ize
*
3
];
r
_state
=
value
.
state_value
[
frame_i
dx
];
r
_state_atv
=
value
.
state_active_value
[
frame_i
dx
];
r
_output_grad
=
grad
.
output_grad
[
frame_i
dx
];
r
_state_grad
=
grad
.
state_grad
[
frame_i
dx
];
if
(
value
.
prev
_state_v
alue
)
{
if
(
is
_batch
)
value
.
prev_state_value
+=
batch_idx
*
frame_s
ize
;
r
_prev_state
=
value
.
prev_state_value
[
frame_i
dx
];
}
op
(
rValueIn
,
rValueIg
,
rValueFg
,
rValueOg
,
rGradIn
,
rGradIg
,
rGradFg
,
rGradOg
,
rPrevState
,
rPrevStateGrad
,
rState
,
rStateGrad
,
rStateAtv
,
rOutputGrad
,
rCheckI
,
rCheckF
,
rCheckO
,
rCheckIGrad
,
rCheckFGrad
,
rCheckOGrad
,
active_node
,
active_gate
,
active_state
);
grad
.
gateGrad
[
frameIdx
]
=
rGradIn
;
grad
.
gateGrad
[
frameIdx
+
frameSize
]
=
rGradIg
;
grad
.
gateGrad
[
frameIdx
+
frameSize
*
2
]
=
rGradFg
;
grad
.
gateGrad
[
frameIdx
+
frameSize
*
3
]
=
rGradOg
;
grad
.
stateGrad
[
frameIdx
]
=
rStateGrad
;
if
(
grad
.
prevStateGrad
)
{
if
(
isBatch
)
grad
.
prevStateGrad
+=
batchIdx
*
frameSize
;
grad
.
prevStateGrad
[
frameIdx
]
=
rPrevStateGrad
;
op
(
r_value_in
,
r_value_ig
,
r_value_fg
,
r_value_og
,
r_grad_in
,
r_grad_ig
,
r_grad_fg
,
r_grad_og
,
r_prev_state
,
r_prev_state_grad
,
r_state
,
r_state_grad
,
r_state_atv
,
r_output_grad
,
r_checkI
,
r_checkF
,
r_checkO
,
r_checkIGrad
,
r_checkFGrad
,
r_checkOGrad
,
active_node
,
active_gate
,
active_state
);
grad
.
gate_grad
[
frame_idx
]
=
r_grad_in
;
grad
.
gate_grad
[
frame_idx
+
frame_size
]
=
r_grad_ig
;
grad
.
gate_grad
[
frame_idx
+
frame_size
*
2
]
=
r_grad_fg
;
grad
.
gate_grad
[
frame_idx
+
frame_size
*
3
]
=
r_grad_og
;
grad
.
state_grad
[
frame_idx
]
=
r_state_grad
;
if
(
grad
.
prev_state_grad
)
{
if
(
is_batch
)
grad
.
prev_state_grad
+=
batch_idx
*
frame_size
;
grad
.
prev_state_grad
[
frame_idx
]
=
r_prev_state_grad
;
}
if
(
is
B
atch
)
{
if
(
value
.
prev
StateV
alue
)
{
if
(
grad
.
check
IgG
rad
)
paddle
::
platform
::
CudaAtomicAdd
(
grad
.
check
IgGrad
+
frameI
dx
,
r
C
heckIGrad
);
if
(
grad
.
check
FgG
rad
)
paddle
::
platform
::
CudaAtomicAdd
(
grad
.
check
FgGrad
+
frameI
dx
,
r
C
heckFGrad
);
if
(
is
_b
atch
)
{
if
(
value
.
prev
_state_v
alue
)
{
if
(
grad
.
check
_ig_g
rad
)
paddle
::
platform
::
CudaAtomicAdd
(
grad
.
check
_ig_grad
+
frame_i
dx
,
r
_c
heckIGrad
);
if
(
grad
.
check
_fg_g
rad
)
paddle
::
platform
::
CudaAtomicAdd
(
grad
.
check
_fg_grad
+
frame_i
dx
,
r
_c
heckFGrad
);
}
if
(
grad
.
checkOgGrad
)
paddle
::
platform
::
CudaAtomicAdd
(
grad
.
checkOgGrad
+
frameIdx
,
rCheckOGrad
);
if
(
grad
.
check_og_grad
)
paddle
::
platform
::
CudaAtomicAdd
(
grad
.
check_og_grad
+
frame_idx
,
r_checkOGrad
);
}
else
{
if
(
value
.
prev
StateV
alue
)
{
if
(
grad
.
check
IgGrad
)
grad
.
checkIgGrad
[
frameIdx
]
+=
rC
heckIGrad
;
if
(
grad
.
check
FgGrad
)
grad
.
checkFgGrad
[
frameIdx
]
+=
rC
heckFGrad
;
if
(
value
.
prev
_state_v
alue
)
{
if
(
grad
.
check
_ig_grad
)
grad
.
check_ig_grad
[
frame_idx
]
+=
r_c
heckIGrad
;
if
(
grad
.
check
_fg_grad
)
grad
.
check_fg_grad
[
frame_idx
]
+=
r_c
heckFGrad
;
}
if
(
grad
.
check
OgGrad
)
grad
.
checkOgGrad
[
frameIdx
]
+=
rC
heckOGrad
;
if
(
grad
.
check
_og_grad
)
grad
.
check_og_grad
[
frame_idx
]
+=
r_c
heckOGrad
;
}
}
template
<
class
T
,
class
Op
>
void
gpu_lstm_forward
(
const
platform
::
DeviceContext
&
context
,
Op
op
,
LstmMetaValue
<
T
>
value
,
int
frame
Size
,
int
batchS
ize
,
LstmMetaValue
<
T
>
value
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
dim3
threads
;
dim3
grid
;
if
(
batch
S
ize
==
1
)
{
int
frame
PerBlock
=
frameSize
<=
1024
?
frameS
ize
:
1024
;
int
frame
Blocks
=
(
frameS
ize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame
PerB
lock
,
1
);
grid
=
dim3
(
frame
B
locks
,
1
);
if
(
batch
_s
ize
==
1
)
{
int
frame
_per_block
=
frame_size
<=
1024
?
frame_s
ize
:
1024
;
int
frame
_blocks
=
(
frame_s
ize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame
_per_b
lock
,
1
);
grid
=
dim3
(
frame
_b
locks
,
1
);
}
else
{
/* frame
PerBlock = 32 batchPerB
lock = 32 */
/* frame
_per_block = 32 batch_per_b
lock = 32 */
threads
=
dim3
(
32
,
32
);
grid
=
dim3
((
frame
Size
+
32
-
1
)
/
32
,
(
batchS
ize
+
32
-
1
)
/
32
);
grid
=
dim3
((
frame
_size
+
32
-
1
)
/
32
,
(
batch_s
ize
+
32
-
1
)
/
32
);
}
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
KeLstmForward
<
T
,
Op
,
/* is
B
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
frame
Size
,
batchS
ize
,
active_node
,
active_gate
,
/* is
_b
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
frame
_size
,
batch_s
ize
,
active_node
,
active_gate
,
active_state
);
}
else
{
KeLstmForward
<
T
,
Op
,
/* is
B
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
frame
Size
,
batchS
ize
,
active_node
,
active_gate
,
/* is
_b
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
frame
_size
,
batch_s
ize
,
active_node
,
active_gate
,
active_state
);
}
}
...
...
@@ -216,34 +219,34 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
template
<
class
T
,
class
Op
>
void
gpu_lstm_backward
(
const
platform
::
DeviceContext
&
context
,
Op
op
,
LstmMetaValue
<
T
>
value
,
LstmMetaGrad
<
T
>
grad
,
int
frame
Size
,
int
batchS
ize
,
int
frame
_size
,
int
batch_s
ize
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
dim3
threads
;
dim3
grid
;
if
(
batch
S
ize
==
1
)
{
int
frame
PerBlock
=
frameSize
<=
1024
?
frameS
ize
:
1024
;
int
frame
Blocks
=
(
frameS
ize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame
PerB
lock
,
1
);
grid
=
dim3
(
frame
B
locks
,
1
);
if
(
batch
_s
ize
==
1
)
{
int
frame
_per_block
=
frame_size
<=
1024
?
frame_s
ize
:
1024
;
int
frame
_blocks
=
(
frame_s
ize
+
1024
-
1
)
/
1024
;
threads
=
dim3
(
frame
_per_b
lock
,
1
);
grid
=
dim3
(
frame
_b
locks
,
1
);
}
else
{
/* frame
PerBlock = 32 batchPerB
lock = 16 */
/* frame
_per_block = 32 batch_per_b
lock = 16 */
threads
=
dim3
(
32
,
16
);
grid
=
dim3
((
frame
Size
+
32
-
1
)
/
32
,
(
batchS
ize
+
16
-
1
)
/
16
);
grid
=
dim3
((
frame
_size
+
32
-
1
)
/
32
,
(
batch_s
ize
+
16
-
1
)
/
16
);
}
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
if
(
batch
S
ize
==
1
)
{
if
(
batch
_s
ize
==
1
)
{
KeLstmBackward
<
T
,
Op
,
/* is
B
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
grad
,
frame
Size
,
batchS
ize
,
active_node
,
active_gate
,
/* is
_b
atch= */
false
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
grad
,
frame
_size
,
batch_s
ize
,
active_node
,
active_gate
,
active_state
);
}
else
{
KeLstmBackward
<
T
,
Op
,
/* is
B
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
grad
,
frame
Size
,
batchS
ize
,
active_node
,
active_gate
,
/* is
_b
atch= */
true
><<<
grid
,
threads
,
0
,
stream
>>>
(
op
,
value
,
grad
,
frame
_size
,
batch_s
ize
,
active_node
,
active_gate
,
active_state
);
}
}
...
...
paddle/operators/math/detail/lstm_kernel.h
浏览文件 @
e5b51c4d
...
...
@@ -27,19 +27,19 @@ namespace forward {
template
<
class
T
>
class
lstm
{
public:
HOSTDEVICE
void
operator
()(
T
&
value
In
,
T
&
valueIg
,
T
&
valueFg
,
T
&
valueO
g
,
T
&
prev
State
,
T
&
state
,
T
&
stateA
tv
,
T
&
output
,
HOSTDEVICE
void
operator
()(
T
&
value
_in
,
T
&
value_ig
,
T
&
value_fg
,
T
&
value_o
g
,
T
&
prev
_state
,
T
&
state
,
T
&
state_a
tv
,
T
&
output
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
value
In
=
activation
(
valueI
n
,
active_node
);
value
Ig
=
activation
(
valueIg
+
prevS
tate
*
checkI
,
active_gate
);
value
Fg
=
activation
(
valueFg
+
prevS
tate
*
checkF
,
active_gate
);
state
=
value
In
*
valueIg
+
prevState
*
valueF
g
;
value
Og
=
activation
(
valueO
g
+
state
*
checkO
,
active_gate
);
state
A
tv
=
activation
(
state
,
active_state
);
output
=
value
Og
*
stateA
tv
;
value
_in
=
activation
(
value_i
n
,
active_node
);
value
_ig
=
activation
(
value_ig
+
prev_s
tate
*
checkI
,
active_gate
);
value
_fg
=
activation
(
value_fg
+
prev_s
tate
*
checkF
,
active_gate
);
state
=
value
_in
*
value_ig
+
prev_state
*
value_f
g
;
value
_og
=
activation
(
value_o
g
+
state
*
checkO
,
active_gate
);
state
_a
tv
=
activation
(
state
,
active_state
);
output
=
value
_og
*
state_a
tv
;
}
#ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
...
...
@@ -48,24 +48,27 @@ class lstm {
// Only float support AVX optimization
static
const
bool
avx
=
std
::
is_same
<
T
,
float
>::
value
;
HOSTDEVICE
void
operator
()(
__m256
&
valueIn
,
__m256
&
valueIg
,
__m256
&
valueFg
,
__m256
&
valueOg
,
__m256
&
prevState
,
__m256
&
state
,
__m256
&
stateAtv
,
__m256
&
output
,
__m256
&
checkI
,
HOSTDEVICE
void
operator
()(
__m256
&
value_in
,
__m256
&
value_ig
,
__m256
&
value_fg
,
__m256
&
value_og
,
__m256
&
prev_state
,
__m256
&
state
,
__m256
&
state_atv
,
__m256
&
output
,
__m256
&
checkI
,
__m256
&
checkF
,
__m256
&
checkO
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
valueIn
=
activation
(
valueIn
,
active_node
);
valueIg
=
activation
(
_mm256_add_ps
(
valueIg
,
_mm256_mul_ps
(
prevState
,
checkI
)),
active_gate
);
valueFg
=
activation
(
_mm256_add_ps
(
valueFg
,
_mm256_mul_ps
(
prevState
,
checkF
)),
active_gate
);
state
=
_mm256_add_ps
(
_mm256_mul_ps
(
valueIn
,
valueIg
),
_mm256_mul_ps
(
prevState
,
valueFg
));
valueOg
=
activation
(
_mm256_add_ps
(
valueOg
,
_mm256_mul_ps
(
state
,
checkO
)),
active_gate
);
stateAtv
=
activation
(
state
,
active_state
);
output
=
_mm256_mul_ps
(
valueOg
,
stateAtv
);
value_in
=
activation
(
value_in
,
active_node
);
value_ig
=
activation
(
_mm256_add_ps
(
value_ig
,
_mm256_mul_ps
(
prev_state
,
checkI
)),
active_gate
);
value_fg
=
activation
(
_mm256_add_ps
(
value_fg
,
_mm256_mul_ps
(
prev_state
,
checkF
)),
active_gate
);
state
=
_mm256_add_ps
(
_mm256_mul_ps
(
value_in
,
value_ig
),
_mm256_mul_ps
(
prev_state
,
value_fg
));
value_og
=
activation
(
_mm256_add_ps
(
value_og
,
_mm256_mul_ps
(
state
,
checkO
)),
active_gate
);
state_atv
=
activation
(
state
,
active_state
);
output
=
_mm256_mul_ps
(
value_og
,
state_atv
);
}
#endif
#endif
...
...
@@ -78,25 +81,26 @@ namespace backward {
template
<
class
T
>
class
lstm
{
public:
HOSTDEVICE
void
operator
()(
T
&
value
In
,
T
&
valueIg
,
T
&
valueFg
,
T
&
valueO
g
,
T
&
grad
In
,
T
&
gradIg
,
T
&
gradFg
,
T
&
gradO
g
,
T
&
prev
State
,
T
&
prevStateG
rad
,
T
&
state
,
T
&
state
Grad
,
T
&
stateAtv
,
T
&
outputG
rad
,
HOSTDEVICE
void
operator
()(
T
&
value
_in
,
T
&
value_ig
,
T
&
value_fg
,
T
&
value_o
g
,
T
&
grad
_in
,
T
&
grad_ig
,
T
&
grad_fg
,
T
&
grad_o
g
,
T
&
prev
_state
,
T
&
prev_state_g
rad
,
T
&
state
,
T
&
state
_grad
,
T
&
state_atv
,
T
&
output_g
rad
,
T
&
checkI
,
T
&
checkF
,
T
&
checkO
,
T
&
checkIGrad
,
T
&
checkFGrad
,
T
&
checkOGrad
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
gradOg
=
activation
(
outputGrad
*
stateAtv
,
valueOg
,
active_gate
);
stateGrad
+=
activation
(
outputGrad
*
valueOg
,
stateAtv
,
active_state
)
+
gradOg
*
checkO
;
gradIn
=
activation
(
stateGrad
*
valueIg
,
valueIn
,
active_node
);
gradIg
=
activation
(
stateGrad
*
valueIn
,
valueIg
,
active_gate
);
gradFg
=
activation
(
stateGrad
*
prevState
,
valueFg
,
active_gate
);
prevStateGrad
=
gradIg
*
checkI
+
gradFg
*
checkF
+
stateGrad
*
valueFg
;
checkIGrad
=
gradIg
*
prevState
;
checkFGrad
=
gradFg
*
prevState
;
checkOGrad
=
gradOg
*
state
;
grad_og
=
activation
(
output_grad
*
state_atv
,
value_og
,
active_gate
);
state_grad
+=
activation
(
output_grad
*
value_og
,
state_atv
,
active_state
)
+
grad_og
*
checkO
;
grad_in
=
activation
(
state_grad
*
value_ig
,
value_in
,
active_node
);
grad_ig
=
activation
(
state_grad
*
value_in
,
value_ig
,
active_gate
);
grad_fg
=
activation
(
state_grad
*
prev_state
,
value_fg
,
active_gate
);
prev_state_grad
=
grad_ig
*
checkI
+
grad_fg
*
checkF
+
state_grad
*
value_fg
;
checkIGrad
=
grad_ig
*
prev_state
;
checkFGrad
=
grad_fg
*
prev_state
;
checkOGrad
=
grad_og
*
state
;
}
#ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
...
...
@@ -105,32 +109,32 @@ class lstm {
// Only float support AVX optimization
static
const
bool
avx
=
std
::
is_same
<
T
,
float
>::
value
;
HOSTDEVICE
void
operator
()(
__m256
&
value
In
,
__m256
&
valueIg
,
__m256
&
valueFg
,
__m256
&
valueO
g
,
__m256
&
grad
In
,
__m256
&
gradIg
,
__m256
&
gradFg
,
__m256
&
gradO
g
,
__m256
&
prev
State
,
__m256
&
prevStateG
rad
,
__m256
&
state
,
__m256
&
state
Grad
,
__m256
&
stateAtv
,
__m256
&
outputGrad
,
__m256
&
checkI
,
__m256
&
check
F
,
__m256
&
checkO
,
__m256
&
checkIGrad
,
__m256
&
checkF
Grad
,
__m256
&
checkOGrad
,
activation_mode_t
active_node
,
__m256
&
value
_in
,
__m256
&
value_ig
,
__m256
&
value_fg
,
__m256
&
value_o
g
,
__m256
&
grad
_in
,
__m256
&
grad_ig
,
__m256
&
grad_fg
,
__m256
&
grad_o
g
,
__m256
&
prev
_state
,
__m256
&
prev_state_g
rad
,
__m256
&
state
,
__m256
&
state
_grad
,
__m256
&
state_atv
,
__m256
&
output_grad
,
__m256
&
check
I
,
__m256
&
checkF
,
__m256
&
checkO
,
__m256
&
checkI
Grad
,
__m256
&
check
FGrad
,
__m256
&
check
OGrad
,
activation_mode_t
active_node
,
activation_mode_t
active_gate
,
activation_mode_t
active_state
)
{
grad
Og
=
activation
(
_mm256_mul_ps
(
outputGrad
,
stateAtv
),
valueOg
,
active_gate
);
state
Grad
=
_mm256_add_ps
(
activation
(
_mm256_mul_ps
(
outputGrad
,
valueOg
),
stateA
tv
,
active_state
),
stateG
rad
);
state
Grad
=
_mm256_add_ps
(
_mm256_mul_ps
(
gradOg
,
checkO
),
stateG
rad
);
grad
I
n
=
activation
(
_mm256_mul_ps
(
state
Grad
,
valueIg
),
valueI
n
,
active_node
);
grad
I
g
=
activation
(
_mm256_mul_ps
(
state
Grad
,
valueIn
),
valueI
g
,
active_gate
);
grad
Fg
=
activation
(
_mm256_mul_ps
(
stateGrad
,
prevState
),
valueFg
,
active_gate
);
prev
StateGrad
=
_mm256_add_ps
(
_mm256_mul_ps
(
gradI
g
,
checkI
),
_mm256_mul_ps
(
gradF
g
,
checkF
));
prev
StateG
rad
=
_mm256_add_ps
(
_mm256_mul_ps
(
state
Grad
,
valueFg
),
prevStateG
rad
);
checkIGrad
=
_mm256_mul_ps
(
grad
Ig
,
prevS
tate
);
checkFGrad
=
_mm256_mul_ps
(
grad
Fg
,
prevS
tate
);
checkOGrad
=
_mm256_mul_ps
(
grad
O
g
,
state
);
grad
_og
=
activation
(
_mm256_mul_ps
(
output_grad
,
state_atv
),
value_og
,
active_gate
);
state
_grad
=
_mm256_add_ps
(
activation
(
_mm256_mul_ps
(
output_grad
,
value_og
),
state_a
tv
,
active_state
),
state_g
rad
);
state
_grad
=
_mm256_add_ps
(
_mm256_mul_ps
(
grad_og
,
checkO
),
state_g
rad
);
grad
_i
n
=
activation
(
_mm256_mul_ps
(
state
_grad
,
value_ig
),
value_i
n
,
active_node
);
grad
_i
g
=
activation
(
_mm256_mul_ps
(
state
_grad
,
value_in
),
value_i
g
,
active_gate
);
grad
_fg
=
activation
(
_mm256_mul_ps
(
state_grad
,
prev_state
),
value_fg
,
active_gate
);
prev
_state_grad
=
_mm256_add_ps
(
_mm256_mul_ps
(
grad_i
g
,
checkI
),
_mm256_mul_ps
(
grad_f
g
,
checkF
));
prev
_state_g
rad
=
_mm256_add_ps
(
_mm256_mul_ps
(
state
_grad
,
value_fg
),
prev_state_g
rad
);
checkIGrad
=
_mm256_mul_ps
(
grad
_ig
,
prev_s
tate
);
checkFGrad
=
_mm256_mul_ps
(
grad
_fg
,
prev_s
tate
);
checkOGrad
=
_mm256_mul_ps
(
grad
_o
g
,
state
);
}
#endif
#endif
...
...
paddle/operators/math/lstm_compute.cc
浏览文件 @
e5b51c4d
...
...
@@ -30,12 +30,12 @@ struct LstmUnitFunctor<platform::CPUPlace, T> {
detail
::
cpu_lstm_forward
(
detail
::
forward
::
lstm
<
T
>
(),
value
,
frame_size
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
ActiveType
(
cell_act
));
value
.
gate
V
alue
+=
frame_size
*
4
;
value
.
state
V
alue
+=
frame_size
;
value
.
state
ActiveV
alue
+=
frame_size
;
value
.
output
V
alue
+=
frame_size
;
if
(
value
.
prev
StateV
alue
)
{
value
.
prev
StateV
alue
+=
frame_size
;
value
.
gate
_v
alue
+=
frame_size
*
4
;
value
.
state
_v
alue
+=
frame_size
;
value
.
state
_active_v
alue
+=
frame_size
;
value
.
output
_v
alue
+=
frame_size
;
if
(
value
.
prev
_state_v
alue
)
{
value
.
prev
_state_v
alue
+=
frame_size
;
}
}
}
...
...
@@ -53,20 +53,20 @@ struct LstmUnitGradFunctor<platform::CPUPlace, T> {
frame_size
,
ActiveType
(
cand_act
),
ActiveType
(
gate_act
),
ActiveType
(
cell_act
));
value
.
gate
V
alue
+=
frame_size
*
4
;
value
.
state
V
alue
+=
frame_size
;
value
.
state
ActiveV
alue
+=
frame_size
;
value
.
output
V
alue
+=
frame_size
;
if
(
value
.
prev
StateV
alue
)
{
value
.
prev
StateV
alue
+=
frame_size
;
value
.
gate
_v
alue
+=
frame_size
*
4
;
value
.
state
_v
alue
+=
frame_size
;
value
.
state
_active_v
alue
+=
frame_size
;
value
.
output
_v
alue
+=
frame_size
;
if
(
value
.
prev
_state_v
alue
)
{
value
.
prev
_state_v
alue
+=
frame_size
;
}
grad
.
gate
G
rad
+=
frame_size
*
4
;
grad
.
state
G
rad
+=
frame_size
;
grad
.
state
ActiveG
rad
+=
frame_size
;
grad
.
output
G
rad
+=
frame_size
;
if
(
grad
.
prev
StateG
rad
)
{
grad
.
prev
StateG
rad
+=
frame_size
;
grad
.
gate
_g
rad
+=
frame_size
*
4
;
grad
.
state
_g
rad
+=
frame_size
;
grad
.
state
_active_g
rad
+=
frame_size
;
grad
.
output
_g
rad
+=
frame_size
;
if
(
grad
.
prev
_state_g
rad
)
{
grad
.
prev
_state_g
rad
+=
frame_size
;
}
}
}
...
...
paddle/operators/math/lstm_compute.h
浏览文件 @
e5b51c4d
...
...
@@ -31,26 +31,26 @@ typedef enum {
template
<
class
T
>
struct
LstmMetaValue
{
T
*
gate
V
alue
;
T
*
prev
StateV
alue
;
T
*
state
V
alue
;
T
*
state
ActiveV
alue
;
T
*
output
V
alue
;
T
*
check
I
g
;
T
*
check
F
g
;
T
*
check
O
g
;
T
*
gate
_v
alue
;
T
*
prev
_state_v
alue
;
T
*
state
_v
alue
;
T
*
state
_active_v
alue
;
T
*
output
_v
alue
;
T
*
check
_i
g
;
T
*
check
_f
g
;
T
*
check
_o
g
;
};
template
<
class
T
>
struct
LstmMetaGrad
{
T
*
gate
G
rad
;
T
*
prev
StateG
rad
;
T
*
state
G
rad
;
T
*
state
ActiveG
rad
;
T
*
output
G
rad
;
T
*
check
IgG
rad
;
T
*
check
FgG
rad
;
T
*
check
OgG
rad
;
T
*
gate
_g
rad
;
T
*
prev
_state_g
rad
;
T
*
state
_g
rad
;
T
*
state
_active_g
rad
;
T
*
output
_g
rad
;
T
*
check
_ig_g
rad
;
T
*
check
_fg_g
rad
;
T
*
check
_og_g
rad
;
};
inline
activation_mode_t
ActiveType
(
const
std
::
string
&
type
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录