Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
2f3b4989
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2f3b4989
编写于
9月 05, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine fusion seq lstm peephole
上级
5f586e22
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
58 addition
and
113 deletion
+58
-113
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
+1
-0
paddle/fluid/operators/fusion_lstm_op.cc
paddle/fluid/operators/fusion_lstm_op.cc
+52
-74
python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
+5
-39
未找到文件。
paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc
浏览文件 @
2f3b4989
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
...
...
paddle/fluid/operators/fusion_lstm_op.cc
浏览文件 @
2f3b4989
...
@@ -78,13 +78,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -78,13 +78,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
"The rank of Input(Bias) should be 2."
);
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
"The rank of Input(Bias) should be 2."
);
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
"The first dimension of Input(Bias) should be 1."
);
"The first dimension of Input(Bias) should be 1."
);
PADDLE_ENFORCE_EQ
(
auto
use_peepholes
=
ctx
->
Attrs
().
Get
<
bool
>
(
"use_peepholes"
);
b_dims
[
1
],
(
ctx
->
Attrs
().
Get
<
bool
>
(
"use_peepholes"
)
?
7
:
4
)
*
frame_size
,
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
(
use_peepholes
?
7
:
4
)
*
frame_size
,
"The second dimension of Input(Bias) should be "
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection or"
"7 * %d if enable peepholes connection or"
"4 * %d if disable peepholes"
,
"4 * %d if disable peepholes"
,
frame_size
,
frame_size
);
frame_size
,
frame_size
);
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
...
@@ -231,18 +230,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -231,18 +230,18 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
act_cand = act_functor(act_cand_str); \
act_cand = act_functor(act_cand_str); \
}
}
#define INIT_BASE_INPUT_OUTPUT
\
#define INIT_BASE_INPUT_OUTPUT \
auto* x = ctx.Input<LoDTensor>("X");
\
auto* x = ctx.Input<LoDTensor>("X"); \
auto* h0 = ctx.Input<Tensor>("H0");
\
auto* h0 = ctx.Input<Tensor>("H0"); \
auto* c0 = ctx.Input<Tensor>("C0");
\
auto* c0 = ctx.Input<Tensor>("C0"); \
auto* wx = ctx.Input<Tensor>("WeightX");
\
auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* wh = ctx.Input<Tensor>("WeightH");
\
auto* wh = ctx.Input<Tensor>("WeightH"); \
auto* bias = ctx.Input<Tensor>("Bias");
\
auto* bias = ctx.Input<Tensor>("Bias"); \
auto* xx = ctx.Output<LoDTensor>("XX");
\
auto* xx = ctx.Output<LoDTensor>("XX"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
\
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
auto* cell_out = ctx.Output<LoDTensor>("Cell");
\
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
bool
use_peepholes = ctx.Attr<bool>("use_peepholes");
\
bool
is_reverse = ctx.Attr<bool>("is_reverse");
\
bool
is_reverse = ctx.Attr<bool>("is_reverse
");
bool
use_peepholes = ctx.Attr<bool>("use_peepholes
");
#define INIT_BASE_SIZES \
#define INIT_BASE_SIZES \
auto x_dims = x->dims();
/* T x M*/
\
auto x_dims = x->dims();
/* T x M*/
\
...
@@ -261,25 +260,24 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -261,25 +260,24 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
auto
x_lod
=
x
->
lod
();
auto
x_lod
=
x
->
lod
();
const
int
total_T
=
x_dims
[
0
];
const
int
total_T
=
x_dims
[
0
];
const
int
N
=
x_lod
[
0
].
size
()
-
1
;
// batch size
const
int
N
=
x_lod
[
0
].
size
()
-
1
;
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
h0_data
=
h0
?
h0
->
data
<
T
>
()
:
nullptr
;
const
T
*
h0_data
=
h0
?
h0
->
data
<
T
>
()
:
nullptr
;
const
T
*
c0_data
=
c0
?
c0
->
data
<
T
>
()
:
nullptr
;
const
T
*
c0_data
=
c0
?
c0
->
data
<
T
>
()
:
nullptr
;
const
T
*
bias_data
=
bias
->
data
<
T
>
();
const
T
*
wc_data
=
bias_data
+
D4
;
// w_ic, w_fc, w_oc
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
wx_data
=
wx
->
data
<
T
>
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
const
T
*
wh_data
=
wh
->
data
<
T
>
();
const
T
*
wc_data
=
bias
->
data
<
T
>
()
+
D4
;
// diagonal weight
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
place
=
ctx
.
GetPlace
();
T
*
hidden_out_data
=
hidden_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
xx_data
=
xx
->
mutable_data
<
T
>
(
place
);
T
*
cell_out_data
=
cell_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
T
*
hidden_out_data
=
hidden_out
->
mutable_data
<
T
>
(
place
);
T
*
cell_out_data
=
cell_out
->
mutable_data
<
T
>
(
place
);
// use local variable
framework
::
DDim
check_dims
({
3
,
D
});
Tensor
checked_cell
;
Tensor
checked_cell
;
// w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
T
*
checked_cell_data
=
nullptr
;
auto
checked_cell_data
=
if
(
use_peepholes
)
{
checked_cell
.
mutable_data
<
T
>
(
check_dims
,
ctx
.
GetPlace
());
// w_ic * Ct-1, w_fc * Ct-1 // , w_oc * Ct => ih
checked_cell_data
=
checked_cell
.
mutable_data
<
T
>
({
2
,
D
},
place
);
}
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
total_T
,
D4
,
M
,
x_data
,
wx_data
,
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
total_T
,
D4
,
M
,
x_data
,
wx_data
,
...
@@ -306,44 +304,31 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -306,44 +304,31 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
int
seq_len
=
x_lod
[
0
][
bid
+
1
]
-
x_lod
[
0
][
bid
];
int
seq_len
=
x_lod
[
0
][
bid
+
1
]
-
x_lod
[
0
][
bid
];
const
T
*
prev_c_data
=
nullptr
;
const
T
*
prev_c_data
=
nullptr
;
const
T
*
prev_h_data
=
nullptr
;
const
T
*
prev_h_data
=
nullptr
;
int
tstart
=
0
;
int
tstart
=
0
;
if
(
h0_data
)
{
if
(
h0_data
)
{
prev_h_data
=
h0_data
+
bid
*
D
;
prev_h_data
=
h0_data
+
bid
*
D
;
prev_c_data
=
c0_data
+
bid
*
D
;
prev_c_data
=
c0_data
+
bid
*
D
;
}
else
{
}
else
{
// If step == 0 and there is no initialized hidden state, that is to say
// W_ch, W_ih, W_fh, W_oh
// the H0 is zeros. Then W_h * H_t-1 can be skipped
act_gate
(
D
,
xx_data
+
D
,
xx_data
+
D
);
// ~C_t
act_cand
(
D
,
xx_data
,
xx_data
);
act_cand
(
D
,
xx_data
,
xx_data
);
if
(
use_peepholes
)
{
// C_t = input * tilde
// I_t, F_t
act_gate
(
D2
,
xx_data
+
D
,
xx_data
+
D
);
}
else
{
// I_t, F_t, O_t
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
}
// C_t = I_t * ~C_t
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
cell_out_data
);
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
cell_out_data
);
// H_t = act_state(cellout) * outgate
if
(
use_peepholes
)
{
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
// + W_oc * C_t for peephole connection
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cell_out_data
,
checked_cell_data
+
D2
);
// put result on W_ih
blas
.
VADD
(
D
,
xx_data
+
D3
,
checked_cell_data
+
D2
,
xx_data
+
D3
);
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cell_out_data
,
xx_data
+
D
);
// O_t
blas
.
VADD
(
D
,
xx_data
+
D
,
xx_data
+
D3
,
xx_data
+
D3
);
act_gate
(
D
,
xx_data
+
D3
,
xx_data
+
D3
);
}
}
act_gate
(
D
,
xx_data
+
D3
,
xx_data
+
D3
);
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cell_out_data
,
xx_data
+
D2
);
act_cell
(
D
,
cell_out_data
,
xx_data
+
D2
);
// H_t = O_t * act_state(C_t)
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
// prev
// prev
prev_h_data
=
hidden_out_data
;
prev_h_data
=
hidden_out_data
;
prev_c_data
=
cell_out_data
;
prev_c_data
=
cell_out_data
;
tstart
=
1
;
tstart
=
1
;
move_step
();
move_step
();
}
}
...
@@ -353,39 +338,32 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -353,39 +338,32 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D4
,
D
,
static_cast
<
T
>
(
1
),
blas
.
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
1
,
D4
,
D
,
static_cast
<
T
>
(
1
),
prev_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
xx_data
,
D4
);
prev_h_data
,
D
,
wh_data
,
D4
,
static_cast
<
T
>
(
1
),
xx_data
,
D4
);
// ~C_t
// W_ch, W_ih, W_fh, W_oh
act_cand
(
D
,
xx_data
,
xx_data
);
if
(
use_peepholes
)
{
if
(
use_peepholes
)
{
// + W_ic|W_fc * C_t-1 for peephole connection
// + W_ic|W_fc * C_t-1 for peephole connection
blas
.
VMUL
(
D
,
wc_data
,
prev_c_data
,
checked_cell_data
);
blas
.
VMUL
(
D
,
wc_data
,
prev_c_data
,
checked_cell_data
);
blas
.
VMUL
(
D
,
wc_data
+
D
,
prev_c_data
,
checked_cell_data
+
D
);
blas
.
VMUL
(
D
,
wc_data
+
D
,
prev_c_data
,
checked_cell_data
+
D
);
blas
.
VADD
(
D2
,
xx_data
+
D
,
checked_cell_data
,
xx_data
+
D
);
blas
.
VADD
(
D2
,
checked_cell_data
,
xx_data
+
D
,
xx_data
+
D
);
// I_t, F_t
act_gate
(
D2
,
xx_data
+
D
,
xx_data
+
D
);
act_gate
(
D2
,
xx_data
+
D
,
xx_data
+
D
);
}
else
{
}
else
{
// I_t, F_t, O_t
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
act_gate
(
D3
,
xx_data
+
D
,
xx_data
+
D
);
}
}
// a = I_t * act_cand(ch)
// F_t * C_t-1
act_cand
(
D
,
xx_data
,
xx_data
);
blas
.
VMUL
(
D
,
xx_data
+
D2
,
prev_c_data
,
xx_data
+
D2
);
// I_t * ~C_t
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
xx_data
+
D
);
blas
.
VMUL
(
D
,
xx_data
,
xx_data
+
D
,
xx_data
+
D
);
// C_t = F_t * C_t-1 + I_t * ~C_t
// b = C_t-1 * F_t
blas
.
VMUL
(
D
,
prev_c_data
,
xx_data
+
D2
,
xx_data
+
D2
);
// C_t = a + b
blas
.
VADD
(
D
,
xx_data
+
D
,
xx_data
+
D2
,
cell_out_data
);
blas
.
VADD
(
D
,
xx_data
+
D
,
xx_data
+
D2
,
cell_out_data
);
// H_t = act_cell(C_t) * act_gate(O_c += C_t * W_oc)
if
(
use_peepholes
)
{
if
(
use_peepholes
)
{
// + W_oc * C_t for peephole connection
// put result on W_ih
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cell_out_data
,
checked_cell_data
+
D2
);
blas
.
VMUL
(
D
,
wc_data
+
D2
,
cell_out_data
,
xx_data
+
D
);
blas
.
VADD
(
D
,
xx_data
+
D3
,
checked_cell_data
+
D2
,
xx_data
+
D3
);
blas
.
VADD
(
D
,
xx_data
+
D
,
xx_data
+
D3
,
xx_data
+
D3
);
// O_t
act_gate
(
D
,
xx_data
+
D3
,
xx_data
+
D3
);
act_gate
(
D
,
xx_data
+
D3
,
xx_data
+
D3
);
}
}
// hidden out= act_state(cellout) * outgate
act_cell
(
D
,
cell_out_data
,
xx_data
+
D2
);
act_cell
(
D
,
cell_out_data
,
xx_data
+
D2
);
// H_t = O_t * act_state(C_t)
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
blas
.
VMUL
(
D
,
xx_data
+
D2
,
xx_data
+
D3
,
hidden_out_data
);
// prev
// prev
...
@@ -393,8 +371,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
...
@@ -393,8 +371,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
prev_c_data
=
cell_out_data
;
prev_c_data
=
cell_out_data
;
move_step
();
move_step
();
}
// for
each step in batch
}
// for
seqlen
}
// for
each
batch
}
// for batch
}
}
void
BatchCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
void
BatchCompute
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
...
...
python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py
浏览文件 @
2f3b4989
...
@@ -53,12 +53,11 @@ class TestFusionLSTMOp(OpTest):
...
@@ -53,12 +53,11 @@ class TestFusionLSTMOp(OpTest):
self
.
M
=
8
self
.
M
=
8
self
.
D
=
16
self
.
D
=
16
self
.
has_initial_state
=
False
self
.
has_initial_state
=
False
self
.
use_peepholes
=
False
self
.
is_reverse
=
False
self
.
is_reverse
=
False
self
.
act_gate
=
'sigmoid'
self
.
act_gate
=
'sigmoid'
self
.
act_cell
=
'tanh'
self
.
act_cell
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
act_cand
=
'tanh'
self
.
use_peepholes
=
False
self
.
use_seq
=
False
self
.
set_conf
()
self
.
set_conf
()
T
=
sum
(
self
.
lod
[
0
])
T
=
sum
(
self
.
lod
[
0
])
...
@@ -108,7 +107,6 @@ class TestFusionLSTMOp(OpTest):
...
@@ -108,7 +107,6 @@ class TestFusionLSTMOp(OpTest):
}
}
self
.
attrs
=
{
self
.
attrs
=
{
'use_peepholes'
:
self
.
use_peepholes
,
'use_peepholes'
:
self
.
use_peepholes
,
'use_seq'
:
self
.
use_seq
,
'is_reverse'
:
self
.
is_reverse
,
'is_reverse'
:
self
.
is_reverse
,
'gate_activation'
:
self
.
act_gate
,
'gate_activation'
:
self
.
act_gate
,
'cell_activation'
:
self
.
act_cell
,
'cell_activation'
:
self
.
act_cell
,
...
@@ -178,50 +176,18 @@ class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp):
...
@@ -178,50 +176,18 @@ class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp):
self
.
is_reverse
=
True
self
.
is_reverse
=
True
class
TestFusionLSTMOpP
oopholesBS1
(
TestFusionLSTMOp
):
class
TestFusionLSTMOpP
eepholesInitReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
def
set_conf
(
self
):
self
.
use_peepholes
=
True
self
.
use_peepholes
=
True
self
.
lod
=
[[
3
]]
self
.
D
=
16
class
TestFusionLSTMOpSeqInit
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
has_initial_state
=
True
class
TestFusionLSTMOpSeqReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
is_reverse
=
True
class
TestFusionLSTMOpSeqInitReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
has_initial_state
=
True
self
.
has_initial_state
=
True
self
.
is_reverse
=
True
self
.
is_reverse
=
True
class
TestFusionLSTMOpSeqPeepholes
(
TestFusionLSTMOp
):
class
TestFusionLSTMOpPoopholesBS1
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
use_peepholes
=
True
class
TestFusionLSTMOpSeqPeepholesInit
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
use_peepholes
=
True
self
.
has_initial_state
=
True
class
TestFusionLSTMOpSeqPeepholesReverse
(
TestFusionLSTMOp
):
def
set_conf
(
self
):
def
set_conf
(
self
):
self
.
use_seq
=
True
self
.
use_peepholes
=
True
self
.
use_peepholes
=
True
self
.
is_reverse
=
True
self
.
lod
=
[[
2
]]
self
.
D
=
8
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录