Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9f2ccf5b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9f2ccf5b
编写于
9月 06, 2018
作者:
T
tensor-tang
提交者:
GitHub
9月 06, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13237 from tensor-tang/refine/op/peephole
refine fusion lstm/peephole and fusion gru
上级
225ecee5
718033e1
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
254 addition
and
329 deletion
+254
-329
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+1
-0
paddle/fluid/operators/fusion_gru_op.cc
paddle/fluid/operators/fusion_gru_op.cc
+8
-10
paddle/fluid/operators/fusion_lstm_op.cc
paddle/fluid/operators/fusion_lstm_op.cc
+240
-280
python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
+5
-39
未找到文件。
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
浏览文件 @
9f2ccf5b
...
...
@@ -11,6 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
...
...
paddle/fluid/operators/fusion_gru_op.cc
浏览文件 @
9f2ccf5b
...
...
@@ -30,14 +30,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
"Input(WeightX) of GRU should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"WeightH"
),
"Input(WeightH) of GRU should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"XX"
),
"Output(XX) of GRU should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedH0"
),
"Output(ReorderedH0) of GRU should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedInput"
),
"Output(BatchedInput) of GRU should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedOut"
),
"Output(BatchedOut) of GRU should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Hidden"
),
"Output(Hidden) of GRU should not be null."
);
...
...
@@ -80,15 +73,20 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
}
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedInput"
,
{
x_dims
[
0
],
wx_dims
[
1
]});
ctx
->
SetOutputDim
(
"BatchedOut"
,
out_dims
);
ctx
->
ShareLoD
(
"X"
,
"Hidden"
);
int
xx_width
;
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"use_seq"
))
{
xx_width
=
wx_dims
[
1
];
}
else
{
xx_width
=
x_dims
[
1
]
>
wx_dims
[
1
]
?
wx_dims
[
1
]
:
x_dims
[
1
];
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedH0"
),
"Output(ReorderedH0) of GRU should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedInput"
),
"Output(BatchedInput) of GRU should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedOut"
),
"Output(BatchedOut) of GRU should not be null."
);
ctx
->
SetOutputDim
(
"BatchedInput"
,
{
x_dims
[
0
],
wx_dims
[
1
]});
ctx
->
SetOutputDim
(
"BatchedOut"
,
out_dims
);
}
ctx
->
SetOutputDim
(
"XX"
,
{
x_dims
[
0
],
xx_width
});
ctx
->
ShareLoD
(
"X"
,
"XX"
);
...
...
paddle/fluid/operators/fusion_lstm_op.cc
浏览文件 @
9f2ccf5b
...
...
@@ -38,16 +38,6 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"Output(Hidden) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
"Output(Cell) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedInput"
),
"Output(BatchedInput) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedHidden"
),
"Output(BatchedHidden) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedCell"
),
"Output(BatchedCell) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedH0"
),
"Output(ReorderedH0) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedC0"
),
"Output(ReorderedC0) of LSTM should not be null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2
,
"Input(X)'s rank must be 2."
);
...
...
@@ -88,28 +78,36 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
"The rank of Input(Bias) should be 2."
);
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
"The first dimension of Input(Bias) should be 1."
);
auto
use_peepholes
=
ctx
->
Attrs
().
Get
<
bool
>
(
"use_peepholes"
);
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
(
use_peepholes
?
7
:
4
)
*
frame_size
,
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection or"
"4 * %d if disable peepholes"
,
frame_size
,
frame_size
);
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
(
ctx
->
Attrs
().
Get
<
bool
>
(
"use_peepholes"
)
?
7
:
4
)
*
frame_size
,
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection or"
"4 * %d if disable peepholes"
,
frame_size
,
frame_size
);
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"Cell"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedInput"
,
{
x_dims
[
0
],
wx_dims
[
1
]});
ctx
->
SetOutputDim
(
"BatchedHidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedCell"
,
out_dims
);
ctx
->
ShareLoD
(
"X"
,
"Hidden"
);
ctx
->
ShareLoD
(
"X"
,
"Cell"
);
int
xx_width
;
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"use_seq"
))
{
xx_width
=
wx_dims
[
1
];
}
else
{
xx_width
=
x_dims
[
1
]
>
wx_dims
[
1
]
?
wx_dims
[
1
]
:
x_dims
[
1
];
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedInput"
),
"Output(BatchedInput) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedHidden"
),
"Output(BatchedHidden) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedCell"
),
"Output(BatchedCell) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedH0"
),
"Output(ReorderedH0) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedC0"
),
"Output(ReorderedC0) of LSTM should not be null."
);
ctx
->
SetOutputDim
(
"BatchedInput"
,
{
x_dims
[
0
],
wx_dims
[
1
]});
ctx
->
SetOutputDim
(
"BatchedHidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedCell"
,
out_dims
);
}
ctx
->
SetOutputDim
(
"XX"
,
{
x_dims
[
0
],
xx_width
});
ctx
->
ShareLoD
(
"X"
,
"XX"
);
...
...
@@ -232,18 +230,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
act_cand = act_functor(act_cand_str); \
}
#define INIT_BASE_INPUT_OUTPUT
\
auto* x = ctx.Input<LoDTensor>("X");
\
auto* h0 = ctx.Input<Tensor>("H0");
\
auto* c0 = ctx.Input<Tensor>("C0");
\
auto* wx = ctx.Input<Tensor>("WeightX");
\
auto* wh = ctx.Input<Tensor>("WeightH");
\
auto* bias = ctx.Input<Tensor>("Bias");
\
auto* xx = ctx.Output<LoDTensor>("XX");
\
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
\
auto* cell_out = ctx.Output<LoDTensor>("Cell");
\
bool
use_peepholes = ctx.Attr<bool>("use_peepholes");
\
bool
is_reverse = ctx.Attr<bool>("is_reverse
");
#define INIT_BASE_INPUT_OUTPUT \
auto* x = ctx.Input<LoDTensor>("X"); \
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* c0 = ctx.Input<Tensor>("C0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* wh = ctx.Input<Tensor>("WeightH"); \
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* xx = ctx.Output<LoDTensor>("XX"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
bool
is_reverse = ctx.Attr<bool>("is_reverse");
\
bool
use_peepholes = ctx.Attr<bool>("use_peepholes
");
#define INIT_BASE_SIZES \
auto x_dims = x->dims();
/* T x M*/
\
...
...
@@ -254,172 +252,183 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int D3 = D * 3; \
const int D4 = wh_dims[1];
#define INIT_BASE_INPUT_DATAS \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/
\
const T* wc_data = bias->data<T>() + D4; \
/* for peephole only*/
\
Tensor checked_cell; \
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/
\
checked_cell_data = checked_cell.mutable_data<T>({2, D}, place); \
}
/// Compute LSTM
#define GEMM_WH_ADDON(bs, prev, out) \
blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast<T>(1), prev, D, \
wh_data, D4, static_cast<T>(1), out, D4)
// gates: W_ch, W_ih, W_fh, W_oh
#define GET_Ct(ct_1, gates, ct) \
/* C_t = C_t-1 * fgated + cand_gated * igated*/
\
act_cand(D, gates, gates); \
blas.VMUL(D, gates, gates + D, gates + D); \
blas.VMUL(D, ct_1, gates + D2, gates + D2); \
blas.VADD(D, gates + D, gates + D2, ct)
#define GET_Ht(ct, gates, ht) \
/* H_t = act_cell(C_t) * ogated */
\
act_cell(D, ct, gates + D2); \
blas.VMUL(D, gates + D2, gates + D3, ht)
#define GET_Ct_NOH0C0(gates, ct) \
/* C_t = igated * cgated*/
\
act_gate(D, gates + D, gates + D); \
act_cand(D, gates, gates); \
blas.VMUL(D, gates, gates + D, ct)
#define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \
GET_Ct_NOH0C0(gates, ct); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \
GET_Ct_NOH0C0(gates, ct); \
/* get outgated, put W_oc * C_t on igated */
\
blas.VMUL(D, wc_data + D2, ct, gates + D); \
blas.VADD(D, gates + D, gates + D3, gates + D3); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt(gates, ct_1, ct, ht) \
act_gate(D3, gates + D, gates + D); \
GET_Ct(ct_1, gates, ct); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt_PEEPHOLE(gates, ct_1, ct, ht) \
/* get fgated and igated*/
\
blas.VMUL(D, wc_data, ct_1, checked_cell_data); \
blas.VMUL(D, wc_data + D, ct_1, checked_cell_data + D); \
blas.VADD(D2, checked_cell_data, gates + D, gates + D); \
act_gate(D2, gates + D, gates + D); \
GET_Ct(ct_1, gates, ct); \
/* get ogated*/
\
blas.VMUL(D, wc_data + D2, ct, gates + D); \
blas.VADD(D, gates + D, gates + D3, gates + D3); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
void
SeqCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
using
DeviceContext
=
paddle
::
platform
::
CPUDeviceContext
;
INIT_BASE_INPUT_OUTPUT
INIT_BASE_SIZES
INIT_VEC_FUNC
INIT_BASE_INPUT_DATAS
auto
x_lod
=
x
->
lod
();
const
int
total_T
=
x_dims
[
0
];
const
int
N
=
x_lod
[
0
].
size
()
-
1
;
// batch size
const
T
*
x_data
=
x
->
data
<
T
>
();
const
int
N
=
x_lod
[
0
].
size
()
-
1
;
const
T
*
h0_data
=
h0
?
h0
->
data
<
T
>
()
:
nullptr
;
const
T
*
c0_data
=
c0
?
c0
->
data
<
T
>
()
:
nullptr
;
const
T
*
bias_data
=
bias
->
data
<
T
>
();
const
T
*
wc_data
=
bias_data
+
D4
;
// w_ic, w_fc, w_oc
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
hidden_out_data
=
hidden_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
cell_out_data
=
cell_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// use local variable
framework
::
DDim
check_dims
({
3
,
D
});
Tensor
checked_cell
;
// w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
auto
checked_cell_data
=
checked_cell
.
mutable_data
<
T
>
(
check_dims
,
ctx
.
GetPlace
());
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
place
);
T
*
h_out_data
=
hidden_out
->
mutable_data
<
T
>
(
place
);
T
*
c_out_data
=
cell_out
->
mutable_data
<
T
>
(
place
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
total_T
,
D4
,
M
,
x_data
,
wx_data
,
xx_data
,
bias
->
data
<
T
>
());
int
xx_offset
=
D4
;
int
gate_offset
=
D
;
if
(
is_reverse
)
{
const
int
offset
=
(
total_T
-
1
)
*
D
;
xx_data
=
xx_data
+
offset
*
4
;
h
idden_out_data
=
hidden
_out_data
+
offset
;
c
ell_out_data
=
cell
_out_data
+
offset
;
h
_out_data
=
h
_out_data
+
offset
;
c
_out_data
=
c
_out_data
+
offset
;
xx_offset
=
-
D4
;
gate_offset
=
-
D
;
}
auto
move_step
=
[
&
]()
{
xx_data
=
xx_data
+
xx_offset
;
hidden_out_data
=
hidden_out_data
+
gate_offset
;
cell_out_data
=
cell_out_data
+
gate_offset
;
};
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
int
bid
=
is_reverse
?
N
-
1
-
i
:
i
;
int
seq_len
=
x_lod
[
0
][
bid
+
1
]
-
x_lod
[
0
][
bid
];
const
T
*
prev_c_data
=
nullptr
;
const
T
*
prev_h_data
=
nullptr
;
int
tstart
=
0
;
if
(
h0_data
)
{
prev_h_data
=
h0_data
+
bid
*
D
;
prev_c_data
=
c0_data
+
bid
*
D
;
}
else
{
// If step == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros. Then W_h * H_t-1 can be skipped
// ~C_t
act_cand
(
D
,
xx_data
,
xx_data
);
if
(
use_peepholes
)
{
// I_t, F_t
act_gate
(
D2
,
xx_data
+
D
,
xx_data
+
D
);
}
else
{
// I_t, F_t, O_t
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
}
// C_t = I_t * ~C_t
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
cell_out_data
);
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cell_out_data
,
checked_cell_data
+
D2
);
blas
.
VADD
(
D
,
xx_data
+
D3
,
checked_cell_data
+
D2
,
xx_data
+
D3
);
// O_t
act_gate
(
D
,
xx_data
+
D3
,
xx_data
+
D3
);
}
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cell_out_data
,
xx_data
+
D2
);
// H_t = O_t * act_state(C_t)
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
// prev
prev_h_data
=
hidden_out_data
;
prev_c_data
=
cell_out_data
;
tstart
=
1
;
move_step
();
}
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
// + W_h * H_t-1
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D4
,
D
,
static_cast
<
T
>
(
1
),
prev_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
xx_data
,
D4
);
#define MOVE_ONE_STEP \
prev_h_data = h_out_data; \
prev_c_data = c_out_data; \
xx_data = xx_data + xx_offset; \
h_out_data = h_out_data + gate_offset; \
c_out_data = c_out_data + gate_offset
#define PROCESS_H0C0_DEFINES \
int bid = is_reverse ? N - 1 - i : i; \
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \
const T* prev_c_data = nullptr; \
const T* prev_h_data = nullptr; \
int tstart = 0
#define PROCESS_H0C0_PEEPHOLE \
PROCESS_H0C0_DEFINES; \
if (h0_data) { \
prev_h_data = h0_data + bid * D; \
prev_c_data = c0_data + bid * D; \
} else { \
COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \
MOVE_ONE_STEP; \
tstart = 1; \
}
// ~C_t
act_cand
(
D
,
xx_data
,
xx_data
);
#define PROCESS_H0C0 \
PROCESS_H0C0_DEFINES; \
if (h0_data) { \
prev_h_data = h0_data + bid * D; \
prev_c_data = c0_data + bid * D; \
} else { \
COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \
MOVE_ONE_STEP; \
tstart = 1; \
}
if
(
use_peepholes
)
{
// + W_ic|W_fc * C_t-1 for peephole connection
blas
.
VMUL
(
D
,
wc_data
,
prev_c_data
,
checked_cell_data
);
blas
.
VMUL
(
D
,
wc_data
+
D
,
prev_c_data
,
checked_cell_data
+
D
);
blas
.
VADD
(
D2
,
xx_data
+
D
,
checked_cell_data
,
xx_data
+
D
);
// I_t, F_t
act_gate
(
D2
,
xx_data
+
D
,
xx_data
+
D
);
}
else
{
// I_t, F_t, O_t
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
if
(
use_peepholes
)
{
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
PROCESS_H0C0_PEEPHOLE
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
GEMM_WH_ADDON
(
1
,
prev_h_data
,
xx_data
);
COMPUTE_CtHt_PEEPHOLE
(
xx_data
,
prev_c_data
,
c_out_data
,
h_out_data
);
MOVE_ONE_STEP
;
}
// F_t * C_t-1
blas
.
VMUL
(
D
,
xx_data
+
D2
,
prev_c_data
,
xx_data
+
D2
);
// I_t * ~C_t
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
xx_data
+
D
);
// C_t = F_t * C_t-1 + I_t * ~C_t
blas
.
VADD
(
D
,
xx_data
+
D
,
xx_data
+
D2
,
cell_out_data
);
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cell_out_data
,
checked_cell_data
+
D2
);
blas
.
VADD
(
D
,
xx_data
+
D3
,
checked_cell_data
+
D2
,
xx_data
+
D3
);
// O_t
act_gate
(
D
,
xx_data
+
D3
,
xx_data
+
D3
);
}
}
else
{
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
PROCESS_H0C0
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
GEMM_WH_ADDON
(
1
,
prev_h_data
,
xx_data
);
COMPUTE_CtHt
(
xx_data
,
prev_c_data
,
c_out_data
,
h_out_data
);
MOVE_ONE_STEP
;
}
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cell_out_data
,
xx_data
+
D2
);
// H_t = O_t * act_state(C_t)
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
// prev
prev_h_data
=
hidden_out_data
;
prev_c_data
=
cell_out_data
;
move_step
();
}
// for each step in batch
}
// for each batch
}
}
#undef PROCESS_H0C0_DEFINES
#undef PROCESS_H0C0_PEEPHOLE
#undef PROCESS_H0C0
#undef MOVE_ONE_STEP
}
void
BatchCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
using
DeviceContext
=
platform
::
CPUDeviceContext
;
INIT_BASE_INPUT_OUTPUT
if
(
x
->
lod
()[
0
].
size
()
==
2
)
{
// batch size == 1
if
(
x
->
lod
()[
0
].
size
()
==
2
)
{
SeqCompute
(
ctx
);
return
;
}
INIT_BASE_SIZES
INIT_VEC_FUNC
INIT_BASE_INPUT_DATAS
auto
*
reordered_h0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedH0"
);
auto
*
reordered_c0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedC0"
);
auto
*
batched_input
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedInput"
);
auto
*
batched_c_out
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedCell"
);
auto
*
batched_h_out
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedHidden"
);
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
const
T
*
bias_data
=
bias
->
data
<
T
>
();
const
T
*
wc_data
=
bias_data
+
D4
;
// w_ic, w_fc, w_oc
auto
place
=
ctx
.
GetPlace
();
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
place
);
T
*
batched_input_data
=
batched_input
->
mutable_data
<
T
>
(
place
);
T
*
batched_c_out_data
=
batched_c_out
->
mutable_data
<
T
>
(
place
);
...
...
@@ -427,12 +436,6 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
hidden_out
->
mutable_data
<
T
>
(
place
);
cell_out
->
mutable_data
<
T
>
(
place
);
// use local variable
framework
::
DDim
check_dims
({
3
,
D
});
Tensor
checked_cell
;
// w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
auto
checked_cell_data
=
checked_cell
.
mutable_data
<
T
>
(
check_dims
,
ctx
.
GetPlace
());
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
...
...
@@ -454,27 +457,17 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
reordered_h0
->
Resize
({
max_bs
,
D
});
reordered_c0
->
Resize
({
max_bs
,
D
});
T
*
prev_batch_h_data
=
nullptr
;
T
*
prev_batch_c_data
=
nullptr
;
T
*
cur_batch_in_data
=
batched_input_data
;
T
*
cur_batch_h_out_data
=
batched_h_out_data
;
T
*
cur_batch_c_out_data
=
batched_c_out_data
;
auto
move_step
=
[
&
](
int
bs
)
{
cur_batch_in_data
+=
bs
*
D4
;
cur_batch_c_out_data
+=
bs
*
D
;
cur_batch_h_out_data
+=
bs
*
D
;
};
int
tstart
=
0
;
T
*
prev_h_data
=
nullptr
;
T
*
prev_c_data
=
nullptr
;
if
(
h0
)
{
// reorder h0, c0
T
*
reordered_h0_data
=
reordered_h0
->
mutable_data
<
T
>
(
place
);
T
*
reordered_c0_data
=
reordered_c0
->
mutable_data
<
T
>
(
place
);
const
T
*
h0_data
=
h0
->
data
<
T
>
();
const
T
*
c0_data
=
c0
->
data
<
T
>
();
prev_
batch_
h_data
=
reordered_h0_data
;
prev_
batch_
c_data
=
reordered_c0_data
;
prev_h_data
=
reordered_h0_data
;
prev_c_data
=
reordered_c0_data
;
size_t
sz
=
sizeof
(
T
)
*
D
;
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
std
::
memcpy
(
reordered_h0_data
,
h0_data
+
seq_order
[
i
]
*
D
,
sz
);
...
...
@@ -483,123 +476,80 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
reordered_c0_data
+=
D
;
}
}
else
{
// Compute with no H0/C0
T
*
cur_in_data
=
cur_batch_in_data
;
T
*
cur_c_out_data
=
cur_batch_c_out_data
;
T
*
cur_h_out_data
=
cur_batch_h_out_data
;
// If step == 0 and there is no initialized hidden state, that is to say
// the H0 is zeros. Then W_h * H_t-1 can be skiped
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
// iterate each data in 1st batch
// ~C_t
act_cand
(
D
,
cur_in_data
,
cur_in_data
);
if
(
use_peepholes
)
{
// I_t, F_t
act_gate
(
D2
,
cur_in_data
+
D
,
cur_in_data
+
D
);
}
else
{
// I_t, F_t, O_t
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
}
// C_t = I_t * ~C_t
blas
.
VMUL
(
D
,
cur_in_data
,
cur_in_data
+
D
,
cur_c_out_data
);
// compute without h0, c0
T
*
cur_in_data
=
batched_input_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
GET_Ct_NOH0C0
(
cur_in_data
,
cur_c_out_data
);
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cur_c_out_data
,
checked_cell_data
+
D2
);
blas
.
VADD
(
D
,
cur_in_data
+
D3
,
checked_cell_data
+
D2
,
cur_in_data
+
D3
);
// O_t
act_gate
(
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cur_c_out_data
,
cur_in_data
+
D
);
blas
.
VADD
(
D
,
cur_in_data
+
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
}
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cur_c_out_data
,
cur_in_data
+
D2
);
// H_t = O_t * act_state(C_t)
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_in_data
+
D3
,
cur_h_out_data
);
// move to next data in the same batch
act_gate
(
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
GET_Ht
(
cur_c_out_data
,
cur_in_data
,
cur_h_out_data
);
cur_in_data
+=
D4
;
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
}
// move to data for next timestep
prev_batch_h_data
=
cur_batch_h_out_data
;
prev_batch_c_data
=
cur_batch_c_out_data
;
move_step
(
max_bs
);
tstart
=
1
;
prev_h_data
=
batched_h_out_data
;
prev_c_data
=
batched_c_out_data
;
}
const
auto
&
batch_starts
=
batched_lod
[
0
];
const
int
max_seq_len
=
batch_starts
.
size
()
-
1
;
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
// + W_h * H_t-1
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
cur_bs
,
D4
,
D
,
static_cast
<
T
>
(
1
),
prev_batch_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
cur_batch_in_data
,
D4
);
T
*
cur_in_data
=
cur_batch_in_data
;
T
*
cur_c_out_data
=
cur_batch_c_out_data
;
T
*
cur_h_out_data
=
cur_batch_h_out_data
;
T
*
prev_c_data
=
prev_batch_c_data
;
// NULL if no C0 in step0
T
*
prev_h_data
=
prev_batch_h_data
;
// NULL if no H0 in step0
auto
next_data_in_batch
=
[
&
]()
{
cur_in_data
+=
D4
;
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
prev_c_data
=
prev_c_data
?
prev_c_data
+
D
:
nullptr
;
prev_h_data
=
prev_h_data
?
prev_h_data
+
D
:
nullptr
;
};
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
// iterate each data in same batch
// ~C_t
act_cand
(
D
,
cur_in_data
,
cur_in_data
);
if
(
use_peepholes
)
{
// + W_ic|W_fc * C_t-1 for peephole connection
blas
.
VMUL
(
D
,
wc_data
,
prev_c_data
,
checked_cell_data
);
blas
.
VMUL
(
D
,
wc_data
+
D
,
prev_c_data
,
checked_cell_data
+
D
);
blas
.
VADD
(
D2
,
cur_in_data
+
D
,
checked_cell_data
,
cur_in_data
+
D
);
// I_t, F_t
act_gate
(
D2
,
cur_in_data
+
D
,
cur_in_data
+
D
);
}
else
{
// I_t, F_t, O_t
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
const
int
offset
=
tstart
*
max_bs
*
D
;
batched_input_data
=
batched_input_data
+
offset
*
4
;
batched_h_out_data
=
batched_h_out_data
+
offset
;
batched_c_out_data
=
batched_c_out_data
+
offset
;
#define DEFINE_CUR \
T* cur_in_data = batched_input_data; \
T* cur_prev_c_data = prev_c_data; \
T* cur_c_out_data = batched_c_out_data; \
T* cur_h_out_data = batched_h_out_data
#define MOVE_ONE_BATCH \
cur_in_data += D4; \
cur_prev_c_data += D; \
cur_c_out_data += D; \
cur_h_out_data += D
#define MOVE_ONE_STEP \
prev_c_data = batched_c_out_data; \
prev_h_data = batched_h_out_data; \
batched_c_out_data = cur_c_out_data; \
batched_h_out_data = cur_h_out_data; \
batched_input_data = cur_in_data
if
(
use_peepholes
)
{
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
GEMM_WH_ADDON
(
cur_bs
,
prev_h_data
,
batched_input_data
);
DEFINE_CUR
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
COMPUTE_CtHt_PEEPHOLE
(
cur_in_data
,
cur_prev_c_data
,
cur_c_out_data
,
cur_h_out_data
);
MOVE_ONE_BATCH
;
}
// F_t * C_t-1
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
prev_c_data
,
cur_in_data
+
D2
);
// I_t * ~C_t
blas
.
VMUL
(
D
,
cur_in_data
,
cur_in_data
+
D
,
cur_in_data
+
D
);
// C_t = F_t * C_t-1 + I_t * ~C_t
blas
.
VADD
(
D
,
cur_in_data
+
D
,
cur_in_data
+
D2
,
cur_c_out_data
);
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cur_c_out_data
,
checked_cell_data
+
D2
);
blas
.
VADD
(
D
,
cur_in_data
+
D3
,
checked_cell_data
+
D2
,
cur_in_data
+
D3
);
// O_t
act_gate
(
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
MOVE_ONE_STEP
;
}
}
else
{
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
GEMM_WH_ADDON
(
cur_bs
,
prev_h_data
,
batched_input_data
);
DEFINE_CUR
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
COMPUTE_CtHt
(
cur_in_data
,
cur_prev_c_data
,
cur_c_out_data
,
cur_h_out_data
);
MOVE_ONE_BATCH
;
}
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cur_c_out_data
,
cur_in_data
+
D2
);
// H_t = O_t * act_state(C_t)
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_in_data
+
D3
,
cur_h_out_data
);
// move to next data in same batch
next_data_in_batch
();
MOVE_ONE_STEP
;
}
// move to data for next timestep
prev_batch_h_data
=
cur_batch_h_out_data
;
prev_batch_c_data
=
cur_batch_c_out_data
;
move_step
(
cur_bs
);
}
#undef MOVE_ONE_STEP
#undef MOVE_ONE_BATCH
#undef DEFINE_CUR
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
batched_h_out
->
set_lod
(
batched_lod
);
...
...
@@ -615,6 +565,16 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
BatchCompute
(
ctx
);
}
}
#undef COMPUTE_CtHt_PEEPHOLE
#undef COMPUTE_CtHt
#undef GET_Ct_NOH0C0
#undef COMPUTE_CtHt_NOH0C0
#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0
#undef GET_Ht
#undef GET_Ct
#undef GEMM_WH_ADDON
#undef INIT_BASE_INPUT_DATAS
#undef INIT_BASE_SIZES
#undef INIT_BASE_INPUT_OUTPUT
#undef INIT_VEC_FUNC
...
...
python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
浏览文件 @
9f2ccf5b
...
...
@@ -53,12 +53,11 @@ class TestFusionLSTMOp(OpTest):
self
.
M
=
8
self
.
D
=
16
self
.
has_initial_state
=
False
self
.
use_peepholes
=
False
self
.
is_reverse
=
False
self
.
act_gate
=
'sigmoid'
self
.
act_cell
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
use_peepholes
=
False
self
.
use_seq
=
False
self
.
set_conf
()
T
=
sum
(
self
.
lod
[
0
])
...
...
@@ -108,7 +107,6 @@ class TestFusionLSTMOp(OpTest):
}
self
.
attrs
=
{
'use_peepholes'
:
self
.
use_peepholes
,
'use_seq'
:
self
.
use_seq
,
'is_reverse'
:
self
.
is_reverse
,
'gate_activation'
:
self
.
act_gate
,
'cell_activation'
:
self
.
act_cell
,
...
...
@@ -178,50 +176,18 @@ class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp):
self
.
is_reverse
=
True
class
TestFusionLSTMOpP
oopholesBS1
(
TestFusionLSTMOp
):
class
TestFusionLSTMOpP
eepholesInitReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_peepholes
=
True
self
.
lod
=
[[
3
]]
self
.
D
=
16
class
TestFusionLSTMOpSeqInit
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
has_initial_state
=
True
class
TestFusionLSTMOpSeqReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
is_reverse
=
True
class
TestFusionLSTMOpSeqInitReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
has_initial_state
=
True
self
.
is_reverse
=
True
class
TestFusionLSTMOp
SeqPeepholes
(
TestFusionLSTMOp
):
class
TestFusionLSTMOp
PeepholesBS1
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
use_peepholes
=
True
class
TestFusionLSTMOpSeqPeepholesInit
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
use_peepholes
=
True
self
.
has_initial_state
=
True
class
TestFusionLSTMOpSeqPeepholesReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
use_peepholes
=
True
self
.
is_reverse
=
True
self
.
lod
=
[[
2
]]
self
.
D
=
8
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录