Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c459fb5b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c459fb5b
编写于
8月 31, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add fusion lstm batch mode
上级
c709a04a
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
175 addition
and
132 deletion
+175
-132
paddle/fluid/operators/fusion_lstm_op.cc
paddle/fluid/operators/fusion_lstm_op.cc
+175
-132
未找到文件。
paddle/fluid/operators/fusion_lstm_op.cc
浏览文件 @
c459fb5b
...
@@ -22,7 +22,7 @@ limitations under the License. */
...
@@ -22,7 +22,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/cpu_info.h"
DEFINE_bool
(
seq_mode
,
tru
e
,
"Use sequence mode"
);
DEFINE_bool
(
seq_mode
,
fals
e
,
"Use sequence mode"
);
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -42,10 +42,16 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -42,10 +42,16 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"Output(Hidden) of LSTM should not be null."
);
"Output(Hidden) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
"Output(Cell) of LSTM should not be null."
);
"Output(Cell) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedGate"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedInput"
),
"Output(BatchedGate) of LSTM should not be null."
);
"Output(BatchedInput) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchCellPreAct"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedHidden"
),
"Output(BatchedGate) of LSTM should not be null."
);
"Output(BatchedHidden) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedCell"
),
"Output(BatchedCell) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedH0"
),
"Output(ReorderedH0) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedC0"
),
"Output(ReorderedC0) of LSTM should not be null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2
,
"Input(X)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2
,
"Input(X)'s rank must be 2."
);
...
@@ -97,8 +103,9 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -97,8 +103,9 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"Cell"
,
out_dims
);
ctx
->
SetOutputDim
(
"Cell"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedGate"
,
{
x_dims
[
0
],
wx_dims
[
1
]});
ctx
->
SetOutputDim
(
"BatchedInput"
,
{
x_dims
[
0
],
wx_dims
[
1
]});
ctx
->
SetOutputDim
(
"BatchCellPreAct"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedHidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedCell"
,
out_dims
);
ctx
->
ShareLoD
(
"X"
,
"Hidden"
);
ctx
->
ShareLoD
(
"X"
,
"Hidden"
);
ctx
->
ShareLoD
(
"X"
,
"Cell"
);
ctx
->
ShareLoD
(
"X"
,
"Cell"
);
...
@@ -169,9 +176,11 @@ void FusionLSTMOpMaker::Make() {
...
@@ -169,9 +176,11 @@ void FusionLSTMOpMaker::Make() {
" where T is the total time steps in this mini-batch,"
" where T is the total time steps in this mini-batch,"
" D is the hidden size, M is the dim size of x input."
)
" D is the hidden size, M is the dim size of x input."
)
.
AsIntermediate
();
.
AsIntermediate
();
AddOutput
(
"BatchedGate"
,
"(LoDTensor) (same as LSTMOp)."
).
AsIntermediate
();
AddOutput
(
"BatchedInput"
,
"(LoDTensor) (T x 4D)."
).
AsIntermediate
();
AddOutput
(
"BatchCellPreAct"
,
"(LoDTensor) (same as LSTMOp)."
)
AddOutput
(
"BatchedHidden"
,
"(LoDTensor) (T x D)."
).
AsIntermediate
();
.
AsIntermediate
();
AddOutput
(
"BatchedCell"
,
"(LoDTensor) (T x D)."
).
AsIntermediate
();
AddOutput
(
"ReorderedH0"
,
"(LoDTensor) (N x D)."
).
AsIntermediate
();
AddOutput
(
"ReorderedC0"
,
"(LoDTensor) (N x D)."
).
AsIntermediate
();
AddAttr
<
bool
>
(
"use_peepholes"
,
AddAttr
<
bool
>
(
"use_peepholes"
,
"(bool, defalut: True) "
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections."
)
"whether to enable diagonal/peephole connections."
)
...
@@ -203,17 +212,6 @@ This operator fuse the X into LSTM, more details can refer to LSTM op.
...
@@ -203,17 +212,6 @@ This operator fuse the X into LSTM, more details can refer to LSTM op.
)DOC"
);
)DOC"
);
}
}
template
<
typename
DeviceContext
,
typename
T
>
inline
void
ReorderInitState
(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
src
,
framework
::
Vector
<
size_t
>
index_lod
,
framework
::
Tensor
*
dst
,
bool
indexed_src
)
{
math
::
CopyMatrixRowsFunctor
<
DeviceContext
,
T
>
row_shuffle
;
dst
->
mutable_data
<
T
>
(
src
.
dims
(),
ctx
.
GetPlace
());
// TODO(TJ): check mem copy perf
row_shuffle
(
ctx
,
src
,
index_lod
,
dst
,
indexed_src
);
}
template
<
typename
T
>
template
<
typename
T
>
class
FuisonLSTMKernel
:
public
framework
::
OpKernel
<
T
>
{
class
FuisonLSTMKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -290,12 +288,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -290,12 +288,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
int
bid
=
is_reverse
?
N
-
1
-
i
:
i
;
int
bid
=
is_reverse
?
N
-
1
-
i
:
i
;
int
seq_len
=
x_lod
[
0
][
bid
+
1
]
-
x_lod
[
0
][
bid
];
int
seq_len
=
x_lod
[
0
][
bid
+
1
]
-
x_lod
[
0
][
bid
];
const
T
*
prev_c
ell
_data
=
NULL
;
const
T
*
prev_c_data
=
NULL
;
const
T
*
prev_h
idden
_data
=
NULL
;
const
T
*
prev_h_data
=
NULL
;
int
tstart
=
0
;
int
tstart
=
0
;
if
(
h0_data
)
{
if
(
h0_data
)
{
prev_h
idden
_data
=
h0_data
+
bid
*
D
;
prev_h_data
=
h0_data
+
bid
*
D
;
prev_c
ell
_data
=
c0_data
+
bid
*
D
;
prev_c_data
=
c0_data
+
bid
*
D
;
}
else
{
}
else
{
// W_ch, W_ih, W_fh, W_oh
// W_ch, W_ih, W_fh, W_oh
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
...
@@ -307,23 +305,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -307,23 +305,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
// prev
// prev
prev_h
idden
_data
=
hidden_out_data
;
prev_h_data
=
hidden_out_data
;
prev_c
ell
_data
=
cell_out_data
;
prev_c_data
=
cell_out_data
;
tstart
=
1
;
tstart
=
1
;
move_step
();
move_step
();
}
}
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D4
,
D
,
static_cast
<
T
>
(
1
),
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D4
,
D
,
static_cast
<
T
>
(
1
),
prev_hidden_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
xx_data
,
prev_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
xx_data
,
D4
);
D4
);
// W_ch, W_ih, W_fh, W_oh
// W_ch, W_ih, W_fh, W_oh
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
act_cand
(
D
,
xx_data
,
xx_data
);
act_cand
(
D
,
xx_data
,
xx_data
);
// a = forget * prev_cell
// a = forget * prev_cell
blas
.
VMUL
(
D
,
xx_data
+
D2
,
prev_c
ell
_data
,
xx_data
+
D2
);
blas
.
VMUL
(
D
,
xx_data
+
D2
,
prev_c_data
,
xx_data
+
D2
);
// b = input * tilde
// b = input * tilde
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
xx_data
+
D
);
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
xx_data
+
D
);
...
@@ -336,8 +333,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -336,8 +333,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
// prev
// prev
prev_h
idden
_data
=
hidden_out_data
;
prev_h_data
=
hidden_out_data
;
prev_c
ell
_data
=
cell_out_data
;
prev_c_data
=
cell_out_data
;
move_step
();
move_step
();
}
}
...
@@ -350,132 +347,178 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -350,132 +347,178 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
auto
*
wx
=
ctx
.
Input
<
Tensor
>
(
"WeightX"
);
auto
*
wx
=
ctx
.
Input
<
Tensor
>
(
"WeightX"
);
auto
*
wh
=
ctx
.
Input
<
Tensor
>
(
"WeightH"
);
auto
*
wh
=
ctx
.
Input
<
Tensor
>
(
"WeightH"
);
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
h
idden_t
0
=
ctx
.
Input
<
Tensor
>
(
"H0"
);
auto
*
h0
=
ctx
.
Input
<
Tensor
>
(
"H0"
);
auto
*
c
ell_t
0
=
ctx
.
Input
<
Tensor
>
(
"C0"
);
auto
*
c0
=
ctx
.
Input
<
Tensor
>
(
"C0"
);
auto
*
xx
=
ctx
.
Output
<
LoDTensor
>
(
"XX"
);
auto
*
xx
=
ctx
.
Output
<
LoDTensor
>
(
"XX"
);
auto
*
batched_gate
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedGate"
);
auto
*
reordered_h0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedH0"
);
auto
*
reordered_c0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedC0"
);
auto
*
batched_input
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedInput"
);
auto
*
batched_c_out
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedCell"
);
auto
*
batched_h_out
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedHidden"
);
auto
*
hidden_out
=
ctx
.
Output
<
LoDTensor
>
(
"Hidden"
);
auto
*
hidden_out
=
ctx
.
Output
<
LoDTensor
>
(
"Hidden"
);
auto
*
cell_out
=
ctx
.
Output
<
LoDTensor
>
(
"Cell"
);
auto
*
cell_out
=
ctx
.
Output
<
LoDTensor
>
(
"Cell"
);
bool
is_reverse
=
ctx
.
Attr
<
bool
>
(
"is_reverse"
);
bool
is_reverse
=
ctx
.
Attr
<
bool
>
(
"is_reverse"
);
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
std
::
function
<
void
(
const
int
,
const
T
*
,
T
*
)
>
act_gate
,
act_cell
,
act_cand
;
T
*
batched_gate_data
=
batched_gate
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
act_gate_str
=
ctx
.
Attr
<
std
::
string
>
(
"gate_activation"
);
hidden_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
act_cell_str
=
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
);
cell_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
act_cand_str
=
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
);
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx
))
{
math
::
VecActivations
<
T
,
platform
::
jit
::
avx
>
act_functor
;
act_gate
=
act_functor
(
act_gate_str
);
act_cell
=
act_functor
(
act_cell_str
);
act_cand
=
act_functor
(
act_cand_str
);
}
else
{
math
::
VecActivations
<
T
,
platform
::
jit
::
isa_any
>
act_functor
;
act_gate
=
act_functor
(
act_gate_str
);
act_cell
=
act_functor
(
act_cell_str
);
act_cand
=
act_functor
(
act_cand_str
);
}
auto
x_dims
=
x
->
dims
();
// T x M
auto
wh_dims
=
wh
->
dims
();
// D x 4D
// auto x_lod = x->lod();
// const int N = x_lod[0].size() - 1; // batch size
// if (N == 1) {
// SeqCompute(ctx);
// }
const
int
M
=
x_dims
[
1
];
const
int
D
=
wh_dims
[
0
];
const
int
D2
=
D
*
2
;
const
int
D3
=
D
*
3
;
const
int
D4
=
wh_dims
[
1
];
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
wx_data
=
wx
->
data
<
T
>
();
auto
x_dims
=
x
->
dims
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
auto
wx_dims
=
wx
->
dims
();
auto
place
=
ctx
.
GetPlace
();
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
place
);
T
*
batched_input_data
=
batched_input
->
mutable_data
<
T
>
(
place
);
T
*
batched_c_out_data
=
batched_c_out
->
mutable_data
<
T
>
(
place
);
T
*
batched_h_out_data
=
batched_h_out
->
mutable_data
<
T
>
(
place
);
hidden_out
->
mutable_data
<
T
>
(
place
);
cell_out
->
mutable_data
<
T
>
(
place
);
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
if
(
x_dims
[
1
]
>
wx_dims
[
1
])
{
if
(
M
>
D4
)
{
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
x_dims
[
0
],
wx_dims
[
1
],
x_dims
[
1
],
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
x_dims
[
0
],
D4
,
M
,
x_data
,
wx_data
,
x_data
,
wx_data
,
xx_data
,
xx_data
,
bias
->
data
<
T
>
());
bias
->
data
<
T
>
());
to_batch
(
dev_ctx
,
*
xx
,
batched_input
,
true
,
is_reverse
);
to_batch
(
dev_ctx
,
*
xx
,
batched_gate
,
true
,
is_reverse
);
}
else
{
}
else
{
to_batch
(
dev_ctx
,
*
x
,
xx
,
true
,
is_reverse
);
to_batch
(
dev_ctx
,
*
x
,
xx
,
true
,
is_reverse
);
batched_
gate
->
set_lod
(
xx
->
lod
());
batched_
input
->
set_lod
(
xx
->
lod
());
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
x_dims
[
0
],
wx_dims
[
1
],
x_dims
[
1
]
,
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
x_dims
[
0
],
D4
,
M
,
xx_data
,
xx_data
,
wx_data
,
batched_gate
_data
,
wx_data
,
batched_input
_data
,
bias
->
data
<
T
>
());
bias
->
data
<
T
>
());
}
}
int
frame_size
=
static_cast
<
int
>
(
wx_dims
[
1
]
/
4
);
auto
batched_lod
=
batched_input
->
lod
();
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
const
auto
&
seq_order
=
batched_lod
[
2
];
math
::
LstmMetaValue
<
T
>
lstm_value
;
const
int
max_bs
=
seq_order
.
size
();
// no peephole
reordered_h0
->
Resize
({
max_bs
,
D
});
lstm_value
.
check_ig
=
nullptr
;
reordered_c0
->
Resize
({
max_bs
,
D
});
lstm_value
.
check_fg
=
nullptr
;
lstm_value
.
check_og
=
nullptr
;
int
tstart
=
0
;
lstm_value
.
prev_state_value
=
nullptr
;
T
*
prev_h_data
=
NULL
;
Tensor
ordered_c0
;
T
*
prev_c_data
=
NULL
;
if
(
h0
)
{
framework
::
Vector
<
size_t
>
order
(
batched_gate
->
lod
()[
2
]);
// reorder h0, c0
T
*
reordered_h0_data
=
reordered_h0
->
mutable_data
<
T
>
(
place
);
if
(
cell_t0
)
{
T
*
reordered_c0_data
=
reordered_c0
->
mutable_data
<
T
>
(
place
);
// Since the batch computing for LSTM reorders the input sequence
const
T
*
h0_data
=
h0
->
data
<
T
>
();
// according to their length. The initialized cell state also needs
const
T
*
c0_data
=
c0
->
data
<
T
>
();
// to reorder.
prev_h_data
=
reordered_h0_data
;
ReorderInitState
<
DeviceContext
,
T
>
(
dev_ctx
,
*
cell_t0
,
order
,
&
ordered_c0
,
prev_c_data
=
reordered_c0_data
;
true
);
size_t
sz
=
sizeof
(
T
)
*
D
;
lstm_value
.
prev_state_value
=
ordered_c0
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
std
::
memcpy
(
reordered_h0_data
,
h0_data
+
seq_order
[
i
]
*
D
,
sz
);
std
::
memcpy
(
reordered_c0_data
,
c0_data
+
seq_order
[
i
]
*
D
,
sz
);
reordered_h0_data
+=
D
;
reordered_c0_data
+=
D
;
}
}
else
{
// compute without h0, c0
T
*
cur_in_data
=
batched_input_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
// W_ch, W_ih, W_fh, W_oh
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
act_cand
(
D
,
cur_in_data
,
cur_in_data
);
// cell out= input*tilde
blas
.
VMUL
(
D
,
cur_in_data
,
cur_in_data
+
D
,
cur_c_out_data
);
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cur_c_out_data
,
cur_in_data
+
D2
);
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_in_data
+
D3
,
cur_h_out_data
);
// add offset
cur_in_data
+=
D4
;
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
}
tstart
=
1
;
prev_h_data
=
batched_h_out_data
;
prev_c_data
=
batched_c_out_data
;
}
}
// Then start from next
const
auto
&
batch_starts
=
batched_lod
[
0
];
const
int
max_seq_len
=
batch_starts
.
size
()
-
1
;
const
int
offset
=
tstart
*
max_bs
*
D
;
batched_input_data
=
batched_input_data
+
offset
*
4
;
batched_h_out_data
=
batched_h_out_data
+
offset
;
batched_c_out_data
=
batched_c_out_data
+
offset
;
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
cur_bs
,
D4
,
D
,
static_cast
<
T
>
(
1
),
prev_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
batched_input_data
,
D4
);
T
*
cur_in_data
=
batched_input_data
;
T
*
cur_prev_c_data
=
prev_c_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
// W_ch, W_ih, W_fh, W_oh
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
act_cand
(
D
,
cur_in_data
,
cur_in_data
);
// a = forget * prev_cell
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_prev_c_data
,
cur_in_data
+
D2
);
// b = input * tilde
blas
.
VMUL
(
D
,
cur_in_data
,
cur_in_data
+
D
,
cur_in_data
+
D
);
// Use the local variable as here.
// cell out= a+b
LoDTensor
batch_hidden
,
batch_cell
;
blas
.
VADD
(
D
,
cur_in_data
+
D
,
cur_in_data
+
D2
,
cur_c_out_data
);
auto
*
batch_cell_pre_act
=
ctx
.
Output
<
LoDTensor
>
(
"BatchCellPreAct"
);
batch_hidden
.
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
// hidden out= act_state(cellout) * outgate
batch_cell
.
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
act_cell
(
D
,
cur_c_out_data
,
cur_in_data
+
D2
);
batch_cell_pre_act
->
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_in_data
+
D3
,
cur_h_out_data
);
auto
batch_starts
=
batched_gate
->
lod
()[
0
];
cur_in_data
+=
D4
;
size_t
max_seq_len
=
batch_starts
.
size
()
-
1
;
cur_prev_c_data
+=
D
;
auto
gate_act
=
math
::
detail
::
GetActivationType
(
cur_c_out_data
+=
D
;
ctx
.
Attr
<
std
::
string
>
(
"gate_activation"
));
cur_h_out_data
+=
D
;
auto
cell_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
));
auto
cand_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
));
for
(
size_t
n
=
0
;
n
<
max_seq_len
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
Tensor
gate_t
=
batched_gate
->
Slice
(
bstart
,
bend
);
Tensor
out_t
=
batch_hidden
.
Slice
(
bstart
,
bend
);
Tensor
cell_t
=
batch_cell
.
Slice
(
bstart
,
bend
);
Tensor
cell_pre_act_t
=
batch_cell_pre_act
->
Slice
(
bstart
,
bend
);
int
cur_batch_size
=
bend
-
bstart
;
if
(
n
>
0
)
{
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
auto
pre_hidden_t
=
batch_hidden
.
Slice
(
pre_h_start
,
pre_h_end
);
// TODO(TJ): use gemm directly
blas
.
MatMul
(
pre_hidden_t
,
false
,
*
wh
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
}
else
if
(
hidden_t0
)
{
// TODO(TJ): move h0 outside for
// If n == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros, the calculation W_h * H0 will be skiped.
// If n == 0 and there is initialized hidden state, calculate W_h * H0.
// Since the batch computing for LSTM reorders the input sequence
// according to their length. The initialized hidden state also needs
// to reorder.
Tensor
ordered_h0
;
ReorderInitState
<
DeviceContext
,
T
>
(
dev_ctx
,
*
hidden_t0
,
order
,
&
ordered_h0
,
true
);
// TODO(TJ): use gemm directly
blas
.
MatMul
(
ordered_h0
,
false
,
*
wh
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
}
}
lstm_value
.
gate_value
=
gate_t
.
data
<
T
>
();
prev_c_data
=
batched_c_out_data
;
lstm_value
.
output_value
=
out_t
.
data
<
T
>
();
prev_h_data
=
batched_h_out_data
;
lstm_value
.
state_value
=
cell_t
.
data
<
T
>
();
batched_c_out_data
=
cur_c_out_data
;
lstm_value
.
state_active_value
=
cell_pre_act_t
.
data
<
T
>
();
batched_h_out_data
=
cur_h_out_data
;
math
::
LstmUnitFunctor
<
DeviceContext
,
T
>::
compute
(
batched_input_data
=
cur_in_data
;
dev_ctx
,
lstm_value
,
frame_size
,
cur_batch_size
,
gate_act
,
cell_act
,
cand_act
);
lstm_value
.
prev_state_value
=
lstm_value
.
state_value
;
}
}
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
batch_hidden
.
set_lod
(
batched_gate
->
lod
());
batched_h_out
->
set_lod
(
batched_lod
);
// restore the output hidden in LoDTensor from the batch hidden
to_seq
(
dev_ctx
,
*
batched_h_out
,
hidden_out
);
to_seq
(
dev_ctx
,
batch_hidden
,
hidden_out
);
batched_c_out
->
set_lod
(
batched_lod
);
to_seq
(
dev_ctx
,
*
batched_c_out
,
cell_out
);
batch_cell
.
set_lod
(
batched_gate
->
lod
());
// restore the output cell state in LoDTensor from the batch cell
to_seq
(
dev_ctx
,
batch_cell
,
cell_out
);
}
}
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
if
(
FLAGS_seq_mode
)
{
if
(
FLAGS_seq_mode
)
{
SeqCompute
(
ctx
);
SeqCompute
(
ctx
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录