Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
417b576c
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
417b576c
编写于
5月 13, 2020
作者:
L
liu zhengxi
提交者:
GitHub
5月 13, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
API(dynamic_lstm, dynamic_lstmp) error message enhancement (#24450)
* update err msg for dynamic_lstm and dynamic_lstmp, test=develop
上级
53bdee64
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
269 addition
and
123 deletion
+269
-123
paddle/fluid/operators/lstm_op.cc
paddle/fluid/operators/lstm_op.cc
+67
-56
paddle/fluid/operators/lstmp_op.cc
paddle/fluid/operators/lstmp_op.cc
+81
-66
python/paddle/fluid/layers/rnn.py
python/paddle/fluid/layers/rnn.py
+29
-1
python/paddle/fluid/tests/unittests/test_lstm_op.py
python/paddle/fluid/tests/unittests/test_lstm_op.py
+36
-0
python/paddle/fluid/tests/unittests/test_lstmp_op.py
python/paddle/fluid/tests/unittests/test_lstmp_op.py
+56
-0
未找到文件。
paddle/fluid/operators/lstm_op.cc
浏览文件 @
417b576c
...
...
@@ -24,64 +24,80 @@ class LSTMOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(Input) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Weight"
),
"Input(Weight) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Bias"
),
"Input(Bias) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Hidden"
),
"Output(Hidden) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
"Output(Cell) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchGate"
),
"Output(BatchGate) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchCellPreAct"
),
"Output(BatchGate) of LSTM should not be null."
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Input"
),
"Input"
,
"Input"
,
"LSTM"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Weight"
),
"Input"
,
"Weight"
,
"LSTM"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Bias"
),
"Input"
,
"Bias"
,
"LSTM"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Hidden"
),
"Output"
,
"Hidden"
,
"LSTM"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Cell"
),
"Output"
,
"Cell"
,
"LSTM"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"BatchGate"
),
"Output"
,
"BatchGate"
,
"LSTM"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"BatchCellPreAct"
),
"Output"
,
"BatchCellPreAct"
,
"LSTM"
);
auto
in_dims
=
ctx
->
GetInputDim
(
"Input"
);
PADDLE_ENFORCE_EQ
(
in_dims
.
size
(),
2
,
"Input(X)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
in_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"Input(X)'s rank must be 2, but received %d."
,
in_dims
.
size
()));
if
(
ctx
->
HasInput
(
"H0"
))
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C0"
),
"Input(Cell) and Input(Hidden) of LSTM should not "
"be null at the same time."
);
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"C0"
),
true
,
platform
::
errors
::
NotFound
(
"Input(Cell) and Input(Hidden) of LSTM "
"should not be null at the same time."
));
auto
h_dims
=
ctx
->
GetInputDim
(
"H0"
);
auto
c_dims
=
ctx
->
GetInputDim
(
"C0"
);
PADDLE_ENFORCE
(
h_dims
==
c_dims
,
"The dimension of Input(H0) and Input(C0) "
"should be the same."
);
PADDLE_ENFORCE_EQ
(
h_dims
,
c_dims
,
platform
::
errors
::
InvalidArgument
(
"The dimension of Input(H0) and Input(C0) should "
"be the same, but received [%s] (H0) vs [%s] (C0)."
,
h_dims
,
c_dims
));
}
int
frame_size
=
in_dims
[
1
]
/
4
;
auto
w_dims
=
ctx
->
GetInputDim
(
"Weight"
);
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
2
,
"The rank of Input(Weight) should be 2."
);
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"The rank of Input(Weight) should be 2, but received %d."
,
w_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
w_dims
[
0
],
frame_size
,
"The first dimension of Input(Weight) "
"should be %d."
,
frame_size
);
platform
::
errors
::
InvalidArgument
(
"The first dimension of Input(Weight) should be %d, "
"but received %d."
,
frame_size
,
w_dims
[
0
]));
PADDLE_ENFORCE_EQ
(
w_dims
[
1
],
4
*
frame_size
,
"The second dimension of Input(Weight) "
"should be 4 * %d."
,
frame_size
);
platform
::
errors
::
InvalidArgument
(
"The second dimension of Input(Weight) should be 4 * "
"%d, but received %d."
,
frame_size
,
w_dims
[
1
]));
auto
b_dims
=
ctx
->
GetInputDim
(
"Bias"
);
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
"The rank of Input(Bias) should be 2."
);
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
"The first dimension of Input(Bias) should be 1."
);
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"The rank of Input(Bias) should be 2, but received %d."
,
b_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
platform
::
errors
::
InvalidArgument
(
"The first dimension of Input(Bias) should be 1, but received %d."
,
b_dims
[
0
]));
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"use_peepholes"
))
{
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
7
*
frame_size
,
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection"
,
frame_size
);
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
7
*
frame_size
,
platform
::
errors
::
InvalidArgument
(
"The second dimension of Input(Bias) should be 7 * %d if enable "
"peepholes connection, but received %d."
,
frame_size
,
b_dims
[
1
]));
}
else
{
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
4
*
frame_size
,
"The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes connection"
,
frame_size
);
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
4
*
frame_size
,
platform
::
errors
::
InvalidArgument
(
"The second dimension of Input(Bias) should be 4 * %d if disable "
"peepholes connection, but received %d."
,
frame_size
,
b_dims
[
1
]));
}
framework
::
DDim
out_dims
({
in_dims
[
0
],
frame_size
});
...
...
@@ -229,21 +245,16 @@ class LSTMGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(Input) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Hidden"
),
"Input(Hidden) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Cell"
),
"Input(Cell) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Weight"
),
"Input(Weight) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Bias"
),
"Input(Bias) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"BatchGate"
),
"Input(BatchGate) of LSTM should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"BatchCellPreAct"
),
"Input(BatchGate) of LSTM should not be null."
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Input"
),
"Input"
,
"Input"
,
"LSTM@Grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Hidden"
),
"Input"
,
"Hidden"
,
"LSTM@Grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Cell"
),
"Input"
,
"Cell"
,
"LSTM@Grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Weight"
),
"Input"
,
"Weight"
,
"LSTM@Grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Bias"
),
"Input"
,
"Bias"
,
"LSTM@Grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"BatchGate"
),
"Input"
,
"BatchGate"
,
"LSTM@Grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"BatchCellPreAct"
),
"Input"
,
"BatchCellPreAct"
,
"LSTM@Grad"
);
auto
SetOutGradDim
=
[
&
ctx
](
const
std
::
string
&
name
)
{
auto
g_name
=
framework
::
GradVarName
(
name
);
...
...
paddle/fluid/operators/lstmp_op.cc
浏览文件 @
417b576c
...
...
@@ -24,74 +24,92 @@ class LSTMPOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Input"
),
"Input(Input) of LSTMP operator should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Weight"
),
"Input(Weight) of LSTMP operator should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"ProjWeight"
),
"Input(ProjWeight) of LSTMP operator should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Bias"
),
"Input(Bias) of LSTMP operator should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Projection"
),
"Output(Projection) of LSTMP operator should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
"Output(Cell) of LSTMP operator should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchGate"
),
"Output(BatchGate) of LSTMP operator should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchCellPreAct"
),
"Output(BatchCellPreAct) of LSTMP operator should not be "
"null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchHidden"
),
"Output(BatchHidden) of LSTMP operator should not be null."
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Input"
),
"Input"
,
"Input"
,
"LSTMP"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Weight"
),
"Input"
,
"Weight"
,
"LSTMP"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"ProjWeight"
),
"Input"
,
"ProjWeight"
,
"LSTMP"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Bias"
),
"Input"
,
"Bias"
,
"LSTMP"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Projection"
),
"Output"
,
"Projection"
,
"LSTMP"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Cell"
),
"Output"
,
"Cell"
,
"LSTMP"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"BatchGate"
),
"Output"
,
"BatchGate"
,
"LSTMP"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"BatchCellPreAct"
),
"Output"
,
"BatchCellPreAct"
,
"LSTMP"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"BatchHidden"
),
"Output"
,
"BatchHidden"
,
"LSTMP"
);
auto
in_dims
=
ctx
->
GetInputDim
(
"Input"
);
PADDLE_ENFORCE_EQ
(
in_dims
.
size
(),
2
,
"Input(X)'s rank of LSTMP operator must be 2."
);
PADDLE_ENFORCE_EQ
(
in_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"Input(X)'s rank of LSTMP operator must be 2, but received %d."
,
in_dims
.
size
()));
int
frame_size
=
in_dims
[
1
]
/
4
;
auto
w_dims
=
ctx
->
GetInputDim
(
"Weight"
);
auto
proj_dims
=
ctx
->
GetInputDim
(
"ProjWeight"
);
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
2
,
"The rank of Input(Weight) should be 2."
);
PADDLE_ENFORCE_EQ
(
w_dims
[
0
],
proj_dims
[
1
],
"The first dimension of Input(Weight) "
"should be %d."
,
proj_dims
[
1
]);
PADDLE_ENFORCE_EQ
(
w_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"The rank of Input(Weight) should be 2, but received %d."
,
w_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
w_dims
[
0
],
proj_dims
[
1
],
platform
::
errors
::
InvalidArgument
(
"The first dimension of Input(Weight) and the second dimension of "
"Input(ProjWeight) should be the same, but received %d vs %d."
,
w_dims
[
0
],
proj_dims
[
1
]));
PADDLE_ENFORCE_EQ
(
w_dims
[
1
],
4
*
frame_size
,
"The second dimension of Input(Weight) "
"should be 4 * %d."
,
frame_size
);
PADDLE_ENFORCE_EQ
(
proj_dims
.
size
(),
2
,
"The rank of Input(ProjWeight) should be 2."
);
platform
::
errors
::
InvalidArgument
(
"The second dimension of Input(Weight) should be 4 * "
"%d, but received %d."
,
frame_size
,
w_dims
[
1
]));
PADDLE_ENFORCE_EQ
(
proj_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"The rank of Input(ProjWeight) should be 2, but received %d."
,
proj_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
proj_dims
[
0
],
frame_size
,
"The first dimension of Input(ProjWeight) "
"should be %d."
,
frame_size
);
platform
::
errors
::
InvalidArgument
(
"The first dimension of Input(ProjWeight) should be "
"%d, but received %d."
,
frame_size
,
proj_dims
[
0
]));
if
(
ctx
->
HasInput
(
"H0"
))
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C0"
),
"Input(C0) of LSTMP operator should not be null after "
"Input(H0) provided."
);
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"C0"
),
true
,
platform
::
errors
::
NotFound
(
"Input(C0) of LSTMP operator should not "
"be null after Input(H0) provided."
));
}
auto
b_dims
=
ctx
->
GetInputDim
(
"Bias"
);
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
"The rank of Input(Bias) should be 2."
);
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
"The first dimension of Input(Bias) should be 1."
);
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"The rank of Input(Bias) should be 2, but received %d."
,
b_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
platform
::
errors
::
InvalidArgument
(
"The first dimension of Input(Bias) should be 1, but received %d."
,
b_dims
[
0
]));
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"use_peepholes"
))
{
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
7
*
frame_size
,
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection"
,
frame_size
);
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
7
*
frame_size
,
platform
::
errors
::
InvalidArgument
(
"The second dimension of Input(Bias) should be 7 * %d if enable "
"peepholes connection, but received %d."
,
frame_size
,
b_dims
[
1
]));
}
else
{
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
4
*
frame_size
,
"The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes connection"
,
frame_size
);
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
4
*
frame_size
,
platform
::
errors
::
InvalidArgument
(
"The second dimension of Input(Bias) should be 4 * %d if disable "
"peepholes connection, but received %d."
,
frame_size
,
b_dims
[
1
]));
}
framework
::
DDim
out_dims
({
in_dims
[
0
],
frame_size
});
...
...
@@ -314,21 +332,18 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Projection"
),
"Input(Projection) of LSTMP operator should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Cell"
),
"Input(Cell) of LSTMP operator should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Weight"
),
"Input(Weight) of LSTMP operator should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"ProjWeight"
),
"Input(ProjWeight) of LSTMP operator should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Bias"
),
"Input(Bias) of LSTMP operator should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"BatchGate"
),
"Input(BatchGate) of LSTMP operator should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"BatchCellPreAct"
),
"Input(BatchGate) of LSTMP operator should not be null."
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Projection"
),
"Input"
,
"Projection"
,
"LSTMP@Grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Cell"
),
"Input"
,
"Cell"
,
"LSTMP@Grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Weight"
),
"Input"
,
"Weight"
,
"LSTMP@Grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"ProjWeight"
),
"Input"
,
"ProjWeight"
,
"LSTMP@Grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Bias"
),
"Input"
,
"Bias"
,
"LSTMP@Grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"BatchGate"
),
"Input"
,
"BatchGate"
,
"LSTMP@Grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"BatchCellPreAct"
),
"Input"
,
"BatchCellPreAct"
,
"LSTMP@Grad"
);
auto
SetOutGradDim
=
[
&
ctx
](
const
std
::
string
&
name
)
{
auto
g_name
=
framework
::
GradVarName
(
name
);
...
...
python/paddle/fluid/layers/rnn.py
浏览文件 @
417b576c
...
...
@@ -2073,7 +2073,21 @@ def dynamic_lstm(input,
"""
assert
in_dygraph_mode
(
)
is
not
True
,
"please use lstm instead of dynamic_lstm in dygraph mode!"
assert
bias_attr
is
not
False
,
"bias_attr should not be False in dynamic_lstmp."
assert
bias_attr
is
not
False
,
"bias_attr should not be False in dynamic_lstm."
check_variable_and_dtype
(
input
,
'input'
,
[
'float32'
,
'float64'
],
'dynamic_lstm'
)
check_type
(
h_0
,
'h_0'
,
(
Variable
,
type
(
None
)),
'dynamic_lstm'
)
if
isinstance
(
h_0
,
Variable
):
check_variable_and_dtype
(
h_0
,
'h_0'
,
[
'float32'
,
'float64'
],
'dynamic_lstm'
)
check_type
(
c_0
,
'c_0'
,
(
Variable
,
type
(
None
)),
'dynamic_lstm'
)
if
isinstance
(
c_0
,
Variable
):
check_variable_and_dtype
(
c_0
,
'c_0'
,
[
'float32'
,
'float64'
],
'dynamic_lstm'
)
helper
=
LayerHelper
(
'lstm'
,
**
locals
())
size
=
size
//
4
weight
=
helper
.
create_parameter
(
...
...
@@ -2439,6 +2453,20 @@ def dynamic_lstmp(input,
)
is
not
True
,
"please use lstm instead of dynamic_lstmp in dygraph mode!"
assert
bias_attr
is
not
False
,
"bias_attr should not be False in dynamic_lstmp."
check_variable_and_dtype
(
input
,
'input'
,
[
'float32'
,
'float64'
],
'dynamic_lstmp'
)
check_type
(
h_0
,
'h_0'
,
(
Variable
,
type
(
None
)),
'dynamic_lstmp'
)
if
isinstance
(
h_0
,
Variable
):
check_variable_and_dtype
(
h_0
,
'h_0'
,
[
'float32'
,
'float64'
],
'dynamic_lstmp'
)
check_type
(
c_0
,
'c_0'
,
(
Variable
,
type
(
None
)),
'dynamic_lstmp'
)
if
isinstance
(
c_0
,
Variable
):
check_variable_and_dtype
(
c_0
,
'c_0'
,
[
'float32'
,
'float64'
],
'dynamic_lstmp'
)
helper
=
LayerHelper
(
'lstmp'
,
**
locals
())
size
=
size
//
4
weight
=
helper
.
create_parameter
(
...
...
python/paddle/fluid/tests/unittests/test_lstm_op.py
浏览文件 @
417b576c
...
...
@@ -301,6 +301,42 @@ class TestLstmOpCase3(TestLstmOp):
self
.
lod
=
[[
2
,
0
,
4
]]
class
TestLstmOpError
(
unittest
.
TestCase
):
def
test_errors
(
self
):
with
program_guard
(
Program
(),
Program
()):
def
test_Variable
():
input_data
=
np
.
random
.
random
((
1
,
2048
)).
astype
(
"float32"
)
fluid
.
layers
.
dynamic_lstm
(
input
=
input_data
,
size
=
2048
,
use_peepholes
=
False
)
self
.
assertRaises
(
TypeError
,
test_Variable
)
def
test_h_0
():
in_data
=
fluid
.
data
(
name
=
"input"
,
shape
=
[
None
,
2048
],
dtype
=
"float32"
)
h
=
fluid
.
data
(
name
=
"h"
,
shape
=
[
None
,
512
],
dtype
=
"int32"
)
c
=
fluid
.
data
(
name
=
"c"
,
shape
=
[
None
,
512
],
dtype
=
"float32"
)
fluid
.
layers
.
dynamic_lstm
(
input
=
in_data
,
size
=
2048
,
use_peepholes
=
False
,
h_0
=
h
,
c_0
=
c
)
self
.
assertRaises
(
TypeError
,
test_h_0
)
def
test_c_0
():
in_data_
=
fluid
.
data
(
name
=
"input_"
,
shape
=
[
None
,
2048
],
dtype
=
"float32"
)
h_
=
fluid
.
data
(
name
=
"h_"
,
shape
=
[
None
,
512
],
dtype
=
"float32"
)
c_
=
fluid
.
data
(
name
=
"c_"
,
shape
=
[
None
,
512
],
dtype
=
"int32"
)
fluid
.
layers
.
dynamic_lstm
(
input
=
in_data_
,
size
=
2048
,
use_peepholes
=
False
,
h_0
=
h_
,
c_0
=
c_
)
self
.
assertRaises
(
TypeError
,
test_c_0
)
# class TestLstmOpHasInitial(TestLstmOp):
# def set_argument(self):
# self.lod = [[2, 3, 2]]
...
...
python/paddle/fluid/tests/unittests/test_lstmp_op.py
浏览文件 @
417b576c
...
...
@@ -16,6 +16,8 @@ from __future__ import print_function
import
unittest
import
numpy
as
np
import
test_lstm_op
as
LstmTest
from
paddle
import
fluid
from
paddle.fluid
import
Program
,
program_guard
ACTIVATION
=
{
'identity'
:
LstmTest
.
identity
,
...
...
@@ -315,5 +317,59 @@ class TestLstmpOpLen0Case2(TestLstmpOp):
self
.
lod
=
[[
2
,
0
,
3
]]
class
TestLstmpOpError
(
unittest
.
TestCase
):
def
test_errors
(
self
):
with
program_guard
(
Program
(),
Program
()):
def
test_Variable
():
input_data
=
np
.
random
.
random
((
1
,
2048
)).
astype
(
"float32"
)
fluid
.
layers
.
dynamic_lstmp
(
input
=
input_data
,
size
=
2048
,
proj_size
=
256
,
use_peepholes
=
False
,
is_reverse
=
True
,
cell_activation
=
"tanh"
,
proj_activation
=
"tanh"
)
self
.
assertRaises
(
TypeError
,
test_Variable
)
def
test_h_0
():
in_data
=
fluid
.
data
(
name
=
"input"
,
shape
=
[
None
,
2048
],
dtype
=
"float32"
)
h
=
fluid
.
data
(
name
=
"h"
,
shape
=
[
None
,
512
],
dtype
=
"int32"
)
c
=
fluid
.
data
(
name
=
"c"
,
shape
=
[
None
,
512
],
dtype
=
"float32"
)
fluid
.
layers
.
dynamic_lstmp
(
input
=
in_data
,
size
=
2048
,
proj_size
=
256
,
use_peepholes
=
False
,
is_reverse
=
True
,
cell_activation
=
"tanh"
,
proj_activation
=
"tanh"
,
h_0
=
h
,
c_0
=
c
)
self
.
assertRaises
(
TypeError
,
test_h_0
)
def
test_c_0
():
in_data_
=
fluid
.
data
(
name
=
"input_"
,
shape
=
[
None
,
2048
],
dtype
=
"float32"
)
h_
=
fluid
.
data
(
name
=
"h_"
,
shape
=
[
None
,
512
],
dtype
=
"float32"
)
c_
=
fluid
.
data
(
name
=
"c_"
,
shape
=
[
None
,
512
],
dtype
=
"int32"
)
fluid
.
layers
.
dynamic_lstmp
(
input
=
in_data_
,
size
=
2048
,
proj_size
=
256
,
use_peepholes
=
False
,
is_reverse
=
True
,
cell_activation
=
"tanh"
,
proj_activation
=
"tanh"
,
h_0
=
h_
,
c_0
=
c_
)
self
.
assertRaises
(
TypeError
,
test_c_0
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录