Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1cc35f36
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1cc35f36
编写于
9月 03, 2018
作者:
T
tensor-tang
提交者:
GitHub
9月 03, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13118 from tensor-tang/optimize/op/fusion_lstm
Optimize fusion lstm batch mode
上级
6fb28796
93c034ee
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
221 addition
and
185 deletion
+221
-185
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+23
-3
paddle/fluid/operators/fusion_lstm_op.cc
paddle/fluid/operators/fusion_lstm_op.cc
+198
-182
未找到文件。
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
浏览文件 @
1cc35f36
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -94,11 +95,31 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
...
@@ -94,11 +95,31 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
op_desc
.
SetOutput
(
"Hidden"
,
{
hidden_n
->
Name
()});
op_desc
.
SetOutput
(
"Hidden"
,
{
hidden_n
->
Name
()});
op_desc
.
SetOutput
(
"Cell"
,
{
cell_n
->
Name
()});
op_desc
.
SetOutput
(
"Cell"
,
{
cell_n
->
Name
()});
op_desc
.
SetOutput
(
"XX"
,
{
xx_n
->
Name
()});
op_desc
.
SetOutput
(
"XX"
,
{
xx_n
->
Name
()});
op_desc
.
SetOutput
(
"BatchedGate"
,
{
"blstm_0.tmp_2"
});
op_desc
.
SetOutput
(
"BatchedInput"
,
{
"blstm_0.tmp_2"
});
op_desc
.
SetOutput
(
"BatchCellPreAct"
,
{
"blstm_1.tmp_2"
});
op_desc
.
SetAttr
(
"is_reverse"
,
lstm_n
->
Op
()
->
GetAttr
(
"is_reverse"
));
op_desc
.
SetAttr
(
"is_reverse"
,
lstm_n
->
Op
()
->
GetAttr
(
"is_reverse"
));
op_desc
.
SetAttr
(
"use_peepholes"
,
lstm_n
->
Op
()
->
GetAttr
(
"use_peepholes"
));
op_desc
.
SetAttr
(
"use_peepholes"
,
lstm_n
->
Op
()
->
GetAttr
(
"use_peepholes"
));
// TODO(TJ): get from attr
op_desc
.
SetAttr
(
"use_seq"
,
true
);
#define TMP_NAME(x) "at.new.tmp." #x
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)})
OP_SET_OUT
(
BatchedCell
);
OP_SET_OUT
(
BatchedHidden
);
OP_SET_OUT
(
ReorderedH0
);
OP_SET_OUT
(
ReorderedC0
);
#undef OP_SET_OUT
auto
*
op
=
graph
->
CreateOpNode
(
&
op_desc
);
auto
*
op
=
graph
->
CreateOpNode
(
&
op_desc
);
PADDLE_ENFORCE
(
graph
->
Has
(
kParamScopeAttr
));
auto
*
scope
=
graph
->
Get
<
Scope
*>
(
kParamScopeAttr
);
#define TMP_NEW(x) scope->Var(TMP_NAME(x))->GetMutable<LoDTensor>()
TMP_NEW
(
BatchedCell
);
TMP_NEW
(
BatchedHidden
);
TMP_NEW
(
ReorderedH0
);
TMP_NEW
(
ReorderedC0
);
#undef TMP_NEW
#undef TMP_NAME
#define LINK_TO(a, b) \
#define LINK_TO(a, b) \
a->outputs.push_back(b); \
a->outputs.push_back(b); \
...
@@ -116,7 +137,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
...
@@ -116,7 +137,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
auto
fc_no_bias_handler
=
[
&
](
auto
fc_no_bias_handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
#define GET_NODE(name__) \
#define GET_NODE(name__) \
std::string name__##key = name_scope + "/" + #name__; \
std::string name__##key = name_scope + "/" + #name__; \
auto* name__##n = pattern->RetrieveNode(name__##key); \
auto* name__##n = pattern->RetrieveNode(name__##key); \
...
...
paddle/fluid/operators/fusion_lstm_op.cc
浏览文件 @
1cc35f36
...
@@ -16,14 +16,10 @@ limitations under the License. */
...
@@ -16,14 +16,10 @@ limitations under the License. */
#include <string>
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/cpu_info.h"
DEFINE_bool
(
seq_mode
,
true
,
"Use sequence mode"
);
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -42,10 +38,16 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -42,10 +38,16 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"Output(Hidden) of LSTM should not be null."
);
"Output(Hidden) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
"Output(Cell) of LSTM should not be null."
);
"Output(Cell) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedGate"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedInput"
),
"Output(BatchedGate) of LSTM should not be null."
);
"Output(BatchedInput) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchCellPreAct"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedHidden"
),
"Output(BatchedGate) of LSTM should not be null."
);
"Output(BatchedHidden) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedCell"
),
"Output(BatchedCell) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedH0"
),
"Output(ReorderedH0) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedC0"
),
"Output(ReorderedC0) of LSTM should not be null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2
,
"Input(X)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2
,
"Input(X)'s rank must be 2."
);
...
@@ -97,13 +99,14 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -97,13 +99,14 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"Cell"
,
out_dims
);
ctx
->
SetOutputDim
(
"Cell"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedGate"
,
{
x_dims
[
0
],
wx_dims
[
1
]});
ctx
->
SetOutputDim
(
"BatchedInput"
,
{
x_dims
[
0
],
wx_dims
[
1
]});
ctx
->
SetOutputDim
(
"BatchCellPreAct"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedHidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedCell"
,
out_dims
);
ctx
->
ShareLoD
(
"X"
,
"Hidden"
);
ctx
->
ShareLoD
(
"X"
,
"Hidden"
);
ctx
->
ShareLoD
(
"X"
,
"Cell"
);
ctx
->
ShareLoD
(
"X"
,
"Cell"
);
int
xx_width
;
int
xx_width
;
if
(
FLAGS_seq_mode
)
{
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"use_seq"
)
)
{
xx_width
=
wx_dims
[
1
];
xx_width
=
wx_dims
[
1
];
}
else
{
}
else
{
xx_width
=
x_dims
[
1
]
>
wx_dims
[
1
]
?
wx_dims
[
1
]
:
x_dims
[
1
];
xx_width
=
x_dims
[
1
]
>
wx_dims
[
1
]
?
wx_dims
[
1
]
:
x_dims
[
1
];
...
@@ -169,9 +172,11 @@ void FusionLSTMOpMaker::Make() {
...
@@ -169,9 +172,11 @@ void FusionLSTMOpMaker::Make() {
" where T is the total time steps in this mini-batch,"
" where T is the total time steps in this mini-batch,"
" D is the hidden size, M is the dim size of x input."
)
" D is the hidden size, M is the dim size of x input."
)
.
AsIntermediate
();
.
AsIntermediate
();
AddOutput
(
"BatchedGate"
,
"(LoDTensor) (same as LSTMOp)."
).
AsIntermediate
();
AddOutput
(
"BatchedInput"
,
"(LoDTensor) (T x 4D)."
).
AsIntermediate
();
AddOutput
(
"BatchCellPreAct"
,
"(LoDTensor) (same as LSTMOp)."
)
AddOutput
(
"BatchedHidden"
,
"(LoDTensor) (T x D)."
).
AsIntermediate
();
.
AsIntermediate
();
AddOutput
(
"BatchedCell"
,
"(LoDTensor) (T x D)."
).
AsIntermediate
();
AddOutput
(
"ReorderedH0"
,
"(LoDTensor) (N x D)."
).
AsIntermediate
();
AddOutput
(
"ReorderedC0"
,
"(LoDTensor) (N x D)."
).
AsIntermediate
();
AddAttr
<
bool
>
(
"use_peepholes"
,
AddAttr
<
bool
>
(
"use_peepholes"
,
"(bool, defalut: True) "
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections."
)
"whether to enable diagonal/peephole connections."
)
...
@@ -180,6 +185,10 @@ void FusionLSTMOpMaker::Make() {
...
@@ -180,6 +185,10 @@ void FusionLSTMOpMaker::Make() {
"(bool, defalut: False) "
"(bool, defalut: False) "
"whether to compute reversed LSTM."
)
"whether to compute reversed LSTM."
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"use_seq"
,
"(bool, defalut: True) "
"whether to use seq mode to compute."
)
.
SetDefault
(
true
);
AddAttr
<
std
::
string
>
(
"gate_activation"
,
AddAttr
<
std
::
string
>
(
"gate_activation"
,
"(string, default: sigmoid)"
"(string, default: sigmoid)"
"The activation for input gate, forget gate and output "
"The activation for input gate, forget gate and output "
...
@@ -203,64 +212,60 @@ This operator fuse the X into LSTM, more details can refer to LSTM op.
...
@@ -203,64 +212,60 @@ This operator fuse the X into LSTM, more details can refer to LSTM op.
)DOC"
);
)DOC"
);
}
}
template
<
typename
DeviceContext
,
typename
T
>
inline
void
ReorderInitState
(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
src
,
framework
::
Vector
<
size_t
>
index_lod
,
framework
::
Tensor
*
dst
,
bool
indexed_src
)
{
math
::
CopyMatrixRowsFunctor
<
DeviceContext
,
T
>
row_shuffle
;
dst
->
mutable_data
<
T
>
(
src
.
dims
(),
ctx
.
GetPlace
());
// TODO(TJ): check mem copy perf
row_shuffle
(
ctx
,
src
,
index_lod
,
dst
,
indexed_src
);
}
template
<
typename
T
>
template
<
typename
T
>
class
FuisonLSTMKernel
:
public
framework
::
OpKernel
<
T
>
{
class
FuisonLSTMKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
#define INIT_VEC_FUNC \
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand; \
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \
auto& act_cell_str = ctx.Attr<std::string>("cell_activation"); \
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation"); \
if (platform::jit::MayIUse(platform::jit::avx)) { \
math::VecActivations<T, platform::jit::avx> act_functor; \
act_gate = act_functor(act_gate_str); \
act_cell = act_functor(act_cell_str); \
act_cand = act_functor(act_cand_str); \
} else { \
math::VecActivations<T, platform::jit::isa_any> act_functor; \
act_gate = act_functor(act_gate_str); \
act_cell = act_functor(act_cell_str); \
act_cand = act_functor(act_cand_str); \
}
#define INIT_BASE_INPUT_OUTPUT \
auto* x = ctx.Input<LoDTensor>("X"); \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* c0 = ctx.Input<Tensor>("C0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* wh = ctx.Input<Tensor>("WeightH"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* xx = ctx.Output<LoDTensor>("XX"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
bool is_reverse = ctx.Attr<bool>("is_reverse");
#define INIT_BASE_SIZES \
auto x_dims = x->dims();
/* T x M*/
\
auto wh_dims = wh->dims();
/* D x 4D*/
\
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
const int D2 = D * 2; \
const int D3 = D * 3; \
const int D4 = wh_dims[1];
void
SeqCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
SeqCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
using
DeviceContext
=
paddle
::
platform
::
CPUDeviceContext
;
using
DeviceContext
=
paddle
::
platform
::
CPUDeviceContext
;
auto
*
x
=
ctx
.
Input
<
LoDTensor
>
(
"X"
);
INIT_BASE_INPUT_OUTPUT
auto
*
h0
=
ctx
.
Input
<
Tensor
>
(
"H0"
);
INIT_BASE_SIZES
auto
*
c0
=
ctx
.
Input
<
Tensor
>
(
"C0"
);
INIT_VEC_FUNC
auto
*
wx
=
ctx
.
Input
<
Tensor
>
(
"WeightX"
);
auto
*
wh
=
ctx
.
Input
<
Tensor
>
(
"WeightH"
);
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
xx
=
ctx
.
Output
<
LoDTensor
>
(
"XX"
);
auto
*
hidden_out
=
ctx
.
Output
<
LoDTensor
>
(
"Hidden"
);
auto
*
cell_out
=
ctx
.
Output
<
LoDTensor
>
(
"Cell"
);
bool
is_reverse
=
ctx
.
Attr
<
bool
>
(
"is_reverse"
);
std
::
function
<
void
(
const
int
,
const
T
*
,
T
*
)
>
act_gate
,
act_cell
,
act_cand
;
auto
&
act_gate_str
=
ctx
.
Attr
<
std
::
string
>
(
"gate_activation"
);
auto
&
act_cell_str
=
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
);
auto
&
act_cand_str
=
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
);
if
(
platform
::
jit
::
MayIUse
(
platform
::
jit
::
avx
))
{
math
::
VecActivations
<
T
,
platform
::
jit
::
avx
>
act_functor
;
act_gate
=
act_functor
(
act_gate_str
);
act_cell
=
act_functor
(
act_cell_str
);
act_cand
=
act_functor
(
act_cand_str
);
}
else
{
math
::
VecActivations
<
T
,
platform
::
jit
::
isa_any
>
act_functor
;
act_gate
=
act_functor
(
act_gate_str
);
act_cell
=
act_functor
(
act_cell_str
);
act_cand
=
act_functor
(
act_cand_str
);
}
auto
x_lod
=
x
->
lod
();
auto
x_lod
=
x
->
lod
();
auto
x_dims
=
x
->
dims
();
// T x M
auto
wh_dims
=
wh
->
dims
();
// D x 4D
const
int
total_T
=
x_dims
[
0
];
const
int
total_T
=
x_dims
[
0
];
const
int
N
=
x_lod
[
0
].
size
()
-
1
;
// batch size
const
int
N
=
x_lod
[
0
].
size
()
-
1
;
// batch size
const
int
M
=
x_dims
[
1
];
// x frame size
const
int
D
=
wh_dims
[
0
];
const
int
D2
=
D
*
2
;
const
int
D3
=
D
*
3
;
const
int
D4
=
wh_dims
[
1
];
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
h0_data
=
h0
?
h0
->
data
<
T
>
()
:
NULL
;
const
T
*
h0_data
=
h0
?
h0
->
data
<
T
>
()
:
nullptr
;
const
T
*
c0_data
=
c0
?
c0
->
data
<
T
>
()
:
NULL
;
const
T
*
c0_data
=
c0
?
c0
->
data
<
T
>
()
:
nullptr
;
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
...
@@ -290,12 +295,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -290,12 +295,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
int
bid
=
is_reverse
?
N
-
1
-
i
:
i
;
int
bid
=
is_reverse
?
N
-
1
-
i
:
i
;
int
seq_len
=
x_lod
[
0
][
bid
+
1
]
-
x_lod
[
0
][
bid
];
int
seq_len
=
x_lod
[
0
][
bid
+
1
]
-
x_lod
[
0
][
bid
];
const
T
*
prev_c
ell_data
=
NULL
;
const
T
*
prev_c
_data
=
nullptr
;
const
T
*
prev_h
idden_data
=
NULL
;
const
T
*
prev_h
_data
=
nullptr
;
int
tstart
=
0
;
int
tstart
=
0
;
if
(
h0_data
)
{
if
(
h0_data
)
{
prev_h
idden
_data
=
h0_data
+
bid
*
D
;
prev_h_data
=
h0_data
+
bid
*
D
;
prev_c
ell
_data
=
c0_data
+
bid
*
D
;
prev_c_data
=
c0_data
+
bid
*
D
;
}
else
{
}
else
{
// W_ch, W_ih, W_fh, W_oh
// W_ch, W_ih, W_fh, W_oh
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
...
@@ -307,23 +312,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -307,23 +312,22 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
// prev
// prev
prev_h
idden
_data
=
hidden_out_data
;
prev_h_data
=
hidden_out_data
;
prev_c
ell
_data
=
cell_out_data
;
prev_c_data
=
cell_out_data
;
tstart
=
1
;
tstart
=
1
;
move_step
();
move_step
();
}
}
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D4
,
D
,
static_cast
<
T
>
(
1
),
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D4
,
D
,
static_cast
<
T
>
(
1
),
prev_hidden_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
xx_data
,
prev_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
xx_data
,
D4
);
D4
);
// W_ch, W_ih, W_fh, W_oh
// W_ch, W_ih, W_fh, W_oh
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
act_cand
(
D
,
xx_data
,
xx_data
);
act_cand
(
D
,
xx_data
,
xx_data
);
// a = forget * prev_cell
// a = forget * prev_cell
blas
.
VMUL
(
D
,
xx_data
+
D2
,
prev_c
ell
_data
,
xx_data
+
D2
);
blas
.
VMUL
(
D
,
xx_data
+
D2
,
prev_c_data
,
xx_data
+
D2
);
// b = input * tilde
// b = input * tilde
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
xx_data
+
D
);
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
xx_data
+
D
);
...
@@ -336,8 +340,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -336,8 +340,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
// prev
// prev
prev_h
idden
_data
=
hidden_out_data
;
prev_h_data
=
hidden_out_data
;
prev_c
ell
_data
=
cell_out_data
;
prev_c_data
=
cell_out_data
;
move_step
();
move_step
();
}
}
...
@@ -346,143 +350,155 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -346,143 +350,155 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
void
BatchCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
BatchCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
using
DeviceContext
=
platform
::
CPUDeviceContext
;
using
DeviceContext
=
platform
::
CPUDeviceContext
;
auto
*
x
=
ctx
.
Input
<
LoDTensor
>
(
"X"
);
INIT_BASE_INPUT_OUTPUT
auto
*
wx
=
ctx
.
Input
<
Tensor
>
(
"WeightX"
);
if
(
x
->
lod
()[
0
].
size
()
==
2
)
{
auto
*
wh
=
ctx
.
Input
<
Tensor
>
(
"WeightH"
);
SeqCompute
(
ctx
);
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
return
;
auto
*
hidden_t0
=
ctx
.
Input
<
Tensor
>
(
"H0"
);
}
auto
*
cell_t0
=
ctx
.
Input
<
Tensor
>
(
"C0"
);
INIT_BASE_SIZES
INIT_VEC_FUNC
auto
*
xx
=
ctx
.
Output
<
LoDTensor
>
(
"XX"
);
auto
*
batched_gate
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedGate"
);
auto
*
hidden_out
=
ctx
.
Output
<
LoDTensor
>
(
"Hidden"
);
auto
*
cell_out
=
ctx
.
Output
<
LoDTensor
>
(
"Cell"
);
bool
is_reverse
=
ctx
.
Attr
<
bool
>
(
"is_reverse"
);
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
reordered_h0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedH0"
);
T
*
batched_gate_data
=
batched_gate
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
reordered_c0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedC0"
);
hidden_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
batched_input
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedInput"
);
cell_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
batched_c_out
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedCell"
);
auto
*
batched_h_out
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedHidden"
);
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
wx_data
=
wx
->
data
<
T
>
();
auto
x_dims
=
x
->
dims
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
auto
wx_dims
=
wx
->
dims
();
auto
place
=
ctx
.
GetPlace
();
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
place
);
T
*
batched_input_data
=
batched_input
->
mutable_data
<
T
>
(
place
);
T
*
batched_c_out_data
=
batched_c_out
->
mutable_data
<
T
>
(
place
);
T
*
batched_h_out_data
=
batched_h_out
->
mutable_data
<
T
>
(
place
);
hidden_out
->
mutable_data
<
T
>
(
place
);
cell_out
->
mutable_data
<
T
>
(
place
);
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
if
(
x_dims
[
1
]
>
wx_dims
[
1
])
{
if
(
M
>
D4
)
{
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
x_dims
[
0
],
wx_dims
[
1
],
x_dims
[
1
],
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
x_dims
[
0
],
D4
,
M
,
x_data
,
wx_data
,
x_data
,
wx_data
,
xx_data
,
xx_data
,
bias
->
data
<
T
>
());
bias
->
data
<
T
>
());
to_batch
(
dev_ctx
,
*
xx
,
batched_input
,
true
,
is_reverse
);
to_batch
(
dev_ctx
,
*
xx
,
batched_gate
,
true
,
is_reverse
);
}
else
{
}
else
{
to_batch
(
dev_ctx
,
*
x
,
xx
,
true
,
is_reverse
);
to_batch
(
dev_ctx
,
*
x
,
xx
,
true
,
is_reverse
);
batched_
gate
->
set_lod
(
xx
->
lod
());
batched_
input
->
set_lod
(
xx
->
lod
());
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
x_dims
[
0
],
wx_dims
[
1
],
x_dims
[
1
]
,
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
x_dims
[
0
],
D4
,
M
,
xx_data
,
xx_data
,
wx_data
,
batched_gate
_data
,
wx_data
,
batched_input
_data
,
bias
->
data
<
T
>
());
bias
->
data
<
T
>
());
}
}
int
frame_size
=
static_cast
<
int
>
(
wx_dims
[
1
]
/
4
);
auto
batched_lod
=
batched_input
->
lod
();
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
const
auto
&
seq_order
=
batched_lod
[
2
];
math
::
LstmMetaValue
<
T
>
lstm_value
;
const
int
max_bs
=
seq_order
.
size
();
// no peephole
reordered_h0
->
Resize
({
max_bs
,
D
});
lstm_value
.
check_ig
=
nullptr
;
reordered_c0
->
Resize
({
max_bs
,
D
});
lstm_value
.
check_fg
=
nullptr
;
lstm_value
.
check_og
=
nullptr
;
int
tstart
=
0
;
lstm_value
.
prev_state_value
=
nullptr
;
T
*
prev_h_data
=
nullptr
;
Tensor
ordered_c0
;
T
*
prev_c_data
=
nullptr
;
if
(
h0
)
{
framework
::
Vector
<
size_t
>
order
(
batched_gate
->
lod
()[
2
]);
// reorder h0, c0
T
*
reordered_h0_data
=
reordered_h0
->
mutable_data
<
T
>
(
place
);
if
(
cell_t0
)
{
T
*
reordered_c0_data
=
reordered_c0
->
mutable_data
<
T
>
(
place
);
// Since the batch computing for LSTM reorders the input sequence
const
T
*
h0_data
=
h0
->
data
<
T
>
();
// according to their length. The initialized cell state also needs
const
T
*
c0_data
=
c0
->
data
<
T
>
();
// to reorder.
prev_h_data
=
reordered_h0_data
;
ReorderInitState
<
DeviceContext
,
T
>
(
dev_ctx
,
*
cell_t0
,
order
,
&
ordered_c0
,
prev_c_data
=
reordered_c0_data
;
true
);
size_t
sz
=
sizeof
(
T
)
*
D
;
lstm_value
.
prev_state_value
=
ordered_c0
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
std
::
memcpy
(
reordered_h0_data
,
h0_data
+
seq_order
[
i
]
*
D
,
sz
);
std
::
memcpy
(
reordered_c0_data
,
c0_data
+
seq_order
[
i
]
*
D
,
sz
);
reordered_h0_data
+=
D
;
reordered_c0_data
+=
D
;
}
}
else
{
// compute without h0, c0
T
*
cur_in_data
=
batched_input_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
// W_ch, W_ih, W_fh, W_oh
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
act_cand
(
D
,
cur_in_data
,
cur_in_data
);
// cell out= input*tilde
blas
.
VMUL
(
D
,
cur_in_data
,
cur_in_data
+
D
,
cur_c_out_data
);
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cur_c_out_data
,
cur_in_data
+
D2
);
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_in_data
+
D3
,
cur_h_out_data
);
// add offset
cur_in_data
+=
D4
;
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
}
tstart
=
1
;
prev_h_data
=
batched_h_out_data
;
prev_c_data
=
batched_c_out_data
;
}
}
// Then start from next
const
auto
&
batch_starts
=
batched_lod
[
0
];
const
int
max_seq_len
=
batch_starts
.
size
()
-
1
;
const
int
offset
=
tstart
*
max_bs
*
D
;
batched_input_data
=
batched_input_data
+
offset
*
4
;
batched_h_out_data
=
batched_h_out_data
+
offset
;
batched_c_out_data
=
batched_c_out_data
+
offset
;
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
cur_bs
,
D4
,
D
,
static_cast
<
T
>
(
1
),
prev_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
batched_input_data
,
D4
);
T
*
cur_in_data
=
batched_input_data
;
T
*
cur_prev_c_data
=
prev_c_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
// W_ch, W_ih, W_fh, W_oh
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
act_cand
(
D
,
cur_in_data
,
cur_in_data
);
// a = forget * prev_cell
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_prev_c_data
,
cur_in_data
+
D2
);
// b = input * tilde
blas
.
VMUL
(
D
,
cur_in_data
,
cur_in_data
+
D
,
cur_in_data
+
D
);
// cell out= a+b
blas
.
VADD
(
D
,
cur_in_data
+
D
,
cur_in_data
+
D2
,
cur_c_out_data
);
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cur_c_out_data
,
cur_in_data
+
D2
);
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_in_data
+
D3
,
cur_h_out_data
);
// Use the local variable as here.
cur_in_data
+=
D4
;
LoDTensor
batch_hidden
,
batch_cell
;
cur_prev_c_data
+=
D
;
auto
*
batch_cell_pre_act
=
ctx
.
Output
<
LoDTensor
>
(
"BatchCellPreAct"
);
cur_c_out_data
+=
D
;
batch_hidden
.
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
cur_h_out_data
+=
D
;
batch_cell
.
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
batch_cell_pre_act
->
mutable_data
<
T
>
(
out_dims
,
ctx
.
GetPlace
());
auto
batch_starts
=
batched_gate
->
lod
()[
0
];
size_t
max_seq_len
=
batch_starts
.
size
()
-
1
;
auto
gate_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"gate_activation"
));
auto
cell_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"cell_activation"
));
auto
cand_act
=
math
::
detail
::
GetActivationType
(
ctx
.
Attr
<
std
::
string
>
(
"candidate_activation"
));
for
(
size_t
n
=
0
;
n
<
max_seq_len
;
n
++
)
{
int
bstart
=
static_cast
<
int
>
(
batch_starts
[
n
]);
int
bend
=
static_cast
<
int
>
(
batch_starts
[
n
+
1
]);
Tensor
gate_t
=
batched_gate
->
Slice
(
bstart
,
bend
);
Tensor
out_t
=
batch_hidden
.
Slice
(
bstart
,
bend
);
Tensor
cell_t
=
batch_cell
.
Slice
(
bstart
,
bend
);
Tensor
cell_pre_act_t
=
batch_cell_pre_act
->
Slice
(
bstart
,
bend
);
int
cur_batch_size
=
bend
-
bstart
;
if
(
n
>
0
)
{
int
pre_h_start
=
static_cast
<
int
>
(
batch_starts
[
n
-
1
]);
int
pre_h_end
=
pre_h_start
+
cur_batch_size
;
auto
pre_hidden_t
=
batch_hidden
.
Slice
(
pre_h_start
,
pre_h_end
);
// TODO(TJ): use gemm directly
blas
.
MatMul
(
pre_hidden_t
,
false
,
*
wh
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
}
else
if
(
hidden_t0
)
{
// TODO(TJ): move h0 outside for
// If n == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros, the calculation W_h * H0 will be skiped.
// If n == 0 and there is initialized hidden state, calculate W_h * H0.
// Since the batch computing for LSTM reorders the input sequence
// according to their length. The initialized hidden state also needs
// to reorder.
Tensor
ordered_h0
;
ReorderInitState
<
DeviceContext
,
T
>
(
dev_ctx
,
*
hidden_t0
,
order
,
&
ordered_h0
,
true
);
// TODO(TJ): use gemm directly
blas
.
MatMul
(
ordered_h0
,
false
,
*
wh
,
false
,
static_cast
<
T
>
(
1.0
),
&
gate_t
,
static_cast
<
T
>
(
1.0
));
}
}
lstm_value
.
gate_value
=
gate_t
.
data
<
T
>
();
prev_c_data
=
batched_c_out_data
;
lstm_value
.
output_value
=
out_t
.
data
<
T
>
();
prev_h_data
=
batched_h_out_data
;
lstm_value
.
state_value
=
cell_t
.
data
<
T
>
();
batched_c_out_data
=
cur_c_out_data
;
lstm_value
.
state_active_value
=
cell_pre_act_t
.
data
<
T
>
();
batched_h_out_data
=
cur_h_out_data
;
math
::
LstmUnitFunctor
<
DeviceContext
,
T
>::
compute
(
batched_input_data
=
cur_in_data
;
dev_ctx
,
lstm_value
,
frame_size
,
cur_batch_size
,
gate_act
,
cell_act
,
cand_act
);
lstm_value
.
prev_state_value
=
lstm_value
.
state_value
;
}
}
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
batch_hidden
.
set_lod
(
batched_gate
->
lod
());
batched_h_out
->
set_lod
(
batched_lod
);
// restore the output hidden in LoDTensor from the batch hidden
to_seq
(
dev_ctx
,
*
batched_h_out
,
hidden_out
);
to_seq
(
dev_ctx
,
batch_hidden
,
hidden_out
);
batched_c_out
->
set_lod
(
batched_lod
);
to_seq
(
dev_ctx
,
*
batched_c_out
,
cell_out
);
batch_cell
.
set_lod
(
batched_gate
->
lod
());
// restore the output cell state in LoDTensor from the batch cell
to_seq
(
dev_ctx
,
batch_cell
,
cell_out
);
}
}
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
if
(
FLAGS_seq_mode
)
{
if
(
ctx
.
Attr
<
bool
>
(
"use_seq"
)
)
{
SeqCompute
(
ctx
);
SeqCompute
(
ctx
);
}
else
{
}
else
{
BatchCompute
(
ctx
);
BatchCompute
(
ctx
);
}
}
}
}
#undef INIT_BASE_SIZES
#undef INIT_BASE_INPUT_OUTPUT
#undef INIT_VEC_FUNC
};
};
}
// namespace operators
}
// namespace operators
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录