机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Commit 74725d05 (unverified)
Authored by qingqing01 on Dec 04, 2017; committed via GitHub on Dec 04, 2017
Merge pull request #6198 from qingqing01/lstm_style
Make lstm_op follow google code style.
Parents: 0c15b6c2, e5b51c4d
6 changed files with 505 additions and 492 deletions (+505 −492)
paddle/operators/lstm_op.h (+35 −35)
paddle/operators/math/detail/lstm_cpu_kernel.h (+216 −210)
paddle/operators/math/detail/lstm_gpu_kernel.h (+154 −151)
paddle/operators/math/detail/lstm_kernel.h (+66 −62)
paddle/operators/math/lstm_compute.cc (+18 −18)
paddle/operators/math/lstm_compute.h (+16 −16)
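The rename itself is mechanical: data members and local variables move from camelCase to snake_case, in line with Google C++ naming for variables and members, while type names such as LstmMetaValue keep their CamelCase form. A minimal, self-contained sketch of the pattern (the struct below is a cut-down stand-in, not the full Paddle definition):

```cpp
// Sketch of the naming change applied in this commit (illustrative only).
template <class T>
struct LstmMetaValue {
  T* gate_value;        // was: gateValue
  T* prev_state_value;  // was: prevStateValue
  T* check_ig;          // was: checkIg
  T* check_fg;          // was: checkFg
  T* check_og;          // was: checkOg
};

int main() {
  const int frame_size = 8;    // was: frameSize
  float bias[7 * 8] = {};      // 4 gate biases + 3 peephole weights
  LstmMetaValue<float> value{};
  // Same peephole layout as set up in lstm_op.h below.
  value.check_ig = bias + 4 * frame_size;
  value.check_fg = value.check_ig + frame_size;
  value.check_og = value.check_fg + frame_size;
  return 0;
}
```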
paddle/operators/lstm_op.h
@@ -73,15 +73,15 @@ class LSTMKernel : public framework::OpKernel<T> {
       T* bias_data = const_cast<T*>(bias->data<T>());
       // the code style in LstmMetaValue will be updated later.
-      lstm_value.checkIg = bias_data + 4 * frame_size;
-      lstm_value.checkFg = lstm_value.checkIg + frame_size;
-      lstm_value.checkOg = lstm_value.checkFg + frame_size;
+      lstm_value.check_ig = bias_data + 4 * frame_size;
+      lstm_value.check_fg = lstm_value.check_ig + frame_size;
+      lstm_value.check_og = lstm_value.check_fg + frame_size;
     } else {
-      lstm_value.checkIg = nullptr;
-      lstm_value.checkFg = nullptr;
-      lstm_value.checkOg = nullptr;
+      lstm_value.check_ig = nullptr;
+      lstm_value.check_fg = nullptr;
+      lstm_value.check_og = nullptr;
     }
-    lstm_value.prevStateValue = nullptr;
+    lstm_value.prev_state_value = nullptr;
     Tensor ordered_c0;
     const size_t* order = batch_gate->lod()[2].data();
     if (cell_t0) {
...
@@ -90,7 +90,7 @@ class LSTMKernel : public framework::OpKernel<T> {
       // to reorder.
       ReorderInitState<Place, T>(device_ctx, *cell_t0, order, &ordered_c0,
                                  true);
-      lstm_value.prevStateValue = ordered_c0.data<T>();
+      lstm_value.prev_state_value = ordered_c0.data<T>();
     }
     // Use the local variable as here.
...
@@ -140,14 +140,14 @@ class LSTMKernel : public framework::OpKernel<T> {
                          static_cast<T>(1.0));
       }
-      lstm_value.gateValue = gate_t.data<T>();
-      lstm_value.outputValue = out_t.data<T>();
-      lstm_value.stateValue = cell_t.data<T>();
-      lstm_value.stateActiveValue = cell_pre_act_t.data<T>();
+      lstm_value.gate_value = gate_t.data<T>();
+      lstm_value.output_value = out_t.data<T>();
+      lstm_value.state_value = cell_t.data<T>();
+      lstm_value.state_active_value = cell_pre_act_t.data<T>();
       math::LstmUnitFunctor<Place, T>::compute(device_ctx, lstm_value,
                                                frame_size, cur_batch_size,
                                                gate_act, cell_act, cand_act);
-      lstm_value.prevStateValue = lstm_value.stateValue;
+      lstm_value.prev_state_value = lstm_value.state_value;
     }
     math::Batch2LoDTensorFunctor<Place, T> to_seq;
...
@@ -214,13 +214,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     math::LstmMetaValue<T> lstm_value;
     if (bias && ctx.Attr<bool>("use_peepholes")) {
       T* bias_data = const_cast<T*>(bias->data<T>());
-      lstm_value.checkIg = bias_data + 4 * frame_size;
-      lstm_value.checkFg = lstm_value.checkIg + frame_size;
-      lstm_value.checkOg = lstm_value.checkFg + frame_size;
+      lstm_value.check_ig = bias_data + 4 * frame_size;
+      lstm_value.check_fg = lstm_value.check_ig + frame_size;
+      lstm_value.check_og = lstm_value.check_fg + frame_size;
     } else {
-      lstm_value.checkIg = nullptr;
-      lstm_value.checkFg = nullptr;
-      lstm_value.checkOg = nullptr;
+      lstm_value.check_ig = nullptr;
+      lstm_value.check_fg = nullptr;
+      lstm_value.check_og = nullptr;
     }
     math::LstmMetaGrad<T> lstm_grad;
...
@@ -231,13 +231,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     }
     if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) {
       T* bias_g_data = bias_g->data<T>();
-      lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size;
-      lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size;
-      lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size;
+      lstm_grad.check_ig_grad = bias_g_data + 4 * frame_size;
+      lstm_grad.check_fg_grad = lstm_grad.check_ig_grad + frame_size;
+      lstm_grad.check_og_grad = lstm_grad.check_fg_grad + frame_size;
     } else {
-      lstm_grad.checkIgGrad = nullptr;
-      lstm_grad.checkFgGrad = nullptr;
-      lstm_grad.checkOgGrad = nullptr;
+      lstm_grad.check_ig_grad = nullptr;
+      lstm_grad.check_fg_grad = nullptr;
+      lstm_grad.check_og_grad = nullptr;
     }
     math::LoDTensor2BatchFunctor<Place, T> to_batch;
...
@@ -276,26 +276,26 @@ class LSTMGradKernel : public framework::OpKernel<T> {
       Tensor gate = batch_gate->Slice(bstart, bend);
       Tensor cell = batch_cell.Slice(bstart, bend);
       Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend);
-      lstm_value.gateValue = gate.data<T>();
-      lstm_value.stateValue = cell.data<T>();
-      lstm_value.stateActiveValue = cell_pre_act.data<T>();
+      lstm_value.gate_value = gate.data<T>();
+      lstm_value.state_value = cell.data<T>();
+      lstm_value.state_active_value = cell_pre_act.data<T>();
       Tensor out_g = batch_hidden_g.Slice(bstart, bend);
       Tensor gate_g = batch_gate_g.Slice(bstart, bend);
       Tensor cell_g = batch_cell_g.Slice(bstart, bend);
-      lstm_grad.stateGrad = cell_g.data<T>();
-      lstm_grad.gateGrad = gate_g.data<T>();
-      lstm_grad.outputGrad = out_g.data<T>();
+      lstm_grad.state_grad = cell_g.data<T>();
+      lstm_grad.gate_grad = gate_g.data<T>();
+      lstm_grad.output_grad = out_g.data<T>();
       if (n > 0) {
         int bstart_pre = static_cast<int>(batch_starts[n - 1]);
         Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart);
         Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart);
-        lstm_value.prevStateValue = cell_pre.data<T>();
-        lstm_grad.prevStateGrad = cell_pre_g.data<T>();
+        lstm_value.prev_state_value = cell_pre.data<T>();
+        lstm_grad.prev_state_grad = cell_pre_g.data<T>();
       } else {
-        lstm_value.prevStateValue = c0 ? ordered_c0.data<T>() : nullptr;
-        lstm_grad.prevStateGrad = c0_g ? ordered_c0_g.data<T>() : nullptr;
+        lstm_value.prev_state_value = c0 ? ordered_c0.data<T>() : nullptr;
+        lstm_grad.prev_state_grad = c0_g ? ordered_c0_g.data<T>() : nullptr;
       }
       int cur_batch_size = bend - bstart;
...
paddle/operators/math/detail/lstm_cpu_kernel.h
@@ -26,278 +26,284 @@ namespace detail {

 template <class T, class Op>
 void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
-                                     int frameSize,
+                                     int frame_size,
                                      activation_mode_t active_node,
                                      activation_mode_t active_gate,
                                      activation_mode_t active_state) {
-  T rValueIn;
-  T rValueIg;
-  T rValueFg;
-  T rValueOg;
-  T rCheckI;
-  T rCheckF;
-  T rCheckO;
-  T rState;
-  T rPrevState = 0;
-  T rStateAtv;
-  T rOut;
-
-  T *valueIn = value.gateValue;
-  T *valueIg = value.gateValue + frameSize;
-  T *valueFg = value.gateValue + frameSize * 2;
-  T *valueOg = value.gateValue + frameSize * 3;
-
-  for (int i = 0; i < frameSize; i++) {
-    rValueIn = valueIn[i];
-    rValueIg = valueIg[i];
-    rValueFg = valueFg[i];
-    rValueOg = valueOg[i];
-    rCheckI = value.checkIg ? value.checkIg[i] : 0;
-    rCheckF = value.checkFg ? value.checkFg[i] : 0;
-    rCheckO = value.checkOg ? value.checkOg[i] : 0;
-
-    if (value.prevStateValue) {
-      rPrevState = value.prevStateValue[i];
-    }
-
-    op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
-       rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate,
-       active_state);
-
-    valueIn[i] = rValueIn;
-    valueIg[i] = rValueIg;
-    valueFg[i] = rValueFg;
-    valueOg[i] = rValueOg;
-    value.stateValue[i] = rState;
-    value.stateActiveValue[i] = rStateAtv;
-    value.outputValue[i] = rOut;
+  T r_value_in;
+  T r_value_ig;
+  T r_value_fg;
+  T r_value_og;
+  T r_checkI;
+  T r_checkF;
+  T r_checkO;
+  T r_state;
+  T r_prev_state = 0;
+  T r_state_atv;
+  T r_out;
+
+  T *value_in = value.gate_value;
+  T *value_ig = value.gate_value + frame_size;
+  T *value_fg = value.gate_value + frame_size * 2;
+  T *value_og = value.gate_value + frame_size * 3;
+
+  for (int i = 0; i < frame_size; i++) {
+    r_value_in = value_in[i];
+    r_value_ig = value_ig[i];
+    r_value_fg = value_fg[i];
+    r_value_og = value_og[i];
+    r_checkI = value.check_ig ? value.check_ig[i] : 0;
+    r_checkF = value.check_fg ? value.check_fg[i] : 0;
+    r_checkO = value.check_og ? value.check_og[i] : 0;
+
+    if (value.prev_state_value) {
+      r_prev_state = value.prev_state_value[i];
+    }
+
+    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
+       r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
+       active_gate, active_state);
+
+    value_in[i] = r_value_in;
+    value_ig[i] = r_value_ig;
+    value_fg[i] = r_value_fg;
+    value_og[i] = r_value_og;
+    value.state_value[i] = r_state;
+    value.state_active_value[i] = r_state_atv;
+    value.output_value[i] = r_out;
   }
 }

 template <class T, class Op>
 void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
-                                      LstmMetaGrad<T> grad, int frameSize,
+                                      LstmMetaGrad<T> grad, int frame_size,
                                       activation_mode_t active_node,
                                       activation_mode_t active_gate,
                                       activation_mode_t active_state) {
...
-    op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
-       rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
-       rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
-       rCheckOGrad, active_node, active_gate, active_state);
-
-    gradIn[i] = rGradIn;
-    gradIg[i] = rGradIg;
-    gradFg[i] = rGradFg;
-    gradOg[i] = rGradOg;
-    grad.stateGrad[i] = rStateGrad;
-
-    if (grad.prevStateGrad) grad.prevStateGrad[i] = rPrevStateGrad;
-    if (value.prevStateValue) {
-      if (grad.checkIgGrad) grad.checkIgGrad[i] += rCheckIGrad;
-      if (grad.checkFgGrad) grad.checkFgGrad[i] += rCheckFGrad;
-    }
-    if (grad.checkOgGrad) grad.checkOgGrad[i] += rCheckOGrad;
+    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
+       r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
+       r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
+       r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
+       active_state);
+
+    grad_in[i] = r_grad_in;
+    grad_ig[i] = r_grad_ig;
+    grad_fg[i] = r_grad_fg;
+    grad_og[i] = r_grad_og;
+    grad.state_grad[i] = r_state_grad;
+
+    if (grad.prev_state_grad) grad.prev_state_grad[i] = r_prev_state_grad;
+    if (value.prev_state_value) {
+      if (grad.check_ig_grad) grad.check_ig_grad[i] += r_checkIGrad;
+      if (grad.check_fg_grad) grad.check_fg_grad[i] += r_checkFGrad;
+    }
+    if (grad.check_og_grad) grad.check_og_grad[i] += r_checkOGrad;
   }
 }

 template <class T, class Op>
 void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
-                                   int frameSize,
+                                   int frame_size,
                                    activation_mode_t active_node,
                                    activation_mode_t active_gate,
                                    activation_mode_t active_state) {
 #ifdef __AVX__
-  __m256 *valueIn = (__m256 *)value.gateValue;
-  __m256 *valueIg = (__m256 *)(value.gateValue + frameSize);
-  __m256 *valueFg = (__m256 *)(value.gateValue + frameSize * 2);
-  __m256 *valueOg = (__m256 *)(value.gateValue + frameSize * 3);
-
-  for (int i = 0; i < frameSize / 8; i++) {
-    if (value.checkIg) {
-      rCheckI = ((__m256 *)value.checkIg)[i];
-      rCheckF = ((__m256 *)value.checkFg)[i];
-      rCheckO = ((__m256 *)value.checkOg)[i];
-    }
-    if (value.prevStateValue) {
-      rPrevState = ((__m256 *)value.prevStateValue)[i];
-    }
-    ((__m256 *)value.stateValue)[i] = rState;
-    ((__m256 *)value.stateActiveValue)[i] = rStateAtv;
-    ((__m256 *)value.outputValue)[i] = rOut;
+  __m256 *value_in = (__m256 *)value.gate_value;
+  __m256 *value_ig = (__m256 *)(value.gate_value + frame_size);
+  __m256 *value_fg = (__m256 *)(value.gate_value + frame_size * 2);
+  __m256 *value_og = (__m256 *)(value.gate_value + frame_size * 3);
+
+  for (int i = 0; i < frame_size / 8; i++) {
+    if (value.check_ig) {
+      r_checkI = ((__m256 *)value.check_ig)[i];
+      r_checkF = ((__m256 *)value.check_fg)[i];
+      r_checkO = ((__m256 *)value.check_og)[i];
+    }
+    if (value.prev_state_value) {
+      r_prev_state = ((__m256 *)value.prev_state_value)[i];
+    }
+    ((__m256 *)value.state_value)[i] = r_state;
+    ((__m256 *)value.state_active_value)[i] = r_state_atv;
+    ((__m256 *)value.output_value)[i] = r_out;
   }
 #endif
 }

 template <class T, class Op>
 void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
-                                    LstmMetaGrad<T> grad, int frameSize,
+                                    LstmMetaGrad<T> grad, int frame_size,
                                     activation_mode_t active_node,
                                     activation_mode_t active_gate,
                                     activation_mode_t active_state) {
 #ifdef __AVX__
...
-    if (grad.prevStateGrad)
-      ((__m256 *)grad.prevStateGrad)[i] = rPrevStateGrad;
-    if (value.prevStateValue) {
-      if (grad.checkIgGrad) ((__m256 *)grad.checkIgGrad)[i] += rCheckIGrad;
-      if (grad.checkFgGrad) ((__m256 *)grad.checkFgGrad)[i] += rCheckFGrad;
-    }
-    if (grad.checkOgGrad) ((__m256 *)grad.checkOgGrad)[i] += rCheckOGrad;
+    if (grad.prev_state_grad)
+      ((__m256 *)grad.prev_state_grad)[i] = r_prev_state_grad;
+    if (value.prev_state_value) {
+      if (grad.check_ig_grad) ((__m256 *)grad.check_ig_grad)[i] += r_checkIGrad;
+      if (grad.check_fg_grad) ((__m256 *)grad.check_fg_grad)[i] += r_checkFGrad;
+    }
+    if (grad.check_og_grad) ((__m256 *)grad.check_og_grad)[i] += r_checkOGrad;
   }
 #endif
 }

 template <class T, class Op>
-void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frameSize,
+void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
                       activation_mode_t active_node,
                       activation_mode_t active_gate,
                       activation_mode_t active_state) {
-  if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same<T, float>::value)) {
-    avx_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
-                                     active_gate, active_state);
+  if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
+    avx_lstm_forward_one_sequence<T>(op, value, frame_size, active_node,
+                                     active_gate, active_state);
   } else {
-    naive_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
-                                       active_gate, active_state);
+    naive_lstm_forward_one_sequence<T>(op, value, frame_size, active_node,
+                                       active_gate, active_state);
   }
 }

 template <class T, class Op>
 void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
-                       int frameSize, activation_mode_t active_node,
+                       int frame_size, activation_mode_t active_node,
                        activation_mode_t active_gate,
                        activation_mode_t active_state) {
-  if (Op::avx && !(frameSize & (8 - 1)) && (std::is_same<T, float>::value)) {
-    avx_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
-                                      active_gate, active_state);
+  if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
+    avx_lstm_backward_one_sequence<T>(op, value, grad, frame_size, active_node,
+                                      active_gate, active_state);
   } else {
-    naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize,
-                                        active_node, active_gate,
-                                        active_state);
+    naive_lstm_backward_one_sequence<T>(op, value, grad, frame_size,
+                                        active_node, active_gate,
+                                        active_state);
   }
 }
...
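Both cpu_lstm_forward and cpu_lstm_backward above fall back to the naive path unless Op::avx is set, T is float, and frame_size is a multiple of 8, since one __m256 register holds eight floats; `!(frame_size & (8 - 1))` is the usual bit trick for divisibility by 8. A standalone sketch of that check (illustrative only; the Op::avx part is omitted here):

```cpp
// Dispatch condition sketch: AVX path only for float and frame_size % 8 == 0.
#include <cstdio>
#include <type_traits>

template <class T>
bool use_avx_path(int frame_size) {
  // (frame_size & (8 - 1)) == 0  <=>  frame_size % 8 == 0
  return !(frame_size & (8 - 1)) && std::is_same<T, float>::value;
}

int main() {
  std::printf("%d %d %d\n",
              use_avx_path<float>(64),    // 1: float, multiple of 8
              use_avx_path<float>(30),    // 0: not a multiple of 8
              use_avx_path<double>(64));  // 0: not float
  return 0;
}
```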
paddle/operators/math/detail/lstm_gpu_kernel.h
@@ -26,189 +26,192 @@ namespace math {
 namespace detail {

 /*
- * threads(framePerBlock, batchPerBlock)
- * grid(frameBlocks, batchBlocks)
+ * threads(frame_per_block, batch_per_block)
+ * grid(frame_blocks, batch_blocks)
  */
-template <class T, class Op, bool isBatch>
-__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
-                              int batchSize, activation_mode_t active_node,
+template <class T, class Op, bool is_batch>
+__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
+                              int batch_size, activation_mode_t active_node,
                               activation_mode_t active_gate,
                               activation_mode_t active_state) {
-  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (frameIdx >= frameSize) return;
-
-  int batchIdx = 0;
-  if (isBatch) {
-    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
-    if (batchIdx >= batchSize) return;
-    value.gateValue += batchIdx * frameSize * 4;
-    value.outputValue += batchIdx * frameSize;
-    value.stateValue += batchIdx * frameSize;
-    value.stateActiveValue += batchIdx * frameSize;
-  }
+  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (frame_idx >= frame_size) return;
+
+  int batch_idx = 0;
+  if (is_batch) {
+    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
+    if (batch_idx >= batch_size) return;
+    value.gate_value += batch_idx * frame_size * 4;
+    value.output_value += batch_idx * frame_size;
+    value.state_value += batch_idx * frame_size;
+    value.state_active_value += batch_idx * frame_size;
+  }
...
-  if (value.prevStateValue) {
-    if (isBatch) value.prevStateValue += batchIdx * frameSize;
-    rPrevState = value.prevStateValue[frameIdx];
-  }
-
-  op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
-     rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state);
-
-  value.gateValue[frameIdx] = rValueIn;
-  value.gateValue[frameIdx + frameSize] = rValueIg;
-  value.gateValue[frameIdx + frameSize * 2] = rValueFg;
-  value.gateValue[frameIdx + frameSize * 3] = rValueOg;
-
-  value.stateValue[frameIdx] = rState;
-  value.stateActiveValue[frameIdx] = rStateAtv;
-  value.outputValue[frameIdx] = rOut;
+  if (value.prev_state_value) {
+    if (is_batch) value.prev_state_value += batch_idx * frame_size;
+    r_prev_state = value.prev_state_value[frame_idx];
+  }
+
+  op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
+     r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
+     active_gate, active_state);
+
+  value.gate_value[frame_idx] = r_value_in;
+  value.gate_value[frame_idx + frame_size] = r_value_ig;
+  value.gate_value[frame_idx + frame_size * 2] = r_value_fg;
+  value.gate_value[frame_idx + frame_size * 3] = r_value_og;
+
+  value.state_value[frame_idx] = r_state;
+  value.state_active_value[frame_idx] = r_state_atv;
+  value.output_value[frame_idx] = r_out;
 }

-template <class T, class Op, bool isBatch>
-__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
-                               LstmMetaGrad<T> grad, int frameSize,
-                               int batchSize, activation_mode_t active_node,
+template <class T, class Op, bool is_batch>
+__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
+                               LstmMetaGrad<T> grad, int frame_size,
+                               int batch_size, activation_mode_t active_node,
                                activation_mode_t active_gate,
                                activation_mode_t active_state) {
-  const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
-  if (frameIdx >= frameSize) return;
-
-  int batchIdx = 0;
-  if (isBatch) {
-    batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
-    if (batchIdx >= batchSize) return;
-    value.gateValue += batchIdx * frameSize * 4;
-    value.stateValue += batchIdx * frameSize;
-    value.stateActiveValue += batchIdx * frameSize;
-    grad.gateGrad += batchIdx * frameSize * 4;
-    grad.stateGrad += batchIdx * frameSize;
-    grad.outputGrad += batchIdx * frameSize;
-  }
+  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (frame_idx >= frame_size) return;
+
+  int batch_idx = 0;
+  if (is_batch) {
+    batch_idx = blockIdx.y * blockDim.y + threadIdx.y;
+    if (batch_idx >= batch_size) return;
+    value.gate_value += batch_idx * frame_size * 4;
+    value.state_value += batch_idx * frame_size;
+    value.state_active_value += batch_idx * frame_size;
+    grad.gate_grad += batch_idx * frame_size * 4;
+    grad.state_grad += batch_idx * frame_size;
+    grad.output_grad += batch_idx * frame_size;
+  }
...
-  grad.gateGrad[frameIdx] = rGradIn;
-  grad.gateGrad[frameIdx + frameSize] = rGradIg;
-  grad.gateGrad[frameIdx + frameSize * 2] = rGradFg;
-  grad.gateGrad[frameIdx + frameSize * 3] = rGradOg;
-  grad.stateGrad[frameIdx] = rStateGrad;
-
-  if (grad.prevStateGrad) {
-    if (isBatch) grad.prevStateGrad += batchIdx * frameSize;
-    grad.prevStateGrad[frameIdx] = rPrevStateGrad;
-  }
-
-  if (isBatch) {
-    if (value.prevStateValue) {
-      if (grad.checkIgGrad)
-        paddle::platform::CudaAtomicAdd(grad.checkIgGrad + frameIdx,
-                                        rCheckIGrad);
-      if (grad.checkFgGrad)
-        paddle::platform::CudaAtomicAdd(grad.checkFgGrad + frameIdx,
-                                        rCheckFGrad);
-    }
-    if (grad.checkOgGrad)
-      paddle::platform::CudaAtomicAdd(grad.checkOgGrad + frameIdx,
-                                      rCheckOGrad);
-  } else {
-    if (value.prevStateValue) {
-      if (grad.checkIgGrad) grad.checkIgGrad[frameIdx] += rCheckIGrad;
-      if (grad.checkFgGrad) grad.checkFgGrad[frameIdx] += rCheckFGrad;
-    }
-    if (grad.checkOgGrad) grad.checkOgGrad[frameIdx] += rCheckOGrad;
-  }
+  grad.gate_grad[frame_idx] = r_grad_in;
+  grad.gate_grad[frame_idx + frame_size] = r_grad_ig;
+  grad.gate_grad[frame_idx + frame_size * 2] = r_grad_fg;
+  grad.gate_grad[frame_idx + frame_size * 3] = r_grad_og;
+  grad.state_grad[frame_idx] = r_state_grad;
+
+  if (grad.prev_state_grad) {
+    if (is_batch) grad.prev_state_grad += batch_idx * frame_size;
+    grad.prev_state_grad[frame_idx] = r_prev_state_grad;
+  }
+
+  if (is_batch) {
+    if (value.prev_state_value) {
+      if (grad.check_ig_grad)
+        paddle::platform::CudaAtomicAdd(grad.check_ig_grad + frame_idx,
+                                        r_checkIGrad);
+      if (grad.check_fg_grad)
+        paddle::platform::CudaAtomicAdd(grad.check_fg_grad + frame_idx,
+                                        r_checkFGrad);
+    }
+    if (grad.check_og_grad)
+      paddle::platform::CudaAtomicAdd(grad.check_og_grad + frame_idx,
+                                      r_checkOGrad);
+  } else {
+    if (value.prev_state_value) {
+      if (grad.check_ig_grad) grad.check_ig_grad[frame_idx] += r_checkIGrad;
+      if (grad.check_fg_grad) grad.check_fg_grad[frame_idx] += r_checkFGrad;
+    }
+    if (grad.check_og_grad) grad.check_og_grad[frame_idx] += r_checkOGrad;
+  }
 }

 template <class T, class Op>
 void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
-                      LstmMetaValue<T> value, int frameSize, int batchSize,
+                      LstmMetaValue<T> value, int frame_size, int batch_size,
                       activation_mode_t active_node,
                       activation_mode_t active_gate,
                       activation_mode_t active_state) {
   dim3 threads;
   dim3 grid;
-  if (batchSize == 1) {
-    int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
-    int frameBlocks = (frameSize + 1024 - 1) / 1024;
-    threads = dim3(framePerBlock, 1);
-    grid = dim3(frameBlocks, 1);
+  if (batch_size == 1) {
+    int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
+    int frame_blocks = (frame_size + 1024 - 1) / 1024;
+    threads = dim3(frame_per_block, 1);
+    grid = dim3(frame_blocks, 1);
   } else {
-    /* framePerBlock = 32 batchPerBlock = 32 */
+    /* frame_per_block = 32 batch_per_block = 32 */
     threads = dim3(32, 32);
-    grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
+    grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32);
   }

   auto stream =
       reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
-  if (batchSize == 1) {
+  if (batch_size == 1) {
     KeLstmForward<T, Op,
-                  /* isBatch= */ false><<<grid, threads, 0, stream>>>(
-        op, value, frameSize, batchSize, active_node, active_gate,
-        active_state);
+                  /* is_batch= */ false><<<grid, threads, 0, stream>>>(
+        op, value, frame_size, batch_size, active_node, active_gate,
+        active_state);
   } else {
     KeLstmForward<T, Op,
-                  /* isBatch= */ true><<<grid, threads, 0, stream>>>(
-        op, value, frameSize, batchSize, active_node, active_gate,
-        active_state);
+                  /* is_batch= */ true><<<grid, threads, 0, stream>>>(
+        op, value, frame_size, batch_size, active_node, active_gate,
+        active_state);
   }
 }
...
@@ -216,34 +219,34 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
 template <class T, class Op>
 void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
-                       LstmMetaValue<T> value, LstmMetaGrad<T> grad,
-                       int frameSize, int batchSize,
+                       LstmMetaValue<T> value, LstmMetaGrad<T> grad,
+                       int frame_size, int batch_size,
                        activation_mode_t active_node,
                        activation_mode_t active_gate,
                        activation_mode_t active_state) {
   dim3 threads;
   dim3 grid;
-  if (batchSize == 1) {
-    int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
-    int frameBlocks = (frameSize + 1024 - 1) / 1024;
-    threads = dim3(framePerBlock, 1);
-    grid = dim3(frameBlocks, 1);
+  if (batch_size == 1) {
+    int frame_per_block = frame_size <= 1024 ? frame_size : 1024;
+    int frame_blocks = (frame_size + 1024 - 1) / 1024;
+    threads = dim3(frame_per_block, 1);
+    grid = dim3(frame_blocks, 1);
   } else {
-    /* framePerBlock = 32 batchPerBlock = 16 */
+    /* frame_per_block = 32 batch_per_block = 16 */
     threads = dim3(32, 16);
-    grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 16 - 1) / 16);
+    grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 16 - 1) / 16);
   }

   auto stream =
       reinterpret_cast<const platform::CUDADeviceContext&>(context).stream();
-  if (batchSize == 1) {
+  if (batch_size == 1) {
     KeLstmBackward<T, Op,
-                   /* isBatch= */ false><<<grid, threads, 0, stream>>>(
-        op, value, grad, frameSize, batchSize, active_node, active_gate,
-        active_state);
+                   /* is_batch= */ false><<<grid, threads, 0, stream>>>(
+        op, value, grad, frame_size, batch_size, active_node, active_gate,
+        active_state);
   } else {
     KeLstmBackward<T, Op,
-                   /* isBatch= */ true><<<grid, threads, 0, stream>>>(
-        op, value, grad, frameSize, batchSize, active_node, active_gate,
-        active_state);
+                   /* is_batch= */ true><<<grid, threads, 0, stream>>>(
+        op, value, grad, frame_size, batch_size, active_node, active_gate,
+        active_state);
   }
 }
...
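The launch shape in gpu_lstm_forward/gpu_lstm_backward above is picked per call: with batch_size == 1 the kernel is launched one-dimensionally over frames (at most 1024 threads per block), otherwise a two-dimensional grid of 32×32 blocks (32×16 for the backward pass) is used, with ceil division for the grid extents. A small standalone sketch of the same arithmetic, using plain ints instead of CUDA's dim3:

```cpp
// Launch-shape arithmetic mirrored from gpu_lstm_forward (illustrative only).
#include <cstdio>

int main() {
  const int frame_size = 100, batch_size = 4;
  int threads_x, threads_y, grid_x, grid_y;
  if (batch_size == 1) {
    threads_x = frame_size <= 1024 ? frame_size : 1024;
    threads_y = 1;
    grid_x = (frame_size + 1024 - 1) / 1024;  // ceil(frame_size / 1024)
    grid_y = 1;
  } else {
    threads_x = 32;
    threads_y = 32;                        // 16 in the backward pass
    grid_x = (frame_size + 32 - 1) / 32;   // ceil(frame_size / 32)
    grid_y = (batch_size + 32 - 1) / 32;   // ceil(batch_size / 32)
  }
  std::printf("threads=(%d,%d) grid=(%d,%d)\n", threads_x, threads_y, grid_x,
              grid_y);
  return 0;
}
```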
paddle/operators/math/detail/lstm_kernel.h
@@ -27,19 +27,19 @@ namespace forward {

 template <class T>
 class lstm {
  public:
-  HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
-                             T &prevState, T &state, T &stateAtv, T &output,
+  HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg,
+                             T &value_og, T &prev_state, T &state,
+                             T &state_atv, T &output,
                              T &checkI, T &checkF, T &checkO,
                              activation_mode_t active_node,
                              activation_mode_t active_gate,
                              activation_mode_t active_state) {
-    valueIn = activation(valueIn, active_node);
-    valueIg = activation(valueIg + prevState * checkI, active_gate);
-    valueFg = activation(valueFg + prevState * checkF, active_gate);
-    state = valueIn * valueIg + prevState * valueFg;
-    valueOg = activation(valueOg + state * checkO, active_gate);
-    stateAtv = activation(state, active_state);
-    output = valueOg * stateAtv;
+    value_in = activation(value_in, active_node);
+    value_ig = activation(value_ig + prev_state * checkI, active_gate);
+    value_fg = activation(value_fg + prev_state * checkF, active_gate);
+    state = value_in * value_ig + prev_state * value_fg;
+    value_og = activation(value_og + state * checkO, active_gate);
+    state_atv = activation(state, active_state);
+    output = value_og * state_atv;
   }
 #ifndef __NVCC__
 #ifndef __AVX__  // If not compiled with AVX instructs. Disable AVX by default
...
@@ -48,24 +48,27 @@ class lstm {
   // Only float support AVX optimization
   static const bool avx = std::is_same<T, float>::value;
-  HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg,
-                             __m256 &valueOg, __m256 &prevState, __m256 &state,
-                             __m256 &stateAtv, __m256 &output, __m256 &checkI,
-                             __m256 &checkF, __m256 &checkO,
+  HOSTDEVICE void operator()(__m256 &value_in, __m256 &value_ig,
+                             __m256 &value_fg, __m256 &value_og,
+                             __m256 &prev_state, __m256 &state,
+                             __m256 &state_atv, __m256 &output,
+                             __m256 &checkI, __m256 &checkF, __m256 &checkO,
                              activation_mode_t active_node,
                              activation_mode_t active_gate,
                              activation_mode_t active_state) {
-    valueIn = activation(valueIn, active_node);
-    valueIg = activation(
-        _mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)), active_gate);
-    valueFg = activation(
-        _mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)), active_gate);
-    state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg),
-                          _mm256_mul_ps(prevState, valueFg));
-    valueOg = activation(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO)),
-                         active_gate);
-    stateAtv = activation(state, active_state);
-    output = _mm256_mul_ps(valueOg, stateAtv);
+    value_in = activation(value_in, active_node);
+    value_ig =
+        activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)),
+                   active_gate);
+    value_fg =
+        activation(_mm256_add_ps(value_fg, _mm256_mul_ps(prev_state, checkF)),
+                   active_gate);
+    state = _mm256_add_ps(_mm256_mul_ps(value_in, value_ig),
+                          _mm256_mul_ps(prev_state, value_fg));
+    value_og =
+        activation(_mm256_add_ps(value_og, _mm256_mul_ps(state, checkO)),
+                   active_gate);
+    state_atv = activation(state, active_state);
+    output = _mm256_mul_ps(value_og, state_atv);
   }
 #endif
 #endif
...
@@ -78,25 +81,26 @@ namespace backward {

 template <class T>
 class lstm {
  public:
-  HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
-                             T &gradIn, T &gradIg, T &gradFg, T &gradOg,
-                             T &prevState, T &prevStateGrad, T &state,
-                             T &stateGrad, T &stateAtv, T &outputGrad,
-                             T &checkI, T &checkF, T &checkO, T &checkIGrad,
-                             T &checkFGrad, T &checkOGrad,
+  HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg,
+                             T &value_og, T &grad_in, T &grad_ig, T &grad_fg,
+                             T &grad_og, T &prev_state, T &prev_state_grad,
+                             T &state, T &state_grad, T &state_atv,
+                             T &output_grad, T &checkI, T &checkF, T &checkO,
+                             T &checkIGrad, T &checkFGrad, T &checkOGrad,
                              activation_mode_t active_node,
                              activation_mode_t active_gate,
                              activation_mode_t active_state) {
-    gradOg = activation(outputGrad * stateAtv, valueOg, active_gate);
-    stateGrad += activation(outputGrad * valueOg, stateAtv, active_state) +
-                 gradOg * checkO;
-    gradIn = activation(stateGrad * valueIg, valueIn, active_node);
-    gradIg = activation(stateGrad * valueIn, valueIg, active_gate);
-    gradFg = activation(stateGrad * prevState, valueFg, active_gate);
-    prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
-    checkIGrad = gradIg * prevState;
-    checkFGrad = gradFg * prevState;
-    checkOGrad = gradOg * state;
+    grad_og = activation(output_grad * state_atv, value_og, active_gate);
+    state_grad += activation(output_grad * value_og, state_atv, active_state) +
+                  grad_og * checkO;
+    grad_in = activation(state_grad * value_ig, value_in, active_node);
+    grad_ig = activation(state_grad * value_in, value_ig, active_gate);
+    grad_fg = activation(state_grad * prev_state, value_fg, active_gate);
+    prev_state_grad =
+        grad_ig * checkI + grad_fg * checkF + state_grad * value_fg;
+    checkIGrad = grad_ig * prev_state;
+    checkFGrad = grad_fg * prev_state;
+    checkOGrad = grad_og * state;
   }
 #ifndef __NVCC__
 #ifndef __AVX__  // If not compiled with AVX instructs. Disable AVX by default
...
@@ -105,32 +109,32 @@ class lstm {
   // Only float support AVX optimization
   static const bool avx = std::is_same<T, float>::value;
   HOSTDEVICE void operator()(
-      __m256 &valueIn, __m256 &valueIg, __m256 &valueFg, __m256 &valueOg,
-      __m256 &gradIn, __m256 &gradIg, __m256 &gradFg, __m256 &gradOg,
-      __m256 &prevState, __m256 &prevStateGrad, __m256 &state,
-      __m256 &stateGrad, __m256 &stateAtv, __m256 &outputGrad, __m256 &checkI,
-      __m256 &checkF, __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad,
-      __m256 &checkOGrad, activation_mode_t active_node,
-      activation_mode_t active_gate, activation_mode_t active_state) {
+      __m256 &value_in, __m256 &value_ig, __m256 &value_fg, __m256 &value_og,
+      __m256 &grad_in, __m256 &grad_ig, __m256 &grad_fg, __m256 &grad_og,
+      __m256 &prev_state, __m256 &prev_state_grad, __m256 &state,
+      __m256 &state_grad, __m256 &state_atv, __m256 &output_grad,
+      __m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad,
+      __m256 &checkFGrad, __m256 &checkOGrad, activation_mode_t active_node,
+      activation_mode_t active_gate, activation_mode_t active_state) {
...
-    checkIGrad = _mm256_mul_ps(gradIg, prevState);
-    checkFGrad = _mm256_mul_ps(gradFg, prevState);
-    checkOGrad = _mm256_mul_ps(gradOg, state);
+    checkIGrad = _mm256_mul_ps(grad_ig, prev_state);
+    checkFGrad = _mm256_mul_ps(grad_fg, prev_state);
+    checkOGrad = _mm256_mul_ps(grad_og, state);
   }
 #endif
 #endif
...
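Written out in equations, the scalar forward functor above computes the standard peephole LSTM cell. With the gate pre-activations value_in, value_ig, value_fg, value_og written as g̃, ĩ, f̃, õ, prev_state as c_{t-1}, and checkI, checkF, checkO as the peephole weights w_ci, w_cf, w_co:

```latex
\begin{aligned}
g_t &= \sigma_n(\tilde g_t) \\
i_t &= \sigma_g(\tilde i_t + c_{t-1} w_{ci}) \\
f_t &= \sigma_g(\tilde f_t + c_{t-1} w_{cf}) \\
c_t &= g_t \, i_t + c_{t-1} f_t \\
o_t &= \sigma_g(\tilde o_t + c_t w_{co}) \\
h_t &= o_t \, \sigma_s(c_t)
\end{aligned}
```

Here σ_n, σ_g, σ_s stand for active_node, active_gate and active_state, and all operations are element-wise; the backward functor is the corresponding chain-rule pass, accumulating into state_grad and the three check*Grad peephole gradients.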
paddle/operators/math/lstm_compute.cc
@@ -30,12 +30,12 @@ struct LstmUnitFunctor<platform::CPUPlace, T> {
       detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
                                ActiveType(cand_act), ActiveType(gate_act),
                                ActiveType(cell_act));
-      value.gateValue += frame_size * 4;
-      value.stateValue += frame_size;
-      value.stateActiveValue += frame_size;
-      value.outputValue += frame_size;
-      if (value.prevStateValue) {
-        value.prevStateValue += frame_size;
+      value.gate_value += frame_size * 4;
+      value.state_value += frame_size;
+      value.state_active_value += frame_size;
+      value.output_value += frame_size;
+      if (value.prev_state_value) {
+        value.prev_state_value += frame_size;
       }
     }
   }
...
@@ -53,20 +53,20 @@ struct LstmUnitGradFunctor<platform::CPUPlace, T> {
                                frame_size, ActiveType(cand_act),
                                ActiveType(gate_act), ActiveType(cell_act));

-      value.gateValue += frame_size * 4;
-      value.stateValue += frame_size;
-      value.stateActiveValue += frame_size;
-      value.outputValue += frame_size;
-      if (value.prevStateValue) {
-        value.prevStateValue += frame_size;
+      value.gate_value += frame_size * 4;
+      value.state_value += frame_size;
+      value.state_active_value += frame_size;
+      value.output_value += frame_size;
+      if (value.prev_state_value) {
+        value.prev_state_value += frame_size;
       }

-      grad.gateGrad += frame_size * 4;
-      grad.stateGrad += frame_size;
-      grad.stateActiveGrad += frame_size;
-      grad.outputGrad += frame_size;
-      if (grad.prevStateGrad) {
-        grad.prevStateGrad += frame_size;
+      grad.gate_grad += frame_size * 4;
+      grad.state_grad += frame_size;
+      grad.state_active_grad += frame_size;
+      grad.output_grad += frame_size;
+      if (grad.prev_state_grad) {
+        grad.prev_state_grad += frame_size;
       }
     }
   }
...
paddle/operators/math/lstm_compute.h
@@ -31,26 +31,26 @@ typedef enum {
 template <class T>
 struct LstmMetaValue {
-  T *gateValue;
-  T *prevStateValue;
-  T *stateValue;
-  T *stateActiveValue;
-  T *outputValue;
-  T *checkIg;
-  T *checkFg;
-  T *checkOg;
+  T *gate_value;
+  T *prev_state_value;
+  T *state_value;
+  T *state_active_value;
+  T *output_value;
+  T *check_ig;
+  T *check_fg;
+  T *check_og;
 };

 template <class T>
 struct LstmMetaGrad {
-  T *gateGrad;
-  T *prevStateGrad;
-  T *stateGrad;
-  T *stateActiveGrad;
-  T *outputGrad;
-  T *checkIgGrad;
-  T *checkFgGrad;
-  T *checkOgGrad;
+  T *gate_grad;
+  T *prev_state_grad;
+  T *state_grad;
+  T *state_active_grad;
+  T *output_grad;
+  T *check_ig_grad;
+  T *check_fg_grad;
+  T *check_og_grad;
 };

 inline activation_mode_t ActiveType(const std::string &type) {
...
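The renamed LstmMetaValue/LstmMetaGrad members are raw pointers into buffers that the kernels above index with fixed offsets: gate_value points at a block of 4 * frame_size entries per step (candidate input, input gate, forget gate and output gate, each frame_size wide), and check_ig/check_fg/check_og are the three frame_size-long peephole vectors. A hypothetical wiring sketch (not Paddle code) of how such a struct is filled:

```cpp
// Illustrative only: mapping the snake_case pointers onto contiguous buffers,
// matching the offsets used by the CPU/GPU kernels in this commit.
#include <vector>

template <class T>
struct LstmMetaValue {  // same member names as lstm_compute.h after the rename
  T *gate_value, *prev_state_value, *state_value, *state_active_value,
      *output_value, *check_ig, *check_fg, *check_og;
};

int main() {
  const int frame_size = 16;
  std::vector<float> gates(4 * frame_size), state(frame_size),
      state_act(frame_size), out(frame_size), checks(3 * frame_size);
  LstmMetaValue<float> value{};
  value.gate_value = gates.data();  // in | ig | fg | og, each frame_size wide
  value.state_value = state.data();
  value.state_active_value = state_act.data();
  value.output_value = out.data();
  value.check_ig = checks.data();
  value.check_fg = checks.data() + frame_size;
  value.check_og = checks.data() + 2 * frame_size;
  value.prev_state_value = nullptr;  // first step has no previous cell state
  return 0;
}
```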