PaddlePaddle / Paddle — commit 74725d05 (unverified)
Authored by qingqing01 on Dec 04, 2017; committed via GitHub on Dec 04, 2017.
Merge pull request #6198 from qingqing01/lstm_style
Make lstm_op follow google code style.
Parents: 0c15b6c2, e5b51c4d
Showing 6 changed files with 505 additions and 492 deletions (+505 −492).
paddle/operators/lstm_op.h                        +35  −35
paddle/operators/math/detail/lstm_cpu_kernel.h    +216 −210
paddle/operators/math/detail/lstm_gpu_kernel.h    +154 −151
paddle/operators/math/detail/lstm_kernel.h        +66  −62
paddle/operators/math/lstm_compute.cc             +18  −18
paddle/operators/math/lstm_compute.h              +16  −16
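The change is purely mechanical: identifiers move from camelCase to the snake_case that the Google C++ Style Guide prescribes for variables and data members, with no behavioral change. A minimal sketch of the pattern, reusing two member names that appear in this diff (the Old/New struct names are illustrative only; the real struct keeps its name, LstmMetaValue):

    // Before this commit (old PaddlePaddle style):
    struct LstmMetaValueOld {
      float* gateValue;   // camelCase data member
    };

    // After this commit (Google C++ style):
    struct LstmMetaValueNew {
      float* gate_value;  // snake_case data member
    };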
paddle/operators/lstm_op.h

@@ -73,15 +73,15 @@ class LSTMKernel : public framework::OpKernel<T> {
       T* bias_data = const_cast<T*>(bias->data<T>());
       // the code style in LstmMetaValue will be updated later.
-      lstm_value.checkIg = bias_data + 4 * frame_size;
-      lstm_value.checkFg = lstm_value.checkIg + frame_size;
-      lstm_value.checkOg = lstm_value.checkFg + frame_size;
+      lstm_value.check_ig = bias_data + 4 * frame_size;
+      lstm_value.check_fg = lstm_value.check_ig + frame_size;
+      lstm_value.check_og = lstm_value.check_fg + frame_size;
     } else {
-      lstm_value.checkIg = nullptr;
-      lstm_value.checkFg = nullptr;
-      lstm_value.checkOg = nullptr;
+      lstm_value.check_ig = nullptr;
+      lstm_value.check_fg = nullptr;
+      lstm_value.check_og = nullptr;
     }
-    lstm_value.prevStateValue = nullptr;
+    lstm_value.prev_state_value = nullptr;
     Tensor ordered_c0;
     const size_t* order = batch_gate->lod()[2].data();
     if (cell_t0) {
@@ -90,7 +90,7 @@ class LSTMKernel : public framework::OpKernel<T> {
       // to reorder.
       ReorderInitState<Place, T>(device_ctx, *cell_t0, order, &ordered_c0,
                                  true);
-      lstm_value.prevStateValue = ordered_c0.data<T>();
+      lstm_value.prev_state_value = ordered_c0.data<T>();
     }
     // Use the local variable as here.
@@ -140,14 +140,14 @@ class LSTMKernel : public framework::OpKernel<T> {
                               static_cast<T>(1.0));
       }
-      lstm_value.gateValue = gate_t.data<T>();
-      lstm_value.outputValue = out_t.data<T>();
-      lstm_value.stateValue = cell_t.data<T>();
-      lstm_value.stateActiveValue = cell_pre_act_t.data<T>();
+      lstm_value.gate_value = gate_t.data<T>();
+      lstm_value.output_value = out_t.data<T>();
+      lstm_value.state_value = cell_t.data<T>();
+      lstm_value.state_active_value = cell_pre_act_t.data<T>();
       math::LstmUnitFunctor<Place, T>::compute(device_ctx, lstm_value,
                                                frame_size, cur_batch_size,
                                                gate_act, cell_act, cand_act);
-      lstm_value.prevStateValue = lstm_value.stateValue;
+      lstm_value.prev_state_value = lstm_value.state_value;
     }
     math::Batch2LoDTensorFunctor<Place, T> to_seq;
@@ -214,13 +214,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     math::LstmMetaValue<T> lstm_value;
     if (bias && ctx.Attr<bool>("use_peepholes")) {
       T* bias_data = const_cast<T*>(bias->data<T>());
-      lstm_value.checkIg = bias_data + 4 * frame_size;
-      lstm_value.checkFg = lstm_value.checkIg + frame_size;
-      lstm_value.checkOg = lstm_value.checkFg + frame_size;
+      lstm_value.check_ig = bias_data + 4 * frame_size;
+      lstm_value.check_fg = lstm_value.check_ig + frame_size;
+      lstm_value.check_og = lstm_value.check_fg + frame_size;
     } else {
-      lstm_value.checkIg = nullptr;
-      lstm_value.checkFg = nullptr;
-      lstm_value.checkOg = nullptr;
+      lstm_value.check_ig = nullptr;
+      lstm_value.check_fg = nullptr;
+      lstm_value.check_og = nullptr;
     }
     math::LstmMetaGrad<T> lstm_grad;
@@ -231,13 +231,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     }
     if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) {
       T* bias_g_data = bias_g->data<T>();
-      lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size;
-      lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size;
-      lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size;
+      lstm_grad.check_ig_grad = bias_g_data + 4 * frame_size;
+      lstm_grad.check_fg_grad = lstm_grad.check_ig_grad + frame_size;
+      lstm_grad.check_og_grad = lstm_grad.check_fg_grad + frame_size;
     } else {
-      lstm_grad.checkIgGrad = nullptr;
-      lstm_grad.checkFgGrad = nullptr;
-      lstm_grad.checkOgGrad = nullptr;
+      lstm_grad.check_ig_grad = nullptr;
+      lstm_grad.check_fg_grad = nullptr;
+      lstm_grad.check_og_grad = nullptr;
     }
     math::LoDTensor2BatchFunctor<Place, T> to_batch;
@@ -276,26 +276,26 @@ class LSTMGradKernel : public framework::OpKernel<T> {
       Tensor gate = batch_gate->Slice(bstart, bend);
       Tensor cell = batch_cell.Slice(bstart, bend);
       Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend);
-      lstm_value.gateValue = gate.data<T>();
-      lstm_value.stateValue = cell.data<T>();
-      lstm_value.stateActiveValue = cell_pre_act.data<T>();
+      lstm_value.gate_value = gate.data<T>();
+      lstm_value.state_value = cell.data<T>();
+      lstm_value.state_active_value = cell_pre_act.data<T>();
       Tensor out_g = batch_hidden_g.Slice(bstart, bend);
       Tensor gate_g = batch_gate_g.Slice(bstart, bend);
       Tensor cell_g = batch_cell_g.Slice(bstart, bend);
-      lstm_grad.stateGrad = cell_g.data<T>();
-      lstm_grad.gateGrad = gate_g.data<T>();
-      lstm_grad.outputGrad = out_g.data<T>();
+      lstm_grad.state_grad = cell_g.data<T>();
+      lstm_grad.gate_grad = gate_g.data<T>();
+      lstm_grad.output_grad = out_g.data<T>();
       if (n > 0) {
         int bstart_pre = static_cast<int>(batch_starts[n - 1]);
         Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart);
         Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart);
-        lstm_value.prevStateValue = cell_pre.data<T>();
-        lstm_grad.prevStateGrad = cell_pre_g.data<T>();
+        lstm_value.prev_state_value = cell_pre.data<T>();
+        lstm_grad.prev_state_grad = cell_pre_g.data<T>();
       } else {
-        lstm_value.prevStateValue = c0 ? ordered_c0.data<T>() : nullptr;
-        lstm_grad.prevStateGrad = c0_g ? ordered_c0_g.data<T>() : nullptr;
+        lstm_value.prev_state_value = c0 ? ordered_c0.data<T>() : nullptr;
+        lstm_grad.prev_state_grad = c0_g ? ordered_c0_g.data<T>() : nullptr;
       }
       int cur_batch_size = bend - bstart;
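For orientation, the check_ig/check_fg/check_og arithmetic above implies a packed peephole layout in the bias tensor: four gate biases of frame_size entries each come first, then the three peephole ("check") vectors. A short sketch of that layout; the segment diagram is inferred from the pointer arithmetic in this file, not from a separately documented API:

    // Bias buffer when use_peepholes is set, 7 * frame_size entries total:
    //   [ gate biases x 4 | check_ig | check_fg | check_og ]
    float* check_ig = bias_data + 4 * frame_size;  // skip the four gate biases
    float* check_fg = check_ig + frame_size;
    float* check_og = check_fg + frame_size;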
paddle/operators/math/detail/lstm_cpu_kernel.h

@@ -26,278 +26,284 @@ namespace detail {
 template <class T, class Op>
 void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
-                                     int frameSize,
+                                     int frame_size,
                                      activation_mode_t active_node,
                                      activation_mode_t active_gate,
                                      activation_mode_t active_state) {
-  T rValueIn;
-  T rValueIg;
-  T rValueFg;
-  T rValueOg;
-  T rCheckI;
-  T rCheckF;
-  T rCheckO;
-  T rState;
-  T rPrevState = 0;
-  T rStateAtv;
-  T rOut;
+  T r_value_in;
+  T r_value_ig;
+  T r_value_fg;
+  T r_value_og;
+  T r_checkI;
+  T r_checkF;
+  T r_checkO;
+  T r_state;
+  T r_prev_state = 0;
+  T r_state_atv;
+  T r_out;

-  T *valueIn = value.gateValue;
-  T *valueIg = value.gateValue + frameSize;
-  T *valueFg = value.gateValue + frameSize * 2;
-  T *valueOg = value.gateValue + frameSize * 3;
+  T *value_in = value.gate_value;
+  T *value_ig = value.gate_value + frame_size;
+  T *value_fg = value.gate_value + frame_size * 2;
+  T *value_og = value.gate_value + frame_size * 3;

-  for (int i = 0; i < frameSize; i++) {
-    rValueIn = valueIn[i];
-    rValueIg = valueIg[i];
-    rValueFg = valueFg[i];
-    rValueOg = valueOg[i];
-    rCheckI = value.checkIg ? value.checkIg[i] : 0;
-    rCheckF = value.checkFg ? value.checkFg[i] : 0;
-    rCheckO = value.checkOg ? value.checkOg[i] : 0;
+  for (int i = 0; i < frame_size; i++) {
+    r_value_in = value_in[i];
+    r_value_ig = value_ig[i];
+    r_value_fg = value_fg[i];
+    r_value_og = value_og[i];
+    r_checkI = value.check_ig ? value.check_ig[i] : 0;
+    r_checkF = value.check_fg ? value.check_fg[i] : 0;
+    r_checkO = value.check_og ? value.check_og[i] : 0;

-    if (value.prevStateValue) {
-      rPrevState = value.prevStateValue[i];
+    if (value.prev_state_value) {
+      r_prev_state = value.prev_state_value[i];
     }

-    op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
-       rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate,
-       active_state);
+    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
+       r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
+       active_gate, active_state);

-    valueIn[i] = rValueIn;
-    valueIg[i] = rValueIg;
-    valueFg[i] = rValueFg;
-    valueOg[i] = rValueOg;
-    value.stateValue[i] = rState;
-    value.stateActiveValue[i] = rStateAtv;
-    value.outputValue[i] = rOut;
+    value_in[i] = r_value_in;
+    value_ig[i] = r_value_ig;
+    value_fg[i] = r_value_fg;
+    value_og[i] = r_value_og;
+    value.state_value[i] = r_state;
+    value.state_active_value[i] = r_state_atv;
+    value.output_value[i] = r_out;
   }
 }

 template <class T, class Op>
 void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
-                                      LstmMetaGrad<T> grad, int frameSize,
+                                      LstmMetaGrad<T> grad, int frame_size,
                                       activation_mode_t active_node,
                                       activation_mode_t active_gate,
                                       activation_mode_t active_state) {
-  T rValueIn;
-  T rValueIg;
-  T rValueFg;
-  T rValueOg;
-  T rGradIn;
-  T rGradIg;
-  T rGradFg;
-  T rGradOg;
-  T rPrevState = 0;
-  T rPrevStateGrad;
-  T rState;
-  T rStateGrad;
-  T rStateAtv;
-  T rOutputGrad;
-  T rCheckI;
-  T rCheckF;
-  T rCheckO;
-  T rCheckIGrad;
-  T rCheckFGrad;
-  T rCheckOGrad;
+  T r_value_in;
+  T r_value_ig;
+  T r_value_fg;
+  T r_value_og;
+  T r_grad_in;
+  T r_grad_ig;
+  T r_grad_fg;
+  T r_grad_og;
+  T r_prev_state = 0;
+  T r_prev_state_grad;
+  T r_state;
+  T r_state_grad;
+  T r_state_atv;
+  T r_output_grad;
+  T r_checkI;
+  T r_checkF;
+  T r_checkO;
+  T r_checkIGrad;
+  T r_checkFGrad;
+  T r_checkOGrad;

-  T *valueIn = value.gateValue;
-  T *valueIg = value.gateValue + frameSize;
-  T *valueFg = value.gateValue + frameSize * 2;
-  T *valueOg = value.gateValue + frameSize * 3;
-  T *gradIn = grad.gateGrad;
-  T *gradIg = grad.gateGrad + frameSize;
-  T *gradFg = grad.gateGrad + frameSize * 2;
-  T *gradOg = grad.gateGrad + frameSize * 3;
+  T *value_in = value.gate_value;
+  T *value_ig = value.gate_value + frame_size;
+  T *value_fg = value.gate_value + frame_size * 2;
+  T *value_og = value.gate_value + frame_size * 3;
+  T *grad_in = grad.gate_grad;
+  T *grad_ig = grad.gate_grad + frame_size;
+  T *grad_fg = grad.gate_grad + frame_size * 2;
+  T *grad_og = grad.gate_grad + frame_size * 3;

-  for (int i = 0; i < frameSize; i++) {
-    rValueIn = valueIn[i];
-    rValueIg = valueIg[i];
-    rValueFg = valueFg[i];
-    rValueOg = valueOg[i];
-    rCheckI = value.checkIg ? value.checkIg[i] : 0;
-    rCheckF = value.checkFg ? value.checkFg[i] : 0;
-    rCheckO = value.checkOg ? value.checkOg[i] : 0;
-    rState = value.stateValue[i];
-    rStateAtv = value.stateActiveValue[i];
-    rOutputGrad = grad.outputGrad[i];
-    rStateGrad = grad.stateGrad[i];
-    if (value.prevStateValue) {
-      rPrevState = value.prevStateValue[i];
+  for (int i = 0; i < frame_size; i++) {
+    r_value_in = value_in[i];
+    r_value_ig = value_ig[i];
+    r_value_fg = value_fg[i];
+    r_value_og = value_og[i];
+    r_checkI = value.check_ig ? value.check_ig[i] : 0;
+    r_checkF = value.check_fg ? value.check_fg[i] : 0;
+    r_checkO = value.check_og ? value.check_og[i] : 0;
+    r_state = value.state_value[i];
+    r_state_atv = value.state_active_value[i];
+    r_output_grad = grad.output_grad[i];
+    r_state_grad = grad.state_grad[i];
+    if (value.prev_state_value) {
+      r_prev_state = value.prev_state_value[i];
     }

-    op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
-       rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
-       rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
-       rCheckOGrad, active_node, active_gate, active_state);
+    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
+       r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
+       r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
+       r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
+       active_state);

-    gradIn[i] = rGradIn;
-    gradIg[i] = rGradIg;
-    gradFg[i] = rGradFg;
-    gradOg[i] = rGradOg;
-    grad.stateGrad[i] = rStateGrad;
+    grad_in[i] = r_grad_in;
+    grad_ig[i] = r_grad_ig;
+    grad_fg[i] = r_grad_fg;
+    grad_og[i] = r_grad_og;
+    grad.state_grad[i] = r_state_grad;

-    if (grad.prevStateGrad) grad.prevStateGrad[i] = rPrevStateGrad;
-    if (value.prevStateValue) {
-      if (grad.checkIgGrad) grad.checkIgGrad[i] += rCheckIGrad;
-      if (grad.checkFgGrad) grad.checkFgGrad[i] += rCheckFGrad;
+    if (grad.prev_state_grad) grad.prev_state_grad[i] = r_prev_state_grad;
+    if (value.prev_state_value) {
+      if (grad.check_ig_grad) grad.check_ig_grad[i] += r_checkIGrad;
+      if (grad.check_fg_grad) grad.check_fg_grad[i] += r_checkFGrad;
     }
-    if (grad.checkOgGrad) grad.checkOgGrad[i] += rCheckOGrad;
+    if (grad.check_og_grad) grad.check_og_grad[i] += r_checkOGrad;
   }
 }

 template <class T, class Op>
-void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
-                                   int frameSize,
+void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
+                                   int frame_size,
                                    activation_mode_t active_node,
                                    activation_mode_t active_gate,
                                    activation_mode_t active_state) {
 #ifdef __AVX__
-  __m256 rValueIn;
-  __m256 rValueIg;
-  __m256 rValueFg;
-  __m256 rValueOg;
-  __m256 rCheckI = _mm256_set1_ps(0.0f);
-  __m256 rCheckF = _mm256_set1_ps(0.0f);
-  __m256 rCheckO = _mm256_set1_ps(0.0f);
-  __m256 rState;
-  __m256 rPrevState = _mm256_set1_ps(0.0f);
-  __m256 rStateAtv;
-  __m256 rOut;
+  __m256 r_value_in;
+  __m256 r_value_ig;
+  __m256 r_value_fg;
+  __m256 r_value_og;
+  __m256 r_checkI = _mm256_set1_ps(0.0f);
+  __m256 r_checkF = _mm256_set1_ps(0.0f);
+  __m256 r_checkO = _mm256_set1_ps(0.0f);
+  __m256 r_state;
+  __m256 r_prev_state = _mm256_set1_ps(0.0f);
+  __m256 r_state_atv;
+  __m256 r_out;

-  __m256 *valueIn = (__m256 *)value.gateValue;
-  __m256 *valueIg = (__m256 *)(value.gateValue + frameSize);
-  __m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2);
-  __m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3);
+  __m256 *value_in = (__m256 *)value.gate_value;
+  __m256 *value_ig = (__m256 *)(value.gate_value + frame_size);
+  __m256 *value_fg = (__m256 *)(value.gate_value + frame_size * 2);
+  __m256 *value_og = (__m256 *)(value.gate_value + frame_size * 3);

-  for (int i = 0; i < frameSize / 8; i++) {
-    rValueIn = valueIn[i];
-    rValueIg = valueIg[i];
-    rValueFg = valueFg[i];
-    rValueOg = valueOg[i];
+  for (int i = 0; i < frame_size / 8; i++) {
+    r_value_in = value_in[i];
+    r_value_ig = value_ig[i];
+    r_value_fg = value_fg[i];
+    r_value_og = value_og[i];

-    if (value.checkIg) {
-      rCheckI = ((__m256 *)value.checkIg)[i];
-      rCheckF = ((__m256 *)value.checkFg)[i];
-      rCheckO = ((__m256 *)value.checkOg)[i];
+    if (value.check_ig) {
+      r_checkI = ((__m256 *)value.check_ig)[i];
+      r_checkF = ((__m256 *)value.check_fg)[i];
+      r_checkO = ((__m256 *)value.check_og)[i];
     }

-    if (value.prevStateValue) {
-      rPrevState = ((__m256 *)value.prevStateValue)[i];
+    if (value.prev_state_value) {
+      r_prev_state = ((__m256 *)value.prev_state_value)[i];
     }

-    op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
-       rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate,
-       active_state);
+    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
+       r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
+       active_gate, active_state);

-    valueIn[i] = rValueIn;
-    valueIg[i] = rValueIg;
-    valueFg[i] = rValueFg;
-    valueOg[i] = rValueOg;
-    ((__m256 *)value.stateValue)[i] = rState;
-    ((__m256 *)value.stateActiveValue)[i] = rStateAtv;
-    ((__m256 *)value.outputValue)[i] = rOut;
+    value_in[i] = r_value_in;
+    value_ig[i] = r_value_ig;
+    value_fg[i] = r_value_fg;
+    value_og[i] = r_value_og;
+    ((__m256 *)value.state_value)[i] = r_state;
+    ((__m256 *)value.state_active_value)[i] = r_state_atv;
+    ((__m256 *)value.output_value)[i] = r_out;
   }
 #endif
 }

 template <class T, class Op>
 void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
-                                    LstmMetaGrad<T> grad, int frameSize,
+                                    LstmMetaGrad<T> grad, int frame_size,
                                     activation_mode_t active_node,
                                     activation_mode_t active_gate,
                                     activation_mode_t active_state) {
 #ifdef __AVX__
-  __m256 rValueIn;
-  __m256 rValueIg;
-  __m256 rValueFg;
-  __m256 rValueOg;
-  __m256 rGradIn;
-  __m256 rGradIg;
-  __m256 rGradFg;
-  __m256 rGradOg;
-  __m256 rPrevState = _mm256_set1_ps(0.0f);
-  __m256 rPrevStateGrad;
-  __m256 rStateGrad;
-  __m256 rState;
-  __m256 rStateAtv;
-  __m256 rOutputGrad;
-  __m256 rCheckI = _mm256_set1_ps(0.0f);
-  __m256 rCheckF = _mm256_set1_ps(0.0f);
-  __m256 rCheckO = _mm256_set1_ps(0.0f);
-  __m256 rCheckIGrad;
-  __m256 rCheckFGrad;
-  __m256 rCheckOGrad;
+  __m256 r_value_in;
+  __m256 r_value_ig;
+  __m256 r_value_fg;
+  __m256 r_value_og;
+  __m256 r_grad_in;
+  __m256 r_grad_ig;
+  __m256 r_grad_fg;
+  __m256 r_grad_og;
+  __m256 r_prev_state = _mm256_set1_ps(0.0f);
+  __m256 r_prev_state_grad;
+  __m256 r_state_grad;
+  __m256 r_state;
+  __m256 r_state_atv;
+  __m256 r_output_grad;
+  __m256 r_checkI = _mm256_set1_ps(0.0f);
+  __m256 r_checkF = _mm256_set1_ps(0.0f);
+  __m256 r_checkO = _mm256_set1_ps(0.0f);
+  __m256 r_checkIGrad;
+  __m256 r_checkFGrad;
+  __m256 r_checkOGrad;

-  __m256 *valueIn = (__m256 *)value.gateValue;
-  __m256 *valueIg = (__m256 *)(value.gateValue + frameSize);
-  __m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2);
-  __m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3);
-  __m256 *gradIn = (__m256 *)grad.gateGrad;
-  __m256 *gradIg = (__m256 *)(grad.gateGrad + frameSize);
-  __m256 *gradFg = (__m256 *)(grad.gateGrad + frameSize * 2);
-  __m256 *gradOg = (__m256 *)(grad.gateGrad + frameSize * 3);
+  __m256 *value_in = (__m256 *)value.gate_value;
+  __m256 *value_ig = (__m256 *)(value.gate_value + frame_size);
+  __m256 *value_fg = (__m256 *)(value.gate_value + frame_size * 2);
+  __m256 *value_og = (__m256 *)(value.gate_value + frame_size * 3);
+  __m256 *grad_in = (__m256 *)grad.gate_grad;
+  __m256 *grad_ig = (__m256 *)(grad.gate_grad + frame_size);
+  __m256 *grad_fg = (__m256 *)(grad.gate_grad + frame_size * 2);
+  __m256 *grad_og = (__m256 *)(grad.gate_grad + frame_size * 3);

-  for (int i = 0; i < frameSize / 8; i++) {
-    rValueIn = valueIn[i];
-    rValueIg = valueIg[i];
-    rValueFg = valueFg[i];
-    rValueOg = valueOg[i];
-    if (value.checkIg) {
-      rCheckI = ((__m256 *)value.checkIg)[i];
-      rCheckF = ((__m256 *)value.checkFg)[i];
-      rCheckO = ((__m256 *)value.checkOg)[i];
+  for (int i = 0; i < frame_size / 8; i++) {
+    r_value_in = value_in[i];
+    r_value_ig = value_ig[i];
+    r_value_fg = value_fg[i];
+    r_value_og = value_og[i];
+    if (value.check_ig) {
+      r_checkI = ((__m256 *)value.check_ig)[i];
+      r_checkF = ((__m256 *)value.check_fg)[i];
+      r_checkO = ((__m256 *)value.check_og)[i];
     }
-    rState = ((__m256 *)value.stateValue)[i];
-    rStateAtv = ((__m256 *)value.stateActiveValue)[i];
-    rOutputGrad = ((__m256 *)grad.outputGrad)[i];
-    rStateGrad = ((__m256 *)grad.stateGrad)[i];
-    if (value.prevStateValue) {
-      rPrevState = ((__m256 *)value.prevStateValue)[i];
+    r_state = ((__m256 *)value.state_value)[i];
+    r_state_atv = ((__m256 *)value.state_active_value)[i];
+    r_output_grad = ((__m256 *)grad.output_grad)[i];
+    r_state_grad = ((__m256 *)grad.state_grad)[i];
+    if (value.prev_state_value) {
+      r_prev_state = ((__m256 *)value.prev_state_value)[i];
     }

-    op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
-       rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
-       rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
-       rCheckOGrad, active_node, active_gate, active_state);
+    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
+       r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
+       r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
+       r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
+       active_state);

-    gradIn[i] = rGradIn;
-    gradIg[i] = rGradIg;
-    gradFg[i] = rGradFg;
-    gradOg[i] = rGradOg;
-    ((__m256 *)grad.stateGrad)[i] = rStateGrad;
+    grad_in[i] = r_grad_in;
+    grad_ig[i] = r_grad_ig;
+    grad_fg[i] = r_grad_fg;
+    grad_og[i] = r_grad_og;
+    ((__m256 *)grad.state_grad)[i] = r_state_grad;

-    if (grad.prevStateGrad) ((__m256 *)grad.prevStateGrad)[i] = rPrevStateGrad;
-    if (value.prevStateValue) {
-      if (grad.checkIgGrad) ((__m256 *)grad.checkIgGrad)[i] += rCheckIGrad;
-      if (grad.checkFgGrad) ((__m256 *)grad.checkFgGrad)[i] += rCheckFGrad;
+    if (grad.prev_state_grad)
+      ((__m256 *)grad.prev_state_grad)[i] = r_prev_state_grad;
+    if (value.prev_state_value) {
+      if (grad.check_ig_grad) ((__m256 *)grad.check_ig_grad)[i] += r_checkIGrad;
+      if (grad.check_fg_grad) ((__m256 *)grad.check_fg_grad)[i] += r_checkFGrad;
     }
-    if (grad.checkOgGrad) ((__m256 *)grad.checkOgGrad)[i] += rCheckOGrad;
+    if (grad.check_og_grad) ((__m256 *)grad.check_og_grad)[i] += r_checkOGrad;
   }
 #endif
 }

 template <class T, class Op>
-void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frameSize,
+void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
                       activation_mode_t active_node,
                       activation_mode_t active_gate,
                       activation_mode_t active_state) {
-  if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same<T, float>::value)) {
-    avx_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
-                                     active_gate, active_state);
+  if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
+    avx_lstm_forward_one_sequence<T>(op, value, frame_size, active_node,
+                                     active_gate, active_state);
   } else {
-    naive_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
-                                       active_gate, active_state);
+    naive_lstm_forward_one_sequence<T>(op, value, frame_size, active_node,
+                                       active_gate, active_state);
   }
 }

 template <class T, class Op>
 void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
-                       int frameSize, activation_mode_t active_node,
+                       int frame_size, activation_mode_t active_node,
                        activation_mode_t active_gate,
                        activation_mode_t active_state) {
-  if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same<T, float>::value)) {
-    avx_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
-                                      active_gate, active_state);
+  if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
+    avx_lstm_backward_one_sequence<T>(op, value, grad, frame_size, active_node,
+                                      active_gate, active_state);
   } else {
-    naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
-                                        active_gate, active_state);
+    naive_lstm_backward_one_sequence<T>(op, value, grad, frame_size,
+                                        active_node, active_gate, active_state);
   }
 }
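cpu_lstm_forward and cpu_lstm_backward take the AVX path only when the functor advertises it (Op::avx), T is float, and frame_size is a multiple of 8, since one __m256 register holds eight floats; frame_size & (8 - 1) is the usual power-of-two shorthand for frame_size % 8. A self-contained illustration of the dispatch predicate (the helper name is made up for the demo):

    #include <cstdio>

    // !(frame_size & (8 - 1)) == (frame_size % 8 == 0) for non-negative sizes,
    // i.e. the gate rows divide evenly into __m256 vectors of eight floats.
    static bool divisible_by_8(int frame_size) { return !(frame_size & (8 - 1)); }

    int main() {
      for (int n : {8, 12, 16, 30, 32})
        std::printf("frame_size=%d -> AVX path: %s\n", n,
                    divisible_by_8(n) ? "yes" : "no");
      return 0;
    }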
paddle/operators/math/detail/lstm_gpu_kernel.h

@@ -26,189 +26,192 @@ namespace math {
 namespace detail {

 /*
- * threads(framePerBlock, batchPerBlock)
- * grid(frameBlocks, batchBlocks)
+ * threads(frame_per_block, batch_per_block)
+ * grid(frame_blocks, batch_blocks)
  */
-template <class T, class Op, bool isBatch>
-__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
-                              int batchSize, activation_mode_t active_node,
+template <class T, class Op, bool is_batch>
+__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
+                              int batch_size, activation_mode_t active_node,
                               activation_mode_t active_gate,
                               activation_mode_t active_state) {
-  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (frameIdx >= frameSize) return;
+  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (frame_idx >= frame_size) return;

-  int batchIdx = 0;
-  if (isBatch) {
-    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
-    if (batchIdx >= batchSize) return;
-    value.gateValue += batchIdx * frameSize * 4;
-    value.outputValue += batchIdx * frameSize;
-    value.stateValue += batchIdx * frameSize;
-    value.stateActiveValue += batchIdx * frameSize;
+  int batch_idx = 0;
+  if (is_batch) {
+    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
+    if (batch_idx >= batch_size) return;
+    value.gate_value += batch_idx * frame_size * 4;
+    value.output_value += batch_idx * frame_size;
+    value.state_value += batch_idx * frame_size;
+    value.state_active_value += batch_idx * frame_size;
   }

-  T rState;
-  T rPrevState = 0;
-  T rStateAtv;
-  T rOut;
-  T rValueIn;
-  T rValueIg;
-  T rValueFg;
-  T rValueOg;
+  T r_state;
+  T r_prev_state = 0;
+  T r_state_atv;
+  T r_out;
+  T r_value_in;
+  T r_value_ig;
+  T r_value_fg;
+  T r_value_og;

-  T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0;
-  T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0;
-  T rCheckO = value.checkOg ? value.checkOg[frameIdx] : 0;
+  T r_checkI = value.check_ig ? value.check_ig[frame_idx] : 0;
+  T r_checkF = value.check_fg ? value.check_fg[frame_idx] : 0;
+  T r_checkO = value.check_og ? value.check_og[frame_idx] : 0;

-  rValueIn = value.gateValue[frameIdx];
-  rValueIg = value.gateValue[frameIdx + frameSize];
-  rValueFg = value.gateValue[frameIdx + frameSize * 2];
-  rValueOg = value.gateValue[frameIdx + frameSize * 3];
+  r_value_in = value.gate_value[frame_idx];
+  r_value_ig = value.gate_value[frame_idx + frame_size];
+  r_value_fg = value.gate_value[frame_idx + frame_size * 2];
+  r_value_og = value.gate_value[frame_idx + frame_size * 3];

-  if (value.prevStateValue) {
-    if (isBatch) value.prevStateValue += batchIdx * frameSize;
-    rPrevState = value.prevStateValue[frameIdx];
+  if (value.prev_state_value) {
+    if (is_batch) value.prev_state_value += batch_idx * frame_size;
+    r_prev_state = value.prev_state_value[frame_idx];
   }

-  op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
-     rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state);
+  op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
+     r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
+     active_gate, active_state);

-  value.gateValue[frameIdx] = rValueIn;
-  value.gateValue[frameIdx + frameSize] = rValueIg;
-  value.gateValue[frameIdx + frameSize * 2] = rValueFg;
-  value.gateValue[frameIdx + frameSize * 3] = rValueOg;
+  value.gate_value[frame_idx] = r_value_in;
+  value.gate_value[frame_idx + frame_size] = r_value_ig;
+  value.gate_value[frame_idx + frame_size * 2] = r_value_fg;
+  value.gate_value[frame_idx + frame_size * 3] = r_value_og;

-  value.stateValue[frameIdx] = rState;
-  value.stateActiveValue[frameIdx] = rStateAtv;
-  value.outputValue[frameIdx] = rOut;
+  value.state_value[frame_idx] = r_state;
+  value.state_active_value[frame_idx] = r_state_atv;
+  value.output_value[frame_idx] = r_out;
 }

 /*
- * threads(framePerBlock, batchPerBlock)
- * grid(frameBlocks, batchBlocks)
+ * threads(frame_per_block, batch_per_block)
+ * grid(frame_blocks, batch_blocks)
  */
-template <class T, class Op, bool isBatch>
+template <class T, class Op, bool is_batch>
 __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
-                               LstmMetaGrad<T> grad, int frameSize,
-                               int batchSize, activation_mode_t active_node,
+                               LstmMetaGrad<T> grad, int frame_size,
+                               int batch_size, activation_mode_t active_node,
                                activation_mode_t active_gate,
                                activation_mode_t active_state) {
-  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (frameIdx >= frameSize) return;
+  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (frame_idx >= frame_size) return;

-  int batchIdx = 0;
-  if (isBatch) {
-    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
-    if (batchIdx >= batchSize) return;
-    value.gateValue += batchIdx * frameSize * 4;
-    value.stateValue += batchIdx * frameSize;
-    value.stateActiveValue += batchIdx * frameSize;
-    grad.gateGrad += batchIdx * frameSize * 4;
-    grad.stateGrad += batchIdx * frameSize;
-    grad.outputGrad += batchIdx * frameSize;
+  int batch_idx = 0;
+  if (is_batch) {
+    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
+    if (batch_idx >= batch_size) return;
+    value.gate_value += batch_idx * frame_size * 4;
+    value.state_value += batch_idx * frame_size;
+    value.state_active_value += batch_idx * frame_size;
+    grad.gate_grad += batch_idx * frame_size * 4;
+    grad.state_grad += batch_idx * frame_size;
+    grad.output_grad += batch_idx * frame_size;
   }

-  T rValueIn;
-  T rValueIg;
-  T rValueFg;
-  T rValueOg;
-  T rGradIn;
-  T rGradIg;
-  T rGradFg;
-  T rGradOg;
-  T rPrevState = 0;
-  T rPrevStateGrad;
-  T rState;
-  T rStateGrad;
-  T rStateAtv;
-  T rOutputGrad;
-  T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0;
-  T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0;
-  T rCheckO = value.checkOg ? value.checkOg[frameIdx] : 0;
-  T rCheckIGrad;
-  T rCheckFGrad;
-  T rCheckOGrad;
+  T r_value_in;
+  T r_value_ig;
+  T r_value_fg;
+  T r_value_og;
+  T r_grad_in;
+  T r_grad_ig;
+  T r_grad_fg;
+  T r_grad_og;
+  T r_prev_state = 0;
+  T r_prev_state_grad;
+  T r_state;
+  T r_state_grad;
+  T r_state_atv;
+  T r_output_grad;
+  T r_checkI = value.check_ig ? value.check_ig[frame_idx] : 0;
+  T r_checkF = value.check_fg ? value.check_fg[frame_idx] : 0;
+  T r_checkO = value.check_og ? value.check_og[frame_idx] : 0;
+  T r_checkIGrad;
+  T r_checkFGrad;
+  T r_checkOGrad;

-  rValueIn = value.gateValue[frameIdx];
-  rValueIg = value.gateValue[frameIdx + frameSize];
-  rValueFg = value.gateValue[frameIdx + frameSize * 2];
-  rValueOg = value.gateValue[frameIdx + frameSize * 3];
-  rState = value.stateValue[frameIdx];
-  rStateAtv = value.stateActiveValue[frameIdx];
-  rOutputGrad = grad.outputGrad[frameIdx];
-  rStateGrad = grad.stateGrad[frameIdx];
+  r_value_in = value.gate_value[frame_idx];
+  r_value_ig = value.gate_value[frame_idx + frame_size];
+  r_value_fg = value.gate_value[frame_idx + frame_size * 2];
+  r_value_og = value.gate_value[frame_idx + frame_size * 3];
+  r_state = value.state_value[frame_idx];
+  r_state_atv = value.state_active_value[frame_idx];
+  r_output_grad = grad.output_grad[frame_idx];
+  r_state_grad = grad.state_grad[frame_idx];

-  if (value.prevStateValue) {
-    if (isBatch) value.prevStateValue += batchIdx * frameSize;
-    rPrevState = value.prevStateValue[frameIdx];
+  if (value.prev_state_value) {
+    if (is_batch) value.prev_state_value += batch_idx * frame_size;
+    r_prev_state = value.prev_state_value[frame_idx];
   }

-  op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
-     rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
-     rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
-     rCheckOGrad, active_node, active_gate, active_state);
+  op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
+     r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
+     r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
+     r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
+     active_state);

-  grad.gateGrad[frameIdx] = rGradIn;
-  grad.gateGrad[frameIdx + frameSize] = rGradIg;
-  grad.gateGrad[frameIdx + frameSize * 2] = rGradFg;
-  grad.gateGrad[frameIdx + frameSize * 3] = rGradOg;
-  grad.stateGrad[frameIdx] = rStateGrad;
+  grad.gate_grad[frame_idx] = r_grad_in;
+  grad.gate_grad[frame_idx + frame_size] = r_grad_ig;
+  grad.gate_grad[frame_idx + frame_size * 2] = r_grad_fg;
+  grad.gate_grad[frame_idx + frame_size * 3] = r_grad_og;
+  grad.state_grad[frame_idx] = r_state_grad;

-  if (grad.prevStateGrad) {
-    if (isBatch) grad.prevStateGrad += batchIdx * frameSize;
-    grad.prevStateGrad[frameIdx] = rPrevStateGrad;
+  if (grad.prev_state_grad) {
+    if (is_batch) grad.prev_state_grad += batch_idx * frame_size;
+    grad.prev_state_grad[frame_idx] = r_prev_state_grad;
   }

-  if (isBatch) {
-    if (value.prevStateValue) {
-      if (grad.checkIgGrad)
-        paddle::platform::CudaAtomicAdd(grad.checkIgGrad + frameIdx,
-                                        rCheckIGrad);
-      if (grad.checkFgGrad)
-        paddle::platform::CudaAtomicAdd(grad.checkFgGrad + frameIdx,
-                                        rCheckFGrad);
+  if (is_batch) {
+    if (value.prev_state_value) {
+      if (grad.check_ig_grad)
+        paddle::platform::CudaAtomicAdd(grad.check_ig_grad + frame_idx,
+                                        r_checkIGrad);
+      if (grad.check_fg_grad)
+        paddle::platform::CudaAtomicAdd(grad.check_fg_grad + frame_idx,
+                                        r_checkFGrad);
     }
-    if (grad.checkOgGrad)
-      paddle::platform::CudaAtomicAdd(grad.checkOgGrad + frameIdx, rCheckOGrad);
+    if (grad.check_og_grad)
+      paddle::platform::CudaAtomicAdd(grad.check_og_grad + frame_idx,
+                                      r_checkOGrad);
   } else {
-    if (value.prevStateValue) {
-      if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad;
-      if (grad.checkFgGrad) grad.checkFgGrad[frameIdx] += rCheckFGrad;
+    if (value.prev_state_value) {
+      if (grad.check_ig_grad) grad.check_ig_grad[frame_idx] += r_checkIGrad;
+      if (grad.check_fg_grad) grad.check_fg_grad[frame_idx] += r_checkFGrad;
     }
-    if (grad.checkOgGrad) grad.checkOgGrad[frameIdx] += rCheckOGrad;
+    if (grad.check_og_grad) grad.check_og_grad[frame_idx] += r_checkOGrad;
   }
 }

 template <class T, class Op>
 void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
-                      LstmMetaValue<T> value, int frameSize, int batchSize,
+                      LstmMetaValue<T> value, int frame_size, int batch_size,
                       activation_mode_t active_node,
                       activation_mode_t active_gate,
                       activation_mode_t active_state) {
   dim3 threads;
   dim3 grid;
-  if (batchSize == 1) {
-    int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
-    int frameBlocks = (frameSize + 1024 - 1) / 1024;
-    threads = dim3(framePerBlock, 1);
-    grid = dim3(frameBlocks, 1);
+  if (batch_size == 1) {
+    int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
+    int frame_blocks = (frame_size + 1024 - 1) / 1024;
+    threads = dim3(frame_per_block, 1);
+    grid = dim3(frame_blocks, 1);
   } else {
-    /* framePerBlock = 32 batchPerBlock = 32 */
+    /* frame_per_block = 32 batch_per_block = 32 */
     threads = dim3(32, 32);
-    grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
+    grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
   }

   auto stream =
       reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
-  if (batchSize == 1) {
+  if (batch_size == 1) {
     KeLstmForward<T, Op,
-                  /* isBatch= */ false><<<grid, threads, 0, stream>>>(
-        op, value, frameSize, batchSize, active_node, active_gate,
-        active_state);
+                  /* is_batch= */ false><<<grid, threads, 0, stream>>>(
+        op, value, frame_size, batch_size, active_node, active_gate,
+        active_state);
   } else {
     KeLstmForward<T, Op,
-                  /* isBatch= */ true><<<grid, threads, 0, stream>>>(
-        op, value, frameSize, batchSize, active_node, active_gate,
-        active_state);
+                  /* is_batch= */ true><<<grid, threads, 0, stream>>>(
+        op, value, frame_size, batch_size, active_node, active_gate,
+        active_state);
   }
 }

@@ -216,34 +219,34 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
 template <class T, class Op>
 void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
                        LstmMetaValue<T> value, LstmMetaGrad<T> grad,
-                       int frameSize, int batchSize,
+                       int frame_size, int batch_size,
                        activation_mode_t active_node,
                        activation_mode_t active_gate,
                        activation_mode_t active_state) {
   dim3 threads;
   dim3 grid;
-  if (batchSize == 1) {
-    int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
-    int frameBlocks = (frameSize + 1024 - 1) / 1024;
-    threads = dim3(framePerBlock, 1);
-    grid = dim3(frameBlocks, 1);
+  if (batch_size == 1) {
+    int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
+    int frame_blocks = (frame_size + 1024 - 1) / 1024;
+    threads = dim3(frame_per_block, 1);
+    grid = dim3(frame_blocks, 1);
   } else {
-    /* framePerBlock = 32 batchPerBlock = 16 */
+    /* frame_per_block = 32 batch_per_block = 16 */
     threads = dim3(32, 16);
-    grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 16 - 1) / 16);
+    grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 16 - 1) / 16);
   }

   auto stream =
       reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
-  if (batchSize == 1) {
+  if (batch_size == 1) {
     KeLstmBackward<T, Op,
-                   /* isBatch= */ false><<<grid, threads, 0, stream>>>(
-        op, value, grad, frameSize, batchSize, active_node, active_gate,
-        active_state);
+                   /* is_batch= */ false><<<grid, threads, 0, stream>>>(
+        op, value, grad, frame_size, batch_size, active_node, active_gate,
+        active_state);
   } else {
     KeLstmBackward<T, Op,
-                   /* isBatch= */ true><<<grid, threads, 0, stream>>>(
-        op, value, grad, frameSize, batchSize, active_node, active_gate,
-        active_state);
+                   /* is_batch= */ true><<<grid, threads, 0, stream>>>(
+        op, value, grad, frame_size, batch_size, active_node, active_gate,
+        active_state);
   }
 }
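As the threads(frame_per_block, batch_per_block) / grid(frame_blocks, batch_blocks) comments say, the kernels assign one thread per (frame element, batch row): 1-D blocks when a single sequence is processed, 32x32 tiles in the forward pass and 32x16 in the backward pass otherwise. A host-only sketch of just the shape computation, with a stand-in Dim3 so it compiles without CUDA:

    #include <cstdio>

    struct Dim3 { int x, y; };  // stand-in for CUDA's dim3 (host-only demo)

    // Mirrors the launch-shape logic of gpu_lstm_forward above.
    void launch_shape(int frame_size, int batch_size, Dim3* threads, Dim3* grid) {
      if (batch_size == 1) {
        int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
        int frame_blocks = (frame_size + 1024 - 1) / 1024;  // ceiling division
        *threads = {frame_per_block, 1};
        *grid = {frame_blocks, 1};
      } else {
        *threads = {32, 32};  // frame_per_block = 32, batch_per_block = 32
        *grid = {(frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32};
      }
    }

    int main() {
      Dim3 t, g;
      launch_shape(200, 16, &t, &g);
      std::printf("threads=(%d,%d) grid=(%d,%d)\n", t.x, t.y, g.x, g.y);
      // prints threads=(32,32) grid=(7,1)
      return 0;
    }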
paddle/operators/math/detail/lstm_kernel.h

@@ -27,19 +27,19 @@ namespace forward {
 template <class T>
 class lstm {
  public:
-  HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
-                             T &prevState, T &state, T &stateAtv, T &output,
+  HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg,
+                             T &value_og, T &prev_state, T &state,
+                             T &state_atv, T &output,
                              T &checkI, T &checkF, T &checkO,
                              activation_mode_t active_node,
                              activation_mode_t active_gate,
                              activation_mode_t active_state) {
-    valueIn = activation(valueIn, active_node);
-    valueIg = activation(valueIg + prevState * checkI, active_gate);
-    valueFg = activation(valueFg + prevState * checkF, active_gate);
-    state = valueIn * valueIg + prevState * valueFg;
-    valueOg = activation(valueOg + state * checkO, active_gate);
-    stateAtv = activation(state, active_state);
-    output = valueOg * stateAtv;
+    value_in = activation(value_in, active_node);
+    value_ig = activation(value_ig + prev_state * checkI, active_gate);
+    value_fg = activation(value_fg + prev_state * checkF, active_gate);
+    state = value_in * value_ig + prev_state * value_fg;
+    value_og = activation(value_og + state * checkO, active_gate);
+    state_atv = activation(state, active_state);
+    output = value_og * state_atv;
   }
 #ifndef __NVCC__
 #ifndef __AVX__  // If not compiled with AVX instructs. Disable AVX by default
@@ -48,24 +48,27 @@ class lstm {
   // Only float support AVX optimization
   static const bool avx = std::is_same<T, float>::value;

-  HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
-                             __m256 &valueOg, __m256 &prevState, __m256 &state,
-                             __m256 &stateAtv, __m256 &output, __m256 &checkI,
-                             __m256 &checkF, __m256 &checkO,
+  HOSTDEVICE void operator()(__m256 &value_in, __m256 &value_ig,
+                             __m256 &value_fg, __m256 &value_og,
+                             __m256 &prev_state, __m256 &state,
+                             __m256 &state_atv, __m256 &output, __m256 &checkI,
+                             __m256 &checkF, __m256 &checkO,
                              activation_mode_t active_node,
                              activation_mode_t active_gate,
                              activation_mode_t active_state) {
-    valueIn = activation(valueIn, active_node);
-    valueIg = activation(
-        _mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)), active_gate);
-    valueFg = activation(
-        _mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)), active_gate);
-    state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg),
-                          _mm256_mul_ps(prevState, valueFg));
-    valueOg = activation(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO)),
-                         active_gate);
-    stateAtv = activation(state, active_state);
-    output = _mm256_mul_ps(valueOg, stateAtv);
+    value_in = activation(value_in, active_node);
+    value_ig =
+        activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)),
+                   active_gate);
+    value_fg =
+        activation(_mm256_add_ps(value_fg, _mm256_mul_ps(prev_state, checkF)),
+                   active_gate);
+    state = _mm256_add_ps(_mm256_mul_ps(value_in, value_ig),
+                          _mm256_mul_ps(prev_state, value_fg));
+    value_og =
+        activation(_mm256_add_ps(value_og, _mm256_mul_ps(state, checkO)),
+                   active_gate);
+    state_atv = activation(state, active_state);
+    output = _mm256_mul_ps(value_og, state_atv);
   }
 #endif
 #endif
@@ -78,25 +81,26 @@ namespace backward {
 template <class T>
 class lstm {
  public:
-  HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
-                             T &gradIn, T &gradIg, T &gradFg, T &gradOg,
-                             T &prevState, T &prevStateGrad, T &state,
-                             T &stateGrad, T &stateAtv, T &outputGrad,
+  HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg,
+                             T &value_og, T &grad_in, T &grad_ig, T &grad_fg,
+                             T &grad_og, T &prev_state, T &prev_state_grad,
+                             T &state, T &state_grad, T &state_atv,
+                             T &output_grad,
                              T &checkI, T &checkF, T &checkO, T &checkIGrad,
                              T &checkFGrad, T &checkOGrad,
                              activation_mode_t active_node,
                              activation_mode_t active_gate,
                              activation_mode_t active_state) {
-    gradOg = activation(outputGrad * stateAtv, valueOg, active_gate);
-    stateGrad += activation(outputGrad * valueOg, stateAtv, active_state) +
-                 gradOg * checkO;
-    gradIn = activation(stateGrad * valueIg, valueIn, active_node);
-    gradIg = activation(stateGrad * valueIn, valueIg, active_gate);
-    gradFg = activation(stateGrad * prevState, valueFg, active_gate);
-    prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
-    checkIGrad = gradIg * prevState;
-    checkFGrad = gradFg * prevState;
-    checkOGrad = gradOg * state;
+    grad_og = activation(output_grad * state_atv, value_og, active_gate);
+    state_grad += activation(output_grad * value_og, state_atv, active_state) +
+                  grad_og * checkO;
+    grad_in = activation(state_grad * value_ig, value_in, active_node);
+    grad_ig = activation(state_grad * value_in, value_ig, active_gate);
+    grad_fg = activation(state_grad * prev_state, value_fg, active_gate);
+    prev_state_grad =
+        grad_ig * checkI + grad_fg * checkF + state_grad * value_fg;
+    checkIGrad = grad_ig * prev_state;
+    checkFGrad = grad_fg * prev_state;
+    checkOGrad = grad_og * state;
   }
 #ifndef __NVCC__
 #ifndef __AVX__  // If not compiled with AVX instructs. Disable AVX by default
@@ -105,32 +109,32 @@ class lstm {
   // Only float support AVX optimization
   static const bool avx = std::is_same<T, float>::value;
   HOSTDEVICE void operator()(
-      __m256 &valueIn, __m256 &valueIg, __m256 &valueFg, __m256 &valueOg,
-      __m256 &gradIn, __m256 &gradIg, __m256 &gradFg, __m256 &gradOg,
-      __m256 &prevState, __m256 &prevStateGrad, __m256 &state,
-      __m256 &stateGrad, __m256 &stateAtv, __m256 &outputGrad, __m256 &checkI,
-      __m256 &checkF, __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad,
-      __m256 &checkOGrad, activation_mode_t active_node,
+      __m256 &value_in, __m256 &value_ig, __m256 &value_fg, __m256 &value_og,
+      __m256 &grad_in, __m256 &grad_ig, __m256 &grad_fg, __m256 &grad_og,
+      __m256 &prev_state, __m256 &prev_state_grad, __m256 &state,
+      __m256 &state_grad, __m256 &state_atv, __m256 &output_grad,
+      __m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad,
+      __m256 &checkFGrad, __m256 &checkOGrad, activation_mode_t active_node,
       activation_mode_t active_gate, activation_mode_t active_state) {
-    gradOg =
-        activation(_mm256_mul_ps(outputGrad, stateAtv), valueOg,
-                   active_gate);
-    stateGrad = _mm256_add_ps(
-        activation(_mm256_mul_ps(outputGrad, valueOg),
-                   stateAtv, active_state),
-        stateGrad);
-    stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad);
-    gradIn =
-        activation(_mm256_mul_ps(stateGrad, valueIg), valueIn, active_node);
-    gradIg =
-        activation(_mm256_mul_ps(stateGrad, valueIn), valueIg, active_gate);
-    gradFg =
-        activation(_mm256_mul_ps(stateGrad, prevState), valueFg,
-                   active_gate);
-    prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI),
-                                  _mm256_mul_ps(gradFg, checkF));
-    prevStateGrad =
-        _mm256_add_ps(_mm256_mul_ps(stateGrad, valueFg), prevStateGrad);
-    checkIGrad = _mm256_mul_ps(gradIg, prevState);
-    checkFGrad = _mm256_mul_ps(gradFg, prevState);
-    checkOGrad = _mm256_mul_ps(gradOg, state);
+    grad_og =
+        activation(_mm256_mul_ps(output_grad, state_atv), value_og,
+                   active_gate);
+    state_grad = _mm256_add_ps(
+        activation(_mm256_mul_ps(output_grad, value_og),
+                   state_atv, active_state),
+        state_grad);
+    state_grad = _mm256_add_ps(_mm256_mul_ps(grad_og, checkO), state_grad);
+    grad_in =
+        activation(_mm256_mul_ps(state_grad, value_ig), value_in, active_node);
+    grad_ig =
+        activation(_mm256_mul_ps(state_grad, value_in), value_ig, active_gate);
+    grad_fg =
+        activation(_mm256_mul_ps(state_grad, prev_state), value_fg,
+                   active_gate);
+    prev_state_grad = _mm256_add_ps(_mm256_mul_ps(grad_ig, checkI),
+                                    _mm256_mul_ps(grad_fg, checkF));
+    prev_state_grad =
+        _mm256_add_ps(_mm256_mul_ps(state_grad, value_fg), prev_state_grad);
+    checkIGrad = _mm256_mul_ps(grad_ig, prev_state);
+    checkFGrad = _mm256_mul_ps(grad_fg, prev_state);
+    checkOGrad = _mm256_mul_ps(grad_og, state);
   }
 #endif
 #endif
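The forward functor above is a standard peephole LSTM cell. Restated as equations (a paraphrase of the code: a_i, a_c, a_f, a_o are the four frame_size-wide pre-activation slices of gate_value; w_{ic}, w_{fc}, w_{oc} are check_ig/check_fg/check_og; and \sigma_g, \sigma_c, \sigma_h are the activations chosen by active_gate, active_node, active_state):

\begin{aligned}
i_t &= \sigma_g(a_i + w_{ic} \odot c_{t-1}) \\
f_t &= \sigma_g(a_f + w_{fc} \odot c_{t-1}) \\
c_t &= \sigma_c(a_c) \odot i_t + c_{t-1} \odot f_t \\
o_t &= \sigma_g(a_o + w_{oc} \odot c_t) \\
h_t &= o_t \odot \sigma_h(c_t)
\end{aligned}

Here c_t is state, \sigma_h(c_t) is state_atv, and h_t is output; the backward functor is the corresponding chain-rule pass, including the += accumulation into state_grad across time steps.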
paddle/operators/math/lstm_compute.cc

@@ -30,12 +30,12 @@ struct LstmUnitFunctor<platform::CPUPlace, T> {
       detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
                                ActiveType(cand_act), ActiveType(gate_act),
                                ActiveType(cell_act));
-      value.gateValue += frame_size * 4;
-      value.stateValue += frame_size;
-      value.stateActiveValue += frame_size;
-      value.outputValue += frame_size;
-      if (value.prevStateValue) {
-        value.prevStateValue += frame_size;
+      value.gate_value += frame_size * 4;
+      value.state_value += frame_size;
+      value.state_active_value += frame_size;
+      value.output_value += frame_size;
+      if (value.prev_state_value) {
+        value.prev_state_value += frame_size;
       }
     }
   }
@@ -53,20 +53,20 @@ struct LstmUnitGradFunctor<platform::CPUPlace, T> {
                                 frame_size, ActiveType(cand_act),
                                 ActiveType(gate_act), ActiveType(cell_act));

-      value.gateValue += frame_size * 4;
-      value.stateValue += frame_size;
-      value.stateActiveValue += frame_size;
-      value.outputValue += frame_size;
-      if (value.prevStateValue) {
-        value.prevStateValue += frame_size;
+      value.gate_value += frame_size * 4;
+      value.state_value += frame_size;
+      value.state_active_value += frame_size;
+      value.output_value += frame_size;
+      if (value.prev_state_value) {
+        value.prev_state_value += frame_size;
       }

-      grad.gateGrad += frame_size * 4;
-      grad.stateGrad += frame_size;
-      grad.stateActiveGrad += frame_size;
-      grad.outputGrad += frame_size;
-      if (grad.prevStateGrad) {
-        grad.prevStateGrad += frame_size;
+      grad.gate_grad += frame_size * 4;
+      grad.state_grad += frame_size;
+      grad.state_active_grad += frame_size;
+      grad.output_grad += frame_size;
+      if (grad.prev_state_grad) {
+        grad.prev_state_grad += frame_size;
       }
     }
   }
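The CPU functors above iterate over a batch by advancing the raw pointers in LstmMetaValue/LstmMetaGrad instead of re-slicing tensors: the gate buffer packs four gates per frame, so it advances by frame_size * 4 per step, while every state, output, and grad buffer advances by frame_size. A reduced sketch of the same idiom (the struct below is a cut-down stand-in for LstmMetaValue<float>, not the real type):

    struct ValuePtrs {
      float* gate_value;          // 4 * frame_size entries per batch step
      float* state_value;         // frame_size entries per batch step
      float* state_active_value;  // frame_size entries per batch step
      float* output_value;        // frame_size entries per batch step
      float* prev_state_value;    // may be null before the first step
    };

    // Advance all cursors to the next batch step, as LstmUnitFunctor does.
    void advance(ValuePtrs& v, int frame_size) {
      v.gate_value += frame_size * 4;
      v.state_value += frame_size;
      v.state_active_value += frame_size;
      v.output_value += frame_size;
      if (v.prev_state_value) v.prev_state_value += frame_size;
    }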
paddle/operators/math/lstm_compute.h

@@ -31,26 +31,26 @@ typedef enum {
 template <class T>
 struct LstmMetaValue {
-  T *gateValue;
-  T *prevStateValue;
-  T *stateValue;
-  T *stateActiveValue;
-  T *outputValue;
-  T *checkIg;
-  T *checkFg;
-  T *checkOg;
+  T *gate_value;
+  T *prev_state_value;
+  T *state_value;
+  T *state_active_value;
+  T *output_value;
+  T *check_ig;
+  T *check_fg;
+  T *check_og;
 };

 template <class T>
 struct LstmMetaGrad {
-  T *gateGrad;
-  T *prevStateGrad;
-  T *stateGrad;
-  T *stateActiveGrad;
-  T *outputGrad;
-  T *checkIgGrad;
-  T *checkFgGrad;
-  T *checkOgGrad;
+  T *gate_grad;
+  T *prev_state_grad;
+  T *state_grad;
+  T *state_active_grad;
+  T *output_grad;
+  T *check_ig_grad;
+  T *check_fg_grad;
+  T *check_og_grad;
 };

 inline activation_mode_t ActiveType(const std::string& type) {