Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c459fb5b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c459fb5b
编写于
8月 31, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add fusion lstm batch mode
上级
c709a04a
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
175 addition
and
132 deletion
+175
-132
paddle/fluid/operators/fusion_lstm_op.cc
paddle/fluid/operators/fusion_lstm_op.cc
+175
-132
未找到文件。
paddle/fluid/operators/fusion_lstm_op.cc
浏览文件 @
c459fb5b
...
...
@@ -22,7 +22,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
DEFINE_bool
(
seq_mode
,
tru
e
,
"Use sequence mode"
);
DEFINE_bool
(
seq_mode
,
fals
e
,
"Use sequence mode"
);
namespace
paddle
{
namespace
operators
{
...
...
@@ -42,10 +42,16 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"Output(Hidden) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
"Output(Cell) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedGate"
),
"Output(BatchedGate) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchCellPreAct"
),
"Output(BatchedGate) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedInput"
),
"Output(BatchedInput) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedHidden"
),
"Output(BatchedHidden) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedCell"
),
"Output(BatchedCell) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedH0"
),
"Output(ReorderedH0) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedC0"
),
"Output(ReorderedC0) of LSTM should not be null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2
,
"Input(X)'s rank must be 2."
);
...
...
@@ -97,8 +103,9 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"Cell"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedGate"
,
{
x_dims
[
0
],
wx_dims
[
1
]});
ctx
->
SetOutputDim
(
"BatchCellPreAct"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedInput"
,
{
x_dims
[
0
],
wx_dims
[
1
]});
ctx
->
SetOutputDim
(
"BatchedHidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedCell"
,
out_dims
);
ctx
->
ShareLoD
(
"X"
,
"Hidden"
);
ctx
->
ShareLoD
(
"X"
,
"Cell"
);
...
...
@@ -169,9 +176,11 @@ void FusionLSTMOpMaker::Make() {
" where T is the total time steps in this mini-batch,"
" D is the hidden size, M is the dim size of x input."
)
.
AsIntermediate
();
AddOutput
(
"BatchedGate"
,
"(LoDTensor) (same as LSTMOp)."
).
AsIntermediate
();
AddOutput
(
"BatchCellPreAct"
,
"(LoDTensor) (same as LSTMOp)."
)
.
AsIntermediate
();
AddOutput
(
"BatchedInput"
,
"(LoDTensor) (T x 4D)."
).
AsIntermediate
();
AddOutput
(
"BatchedHidden"
,
"(LoDTensor) (T x D)."
).
AsIntermediate
();
AddOutput
(
"BatchedCell"
,
"(LoDTensor) (T x D)."
).
AsIntermediate
();
AddOutput
(
"ReorderedH0"
,
"(LoDTensor) (N x D)."
).
AsIntermediate
();
AddOutput
(
"ReorderedC0"
,
"(LoDTensor) (N x D)."
).
AsIntermediate
();
AddAttr
<
bool
>
(
"use_peepholes"
,
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections."
)
...
...
@@ -203,17 +212,6 @@ This operator fuse the X into LSTM, more details can refer to LSTM op.
)DOC"
);
}
template
<
typename
DeviceContext
,
typename
T
>
inline
void
ReorderInitState
(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
src
,
framework
::
Vector
<
size_t
>
index_lod
,
framework
::
Tensor
*
dst
,
bool
indexed_src
)
{
math
::
CopyMatrixRowsFunctor
<
DeviceContext
,
T
>
row_shuffle
;
dst
->
mutable_data
<
T
>
(
src
.
dims
(),
ctx
.
GetPlace
());
// TODO(TJ): check mem copy perf
row_shuffle
(
ctx
,
src
,
index_lod
,
dst
,
indexed_src
);
}
template
<
typename
T
>
class
FuisonLSTMKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -290,12 +288,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
int
bid
=
is_reverse
?
N
-
1
-
i
:
i
;
int
seq_len
=
x_lod
[
0
][
bid
+
1
]
-
x_lod
[
0
][
bid
];
const
T
*
prev_c
ell
_data
=
NULL
;
const
T
*
prev_h
idden
_data
=
NULL
;
const
T
*
prev_c_data
=
NULL
;
const
T
*
prev_h_data
=
NULL
;
int
tstart
=
0
;
if
(
h0_data
)
{
prev_h
idden
_data
=
h0_data
+
bid
*
D
;
prev_c
ell
_data
=
c0_data
+
bid
*
D
;
prev_h_data
=
h0_data
+
bid
*
D
;
prev_c_data
=
c0_data
+
bid
*
D
;
}
else
{
// W_ch, W_ih, W_fh, W_oh
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
...
...
@@ -307,23 +305,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
// prev
prev_h
idden
_data
=
hidden_out_data
;
prev_c
ell
_data
=
cell_out_data
;
prev_h_data
=
hidden_out_data
;
prev_c_data
=
cell_out_data
;
tstart
=
1
;
move_step
();
}
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D4
,
D
,
static_cast
<
T
>
(
1
),
prev_hidden_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
xx_data
,
D4
);
prev_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
xx_data
,
D4
);
// W_ch, W_ih, W_fh, W_oh
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
act_cand
(
D
,
xx_data
,
xx_data
);
// a = forget * prev_cell
blas
.
VMUL
(
D
,
xx_data
+
D2
,
prev_c
ell
_data
,
xx_data
+
D2
);
blas
.
VMUL
(
D
,
xx_data
+
D2
,
prev_c_data
,
xx_data
+
D2
);
// b = input * tilde
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
xx_data
+
D
);
...
...
@@ -336,8 +333,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
// prev
prev_h
idden
_data
=
hidden_out_data
;
prev_c
ell
_data
=
cell_out_data
;
prev_h_data
=
hidden_out_data
;
prev_c_data
=
cell_out_data
;
move_step
();
}
...
...
@@ -350,132 +347,178 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
auto
*
wx
=
ctx
.
Input
<
Tensor
>
(
"WeightX"
);
auto
*
wh
=
ctx
.
Input
<
Tensor
>
(
"WeightH"
);
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
h
idden_t
0
=
ctx
.
Input
<
Tensor
>
(
"H0"
);
auto
*
c
ell_t
0
=
ctx
.
Input
<
Tensor
>
(
"C0"
);
auto
*
h0
=
ctx
.
Input
<
Tensor
>
(
"H0"
);
auto
*
c0
=
ctx
.
Input
<
Tensor
>
(
"C0"
);
auto
*
xx
=
ctx
.
Output
<
LoDTensor
>
(
"XX"
);
auto
*
batched_gate
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedGate"
);
auto
*
reordered_h0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedH0"
);
auto
*
reordered_c0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedC0"
);
auto
*
batched_input
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedInput"
);
auto
*
batched_c_out
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedCell"
);
auto
*
batched_h_out
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedHidden"
);
auto
*
hidden_out
=
ctx
.
Output
<
LoDTensor
>
(
"Hidden"
);
auto
*
cell_out
=
ctx
.
Output
<
LoDTensor
>
(
"Cell"
);
bool
is_reverse
=
ctx
.
Attr
<
bool
>
(
"is_reverse"
);
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
batched_gate_data
=
batched_gate
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
hidden_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
cell_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
std
::
function
<
void
(
const
int
,
const
T
*
,
T
*
)
>
act_gate
,
act_cell
,
act_cand
;
auto
&
act_gate_str
=
ctx
.
Attr
<
std
::
string
>
(
"gate_activation"
);
auto
&
act_cell_str
=
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
);
auto
&
act_cand_str
=
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
);
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx
))
{
math
::
VecActivations
<
T
,
platform
::
jit
::
avx
>
act_functor
;
act_gate
=
act_functor
(
act_gate_str
);
act_cell
=
act_functor
(
act_cell_str
);
act_cand
=
act_functor
(
act_cand_str
);
}
else
{
math
::
VecActivations
<
T
,
platform
::
jit
::
isa_any
>
act_functor
;
act_gate
=
act_functor
(
act_gate_str
);
act_cell
=
act_functor
(
act_cell_str
);
act_cand
=
act_functor
(
act_cand_str
);
}
auto
x_dims
=
x
->
dims
();
// T x M
auto
wh_dims
=
wh
->
dims
();
// D x 4D
// auto x_lod = x->lod();
// const int N = x_lod[0].size() - 1; // batch size
// if (N == 1) {
// SeqCompute(ctx);
// }
const
int
M
=
x_dims
[
1
];
const
int
D
=
wh_dims
[
0
];
const
int
D2
=
D
*
2
;
const
int
D3
=
D
*
3
;
const
int
D4
=
wh_dims
[
1
];
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
wx_data
=
wx
->
data
<
T
>
();
auto
x_dims
=
x
->
dims
();
auto
wx_dims
=
wx
->
dims
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
auto
place
=
ctx
.
GetPlace
();
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
place
);
T
*
batched_input_data
=
batched_input
->
mutable_data
<
T
>
(
place
);
T
*
batched_c_out_data
=
batched_c_out
->
mutable_data
<
T
>
(
place
);
T
*
batched_h_out_data
=
batched_h_out
->
mutable_data
<
T
>
(
place
);
hidden_out
->
mutable_data
<
T
>
(
place
);
cell_out
->
mutable_data
<
T
>
(
place
);
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
if
(
x_dims
[
1
]
>
wx_dims
[
1
])
{
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
x_dims
[
0
],
wx_dims
[
1
],
x_dims
[
1
],
x_data
,
wx_data
,
xx_data
,
bias
->
data
<
T
>
());
to_batch
(
dev_ctx
,
*
xx
,
batched_gate
,
true
,
is_reverse
);
if
(
M
>
D4
)
{
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
x_dims
[
0
],
D4
,
M
,
x_data
,
wx_data
,
xx_data
,
bias
->
data
<
T
>
());
to_batch
(
dev_ctx
,
*
xx
,
batched_input
,
true
,
is_reverse
);
}
else
{
to_batch
(
dev_ctx
,
*
x
,
xx
,
true
,
is_reverse
);
batched_
gate
->
set_lod
(
xx
->
lod
());
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
x_dims
[
0
],
wx_dims
[
1
],
x_dims
[
1
]
,
xx_data
,
wx_data
,
batched_gate
_data
,
batched_
input
->
set_lod
(
xx
->
lod
());
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
x_dims
[
0
],
D4
,
M
,
xx_data
,
wx_data
,
batched_input
_data
,
bias
->
data
<
T
>
());
}
int
frame_size
=
static_cast
<
int
>
(
wx_dims
[
1
]
/
4
);
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
math
::
LstmMetaValue
<
T
>
lstm_value
;
// no peephole
lstm_value
.
check_ig
=
nullptr
;
lstm_value
.
check_fg
=
nullptr
;
lstm_value
.
check_og
=
nullptr
;
lstm_value
.
prev_state_value
=
nullptr
;
Tensor
ordered_c0
;
framework
::
Vector
<
size_t
>
order
(
batched_gate
->
lod
()[
2
]);
if
(
cell_t0
)
{
// Since the batch computing for LSTM reorders the input sequence
// according to their length. The initialized cell state also needs
// to reorder.
ReorderInitState
<
DeviceContext
,
T
>
(
dev_ctx
,
*
cell_t0
,
order
,
&
ordered_c0
,
true
);
lstm_value
.
prev_state_value
=
ordered_c0
.
data
<
T
>
();
auto
batched_lod
=
batched_input
->
lod
();
const
auto
&
seq_order
=
batched_lod
[
2
];
const
int
max_bs
=
seq_order
.
size
();
reordered_h0
->
Resize
({
max_bs
,
D
});
reordered_c0
->
Resize
({
max_bs
,
D
});
int
tstart
=
0
;
T
*
prev_h_data
=
NULL
;
T
*
prev_c_data
=
NULL
;
if
(
h0
)
{
// reorder h0, c0
T
*
reordered_h0_data
=
reordered_h0
->
mutable_data
<
T
>
(
place
);
T
*
reordered_c0_data
=
reordered_c0
->
mutable_data
<
T
>
(
place
);
const
T
*
h0_data
=
h0
->
data
<
T
>
();
const
T
*
c0_data
=
c0
->
data
<
T
>
();
prev_h_data
=
reordered_h0_data
;
prev_c_data
=
reordered_c0_data
;
size_t
sz
=
sizeof
(
T
)
*
D
;
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
std
::
memcpy
(
reordered_h0_data
,
h0_data
+
seq_order
[
i
]
*
D
,
sz
);
std
::
memcpy
(
reordered_c0_data
,
c0_data
+
seq_order
[
i
]
*
D
,
sz
);
reordered_h0_data
+=
D
;
reordered_c0_data
+=
D
;
}
}
else
{
// compute without h0, c0
T
*
cur_in_data
=
batched_input_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
// W_ch, W_ih, W_fh, W_oh
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
act_cand
(
D
,
cur_in_data
,
cur_in_data
);
// cell out= input*tilde
blas
.
VMUL
(
D
,
cur_in_data
,
cur_in_data
+
D
,
cur_c_out_data
);
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cur_c_out_data
,
cur_in_data
+
D2
);
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_in_data
+
D3
,
cur_h_out_data
);
// add offset
cur_in_data
+=
D4
;
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
}
tstart
=
1
;
prev_h_data
=
batched_h_out_data
;
prev_c_data
=
batched_c_out_data
;
}
// Then start from next
const
auto
&
batch_starts
=
batched_lod
[
0
];
const
int
max_seq_len
=
batch_starts
.
size
()
-
1
;
const
int
offset
=
tstart
*
max_bs
*
D
;
batched_input_data
=
batched_input_data
+
offset
*
4
;
batched_h_out_data
=
batched_h_out_data
+
offset
;
batched_c_out_data
=
batched_c_out_data
+
offset
;
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
cur_bs
,
D4
,
D
,
static_cast
<
T
>
(
1
),
prev_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
batched_input_data
,
D4
);
T
*
cur_in_data
=
batched_input_data
;
T
*
cur_prev_c_data
=
prev_c_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
// W_ch, W_ih, W_fh, W_oh
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
act_cand
(
D
,
cur_in_data
,
cur_in_data
);
// a = forget * prev_cell
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_prev_c_data
,
cur_in_data
+
D2
);
// b = input * tilde
blas
.
VMUL
(
D
,
cur_in_data
,
cur_in_data
+
D
,
cur_in_data
+
D
);
// Use the local variable as here.
LoDTensor
batch_hidden
,
batch_cell
;
auto
*
batch_cell_pre_act
=
ctx
.
Output
<
LoDTensor
>
(
"BatchCellPreAct"
);
batch_hidden
.
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
batch_cell
.
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
batch_cell_pre_act
->
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
auto
batch_starts
=
batched_gate
->
lod
()[
0
];
size_t
max_seq_len
=
batch_starts
.
size
()
-
1
;
auto
gate_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"gate_activation"
));
auto
cell_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
));
auto
cand_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
));
for
(
size_t
n
=
0
;
n
<
max_seq_len
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
Tensor
gate_t
=
batched_gate
->
Slice
(
bstart
,
bend
);
Tensor
out_t
=
batch_hidden
.
Slice
(
bstart
,
bend
);
Tensor
cell_t
=
batch_cell
.
Slice
(
bstart
,
bend
);
Tensor
cell_pre_act_t
=
batch_cell_pre_act
->
Slice
(
bstart
,
bend
);
int
cur_batch_size
=
bend
-
bstart
;
if
(
n
>
0
)
{
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
auto
pre_hidden_t
=
batch_hidden
.
Slice
(
pre_h_start
,
pre_h_end
);
// TODO(TJ): use gemm directly
blas
.
MatMul
(
pre_hidden_t
,
false
,
*
wh
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
}
else
if
(
hidden_t0
)
{
// TODO(TJ): move h0 outside for
// If n == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros, the calculation W_h * H0 will be skiped.
// If n == 0 and there is initialized hidden state, calculate W_h * H0.
// Since the batch computing for LSTM reorders the input sequence
// according to their length. The initialized hidden state also needs
// to reorder.
Tensor
ordered_h0
;
ReorderInitState
<
DeviceContext
,
T
>
(
dev_ctx
,
*
hidden_t0
,
order
,
&
ordered_h0
,
true
);
// TODO(TJ): use gemm directly
blas
.
MatMul
(
ordered_h0
,
false
,
*
wh
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
// cell out= a+b
blas
.
VADD
(
D
,
cur_in_data
+
D
,
cur_in_data
+
D2
,
cur_c_out_data
);
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cur_c_out_data
,
cur_in_data
+
D2
);
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_in_data
+
D3
,
cur_h_out_data
);
cur_in_data
+=
D4
;
cur_prev_c_data
+=
D
;
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
}
lstm_value
.
gate_value
=
gate_t
.
data
<
T
>
();
lstm_value
.
output_value
=
out_t
.
data
<
T
>
();
lstm_value
.
state_value
=
cell_t
.
data
<
T
>
();
lstm_value
.
state_active_value
=
cell_pre_act_t
.
data
<
T
>
();
math
::
LstmUnitFunctor
<
DeviceContext
,
T
>::
compute
(
dev_ctx
,
lstm_value
,
frame_size
,
cur_batch_size
,
gate_act
,
cell_act
,
cand_act
);
lstm_value
.
prev_state_value
=
lstm_value
.
state_value
;
prev_c_data
=
batched_c_out_data
;
prev_h_data
=
batched_h_out_data
;
batched_c_out_data
=
cur_c_out_data
;
batched_h_out_data
=
cur_h_out_data
;
batched_input_data
=
cur_in_data
;
}
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
batch_hidden
.
set_lod
(
batched_gate
->
lod
());
// restore the output hidden in LoDTensor from the batch hidden
to_seq
(
dev_ctx
,
batch_hidden
,
hidden_out
);
batch_cell
.
set_lod
(
batched_gate
->
lod
());
// restore the output cell state in LoDTensor from the batch cell
to_seq
(
dev_ctx
,
batch_cell
,
cell_out
);
batched_h_out
->
set_lod
(
batched_lod
);
to_seq
(
dev_ctx
,
*
batched_h_out
,
hidden_out
);
batched_c_out
->
set_lod
(
batched_lod
);
to_seq
(
dev_ctx
,
*
batched_c_out
,
cell_out
);
}
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
if
(
FLAGS_seq_mode
)
{
SeqCompute
(
ctx
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录