Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
7b5e23c0
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7b5e23c0
编写于
4月 10, 2020
作者:
Z
zhaoyuchen2018
提交者:
GitHub
4月 10, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
OP(fusion_gru) error message enhancement. test=develop (#23599)
C++ OP enhancement.
上级
8c0bdde9
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
76 addition
and
48 deletion
+76
-48
paddle/fluid/operators/fused/fusion_lstm_op.cc
paddle/fluid/operators/fused/fusion_lstm_op.cc
+76
-48
未找到文件。
paddle/fluid/operators/fused/fusion_lstm_op.cc
浏览文件 @
7b5e23c0
...
@@ -23,68 +23,94 @@ namespace paddle {
...
@@ -23,68 +23,94 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
void
FusionLSTMOp
::
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
void
FusionLSTMOp
::
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Assert only one Input(X) of LSTM."
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"fusion_lstm"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"WeightX"
),
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"WeightX"
),
"Input"
,
"WeightX"
,
"fusion_lstm"
);
"Assert only one Input(WeightX) of LSTM."
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"WeightH"
),
"Input"
,
"WeightH"
,
"fusion_lstm"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"WeightH"
),
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Bias"
),
"Input"
,
"Bias"
,
"fusion_lstm"
);
"Assert only one Input(WeightH) of LSTM."
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"XX"
),
"Output"
,
"XX"
,
"fusion_lstm"
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Bias"
),
"Assert only one Input(Bias) of LSTM."
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Hidden"
),
"Output"
,
"Hidden"
,
"fusion_lstm"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"XX"
),
"Assert only one Output(XX) of LSTM."
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Cell"
),
"Output"
,
"Cell"
,
"fusion_lstm"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Hidden"
),
"Assert only one Output(Hidden) of LSTM."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Cell"
),
"Assert only one Output(Cell) of LSTM."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2
,
"Input(X)'s rank must be 2."
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"Input(X)'s rank must be 2, but received x's rank "
"is:%d, x dim is:[%s]"
,
x_dims
.
size
(),
x_dims
));
if
(
ctx
->
HasInput
(
"H0"
))
{
if
(
ctx
->
HasInput
(
"H0"
))
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"C0"
),
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"C0"
),
"Input"
,
"C0"
,
"fusion_lstm"
);
"Input(Cell) and Input(Hidden) of LSTM should not "
"be null at the same time."
);
auto
h_dims
=
ctx
->
GetInputDim
(
"H0"
);
auto
h_dims
=
ctx
->
GetInputDim
(
"H0"
);
auto
c_dims
=
ctx
->
GetInputDim
(
"C0"
);
auto
c_dims
=
ctx
->
GetInputDim
(
"C0"
);
PADDLE_ENFORCE
(
h_dims
==
c_dims
,
PADDLE_ENFORCE_EQ
(
h_dims
,
c_dims
,
"The dimension of Input(H0) and Input(C0) "
platform
::
errors
::
InvalidArgument
(
"should be the same."
);
"The dimension of Input(H0) and Input(C0) should be "
"same, but received h0 dims is:[%s], c0 dims is:[%s]"
,
h_dims
,
c_dims
));
}
}
auto
wx_dims
=
ctx
->
GetInputDim
(
"WeightX"
);
auto
wx_dims
=
ctx
->
GetInputDim
(
"WeightX"
);
PADDLE_ENFORCE_EQ
(
wx_dims
.
size
(),
2
,
PADDLE_ENFORCE_EQ
(
wx_dims
.
size
(),
2
,
"The rank of Input(WeightX) should be 2."
);
platform
::
errors
::
InvalidArgument
(
"The rank of Input(WeightX) should be 2, but received "
"WeightX's rank is:%d, WeightX dim is:[%s]"
,
wx_dims
.
size
(),
wx_dims
));
PADDLE_ENFORCE_EQ
(
wx_dims
[
0
],
x_dims
[
1
],
PADDLE_ENFORCE_EQ
(
wx_dims
[
0
],
x_dims
[
1
],
platform
::
errors
::
InvalidArgument
(
"The first dimension of Input(WeightX) "
"The first dimension of Input(WeightX) "
"should be %d."
,
"should equal to second dimension of Input(X), but "
x_dims
[
1
]);
"received WeightX first dim is:%d, X second dim is:%d"
,
wx_dims
[
0
],
x_dims
[
1
]));
int
frame_size
=
wx_dims
[
1
]
/
4
;
int
frame_size
=
wx_dims
[
1
]
/
4
;
auto
wh_dims
=
ctx
->
GetInputDim
(
"WeightH"
);
auto
wh_dims
=
ctx
->
GetInputDim
(
"WeightH"
);
PADDLE_ENFORCE_EQ
(
wh_dims
.
size
(),
2
,
PADDLE_ENFORCE_EQ
(
wh_dims
.
size
(),
2
,
"The rank of Input(WeightH) should be 2."
);
platform
::
errors
::
InvalidArgument
(
"The rank of Input(WeightH) should be 2, but received "
"WeightH rank is:%d, WeightH dim is:[%s]"
,
wh_dims
.
size
(),
wh_dims
));
PADDLE_ENFORCE_EQ
(
wh_dims
[
0
],
frame_size
,
PADDLE_ENFORCE_EQ
(
wh_dims
[
0
],
frame_size
,
platform
::
errors
::
InvalidArgument
(
"The first dimension of Input(WeightH) "
"The first dimension of Input(WeightH) "
"should be %d."
,
"should equal to frame size, but received WeightH "
frame_size
);
"first dim is:%d, frame size is:%d."
,
wh_dims
[
0
],
frame_size
));
PADDLE_ENFORCE_EQ
(
wh_dims
[
1
],
4
*
frame_size
,
PADDLE_ENFORCE_EQ
(
wh_dims
[
1
],
4
*
frame_size
,
platform
::
errors
::
InvalidArgument
(
"The second dimension of Input(WeightH) "
"The second dimension of Input(WeightH) "
"should be 4 * %d."
,
"should equal to 4 * frame_size, but received WeightH "
frame_size
);
"second dimension is:%d, frame size is:%d."
,
wh_dims
[
1
],
frame_size
));
auto
b_dims
=
ctx
->
GetInputDim
(
"Bias"
);
auto
b_dims
=
ctx
->
GetInputDim
(
"Bias"
);
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
"The rank of Input(Bias) should be 2."
);
PADDLE_ENFORCE_EQ
(
b_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"The rank of Input(Bias) should be 2, but received "
"Bias rank is:%d, Bias dim is:[%s]"
,
b_dims
.
size
(),
b_dims
));
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
PADDLE_ENFORCE_EQ
(
b_dims
[
0
],
1
,
"The first dimension of Input(Bias) should be 1."
);
platform
::
errors
::
InvalidArgument
(
"The first dimension of Input(Bias) should be 1, but "
"received Bias's dimension is:[%s]"
,
b_dims
));
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"use_peepholes"
))
{
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"use_peepholes"
))
{
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
7
*
frame_size
,
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
7
*
frame_size
,
platform
::
errors
::
InvalidArgument
(
"The second dimension of Input(Bias) should be "
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection"
,
"7 * %d if enable peepholes connection, but received "
frame_size
);
"Bias dim is:[%s]"
,
frame_size
,
b_dims
));
ctx
->
SetOutputDim
(
"CheckedCell"
,
{
2
,
frame_size
});
ctx
->
SetOutputDim
(
"CheckedCell"
,
{
2
,
frame_size
});
}
else
{
}
else
{
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
4
*
frame_size
,
PADDLE_ENFORCE_EQ
(
b_dims
[
1
],
4
*
frame_size
,
platform
::
errors
::
InvalidArgument
(
"The second dimension of Input(Bias) should be "
"The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes
"
,
"4 * %d if disable peepholes, but received Bias dim is:[%s]
"
,
frame_size
);
frame_size
,
b_dims
)
);
}
}
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
framework
::
DDim
out_dims
({
x_dims
[
0
],
frame_size
});
...
@@ -97,16 +123,18 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
...
@@ -97,16 +123,18 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
xx_width
=
wx_dims
[
1
];
xx_width
=
wx_dims
[
1
];
}
else
{
}
else
{
xx_width
=
x_dims
[
1
]
>
wx_dims
[
1
]
?
wx_dims
[
1
]
:
x_dims
[
1
];
xx_width
=
x_dims
[
1
]
>
wx_dims
[
1
]
?
wx_dims
[
1
]
:
x_dims
[
1
];
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedInput"
),
"Assert only one Output(BatchedInput) of LSTM."
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"BatchedInput"
),
"Output"
,
"BatchedInput"
,
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedHidden"
),
"fusion_lstm"
);
"Assert only one Output(BatchedHidden) of LSTM."
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"BatchedHidden"
),
"Output"
,
"BatchedHidden"
,
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"BatchedCell"
),
"fusion_lstm"
);
"Assert only one Output(BatchedCell) of LSTM."
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"BatchedCell"
),
"Output"
,
"BatchedCell"
,
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedH0"
),
"fusion_lstm"
);
"Assert only one Output(ReorderedH0) of LSTM"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"ReorderedH0"
),
"Output"
,
"ReorderedH0"
,
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"ReorderedC0"
),
"fusion_lstm"
);
"Assert only one Output(ReorderedC0) of LSTM."
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"ReorderedC0"
),
"Output"
,
"ReorderedC0"
,
"fusion_lstm"
);
ctx
->
SetOutputDim
(
"BatchedInput"
,
{
x_dims
[
0
],
wx_dims
[
1
]});
ctx
->
SetOutputDim
(
"BatchedInput"
,
{
x_dims
[
0
],
wx_dims
[
1
]});
ctx
->
SetOutputDim
(
"BatchedHidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedHidden"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedCell"
,
out_dims
);
ctx
->
SetOutputDim
(
"BatchedCell"
,
out_dims
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录