未验证 提交 7b5e23c0 编写于 作者: Z zhaoyuchen2018 提交者: GitHub

OP(fusion_gru) error message enhancement. test=develop (#23599)

C++ OP enhancement.
上级 8c0bdde9
...@@ -23,68 +23,94 @@ namespace paddle { ...@@ -23,68 +23,94 @@ namespace paddle {
namespace operators { namespace operators {
void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of LSTM."); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_lstm");
PADDLE_ENFORCE(ctx->HasInput("WeightX"), OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_lstm");
"Assert only one Input(WeightX) of LSTM."); OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_lstm");
PADDLE_ENFORCE(ctx->HasInput("WeightH"), OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "fusion_lstm");
"Assert only one Input(WeightH) of LSTM."); OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_lstm");
PADDLE_ENFORCE(ctx->HasInput("Bias"), "Assert only one Input(Bias) of LSTM."); OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_lstm");
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of LSTM."); OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "fusion_lstm");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Assert only one Output(Hidden) of LSTM.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Assert only one Output(Cell) of LSTM.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2,
platform::errors::InvalidArgument(
"Input(X)'s rank must be 2, but received x's rank "
"is:%d, x dim is:[%s]",
x_dims.size(), x_dims));
if (ctx->HasInput("H0")) { if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"), OP_INOUT_CHECK(ctx->HasInput("C0"), "Input", "C0", "fusion_lstm");
"Input(Cell) and Input(Hidden) of LSTM should not "
"be null at the same time.");
auto h_dims = ctx->GetInputDim("H0"); auto h_dims = ctx->GetInputDim("H0");
auto c_dims = ctx->GetInputDim("C0"); auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE(h_dims == c_dims, PADDLE_ENFORCE_EQ(h_dims, c_dims,
"The dimension of Input(H0) and Input(C0) " platform::errors::InvalidArgument(
"should be the same."); "The dimension of Input(H0) and Input(C0) should be "
"same, but received h0 dims is:[%s], c0 dims is:[%s]",
h_dims, c_dims));
} }
auto wx_dims = ctx->GetInputDim("WeightX"); auto wx_dims = ctx->GetInputDim("WeightX");
PADDLE_ENFORCE_EQ(wx_dims.size(), 2, PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
"The rank of Input(WeightX) should be 2."); platform::errors::InvalidArgument(
"The rank of Input(WeightX) should be 2, but received "
"WeightX's rank is:%d, WeightX dim is:[%s]",
wx_dims.size(), wx_dims));
PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1], PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
platform::errors::InvalidArgument(
"The first dimension of Input(WeightX) " "The first dimension of Input(WeightX) "
"should be %d.", "should equal to second dimension of Input(X), but "
x_dims[1]); "received WeightX first dim is:%d, X second dim is:%d",
wx_dims[0], x_dims[1]));
int frame_size = wx_dims[1] / 4; int frame_size = wx_dims[1] / 4;
auto wh_dims = ctx->GetInputDim("WeightH"); auto wh_dims = ctx->GetInputDim("WeightH");
PADDLE_ENFORCE_EQ(wh_dims.size(), 2, PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
"The rank of Input(WeightH) should be 2."); platform::errors::InvalidArgument(
"The rank of Input(WeightH) should be 2, but received "
"WeightH rank is:%d, WeightH dim is:[%s]",
wh_dims.size(), wh_dims));
PADDLE_ENFORCE_EQ(wh_dims[0], frame_size, PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
platform::errors::InvalidArgument(
"The first dimension of Input(WeightH) " "The first dimension of Input(WeightH) "
"should be %d.", "should equal to frame size, but received WeightH "
frame_size); "first dim is:%d, frame size is:%d.",
wh_dims[0], frame_size));
PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size, PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size,
platform::errors::InvalidArgument(
"The second dimension of Input(WeightH) " "The second dimension of Input(WeightH) "
"should be 4 * %d.", "should equal to 4 * frame_size, but received WeightH "
frame_size); "second dimension is:%d, frame size is:%d.",
wh_dims[1], frame_size));
auto b_dims = ctx->GetInputDim("Bias"); auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); PADDLE_ENFORCE_EQ(b_dims.size(), 2,
platform::errors::InvalidArgument(
"The rank of Input(Bias) should be 2, but received "
"Bias rank is:%d, Bias dim is:[%s]",
b_dims.size(), b_dims));
PADDLE_ENFORCE_EQ(b_dims[0], 1, PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1."); platform::errors::InvalidArgument(
"The first dimension of Input(Bias) should be 1, but "
"received Bias's dimension is:[%s]",
b_dims));
if (ctx->Attrs().Get<bool>("use_peepholes")) { if (ctx->Attrs().Get<bool>("use_peepholes")) {
PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size, PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
platform::errors::InvalidArgument(
"The second dimension of Input(Bias) should be " "The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection", "7 * %d if enable peepholes connection, but received "
frame_size); "Bias dim is:[%s]",
frame_size, b_dims));
ctx->SetOutputDim("CheckedCell", {2, frame_size}); ctx->SetOutputDim("CheckedCell", {2, frame_size});
} else { } else {
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size, PADDLE_ENFORCE_EQ(
b_dims[1], 4 * frame_size,
platform::errors::InvalidArgument(
"The second dimension of Input(Bias) should be " "The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes", "4 * %d if disable peepholes, but received Bias dim is:[%s]",
frame_size); frame_size, b_dims));
} }
framework::DDim out_dims({x_dims[0], frame_size}); framework::DDim out_dims({x_dims[0], frame_size});
...@@ -97,16 +123,18 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -97,16 +123,18 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
xx_width = wx_dims[1]; xx_width = wx_dims[1];
} else { } else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"),
"Assert only one Output(BatchedInput) of LSTM."); OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"), "Output", "BatchedInput",
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"), "fusion_lstm");
"Assert only one Output(BatchedHidden) of LSTM."); OP_INOUT_CHECK(ctx->HasOutput("BatchedHidden"), "Output", "BatchedHidden",
PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"), "fusion_lstm");
"Assert only one Output(BatchedCell) of LSTM."); OP_INOUT_CHECK(ctx->HasOutput("BatchedCell"), "Output", "BatchedCell",
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), "fusion_lstm");
"Assert only one Output(ReorderedH0) of LSTM"); OP_INOUT_CHECK(ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0",
PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"), "fusion_lstm");
"Assert only one Output(ReorderedC0) of LSTM."); OP_INOUT_CHECK(ctx->HasOutput("ReorderedC0"), "Output", "ReorderedC0",
"fusion_lstm");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedHidden", out_dims); ctx->SetOutputDim("BatchedHidden", out_dims);
ctx->SetOutputDim("BatchedCell", out_dims); ctx->SetOutputDim("BatchedCell", out_dims);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册