Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
6ed20474
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6ed20474
编写于
8月 22, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine attention lstm infershape
上级
508548f8
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
111 addition
and
87 deletion
+111
-87
paddle/fluid/operators/attention_lstm_op.cc
paddle/fluid/operators/attention_lstm_op.cc
+111
-87
未找到文件。
paddle/fluid/operators/attention_lstm_op.cc
浏览文件 @
6ed20474
...
...
@@ -26,86 +26,102 @@ namespace paddle {
namespace
operators
{
void
AttentionLSTMOp
::
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"WeightX"
),
"Input(WeightX) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"WeightH"
),
"Input(WeightH) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Bias"
),
"Input(Bias) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"XX"
),
"Output(XX) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of AttentionLSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C0"
),
"Input(C0) of AttentionLSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"LSTMWeight"
),
"Input(LSTMWeight) of AttentionLSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"LSTMBias"
),
"Input(LSTMBias) of AttentionLSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"AttentionWeight"
),
"Input(AttentionWeight) of AttentionLSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Hidden"
),
"Output(Hidden) of LSTM should not be null."
);
"Output(Hidden) of
Attention
LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
"Output(Cell) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedGate"
),
"Output(BatchedGate) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchCellPreAct"
),
"Output(BatchedGate) of LSTM should not be null."
);
"Output(Cell) of AttentionLSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"AttentionedX"
),
"Output(AttentionedX) of AttentionLSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"AttentionFCOut"
),
"Output(AttentionFCOut) of AttentionLSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"LSTMX"
),
"Output(LSTMX) of AttentionLSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"LSTMOUT"
),
"Output(LSTMOUT) of AttentionLSTM should not be null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
const
int
M
=
x_dims
[
1
];
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2
,
"Input(X)'s rank must be 2."
);
auto
w_dims
=
ctx
->
GetInputDim
(
"LSTMWeight"
);
const
int
D
=
w_dims
[
1
]
/
4
;
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
2
,
"Input(LSTMWeight)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
w_dims
[
0
],
D
+
M
,
"LSTMWeight dims should be (%d + %d) * %d."
,
D
+
M
,
4
*
D
);
auto
b_dims
=
ctx
->
GetInputDim
(
"LSTMBias"
);
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
"Input(LSTMBias)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
"LSTMBias dims should be 1 x (%d + %d)."
,
M
,
D
);
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
M
+
D
,
"LSTMBias dims should be 1 x (%d + %d)."
,
M
,
D
);
auto
c_dims
=
ctx
->
GetInputDim
(
"C0"
);
PADDLE_ENFORCE_EQ
(
c_dims
.
size
(),
2
,
"Input(C0)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
c_dims
[
1
],
D
,
"C0 dims should be N x %d."
,
D
);
if
(
ctx
->
HasInput
(
"H0"
))
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C0"
),
"Input(Cell) and Input(Hidden) of LSTM should not "
"be null at the same time."
);
auto
h_dims
=
ctx
->
GetInputDim
(
"H0"
);
auto
c_dims
=
ctx
->
GetInputDim
(
"C0"
);
PADDLE_ENFORCE
(
h_dims
==
c_dims
,
"The dimension of Input(H0) and Input(C0) "
"should be the same."
);
}
// fc_out , shape (maxseqlen,1)
int
max_seq_len
=
0
;
auto
wx_dims
=
ctx
->
GetInputDim
(
"WeightX"
);
PADDLE_ENFORCE_EQ
(
wx_dims
.
size
(),
2
,
"The rank of Input(WeightX) should be 2."
);
PADDLE_ENFORCE_EQ
(
wx_dims
[
0
],
x_dims
[
1
],
"The first dimension of Input(WeightX) "
"should be %d."
,
x_dims
[
1
]);
int
frame_size
=
wx_dims
[
1
]
/
4
;
auto
wh_dims
=
ctx
->
GetInputDim
(
"WeightH
"
);
PADDLE_ENFORCE_EQ
(
wh_dims
.
size
(),
2
,
"The rank of Input(WeightH) should be 2
."
);
PADDLE_ENFORCE_EQ
(
wh_dims
[
0
],
frame_size
,
"The first dimension of Input(WeightH) "
"should be %d."
,
frame_size
);
PADDLE_ENFORCE_EQ
(
wh_dims
[
1
],
4
*
frame_size
,
"The second dimension of Input(WeightH) "
"should be 4 * %d."
,
frame_size
);
auto
b_dims
=
ctx
->
GetInputDim
(
"Bias"
);
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
"The rank of Input(Bias) should be 2."
);
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
"The first dimension of Input(Bias) should be 1."
);
PADDLE_ENFORCE
(
!
ctx
->
Attrs
().
Get
<
bool
>
(
"use_peepholes"
),
"Do not support peephole yet."
);
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
4
*
frame_size
,
"The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes connection"
,
frame_size
);
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
auto
atten_w_dims
=
ctx
->
GetInputDim
(
"AttentionWeight"
);
PADDLE_ENFORCE_EQ
(
atten_w_dims
.
size
(),
2
,
"Input(AttentionWeight)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
atten_w_dims
[
0
],
M
+
D
,
"AttentionWeight shapes must be (%d + %d) * 1."
,
M
,
D
);
PADDLE_ENFORCE_EQ
(
atten_w_dims
[
1
],
1
,
"AttentionWeight shapes must be (%d + %d) * 1."
,
M
,
D
);
if
(
ctx
->
HasInput
(
"AttentionBias"
))
{
auto
atten_b_dims
=
ctx
->
GetInputDim
(
"AttentionBias"
);
PADDLE_ENFORCE_EQ
(
atten_b_dims
.
size
(),
2
,
"Input(AttentionBias)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
atten_b_dims
[
0
],
1
,
"AttentionBias shapes must be 1 * 1.
"
);
PADDLE_ENFORCE_EQ
(
atten_b_dims
[
1
],
1
,
"AttentionBias shapes must be 1 * 1
."
);
}
if
(
ctx
->
HasInput
(
"AttentionScalar"
))
{
auto
dims
=
ctx
->
GetInputDim
(
"AttentionScalar"
);
PADDLE_ENFORCE_EQ
(
dims
.
size
(),
2
,
"Input(AttentionScalar)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
dims
[
0
],
1
,
"AttentionScalar shapes must be 1 * 1."
);
PADDLE_ENFORCE_EQ
(
dims
[
1
],
1
,
"AttentionScalar shapes must be 1 * 1."
);
}
if
(
ctx
->
HasInput
(
"AttentionScalarBias"
))
{
auto
dims
=
ctx
->
GetInputDim
(
"AttentionScalarBias"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"AttentionScalar"
),
"AttentionScalar should not be null when have AttentionScalarBias."
);
PADDLE_ENFORCE_EQ
(
dims
.
size
(),
2
,
"Input(AttentionScalarBias)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
dims
[
0
],
1
,
"AttentionScalarBias shapes must be 1 * 1."
);
PADDLE_ENFORCE_EQ
(
dims
[
1
],
1
,
"AttentionScalarBias shapes must be 1 * 1."
);
}
framework
::
DDim
out_dims
({
x_dims
[
0
],
D
});
ctx
->
SetOutputDim
(
"Hidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"Cell"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedGate"
,
{
x_dims
[
0
],
wx_dims
[
1
]});
ctx
->
SetOutputDim
(
"BatchCellPreAct"
,
out_dims
);
ctx
->
SetOutputDim
(
"AttentionedX"
,
{
x_dims
[
0
],
1
});
ctx
->
SetOutputDim
(
"LSTMX"
,
{
1
,
M
});
ctx
->
SetOutputDim
(
"LSTMOUT"
,
{
1
,
4
*
D
});
// AttentionFCOut should be reshape as (maxseqlen,1) in runtime
ctx
->
ShareLoD
(
"X"
,
"Hidden"
);
ctx
->
ShareLoD
(
"X"
,
"Cell"
);
int
xx_width
=
x_dims
[
1
]
>
wx_dims
[
1
]
?
wx_dims
[
1
]
:
x_dims
[
1
];
ctx
->
SetOutputDim
(
"XX"
,
{
x_dims
[
0
],
xx_width
});
ctx
->
ShareLoD
(
"X"
,
"XX"
);
}
framework
::
OpKernelType
AttentionLSTMOp
::
GetExpectedKernelType
(
...
...
@@ -164,11 +180,10 @@ void AttentionLSTMOpMaker::Make() {
AddOutput
(
"Cell"
,
"(LoDTensor) (same as LSTMOp) the cell state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`."
);
AddOutput
(
"AttentionedX"
,
"(LodTensor) shape is (T x 1), the result after X * AttentionWeight,"
" where T is the total time steps in this mini-batch,"
" D is the hidden size."
)
AddOutput
(
"AttentionedX"
,
"(Tensor) shape is (T x 1), the result after X * AttentionWeight,"
" where T is the total time steps in this mini-batch,"
" D is the hidden size."
)
.
AsIntermediate
();
AddOutput
(
"AttentionFCOut"
,
"(Tensor) (max_seq_len, 1), compute at each step."
)
...
...
@@ -316,12 +331,31 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
auto
*
lstm_w
=
ctx
.
Input
<
Tensor
>
(
"LSTMWeight"
);
// (D+M) x D*4
auto
*
lstm_b
=
ctx
.
Input
<
Tensor
>
(
"LSTMBias"
);
// 1 x D*4
auto
*
hidden_out
=
ctx
.
Output
<
LoDTensor
>
(
"Hidden"
);
// TxD
auto
*
cell_out
=
ctx
.
Output
<
LoDTensor
>
(
"Cell"
);
// TxD
auto
*
atted_x
=
ctx
.
Output
<
LoDTensor
>
(
"AttentionedX"
);
// T x 1
auto
*
fc_out
=
ctx
.
Output
<
Tensor
>
(
'
AttentionFCOut
'
);
// max_seq_len x 1
auto
*
lstm_x
=
ctx
.
Output
<
Tensor
>
(
"LSTMX"
);
// 1 x M
auto
*
lstm_out
=
ctx
.
Output
<
Tensor
>
(
"LSTMOUT"
);
// 1 x 4D
auto
*
hidden_out
=
ctx
.
Output
<
LoDTensor
>
(
"Hidden"
);
// TxD
auto
*
cell_out
=
ctx
.
Output
<
LoDTensor
>
(
"Cell"
);
// TxD
auto
*
atted_x
=
ctx
.
Output
<
Tensor
>
(
"AttentionedX"
);
// T x 1
auto
*
fc_out
=
ctx
.
Output
<
Tensor
>
(
'
AttentionFCOut
'
);
// max_seq_len x 1
auto
*
lstm_x
=
ctx
.
Output
<
Tensor
>
(
"LSTMX"
);
// 1 x M
auto
*
lstm_out
=
ctx
.
Output
<
Tensor
>
(
"LSTMOUT"
);
// 1 x 4D
// some shape should be reshape here since infershape can not get lod info
auto
x_lod
=
x
->
lod
();
const
int
N
=
x_lod
[
0
].
size
()
-
1
;
// batch size
auto
x_dims
=
x
->
dims
();
// T x M
auto
w_dims
=
w
->
dims
();
// (D+M) x 4D
const
int
M
=
x_dims
[
1
];
// x frame size
const
int
D
=
w_dims
[
1
]
/
4
;
// gate frame size
const
int
D2
=
D
*
2
;
const
int
D3
=
D
*
3
;
const
int
D4
=
w_dims
[
1
];
int
max_seq_len
=
x_lod
[
0
][
1
];
for
(
int
i
=
1
;
i
<
N
;
++
i
)
{
int
len
=
x_lod
[
0
][
i
+
1
]
-
x_lod
[
0
][
i
];
max_seq_len
=
max_seq_len
<
len
?
len
:
max_seq_len
;
}
PADDLE_ENFORCE_EQ
(
x_lod
.
size
(),
1
,
"Input(X)'s lod size must be 1."
);
PADDLE_ENFORCE_EQ
(
c0
->
dims
()[
0
],
N
,
"C0 dims should be %d x %d."
,
N
,
D
);
fc_out
->
Resize
({
max_seq_len
,
1
});
const
T
*
x_data
=
x
->
data
<
T
>
();
const
T
*
h0_data
=
h0
->
data
<
T
>
();
...
...
@@ -341,16 +375,6 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
T
*
lstm_x_data
=
lstm_x
->
mutable_data
<
T
>
();
T
*
lstm_out_data
=
lstm_out
->
mutable_data
<
T
>
();
auto
x_lod
=
x
->
lod
();
auto
x_dims
=
x
->
dims
();
// T x M
auto
w_dims
=
w
->
dims
();
// (D+M) x 4D
const
int
M
=
x_dims
[
1
];
// x frame size
const
int
D
=
w_dims
[
1
]
/
4
;
// gate frame size
const
int
D2
=
D
*
2
;
const
int
D3
=
D
*
3
;
const
int
D4
=
w_dims
[
1
];
const
int
batch_size
=
x_lod
[
0
].
size
()
-
1
;
// assert lod.size() == 1
// x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
auto
blas
=
math
::
GetBlas
<
DeviceContext
,
T
>
(
ctx
);
math
::
FCCompute
<
DeviceContext
,
T
>
(
blas
,
T
,
1
,
M
,
x_data
,
atten_w_data
,
...
...
@@ -361,7 +385,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
const
T
*
prev_hidden_data
=
NULL
;
T
*
cur_cell_out_data
=
cell_out_data
;
T
*
cur_hidden_out_data
=
hidden_out_data
;
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
for
(
int
i
=
0
;
i
<
N
;
++
i
)
{
int
seq_len
=
x_lod
[
0
][
i
+
1
];
prev_cell_data
=
c0_data
+
i
*
D
;
prev_hidden_data
=
h0
?
h0_data
+
i
*
D
:
NULL
;
...
...
@@ -370,13 +394,13 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
/// compute attention vector
// prev_cell(1xD) * fc(D) rest part of atten_wgt
// T = cblas_dot();
T
prev_cell_bias
=
blas
.
V
DOT
(
D
,
prev_cell_data
,
atten_w_data
+
M
);
T
prev_cell_bias
=
blas
.
DOT
(
D
,
prev_cell_data
,
atten_w_data
+
M
);
// add cell bias and relu
bias_relu
<
T
>
(
seq_len
,
atted_x_data
,
&
prev_cell_bias
,
fc_out_data
);
// fc2: scalar
if
(
atten_scalar_data
)
{
// x = a*x
blas
.
V
SCAL
(
seq_len
,
atten_scalar_data
,
fc_out_data
);
blas
.
SCAL
(
seq_len
,
atten_scalar_data
,
fc_out_data
);
bias_relu
<
T
>
(
seq_len
,
fc_out_data
,
atten_scalar_bias_data
,
fc_out_data
);
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录