Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9f2ccf5b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9f2ccf5b
编写于
9月 06, 2018
作者:
T
tensor-tang
提交者:
GitHub
9月 06, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #13237 from tensor-tang/refine/op/peephole
refine fusion lstm/peephole and fusion gru
上级
225ecee5
718033e1
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
254 addition
and
329 deletion
+254
-329
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+1
-0
paddle/fluid/operators/fusion_gru_op.cc
paddle/fluid/operators/fusion_gru_op.cc
+8
-10
paddle/fluid/operators/fusion_lstm_op.cc
paddle/fluid/operators/fusion_lstm_op.cc
+240
-280
python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
+5
-39
未找到文件。
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
浏览文件 @
9f2ccf5b
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
...
...
paddle/fluid/operators/fusion_gru_op.cc
浏览文件 @
9f2ccf5b
...
@@ -30,14 +30,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -30,14 +30,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
"Input(WeightX) of GRU should not be null."
);
"Input(WeightX) of GRU should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"WeightH"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"WeightH"
),
"Input(WeightH) of GRU should not be null."
);
"Input(WeightH) of GRU should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"XX"
),
"Output(XX) of GRU should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"XX"
),
"Output(XX) of GRU should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedH0"
),
"Output(ReorderedH0) of GRU should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedInput"
),
"Output(BatchedInput) of GRU should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedOut"
),
"Output(BatchedOut) of GRU should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Hidden"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Hidden"
),
"Output(Hidden) of GRU should not be null."
);
"Output(Hidden) of GRU should not be null."
);
...
@@ -80,15 +73,20 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -80,15 +73,20 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
}
}
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedInput"
,
{
x_dims
[
0
],
wx_dims
[
1
]});
ctx
->
SetOutputDim
(
"BatchedOut"
,
out_dims
);
ctx
->
ShareLoD
(
"X"
,
"Hidden"
);
ctx
->
ShareLoD
(
"X"
,
"Hidden"
);
int
xx_width
;
int
xx_width
;
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"use_seq"
))
{
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"use_seq"
))
{
xx_width
=
wx_dims
[
1
];
xx_width
=
wx_dims
[
1
];
}
else
{
}
else
{
xx_width
=
x_dims
[
1
]
>
wx_dims
[
1
]
?
wx_dims
[
1
]
:
x_dims
[
1
];
xx_width
=
x_dims
[
1
]
>
wx_dims
[
1
]
?
wx_dims
[
1
]
:
x_dims
[
1
];
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedH0"
),
"Output(ReorderedH0) of GRU should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedInput"
),
"Output(BatchedInput) of GRU should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedOut"
),
"Output(BatchedOut) of GRU should not be null."
);
ctx
->
SetOutputDim
(
"BatchedInput"
,
{
x_dims
[
0
],
wx_dims
[
1
]});
ctx
->
SetOutputDim
(
"BatchedOut"
,
out_dims
);
}
}
ctx
->
SetOutputDim
(
"XX"
,
{
x_dims
[
0
],
xx_width
});
ctx
->
SetOutputDim
(
"XX"
,
{
x_dims
[
0
],
xx_width
});
ctx
->
ShareLoD
(
"X"
,
"XX"
);
ctx
->
ShareLoD
(
"X"
,
"XX"
);
...
...
paddle/fluid/operators/fusion_lstm_op.cc
浏览文件 @
9f2ccf5b
...
@@ -38,16 +38,6 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -38,16 +38,6 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
"Output(Hidden) of LSTM should not be null."
);
"Output(Hidden) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
"Output(Cell) of LSTM should not be null."
);
"Output(Cell) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedInput"
),
"Output(BatchedInput) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedHidden"
),
"Output(BatchedHidden) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedCell"
),
"Output(BatchedCell) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedH0"
),
"Output(ReorderedH0) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedC0"
),
"Output(ReorderedC0) of LSTM should not be null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2
,
"Input(X)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2
,
"Input(X)'s rank must be 2."
);
...
@@ -88,9 +78,8 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -88,9 +78,8 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
"The rank of Input(Bias) should be 2."
);
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
"The rank of Input(Bias) should be 2."
);
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
"The first dimension of Input(Bias) should be 1."
);
"The first dimension of Input(Bias) should be 1."
);
PADDLE_ENFORCE_EQ
(
auto
use_peepholes
=
ctx
->
Attrs
().
Get
<
bool
>
(
"use_peepholes"
);
b_dims
[
1
],
(
ctx
->
Attrs
().
Get
<
bool
>
(
"use_peepholes"
)
?
7
:
4
)
*
frame_size
,
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
(
use_peepholes
?
7
:
4
)
*
frame_size
,
"The second dimension of Input(Bias) should be "
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection or"
"7 * %d if enable peepholes connection or"
"4 * %d if disable peepholes"
,
"4 * %d if disable peepholes"
,
...
@@ -99,17 +88,26 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -99,17 +88,26 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"Cell"
,
out_dims
);
ctx
->
SetOutputDim
(
"Cell"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedInput"
,
{
x_dims
[
0
],
wx_dims
[
1
]});
ctx
->
SetOutputDim
(
"BatchedHidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedCell"
,
out_dims
);
ctx
->
ShareLoD
(
"X"
,
"Hidden"
);
ctx
->
ShareLoD
(
"X"
,
"Hidden"
);
ctx
->
ShareLoD
(
"X"
,
"Cell"
);
ctx
->
ShareLoD
(
"X"
,
"Cell"
);
int
xx_width
;
int
xx_width
;
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"use_seq"
))
{
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"use_seq"
))
{
xx_width
=
wx_dims
[
1
];
xx_width
=
wx_dims
[
1
];
}
else
{
}
else
{
xx_width
=
x_dims
[
1
]
>
wx_dims
[
1
]
?
wx_dims
[
1
]
:
x_dims
[
1
];
xx_width
=
x_dims
[
1
]
>
wx_dims
[
1
]
?
wx_dims
[
1
]
:
x_dims
[
1
];
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedInput"
),
"Output(BatchedInput) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedHidden"
),
"Output(BatchedHidden) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedCell"
),
"Output(BatchedCell) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedH0"
),
"Output(ReorderedH0) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedC0"
),
"Output(ReorderedC0) of LSTM should not be null."
);
ctx
->
SetOutputDim
(
"BatchedInput"
,
{
x_dims
[
0
],
wx_dims
[
1
]});
ctx
->
SetOutputDim
(
"BatchedHidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedCell"
,
out_dims
);
}
}
ctx
->
SetOutputDim
(
"XX"
,
{
x_dims
[
0
],
xx_width
});
ctx
->
SetOutputDim
(
"XX"
,
{
x_dims
[
0
],
xx_width
});
ctx
->
ShareLoD
(
"X"
,
"XX"
);
ctx
->
ShareLoD
(
"X"
,
"XX"
);
...
@@ -242,8 +240,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -242,8 +240,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
auto* xx = ctx.Output<LoDTensor>("XX"); \
auto* xx = ctx.Output<LoDTensor>("XX"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
bool
use_peepholes = ctx.Attr<bool>("use_peepholes");
\
bool
is_reverse = ctx.Attr<bool>("is_reverse");
\
bool
is_reverse = ctx.Attr<bool>("is_reverse
");
bool
use_peepholes = ctx.Attr<bool>("use_peepholes
");
#define INIT_BASE_SIZES \
#define INIT_BASE_SIZES \
auto x_dims = x->dims();
/* T x M*/
\
auto x_dims = x->dims();
/* T x M*/
\
...
@@ -254,172 +252,183 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -254,172 +252,183 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const int D3 = D * 3; \
const int D3 = D * 3; \
const int D4 = wh_dims[1];
const int D4 = wh_dims[1];
#define INIT_BASE_INPUT_DATAS \
const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
const T* wh_data = wh->data<T>(); \
/* diagonal weight*/
\
const T* wc_data = bias->data<T>() + D4; \
/* for peephole only*/
\
Tensor checked_cell; \
T* checked_cell_data = nullptr; \
auto place = ctx.GetPlace(); \
if (use_peepholes) { \
/* w_ic * Ct-1, w_fc * Ct-1 ; w_oc * Ct => ih*/
\
checked_cell_data = checked_cell.mutable_data<T>({2, D}, place); \
}
/// Compute LSTM
#define GEMM_WH_ADDON(bs, prev, out) \
blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast<T>(1), prev, D, \
wh_data, D4, static_cast<T>(1), out, D4)
// gates: W_ch, W_ih, W_fh, W_oh
#define GET_Ct(ct_1, gates, ct) \
/* C_t = C_t-1 * fgated + cand_gated * igated*/
\
act_cand(D, gates, gates); \
blas.VMUL(D, gates, gates + D, gates + D); \
blas.VMUL(D, ct_1, gates + D2, gates + D2); \
blas.VADD(D, gates + D, gates + D2, ct)
#define GET_Ht(ct, gates, ht) \
/* H_t = act_cell(C_t) * ogated */
\
act_cell(D, ct, gates + D2); \
blas.VMUL(D, gates + D2, gates + D3, ht)
#define GET_Ct_NOH0C0(gates, ct) \
/* C_t = igated * cgated*/
\
act_gate(D, gates + D, gates + D); \
act_cand(D, gates, gates); \
blas.VMUL(D, gates, gates + D, ct)
#define COMPUTE_CtHt_NOH0C0(gates, ct, ht) \
GET_Ct_NOH0C0(gates, ct); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt_PEEPHOLE_NOH0C0(gates, ct, ht) \
GET_Ct_NOH0C0(gates, ct); \
/* get outgated, put W_oc * C_t on igated */
\
blas.VMUL(D, wc_data + D2, ct, gates + D); \
blas.VADD(D, gates + D, gates + D3, gates + D3); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt(gates, ct_1, ct, ht) \
act_gate(D3, gates + D, gates + D); \
GET_Ct(ct_1, gates, ct); \
GET_Ht(ct, gates, ht)
#define COMPUTE_CtHt_PEEPHOLE(gates, ct_1, ct, ht) \
/* get fgated and igated*/
\
blas.VMUL(D, wc_data, ct_1, checked_cell_data); \
blas.VMUL(D, wc_data + D, ct_1, checked_cell_data + D); \
blas.VADD(D2, checked_cell_data, gates + D, gates + D); \
act_gate(D2, gates + D, gates + D); \
GET_Ct(ct_1, gates, ct); \
/* get ogated*/
\
blas.VMUL(D, wc_data + D2, ct, gates + D); \
blas.VADD(D, gates + D, gates + D3, gates + D3); \
act_gate(D, gates + D3, gates + D3); \
GET_Ht(ct, gates, ht)
void
SeqCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
SeqCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
using
DeviceContext
=
paddle
::
platform
::
CPUDeviceContext
;
using
DeviceContext
=
paddle
::
platform
::
CPUDeviceContext
;
INIT_BASE_INPUT_OUTPUT
INIT_BASE_INPUT_OUTPUT
INIT_BASE_SIZES
INIT_BASE_SIZES
INIT_VEC_FUNC
INIT_VEC_FUNC
INIT_BASE_INPUT_DATAS
auto
x_lod
=
x
->
lod
();
auto
x_lod
=
x
->
lod
();
const
int
total_T
=
x_dims
[
0
];
const
int
total_T
=
x_dims
[
0
];
const
int
N
=
x_lod
[
0
].
size
()
-
1
;
// batch size
const
int
N
=
x_lod
[
0
].
size
()
-
1
;
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
h0_data
=
h0
?
h0
->
data
<
T
>
()
:
nullptr
;
const
T
*
h0_data
=
h0
?
h0
->
data
<
T
>
()
:
nullptr
;
const
T
*
c0_data
=
c0
?
c0
->
data
<
T
>
()
:
nullptr
;
const
T
*
c0_data
=
c0
?
c0
->
data
<
T
>
()
:
nullptr
;
const
T
*
bias_data
=
bias
->
data
<
T
>
();
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
place
);
const
T
*
wc_data
=
bias_data
+
D4
;
// w_ic, w_fc, w_oc
T
*
h_out_data
=
hidden_out
->
mutable_data
<
T
>
(
place
);
const
T
*
wx_data
=
wx
->
data
<
T
>
();
T
*
c_out_data
=
cell_out
->
mutable_data
<
T
>
(
place
);
const
T
*
wh_data
=
wh
->
data
<
T
>
();
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
hidden_out_data
=
hidden_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
cell_out_data
=
cell_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// use local variable
framework
::
DDim
check_dims
({
3
,
D
});
Tensor
checked_cell
;
// w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
auto
checked_cell_data
=
checked_cell
.
mutable_data
<
T
>
(
check_dims
,
ctx
.
GetPlace
());
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
total_T
,
D4
,
M
,
x_data
,
wx_data
,
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
total_T
,
D4
,
M
,
x_data
,
wx_data
,
xx_data
,
bias
->
data
<
T
>
());
xx_data
,
bias
->
data
<
T
>
());
int
xx_offset
=
D4
;
int
xx_offset
=
D4
;
int
gate_offset
=
D
;
int
gate_offset
=
D
;
if
(
is_reverse
)
{
if
(
is_reverse
)
{
const
int
offset
=
(
total_T
-
1
)
*
D
;
const
int
offset
=
(
total_T
-
1
)
*
D
;
xx_data
=
xx_data
+
offset
*
4
;
xx_data
=
xx_data
+
offset
*
4
;
h
idden_out_data
=
hidden
_out_data
+
offset
;
h
_out_data
=
h
_out_data
+
offset
;
c
ell_out_data
=
cell
_out_data
+
offset
;
c
_out_data
=
c
_out_data
+
offset
;
xx_offset
=
-
D4
;
xx_offset
=
-
D4
;
gate_offset
=
-
D
;
gate_offset
=
-
D
;
}
}
auto
move_step
=
[
&
]()
{
#define MOVE_ONE_STEP \
xx_data
=
xx_data
+
xx_offset
;
prev_h_data = h_out_data; \
hidden_out_data
=
hidden_out_data
+
gate_offset
;
prev_c_data = c_out_data; \
cell_out_data
=
cell_out_data
+
gate_offset
;
xx_data = xx_data + xx_offset; \
};
h_out_data = h_out_data + gate_offset; \
c_out_data = c_out_data + gate_offset
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
int
bid
=
is_reverse
?
N
-
1
-
i
:
i
;
#define PROCESS_H0C0_DEFINES \
int
seq_len
=
x_lod
[
0
][
bid
+
1
]
-
x_lod
[
0
][
bid
];
int bid = is_reverse ? N - 1 - i : i; \
const
T
*
prev_c_data
=
nullptr
;
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; \
const
T
*
prev_h_data
=
nullptr
;
const T* prev_c_data = nullptr; \
const T* prev_h_data = nullptr; \
int
tstart
=
0
;
int tstart = 0
if
(
h0_data
)
{
prev_h_data
=
h0_data
+
bid
*
D
;
#define PROCESS_H0C0_PEEPHOLE \
prev_c_data
=
c0_data
+
bid
*
D
;
PROCESS_H0C0_DEFINES; \
}
else
{
if (h0_data) { \
// If step == 0 and there is no initialized hidden state, that is to say
prev_h_data = h0_data + bid * D; \
// the H0 is zeros. Then W_h * H_t-1 can be skipped
prev_c_data = c0_data + bid * D; \
} else { \
COMPUTE_CtHt_PEEPHOLE_NOH0C0(xx_data, c_out_data, h_out_data); \
MOVE_ONE_STEP; \
tstart = 1; \
}
// ~C_t
#define PROCESS_H0C0 \
act_cand
(
D
,
xx_data
,
xx_data
);
PROCESS_H0C0_DEFINES; \
if
(
use_peepholes
)
{
if (h0_data) { \
// I_t, F_t
prev_h_data = h0_data + bid * D; \
act_gate
(
D2
,
xx_data
+
D
,
xx_data
+
D
);
prev_c_data = c0_data + bid * D; \
}
else
{
} else { \
// I_t, F_t, O_t
COMPUTE_CtHt_NOH0C0(xx_data, c_out_data, h_out_data); \
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
MOVE_ONE_STEP; \
tstart = 1; \
}
}
// C_t = I_t * ~C_t
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
cell_out_data
);
if
(
use_peepholes
)
{
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cell_out_data
,
checked_cell_data
+
D2
);
PROCESS_H0C0_PEEPHOLE
blas
.
VADD
(
D
,
xx_data
+
D3
,
checked_cell_data
+
D2
,
xx_data
+
D3
);
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
// O_t
GEMM_WH_ADDON
(
1
,
prev_h_data
,
xx_data
);
act_gate
(
D
,
xx_data
+
D3
,
xx_data
+
D3
);
COMPUTE_CtHt_PEEPHOLE
(
xx_data
,
prev_c_data
,
c_out_data
,
h_out_data
);
MOVE_ONE_STEP
;
}
}
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cell_out_data
,
xx_data
+
D2
);
// H_t = O_t * act_state(C_t)
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
// prev
prev_h_data
=
hidden_out_data
;
prev_c_data
=
cell_out_data
;
tstart
=
1
;
move_step
();
}
}
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
// + W_h * H_t-1
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D4
,
D
,
static_cast
<
T
>
(
1
),
prev_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
xx_data
,
D4
);
// ~C_t
act_cand
(
D
,
xx_data
,
xx_data
);
if
(
use_peepholes
)
{
// + W_ic|W_fc * C_t-1 for peephole connection
blas
.
VMUL
(
D
,
wc_data
,
prev_c_data
,
checked_cell_data
);
blas
.
VMUL
(
D
,
wc_data
+
D
,
prev_c_data
,
checked_cell_data
+
D
);
blas
.
VADD
(
D2
,
xx_data
+
D
,
checked_cell_data
,
xx_data
+
D
);
// I_t, F_t
act_gate
(
D2
,
xx_data
+
D
,
xx_data
+
D
);
}
else
{
}
else
{
// I_t, F_t, O_t
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
PROCESS_H0C0
for
(
int
step
=
tstart
;
step
<
seq_len
;
++
step
)
{
GEMM_WH_ADDON
(
1
,
prev_h_data
,
xx_data
);
COMPUTE_CtHt
(
xx_data
,
prev_c_data
,
c_out_data
,
h_out_data
);
MOVE_ONE_STEP
;
}
}
// F_t * C_t-1
blas
.
VMUL
(
D
,
xx_data
+
D2
,
prev_c_data
,
xx_data
+
D2
);
// I_t * ~C_t
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
xx_data
+
D
);
// C_t = F_t * C_t-1 + I_t * ~C_t
blas
.
VADD
(
D
,
xx_data
+
D
,
xx_data
+
D2
,
cell_out_data
);
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cell_out_data
,
checked_cell_data
+
D2
);
blas
.
VADD
(
D
,
xx_data
+
D3
,
checked_cell_data
+
D2
,
xx_data
+
D3
);
// O_t
act_gate
(
D
,
xx_data
+
D3
,
xx_data
+
D3
);
}
}
}
// hidden out= act_state(cellout) * outgate
#undef PROCESS_H0C0_DEFINES
act_cell
(
D
,
cell_out_data
,
xx_data
+
D2
);
#undef PROCESS_H0C0_PEEPHOLE
// H_t = O_t * act_state(C_t)
#undef PROCESS_H0C0
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
#undef MOVE_ONE_STEP
// prev
prev_h_data
=
hidden_out_data
;
prev_c_data
=
cell_out_data
;
move_step
();
}
// for each step in batch
}
// for each batch
}
}
void
BatchCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
BatchCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
using
DeviceContext
=
platform
::
CPUDeviceContext
;
using
DeviceContext
=
platform
::
CPUDeviceContext
;
INIT_BASE_INPUT_OUTPUT
INIT_BASE_INPUT_OUTPUT
if
(
x
->
lod
()[
0
].
size
()
==
2
)
{
// batch size == 1
if
(
x
->
lod
()[
0
].
size
()
==
2
)
{
SeqCompute
(
ctx
);
SeqCompute
(
ctx
);
return
;
return
;
}
}
INIT_BASE_SIZES
INIT_BASE_SIZES
INIT_VEC_FUNC
INIT_VEC_FUNC
INIT_BASE_INPUT_DATAS
auto
*
reordered_h0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedH0"
);
auto
*
reordered_h0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedH0"
);
auto
*
reordered_c0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedC0"
);
auto
*
reordered_c0
=
ctx
.
Output
<
Tensor
>
(
"ReorderedC0"
);
auto
*
batched_input
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedInput"
);
auto
*
batched_input
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedInput"
);
auto
*
batched_c_out
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedCell"
);
auto
*
batched_c_out
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedCell"
);
auto
*
batched_h_out
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedHidden"
);
auto
*
batched_h_out
=
ctx
.
Output
<
LoDTensor
>
(
"BatchedHidden"
);
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
const
T
*
bias_data
=
bias
->
data
<
T
>
();
const
T
*
wc_data
=
bias_data
+
D4
;
// w_ic, w_fc, w_oc
auto
place
=
ctx
.
GetPlace
();
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
place
);
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
place
);
T
*
batched_input_data
=
batched_input
->
mutable_data
<
T
>
(
place
);
T
*
batched_input_data
=
batched_input
->
mutable_data
<
T
>
(
place
);
T
*
batched_c_out_data
=
batched_c_out
->
mutable_data
<
T
>
(
place
);
T
*
batched_c_out_data
=
batched_c_out
->
mutable_data
<
T
>
(
place
);
...
@@ -427,12 +436,6 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -427,12 +436,6 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
hidden_out
->
mutable_data
<
T
>
(
place
);
hidden_out
->
mutable_data
<
T
>
(
place
);
cell_out
->
mutable_data
<
T
>
(
place
);
cell_out
->
mutable_data
<
T
>
(
place
);
// use local variable
framework
::
DDim
check_dims
({
3
,
D
});
Tensor
checked_cell
;
// w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
auto
checked_cell_data
=
checked_cell
.
mutable_data
<
T
>
(
check_dims
,
ctx
.
GetPlace
());
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
math
::
LoDTensor2BatchFunctor
<
DeviceContext
,
T
>
to_batch
;
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
dev_ctx
);
...
@@ -454,27 +457,17 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -454,27 +457,17 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
reordered_h0
->
Resize
({
max_bs
,
D
});
reordered_h0
->
Resize
({
max_bs
,
D
});
reordered_c0
->
Resize
({
max_bs
,
D
});
reordered_c0
->
Resize
({
max_bs
,
D
});
T
*
prev_batch_h_data
=
nullptr
;
T
*
prev_batch_c_data
=
nullptr
;
T
*
cur_batch_in_data
=
batched_input_data
;
T
*
cur_batch_h_out_data
=
batched_h_out_data
;
T
*
cur_batch_c_out_data
=
batched_c_out_data
;
auto
move_step
=
[
&
](
int
bs
)
{
cur_batch_in_data
+=
bs
*
D4
;
cur_batch_c_out_data
+=
bs
*
D
;
cur_batch_h_out_data
+=
bs
*
D
;
};
int
tstart
=
0
;
int
tstart
=
0
;
T
*
prev_h_data
=
nullptr
;
T
*
prev_c_data
=
nullptr
;
if
(
h0
)
{
if
(
h0
)
{
// reorder h0, c0
// reorder h0, c0
T
*
reordered_h0_data
=
reordered_h0
->
mutable_data
<
T
>
(
place
);
T
*
reordered_h0_data
=
reordered_h0
->
mutable_data
<
T
>
(
place
);
T
*
reordered_c0_data
=
reordered_c0
->
mutable_data
<
T
>
(
place
);
T
*
reordered_c0_data
=
reordered_c0
->
mutable_data
<
T
>
(
place
);
const
T
*
h0_data
=
h0
->
data
<
T
>
();
const
T
*
h0_data
=
h0
->
data
<
T
>
();
const
T
*
c0_data
=
c0
->
data
<
T
>
();
const
T
*
c0_data
=
c0
->
data
<
T
>
();
prev_
batch_
h_data
=
reordered_h0_data
;
prev_h_data
=
reordered_h0_data
;
prev_
batch_
c_data
=
reordered_c0_data
;
prev_c_data
=
reordered_c0_data
;
size_t
sz
=
sizeof
(
T
)
*
D
;
size_t
sz
=
sizeof
(
T
)
*
D
;
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
std
::
memcpy
(
reordered_h0_data
,
h0_data
+
seq_order
[
i
]
*
D
,
sz
);
std
::
memcpy
(
reordered_h0_data
,
h0_data
+
seq_order
[
i
]
*
D
,
sz
);
...
@@ -483,123 +476,80 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -483,123 +476,80 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
reordered_c0_data
+=
D
;
reordered_c0_data
+=
D
;
}
}
}
else
{
}
else
{
// Compute with no H0/C0
// compute without h0, c0
T
*
cur_in_data
=
cur_batch_in_data
;
T
*
cur_in_data
=
batched_input_data
;
T
*
cur_c_out_data
=
cur_batch_c_out_data
;
T
*
cur_h_out_data
=
batched_h_out_data
;
T
*
cur_h_out_data
=
cur_batch_h_out_data
;
T
*
cur_c_out_data
=
batched_c_out_data
;
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
// If step == 0 and there is no initialized hidden state, that is to say
GET_Ct_NOH0C0
(
cur_in_data
,
cur_c_out_data
);
// the H0 is zeros. Then W_h * H_t-1 can be skiped
for
(
int
i
=
0
;
i
<
max_bs
;
++
i
)
{
// iterate each data in 1st batch
// ~C_t
act_cand
(
D
,
cur_in_data
,
cur_in_data
);
if
(
use_peepholes
)
{
if
(
use_peepholes
)
{
// I_t, F_t
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cur_c_out_data
,
cur_in_data
+
D
);
act_gate
(
D2
,
cur_in_data
+
D
,
cur_in_data
+
D
);
blas
.
VADD
(
D
,
cur_in_data
+
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
}
else
{
// I_t, F_t, O_t
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
}
}
// C_t = I_t * ~C_t
blas
.
VMUL
(
D
,
cur_in_data
,
cur_in_data
+
D
,
cur_c_out_data
);
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cur_c_out_data
,
checked_cell_data
+
D2
);
blas
.
VADD
(
D
,
cur_in_data
+
D3
,
checked_cell_data
+
D2
,
cur_in_data
+
D3
);
// O_t
act_gate
(
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
act_gate
(
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
}
GET_Ht
(
cur_c_out_data
,
cur_in_data
,
cur_h_out_data
);
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cur_c_out_data
,
cur_in_data
+
D2
);
// H_t = O_t * act_state(C_t)
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_in_data
+
D3
,
cur_h_out_data
);
// move to next data in the same batch
cur_in_data
+=
D4
;
cur_in_data
+=
D4
;
cur_c_out_data
+=
D
;
cur_c_out_data
+=
D
;
cur_h_out_data
+=
D
;
cur_h_out_data
+=
D
;
}
}
// move to data for next timestep
prev_batch_h_data
=
cur_batch_h_out_data
;
prev_batch_c_data
=
cur_batch_c_out_data
;
move_step
(
max_bs
);
tstart
=
1
;
tstart
=
1
;
prev_h_data
=
batched_h_out_data
;
prev_c_data
=
batched_c_out_data
;
}
}
const
auto
&
batch_starts
=
batched_lod
[
0
];
const
auto
&
batch_starts
=
batched_lod
[
0
];
const
int
max_seq_len
=
batch_starts
.
size
()
-
1
;
const
int
max_seq_len
=
batch_starts
.
size
()
-
1
;
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
const
int
offset
=
tstart
*
max_bs
*
D
;
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
]
;
batched_input_data
=
batched_input_data
+
offset
*
4
;
// + W_h * H_t-1
batched_h_out_data
=
batched_h_out_data
+
offset
;
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
cur_bs
,
D4
,
D
,
static_cast
<
T
>
(
1
),
batched_c_out_data
=
batched_c_out_data
+
offset
;
prev_batch_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
cur_batch_in_data
,
D4
);
#define DEFINE_CUR \
T* cur_in_data = batched_input_data; \
T
*
cur_in_data
=
cur_batch_in_data
;
T* cur_prev_c_data = prev_c_data; \
T
*
cur_c_out_data
=
cur_batch_c_out_data
;
T* cur_c_out_data = batched_c_out_data; \
T
*
cur_h_out_data
=
cur_batch_h_out_data
;
T* cur_h_out_data = batched_h_out_data
T
*
prev_c_data
=
prev_batch_c_data
;
// NULL if no C0 in step0
T
*
prev_h_data
=
prev_batch_h_data
;
// NULL if no H0 in step0
#define MOVE_ONE_BATCH \
auto
next_data_in_batch
=
[
&
]()
{
cur_in_data += D4; \
cur_in_data
+=
D4
;
cur_prev_c_data += D; \
cur_c_out_data
+=
D
;
cur_c_out_data += D; \
cur_h_out_data
+=
D
;
cur_h_out_data += D
prev_c_data
=
prev_c_data
?
prev_c_data
+
D
:
nullptr
;
prev_h_data
=
prev_h_data
?
prev_h_data
+
D
:
nullptr
;
#define MOVE_ONE_STEP \
};
prev_c_data = batched_c_out_data; \
prev_h_data = batched_h_out_data; \
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
// iterate each data in same batch
batched_c_out_data = cur_c_out_data; \
// ~C_t
batched_h_out_data = cur_h_out_data; \
act_cand
(
D
,
cur_in_data
,
cur_in_data
);
batched_input_data = cur_in_data
if
(
use_peepholes
)
{
if
(
use_peepholes
)
{
// + W_ic|W_fc * C_t-1 for peephole connection
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
blas
.
VMUL
(
D
,
wc_data
,
prev_c_data
,
checked_cell_data
);
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
blas
.
VMUL
(
D
,
wc_data
+
D
,
prev_c_data
,
checked_cell_data
+
D
);
GEMM_WH_ADDON
(
cur_bs
,
prev_h_data
,
batched_input_data
);
blas
.
VADD
(
D2
,
cur_in_data
+
D
,
checked_cell_data
,
cur_in_data
+
D
);
DEFINE_CUR
;
// I_t, F_t
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
act_gate
(
D2
,
cur_in_data
+
D
,
cur_in_data
+
D
);
COMPUTE_CtHt_PEEPHOLE
(
cur_in_data
,
cur_prev_c_data
,
cur_c_out_data
,
}
else
{
cur_h_out_data
);
// I_t, F_t, O_t
MOVE_ONE_BATCH
;
act_gate
(
D3
,
cur_in_data
+
D
,
cur_in_data
+
D
);
}
}
MOVE_ONE_STEP
;
// F_t * C_t-1
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
prev_c_data
,
cur_in_data
+
D2
);
// I_t * ~C_t
blas
.
VMUL
(
D
,
cur_in_data
,
cur_in_data
+
D
,
cur_in_data
+
D
);
// C_t = F_t * C_t-1 + I_t * ~C_t
blas
.
VADD
(
D
,
cur_in_data
+
D
,
cur_in_data
+
D2
,
cur_c_out_data
);
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cur_c_out_data
,
checked_cell_data
+
D2
);
blas
.
VADD
(
D
,
cur_in_data
+
D3
,
checked_cell_data
+
D2
,
cur_in_data
+
D3
);
// O_t
act_gate
(
D
,
cur_in_data
+
D3
,
cur_in_data
+
D3
);
}
}
}
else
{
// hidden out= act_state(cellout) * outgate
for
(
int
step
=
tstart
;
step
<
max_seq_len
;
++
step
)
{
act_cell
(
D
,
cur_c_out_data
,
cur_in_data
+
D2
);
const
int
cur_bs
=
batch_starts
[
step
+
1
]
-
batch_starts
[
step
];
// H_t = O_t * act_state(C_t)
GEMM_WH_ADDON
(
cur_bs
,
prev_h_data
,
batched_input_data
);
blas
.
VMUL
(
D
,
cur_in_data
+
D2
,
cur_in_data
+
D3
,
cur_h_out_data
);
DEFINE_CUR
;
for
(
int
i
=
0
;
i
<
cur_bs
;
++
i
)
{
// move to next data in same batch
COMPUTE_CtHt
(
cur_in_data
,
cur_prev_c_data
,
cur_c_out_data
,
next_data_in_batch
();
cur_h_out_data
);
MOVE_ONE_BATCH
;
}
MOVE_ONE_STEP
;
}
}
// move to data for next timestep
prev_batch_h_data
=
cur_batch_h_out_data
;
prev_batch_c_data
=
cur_batch_c_out_data
;
move_step
(
cur_bs
);
}
}
#undef MOVE_ONE_STEP
#undef MOVE_ONE_BATCH
#undef DEFINE_CUR
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
math
::
Batch2LoDTensorFunctor
<
DeviceContext
,
T
>
to_seq
;
batched_h_out
->
set_lod
(
batched_lod
);
batched_h_out
->
set_lod
(
batched_lod
);
...
@@ -615,6 +565,16 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -615,6 +565,16 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
BatchCompute
(
ctx
);
BatchCompute
(
ctx
);
}
}
}
}
#undef COMPUTE_CtHt_PEEPHOLE
#undef COMPUTE_CtHt
#undef GET_Ct_NOH0C0
#undef COMPUTE_CtHt_NOH0C0
#undef COMPUTE_CtHt_PEEPHOLE_NOH0C0
#undef GET_Ht
#undef GET_Ct
#undef GEMM_WH_ADDON
#undef INIT_BASE_INPUT_DATAS
#undef INIT_BASE_SIZES
#undef INIT_BASE_SIZES
#undef INIT_BASE_INPUT_OUTPUT
#undef INIT_BASE_INPUT_OUTPUT
#undef INIT_VEC_FUNC
#undef INIT_VEC_FUNC
...
...
python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
浏览文件 @
9f2ccf5b
...
@@ -53,12 +53,11 @@ class TestFusionLSTMOp(OpTest):
...
@@ -53,12 +53,11 @@ class TestFusionLSTMOp(OpTest):
self
.
M
=
8
self
.
M
=
8
self
.
D
=
16
self
.
D
=
16
self
.
has_initial_state
=
False
self
.
has_initial_state
=
False
self
.
use_peepholes
=
False
self
.
is_reverse
=
False
self
.
is_reverse
=
False
self
.
act_gate
=
'sigmoid'
self
.
act_gate
=
'sigmoid'
self
.
act_cell
=
'tanh'
self
.
act_cell
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
use_peepholes
=
False
self
.
use_seq
=
False
self
.
set_conf
()
self
.
set_conf
()
T
=
sum
(
self
.
lod
[
0
])
T
=
sum
(
self
.
lod
[
0
])
...
@@ -108,7 +107,6 @@ class TestFusionLSTMOp(OpTest):
...
@@ -108,7 +107,6 @@ class TestFusionLSTMOp(OpTest):
}
}
self
.
attrs
=
{
self
.
attrs
=
{
'use_peepholes'
:
self
.
use_peepholes
,
'use_peepholes'
:
self
.
use_peepholes
,
'use_seq'
:
self
.
use_seq
,
'is_reverse'
:
self
.
is_reverse
,
'is_reverse'
:
self
.
is_reverse
,
'gate_activation'
:
self
.
act_gate
,
'gate_activation'
:
self
.
act_gate
,
'cell_activation'
:
self
.
act_cell
,
'cell_activation'
:
self
.
act_cell
,
...
@@ -178,50 +176,18 @@ class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp):
...
@@ -178,50 +176,18 @@ class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp):
self
.
is_reverse
=
True
self
.
is_reverse
=
True
class
TestFusionLSTMOpP
oopholesBS1
(
TestFusionLSTMOp
):
class
TestFusionLSTMOpP
eepholesInitReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
def
set_conf
(
self
):
self
.
use_peepholes
=
True
self
.
use_peepholes
=
True
self
.
lod
=
[[
3
]]
self
.
D
=
16
class
TestFusionLSTMOpSeqInit
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
has_initial_state
=
True
class
TestFusionLSTMOpSeqReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
is_reverse
=
True
class
TestFusionLSTMOpSeqInitReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
has_initial_state
=
True
self
.
has_initial_state
=
True
self
.
is_reverse
=
True
self
.
is_reverse
=
True
class
TestFusionLSTMOp
SeqPeepholes
(
TestFusionLSTMOp
):
class
TestFusionLSTMOp
PeepholesBS1
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
use_peepholes
=
True
self
.
use_peepholes
=
True
self
.
lod
=
[[
2
]]
self
.
D
=
8
class
TestFusionLSTMOpSeqPeepholesInit
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
use_peepholes
=
True
self
.
has_initial_state
=
True
class
TestFusionLSTMOpSeqPeepholesReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
use_peepholes
=
True
self
.
is_reverse
=
True
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录