Commit dd938d0b
Authored Aug 22, 2018 by tensor-tang
Parent: 522b3e41

fix bugs and pass op test

Showing 2 changed files with 22 additions and 23 deletions (+22 -23):

    paddle/fluid/operators/attention_lstm_op.cc                       +17 -19
    python/paddle/fluid/tests/unittests/test_attention_lstm_op.py      +5  -4
paddle/fluid/operators/attention_lstm_op.cc

@@ -59,10 +59,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
   auto b_dims = ctx->GetInputDim("LSTMBias");
   PADDLE_ENFORCE_EQ(b_dims.size(), 2, "Input(LSTMBias)'s rank must be 2.");
-  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x (%d + %d).", M,
-                    D);
-  PADDLE_ENFORCE_EQ(b_dims[1], M + D, "LSTMBias dims should be 1 x (%d + %d).",
-                    M, D);
+  PADDLE_ENFORCE_EQ(b_dims[0], 1, "LSTMBias dims should be 1 x %d.", 4 * D);
+  PADDLE_ENFORCE_EQ(b_dims[1], 4 * D, "LSTMBias dims should be 1 x %d.", 4 * D);
   auto c_dims = ctx->GetInputDim("C0");
   PADDLE_ENFORCE_EQ(c_dims.size(), 2, "Input(C0)'s rank must be 2.");
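For reference, the shapes behind the corrected checks can be sketched in numpy. This is a minimal illustration with arbitrary sizes; it relies only on the shape comments that appear later in this file (LSTMWeight is (D+M) x D*4, LSTMBias is 1 x D*4), and the variable names are hypothetical.

    import numpy as np

    M, D = 8, 4                                 # M: dim of x, D: gate (hidden) size of the LSTM
    lstm_weight = np.random.rand(D + M, 4 * D)  # concat[forget, input, output, tilde]
    lstm_bias = np.random.rand(1, 4 * D)        # one bias row covering all four gates

    # Mirrors the corrected InferShape checks: rank 2, shape 1 x 4D.
    assert lstm_bias.ndim == 2
    assert lstm_bias.shape == (1, 4 * D)        # the old check wrongly expected 1 x (M + D)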
@@ -148,8 +146,8 @@ void AttentionLSTMOpMaker::Make() {
            "(Tensor) the weights of attention fc. Always relu the fc result."
            "The shape is ((M+D) x 1), where M is the dim size of x, D is the "
            "gate size of LSTM.");
-  AddInput("AttentionBias, optional",
-           "(Tensor) the bias of attention fc."
+  AddInput("AttentionBias",
+           "(Tensor, optional) the bias of attention fc."
            "The shape is (1 x 1)")
       .AsDispensable();
   AddInput("AttentionScalar",
@@ -281,7 +279,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     auto* atten_w = ctx.Input<Tensor>("AttentionWeight");       // (M+D) x 1
     auto* atten_b = ctx.Input<Tensor>("AttentionBias");         // 1x1
     auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");  // 1x1
-    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalar");      // 1x1
+    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");  // 1x1
     auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");             // (D+M) x D*4
     auto* lstm_b = ctx.Input<Tensor>("LSTMBias");               // 1 x D*4
@@ -319,7 +317,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     // }
     const T* x_data = x->data<T>();
-    const T* h0_data = h0->data<T>();
+    const T* h0_data = h0 ? h0->data<T>() : NULL;
     const T* c0_data = c0->data<T>();
     const T* lstm_w_data = lstm_w->data<T>();
     const T* lstm_b_data = lstm_b->data<T>();
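The same guard, expressed as a minimal Python sketch (the function name and layout below are hypothetical; they only mirror the two changed lines in this commit): the dispensable H0 input must not be dereferenced when it was not fed, and the per-batch offset is taken only when a buffer actually exists.

    import numpy as np

    def prev_hidden_for(h0, i, D):
        # H0 is optional; only touch its buffer when it was provided
        h0_data = h0.ravel() if h0 is not None else None
        return h0_data[i * D:(i + 1) * D] if h0_data is not None else None

    print(prev_hidden_for(np.arange(6.0).reshape(2, 3), i=1, D=3))  # second row of h0
    print(prev_hidden_for(None, i=1, D=3))                          # None: no initial hidden state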
@@ -341,36 +339,35 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data, atten_w_data,
                                       atted_x_data, atten_b_data);
+    const T* cur_atten_x_data = atted_x_data;
     const T* cur_x_data = x_data;
     const T* prev_cell_data = NULL;
     const T* prev_hidden_data = NULL;
     T* cur_cell_out_data = cell_out_data;
     T* cur_hidden_out_data = hidden_out_data;
     for (int i = 0; i < N; ++i) {
-      int seq_len = x_lod[0][i + 1];
+      int seq_len = x_lod[0][i + 1] - x_lod[0][i];
       prev_cell_data = c0_data + i * D;
-      prev_hidden_data = h0 ? h0_data + i * D : NULL;
+      prev_hidden_data = h0_data ? h0_data + i * D : NULL;
       for (int step = 0; step < seq_len; ++step) {
-        /// compute attention vector
-        // prev_cell(1xD) * fc(D) rest part of atten_wgt
-        // T = cblas_dot();
+        /// 1. compute attention vector
+        // 1a. prev_cell(1xD) * fc(D) rest part of atten_wgt
         T prev_cell_bias = blas.DOT(D, prev_cell_data, atten_w_data + M);
-        // add cell bias and relu
-        bias_relu<T>(seq_len, atted_x_data, &prev_cell_bias, fc_out_data);
-        // fc2: scalar
+        // 1b. add cell bias and relu
+        bias_relu<T>(seq_len, cur_atten_x_data, &prev_cell_bias, fc_out_data);
+        // 1c. fc scalar
        if (atten_scalar_data) {
-          // x = a*x
           blas.SCAL(seq_len, *atten_scalar_data, fc_out_data);
           bias_relu<T>(seq_len, fc_out_data, atten_scalar_bias_data,
                        fc_out_data);
         }
+        // 1d. softmax
         vec_softmax<DeviceContext, T>(blas, seq_len, fc_out_data, fc_out_data);
         // mul x(seq_len*M) and sum pool
         math::FCCompute<DeviceContext, T>(blas, 1, M, seq_len, fc_out_data,
                                           cur_x_data, lstm_x_data);
-        /// compute LSTM step
+        /// 2. compute LSTM step
         // lstm weight : concat[forget , input , output , tilde]
         // shape : (D + M) x (4 * D)
         // fc inputX(1xM) * weightX(M*(4D)) => 1 x 4D
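A rough numpy sketch of the renumbered attention steps (1a-1d) for a single time step of one sequence, following the shape comments in this file (atten_w is (M+D) x 1, the x slice is seq_len x M, prev_cell is 1 x D). The function is an illustration of the control flow under those assumptions, not the operator's actual math-library calls.

    import numpy as np

    def relu(v):
        return np.maximum(v, 0.0)

    def attention_weights(xi, prev_cell, atten_w, atten_b, scalar=None, scalar_bias=0.0):
        M = xi.shape[1]
        # done once per batch in the C++ code: fc of x with the first M rows of atten_w
        atted_x = xi.dot(atten_w[:M]) + atten_b                  # seq_len x 1
        # 1a. dot of prev_cell with the remaining D rows of atten_w
        prev_cell_bias = prev_cell.ravel().dot(atten_w[M:].ravel())
        # 1b. add the cell bias and relu
        fc_out = relu(atted_x + prev_cell_bias)
        # 1c. optional scalar fc: scale, add the scalar bias, relu
        if scalar is not None:
            fc_out = relu(fc_out * scalar + scalar_bias)
        # 1d. softmax over the time steps of this sequence
        e = np.exp(fc_out - fc_out.max())
        return e / e.sum()                                       # seq_len x 1 attention weights

The "mul x and sum pool" line that follows corresponds to fc_out.T.dot(xi) in this sketch, producing the 1 x M attended input for the LSTM step.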
@@ -407,6 +404,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
         cur_hidden_out_data = cur_hidden_out_data + D;
       }
       cur_x_data = cur_x_data + seq_len * M;
+      cur_atten_x_data = cur_atten_x_data + seq_len;
     }
   }
 };
...
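On the C++ side the level-0 LoD vector holds cumulative row offsets into the packed batch, so a sequence's length is the difference of two consecutive entries; the old code used the raw offset as if it were a length, and the fixed seq_len also drives the per-sequence pointer advances added above. A small sketch of the corrected walk (names and sizes are illustrative):

    # x_lod[0] holds cumulative offsets: two sequences of length 3 and 2 packed into 5 rows
    x_lod = [[0, 3, 5]]
    N = len(x_lod[0]) - 1

    for i in range(N):
        seq_len = x_lod[0][i + 1] - x_lod[0][i]  # fixed: a length, not the raw offset
        start = x_lod[0][i]                      # row where this sequence begins
        print(i, start, seq_len)                 # (0, 0, 3) then (1, 3, 2)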
python/paddle/fluid/tests/unittests/test_attention_lstm_op.py

@@ -40,19 +40,20 @@ def attention_lstm(
     D = b.shape[1] / 4
     assert T == x.shape[0]
     assert len(fcws) == len(fcbs)
     hidden = []
     cell = []
     start_offset = 0
     for bid in range(N):
         seq_len = lod[0][bid]
-        xi = np.copy(x[start_offset:seq_len, :]).reshape(seq_len, M)
+        xi = np.copy(x[start_offset:start_offset + seq_len, :]).reshape(seq_len, M)
         prev_cell = np.copy(c0[bid]).reshape([1, D])
         prev_hidden = np.copy(h0[bid]).reshape([1, D])
         for step in range(seq_len):
             expanded_cell = np.repeat(prev_cell, seq_len, axis=0)
             tmp = np.concatenate((xi, expanded_cell), axis=1)
+            assert tmp.shape[0] == seq_len
             assert tmp.shape[1] == M + D
             for fcid in range(len(fcbs)):
                 tmp = fc(tmp, fcws[fcid], fcbs[fcid])
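The old slice was only correct for the first sequence in the batch (start_offset == 0); once start_offset has advanced, x[start_offset:seq_len] picks the wrong rows or nothing at all. A small numpy illustration with made-up sizes:

    import numpy as np

    x = np.arange(10).reshape(5, 2)   # 5 packed rows: sequences of length 2 and 3
    start_offset, seq_len = 2, 3      # the second sequence starts at row 2

    wrong = x[start_offset:seq_len, :]                 # rows 2:3 -> a single row
    right = x[start_offset:start_offset + seq_len, :]  # rows 2:5 -> all 3 rows

    print(wrong.shape)  # (1, 2)
    print(right.shape)  # (3, 2)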
@@ -62,7 +63,7 @@ def attention_lstm(
             lstmx = xi * tmp  # seq * M
             lstmx = np.sum(lstmx.reshape(seq_len, M), axis=0).reshape([1, M])
             lstmin = np.concatenate((prev_hidden, lstmx), axis=1)
-            lstmout = np.dot(lstmin, w).reshape([1, 4 * D])
+            lstmout = fc(lstmin, w, b).reshape([1, 4 * D])
             g_f, g_i, g_o, cand = np.split(lstmout, 4, axis=1)
             g_f = act_gate(g_f).reshape([1, D])
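The reference now runs the LSTM projection through the same fc helper as the attention layers, so the 1 x 4D LSTMBias is added before the gate split; plain np.dot silently dropped the bias. A minimal numpy sketch of the difference, assuming the test's fc helper is an affine transform x.dot(w) + b (the stand-in below is not the test's actual definition):

    import numpy as np

    def fc(x, w, b):  # stand-in: fully connected / affine transform
        return np.dot(x, w) + b

    D, M = 4, 6
    lstmin = np.random.rand(1, D + M)
    w = np.random.rand(D + M, 4 * D)   # concat[forget, input, output, tilde]
    b = np.random.rand(1, 4 * D)

    without_bias = np.dot(lstmin, w).reshape([1, 4 * D])  # old reference: bias ignored
    with_bias = fc(lstmin, w, b).reshape([1, 4 * D])      # new reference: matches the op

    g_f, g_i, g_o, cand = np.split(with_bias, 4, axis=1)  # each gate block is 1 x D
    print(np.allclose(with_bias - without_bias, b))       # True: the gap is exactly the bias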
@@ -88,7 +89,7 @@ def attention_lstm(
 class TestAttentionLSTMOp(OpTest):
     def set_conf(self):
-        self.lod = [[3]]
+        pass

     def setUp(self):
         self.op_type = 'attention_lstm'
...