Commit 417b576c (unverified)
Authored May 13, 2020 by liu zhengxi; committed via GitHub on May 13, 2020.
API(dynamic_lstm, dynamic_lstmp) error message enhancement (#24450)
* update err msg for dynamic_lstm and dynamic_lstmp, test=develop
Parent: 53bdee64

Showing 5 changed files with 269 additions and 123 deletions (+269 / -123):
paddle/fluid/operators/lstm_op.cc                      +67  -56
paddle/fluid/operators/lstmp_op.cc                     +81  -66
python/paddle/fluid/layers/rnn.py                      +29  -1
python/paddle/fluid/tests/unittests/test_lstm_op.py    +36  -0
python/paddle/fluid/tests/unittests/test_lstmp_op.py   +56  -0
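The user-visible effect of the change is easiest to see from the Python side: with the new checks in place, feeding dynamic_lstm a raw numpy array fails fast with a TypeError at graph-construction time instead of an opaque enforce failure from C++. A minimal sketch, assuming a Paddle 1.8-era fluid environment (it mirrors the tests added in this commit):

    import numpy as np
    import paddle.fluid as fluid

    with fluid.program_guard(fluid.Program(), fluid.Program()):
        try:
            # A numpy array is not a Variable; the check_variable_and_dtype
            # call added in rnn.py now rejects it up front.
            data = np.random.random((1, 2048)).astype("float32")
            fluid.layers.dynamic_lstm(input=data, size=2048,
                                      use_peepholes=False)
        except TypeError as e:
            print("rejected:", e)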
paddle/fluid/operators/lstm_op.cc

@@ -24,64 +24,80 @@ class LSTMOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Input"),
-                   "Input(Input) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Weight"),
-                   "Input(Weight) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Bias"),
-                   "Input(Bias) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
-                   "Output(Hidden) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Cell"),
-                   "Output(Cell) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
-                   "Output(BatchGate) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
-                   "Output(BatchGate) of LSTM should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTM");
+    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTM");
+    OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTM");
+    OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "LSTM");
+    OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTM");
+    OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "LSTM");
+    OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
+                   "BatchCellPreAct", "LSTM");

     auto in_dims = ctx->GetInputDim("Input");
-    PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(
+        in_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "Input(X)'s rank must be 2, but received %d.", in_dims.size()));

     if (ctx->HasInput("H0")) {
-      PADDLE_ENFORCE(ctx->HasInput("C0"),
-                     "Input(Cell) and Input(Hidden) of LSTM should not "
-                     "be null at the same time.");
+      PADDLE_ENFORCE_EQ(
+          ctx->HasInput("C0"), true,
+          platform::errors::NotFound("Input(Cell) and Input(Hidden) of LSTM "
+                                     "should not be null at the same time."));
       auto h_dims = ctx->GetInputDim("H0");
       auto c_dims = ctx->GetInputDim("C0");
-      PADDLE_ENFORCE(h_dims == c_dims,
-                     "The dimension of Input(H0) and Input(C0) "
-                     "should be the same.");
+      PADDLE_ENFORCE_EQ(h_dims, c_dims,
+                        platform::errors::InvalidArgument(
+                            "The dimension of Input(H0) and Input(C0) should "
+                            "be the same, but received [%s] (H0) vs [%s] (C0).",
+                            h_dims, c_dims));
     }

     int frame_size = in_dims[1] / 4;
     auto w_dims = ctx->GetInputDim("Weight");
-    PADDLE_ENFORCE_EQ(w_dims.size(), 2,
-                      "The rank of Input(Weight) should be 2.");
-    PADDLE_ENFORCE_EQ(w_dims[0], frame_size,
-                      "The first dimension of Input(Weight) "
-                      "should be %d.",
-                      frame_size);
-    PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
-                      "The second dimension of Input(Weight) "
-                      "should be 4 * %d.",
-                      frame_size);
+    PADDLE_ENFORCE_EQ(
+        w_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "The rank of Input(Weight) should be 2, but received %d.",
+            w_dims.size()));
+    PADDLE_ENFORCE_EQ(w_dims[0], frame_size,
+                      platform::errors::InvalidArgument(
+                          "The first dimension of Input(Weight) should be %d, "
+                          "but received %d.",
+                          frame_size, w_dims[0]));
+    PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
+                      platform::errors::InvalidArgument(
+                          "The second dimension of Input(Weight) should be 4 * "
+                          "%d, but received %d.",
+                          frame_size, w_dims[1]));

     auto b_dims = ctx->GetInputDim("Bias");
-    PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
-    PADDLE_ENFORCE_EQ(b_dims[0], 1,
-                      "The first dimension of Input(Bias) should be 1.");
+    PADDLE_ENFORCE_EQ(
+        b_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "The rank of Input(Bias) should be 2, but received %d.",
+            b_dims.size()));
+    PADDLE_ENFORCE_EQ(
+        b_dims[0], 1,
+        platform::errors::InvalidArgument(
+            "The first dimension of Input(Bias) should be 1, but received %d.",
+            b_dims[0]));

     if (ctx->Attrs().Get<bool>("use_peepholes")) {
-      PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
-                        "The second dimension of Input(Bias) should be "
-                        "7 * %d if enable peepholes connection",
-                        frame_size);
+      PADDLE_ENFORCE_EQ(
+          b_dims[1], 7 * frame_size,
+          platform::errors::InvalidArgument(
+              "The second dimension of Input(Bias) should be 7 * %d if enable "
+              "peepholes connection, but received %d.",
+              frame_size, b_dims[1]));
     } else {
-      PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
-                        "The second dimension of Input(Bias) should be "
-                        "4 * %d if disable peepholes connection",
-                        frame_size);
+      PADDLE_ENFORCE_EQ(
+          b_dims[1], 4 * frame_size,
+          platform::errors::InvalidArgument(
+              "The second dimension of Input(Bias) should be 4 * %d if disable "
+              "peepholes connection, but received %d.",
+              frame_size, b_dims[1]));
     }

     framework::DDim out_dims({in_dims[0], frame_size});
 ...
@@ -229,21 +245,16 @@ class LSTMGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Input"),
-                   "Input(Input) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Hidden"),
-                   "Input(Hidden) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Cell"),
-                   "Input(Cell) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Weight"),
-                   "Input(Weight) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Bias"),
-                   "Input(Bias) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
-                   "Input(BatchGate) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
-                   "Input(BatchGate) of LSTM should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTM@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("Hidden"), "Input", "Hidden", "LSTM@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("Cell"), "Input", "Cell", "LSTM@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTM@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTM@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("BatchGate"), "Input", "BatchGate",
+                   "LSTM@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("BatchCellPreAct"), "Input", "BatchCellPreAct",
+                   "LSTM@Grad");

     auto SetOutGradDim = [&ctx](const std::string& name) {
       auto g_name = framework::GradVarName(name);
 ...
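For reference, the shape relations these enforces encode: the input's second dimension is 4 * frame_size, Input(Weight) must be [frame_size, 4 * frame_size], and Input(Bias) must be [1, 7 * frame_size] with peepholes enabled (4 * frame_size otherwise). A small plain-Python sketch of that arithmetic (no Paddle required; the helper name is ours, for illustration only):

    def lstm_param_shapes(input_width, use_peepholes=True):
        # Shapes that LSTMOp::InferShape enforces, per the checks above.
        assert input_width % 4 == 0, "input width must be 4 * frame_size"
        frame_size = input_width // 4
        weight = (frame_size, 4 * frame_size)
        bias = (1, (7 if use_peepholes else 4) * frame_size)
        return weight, bias

    print(lstm_param_shapes(2048))  # ((512, 2048), (1, 3584))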
paddle/fluid/operators/lstmp_op.cc

@@ -24,74 +24,92 @@ class LSTMPOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Input"),
-                   "Input(Input) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Weight"),
-                   "Input(Weight) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
-                   "Input(ProjWeight) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Bias"),
-                   "Input(Bias) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Projection"),
-                   "Output(Projection) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Cell"),
-                   "Output(Cell) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
-                   "Output(BatchGate) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
-                   "Output(BatchCellPreAct) of LSTMP operator should not be "
-                   "null.");
-    PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
-                   "Output(BatchHidden) of LSTMP operator should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasInput("ProjWeight"), "Input", "ProjWeight", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasOutput("Projection"), "Output", "Projection",
+                   "LSTMP");
+    OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
+                   "BatchCellPreAct", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
+                   "LSTMP");

     auto in_dims = ctx->GetInputDim("Input");
-    PADDLE_ENFORCE_EQ(in_dims.size(), 2,
-                      "Input(X)'s rank of LSTMP operator must be 2.");
+    PADDLE_ENFORCE_EQ(
+        in_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "Input(X)'s rank of LSTMP operator must be 2, but received %d.",
+            in_dims.size()));

     int frame_size = in_dims[1] / 4;
     auto w_dims = ctx->GetInputDim("Weight");
     auto proj_dims = ctx->GetInputDim("ProjWeight");
-    PADDLE_ENFORCE_EQ(w_dims.size(), 2,
-                      "The rank of Input(Weight) should be 2.");
-    PADDLE_ENFORCE_EQ(w_dims[0], proj_dims[1],
-                      "The first dimension of Input(Weight) "
-                      "should be %d.",
-                      proj_dims[1]);
+    PADDLE_ENFORCE_EQ(
+        w_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "The rank of Input(Weight) should be 2, but received %d.",
+            w_dims.size()));
+    PADDLE_ENFORCE_EQ(
+        w_dims[0], proj_dims[1],
+        platform::errors::InvalidArgument(
+            "The first dimension of Input(Weight) and the second dimension of "
+            "Input(ProjWeight) should be the same, but received %d vs %d.",
+            w_dims[0], proj_dims[1]));
-    PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
-                      "The second dimension of Input(Weight) "
-                      "should be 4 * %d.",
-                      frame_size);
+    PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
+                      platform::errors::InvalidArgument(
+                          "The second dimension of Input(Weight) should be 4 * "
+                          "%d, but received %d.",
+                          frame_size, w_dims[1]));

-    PADDLE_ENFORCE_EQ(proj_dims.size(), 2,
-                      "The rank of Input(ProjWeight) should be 2.");
-    PADDLE_ENFORCE_EQ(proj_dims[0], frame_size,
-                      "The first dimension of Input(ProjWeight) "
-                      "should be %d.",
-                      frame_size);
+    PADDLE_ENFORCE_EQ(
+        proj_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "The rank of Input(ProjWeight) should be 2, but received %d.",
+            proj_dims.size()));
+    PADDLE_ENFORCE_EQ(proj_dims[0], frame_size,
+                      platform::errors::InvalidArgument(
+                          "The first dimension of Input(ProjWeight) should be "
+                          "%d, but received %d.",
+                          frame_size, proj_dims[0]));

     if (ctx->HasInput("H0")) {
-      PADDLE_ENFORCE(ctx->HasInput("C0"),
-                     "Input(C0) of LSTMP operator should not be null after "
-                     "Input(H0) provided.");
+      PADDLE_ENFORCE_EQ(
+          ctx->HasInput("C0"), true,
+          platform::errors::NotFound("Input(C0) of LSTMP operator should not "
+                                     "be null after Input(H0) provided."));
     }

     auto b_dims = ctx->GetInputDim("Bias");
-    PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
-    PADDLE_ENFORCE_EQ(b_dims[0], 1,
-                      "The first dimension of Input(Bias) should be 1.");
+    PADDLE_ENFORCE_EQ(
+        b_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "The rank of Input(Bias) should be 2, but received %d.",
+            b_dims.size()));
+    PADDLE_ENFORCE_EQ(
+        b_dims[0], 1,
+        platform::errors::InvalidArgument(
+            "The first dimension of Input(Bias) should be 1, but received %d.",
+            b_dims[0]));

     if (ctx->Attrs().Get<bool>("use_peepholes")) {
-      PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
-                        "The second dimension of Input(Bias) should be "
-                        "7 * %d if enable peepholes connection",
-                        frame_size);
+      PADDLE_ENFORCE_EQ(
+          b_dims[1], 7 * frame_size,
+          platform::errors::InvalidArgument(
+              "The second dimension of Input(Bias) should be 7 * %d if enable "
+              "peepholes connection, but received %d.",
+              frame_size, b_dims[1]));
     } else {
-      PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
-                        "The second dimension of Input(Bias) should be "
-                        "4 * %d if disable peepholes connection",
-                        frame_size);
+      PADDLE_ENFORCE_EQ(
+          b_dims[1], 4 * frame_size,
+          platform::errors::InvalidArgument(
+              "The second dimension of Input(Bias) should be 4 * %d if disable "
+              "peepholes connection, but received %d.",
+              frame_size, b_dims[1]));
     }

     framework::DDim out_dims({in_dims[0], frame_size});
 ...
@@ -314,21 +332,18 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Projection"),
-                   "Input(Projection) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Cell"),
-                   "Input(Cell) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Weight"),
-                   "Input(Weight) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
-                   "Input(ProjWeight) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Bias"),
-                   "Input(Bias) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
-                   "Input(BatchGate) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
-                   "Input(BatchGate) of LSTMP operator should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Projection"), "Input", "Projection",
+                   "LSTMP@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("Cell"), "Input", "Cell", "LSTMP@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTMP@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("ProjWeight"), "Input", "ProjWeight",
+                   "LSTMP@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTMP@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("BatchGate"), "Input", "BatchGate",
+                   "LSTMP@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("BatchCellPreAct"), "Input", "BatchCellPreAct",
+                   "LSTMP@Grad");

     auto SetOutGradDim = [&ctx](const std::string& name) {
       auto g_name = framework::GradVarName(name);
 ...
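LSTMP adds a projection, and the checks above tie the two weight matrices together: Input(Weight) is [proj_size, 4 * frame_size] (its first dimension must equal ProjWeight's second), while Input(ProjWeight) is [frame_size, proj_size]. Again a plain-Python sketch, with an illustrative helper name of our own:

    def lstmp_param_shapes(input_width, proj_size, use_peepholes=True):
        # Shapes that LSTMPOp::InferShape enforces, per the checks above.
        frame_size = input_width // 4
        weight = (proj_size, 4 * frame_size)    # w_dims[0] == proj_dims[1]
        proj_weight = (frame_size, proj_size)   # proj_dims[0] == frame_size
        bias = (1, (7 if use_peepholes else 4) * frame_size)
        return weight, proj_weight, bias

    print(lstmp_param_shapes(2048, 256))
    # ((256, 2048), (512, 256), (1, 3584))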
python/paddle/fluid/layers/rnn.py

@@ -2073,7 +2073,21 @@ def dynamic_lstm(input,
"""
"""
assert
in_dygraph_mode
(
assert
in_dygraph_mode
(
)
is
not
True
,
"please use lstm instead of dynamic_lstm in dygraph mode!"
)
is
not
True
,
"please use lstm instead of dynamic_lstm in dygraph mode!"
assert
bias_attr
is
not
False
,
"bias_attr should not be False in dynamic_lstmp."
assert
bias_attr
is
not
False
,
"bias_attr should not be False in dynamic_lstm."
check_variable_and_dtype
(
input
,
'input'
,
[
'float32'
,
'float64'
],
'dynamic_lstm'
)
check_type
(
h_0
,
'h_0'
,
(
Variable
,
type
(
None
)),
'dynamic_lstm'
)
if
isinstance
(
h_0
,
Variable
):
check_variable_and_dtype
(
h_0
,
'h_0'
,
[
'float32'
,
'float64'
],
'dynamic_lstm'
)
check_type
(
c_0
,
'c_0'
,
(
Variable
,
type
(
None
)),
'dynamic_lstm'
)
if
isinstance
(
c_0
,
Variable
):
check_variable_and_dtype
(
c_0
,
'c_0'
,
[
'float32'
,
'float64'
],
'dynamic_lstm'
)
helper
=
LayerHelper
(
'lstm'
,
**
locals
())
helper
=
LayerHelper
(
'lstm'
,
**
locals
())
size
=
size
//
4
size
=
size
//
4
weight
=
helper
.
create_parameter
(
weight
=
helper
.
create_parameter
(
@@ -2439,6 +2453,20 @@ def dynamic_lstmp(input,
     ) is not True, "please use lstm instead of dynamic_lstmp in dygraph mode!"
     assert bias_attr is not False, "bias_attr should not be False in dynamic_lstmp."
+
+    check_variable_and_dtype(input, 'input', ['float32', 'float64'],
+                             'dynamic_lstmp')
+
+    check_type(h_0, 'h_0', (Variable, type(None)), 'dynamic_lstmp')
+    if isinstance(h_0, Variable):
+        check_variable_and_dtype(h_0, 'h_0', ['float32', 'float64'],
+                                 'dynamic_lstmp')
+
+    check_type(c_0, 'c_0', (Variable, type(None)), 'dynamic_lstmp')
+    if isinstance(c_0, Variable):
+        check_variable_and_dtype(c_0, 'c_0', ['float32', 'float64'],
+                                 'dynamic_lstmp')
+
     helper = LayerHelper('lstmp', **locals())
     size = size // 4
     weight = helper.create_parameter(
 ...
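The additions above follow the standard fluid validation idiom: check_type accepts a Variable or None for h_0/c_0, and check_variable_and_dtype restricts the dtype to float32/float64. Under the same Paddle 1.8-era assumption as before, an int32 initial hidden state is now rejected at construction time:

    import paddle.fluid as fluid

    with fluid.program_guard(fluid.Program(), fluid.Program()):
        x = fluid.data(name="x", shape=[None, 2048], dtype="float32")
        h = fluid.data(name="h", shape=[None, 512], dtype="int32")  # bad dtype
        c = fluid.data(name="c", shape=[None, 512], dtype="float32")
        try:
            fluid.layers.dynamic_lstm(input=x, size=2048,
                                      use_peepholes=False, h_0=h, c_0=c)
        except TypeError as e:
            print("rejected:", e)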
python/paddle/fluid/tests/unittests/test_lstm_op.py

@@ -301,6 +301,42 @@ class TestLstmOpCase3(TestLstmOp):
         self.lod = [[2, 0, 4]]


+class TestLstmOpError(unittest.TestCase):
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+
+            def test_Variable():
+                input_data = np.random.random((1, 2048)).astype("float32")
+                fluid.layers.dynamic_lstm(
+                    input=input_data, size=2048, use_peepholes=False)
+
+            self.assertRaises(TypeError, test_Variable)
+
+            def test_h_0():
+                in_data = fluid.data(
+                    name="input", shape=[None, 2048], dtype="float32")
+                h = fluid.data(name="h", shape=[None, 512], dtype="int32")
+                c = fluid.data(name="c", shape=[None, 512], dtype="float32")
+                fluid.layers.dynamic_lstm(
+                    input=in_data, size=2048, use_peepholes=False, h_0=h,
+                    c_0=c)
+
+            self.assertRaises(TypeError, test_h_0)
+
+            def test_c_0():
+                in_data_ = fluid.data(
+                    name="input_", shape=[None, 2048], dtype="float32")
+                h_ = fluid.data(name="h_", shape=[None, 512], dtype="float32")
+                c_ = fluid.data(name="c_", shape=[None, 512], dtype="int32")
+                fluid.layers.dynamic_lstm(
+                    input=in_data_, size=2048, use_peepholes=False, h_0=h_,
+                    c_0=c_)
+
+            self.assertRaises(TypeError, test_c_0)
+
+
 # class TestLstmOpHasInitial(TestLstmOp):
 #     def set_argument(self):
 #         self.lod = [[2, 3, 2]]
 ...
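To run just the new error-path tests, something like the following should work from python/paddle/fluid/tests/unittests (assuming a built Paddle on the path):

    import unittest
    import test_lstm_op

    # Load only the TestLstmOpError case added by this commit.
    suite = unittest.TestLoader().loadTestsFromName("TestLstmOpError",
                                                    test_lstm_op)
    unittest.TextTestRunner(verbosity=2).run(suite)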
python/paddle/fluid/tests/unittests/test_lstmp_op.py

@@ -16,6 +16,8 @@ from __future__ import print_function
 import unittest
 import numpy as np
 import test_lstm_op as LstmTest
+from paddle import fluid
+from paddle.fluid import Program, program_guard

 ACTIVATION = {
     'identity': LstmTest.identity,
 ...
@@ -315,5 +317,59 @@ class TestLstmpOpLen0Case2(TestLstmpOp):
         self.lod = [[2, 0, 3]]


+class TestLstmpOpError(unittest.TestCase):
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+
+            def test_Variable():
+                input_data = np.random.random((1, 2048)).astype("float32")
+                fluid.layers.dynamic_lstmp(
+                    input=input_data,
+                    size=2048,
+                    proj_size=256,
+                    use_peepholes=False,
+                    is_reverse=True,
+                    cell_activation="tanh",
+                    proj_activation="tanh")
+
+            self.assertRaises(TypeError, test_Variable)
+
+            def test_h_0():
+                in_data = fluid.data(
+                    name="input", shape=[None, 2048], dtype="float32")
+                h = fluid.data(name="h", shape=[None, 512], dtype="int32")
+                c = fluid.data(name="c", shape=[None, 512], dtype="float32")
+                fluid.layers.dynamic_lstmp(
+                    input=in_data,
+                    size=2048,
+                    proj_size=256,
+                    use_peepholes=False,
+                    is_reverse=True,
+                    cell_activation="tanh",
+                    proj_activation="tanh",
+                    h_0=h,
+                    c_0=c)
+
+            self.assertRaises(TypeError, test_h_0)
+
+            def test_c_0():
+                in_data_ = fluid.data(
+                    name="input_", shape=[None, 2048], dtype="float32")
+                h_ = fluid.data(name="h_", shape=[None, 512], dtype="float32")
+                c_ = fluid.data(name="c_", shape=[None, 512], dtype="int32")
+                fluid.layers.dynamic_lstmp(
+                    input=in_data_,
+                    size=2048,
+                    proj_size=256,
+                    use_peepholes=False,
+                    is_reverse=True,
+                    cell_activation="tanh",
+                    proj_activation="tanh",
+                    h_0=h_,
+                    c_0=c_)
+
+            self.assertRaises(TypeError, test_c_0)
+
+
 if __name__ == '__main__':
     unittest.main()