未验证 提交 39075b3d 编写于 作者: G GaoWei8 提交者: GitHub

[Cherry-Pick] [2.0-beta] error enhancement of Print, fused_embedding_fc_lstm...

[Cherry-Pick] [2.0-beta] error enhancement of Print, fused_embedding_fc_lstm and fusion_gru (#24097)
上级 7dd68aec
...@@ -24,68 +24,94 @@ namespace operators { ...@@ -24,68 +24,94 @@ namespace operators {
void FusedEmbeddingFCLSTMOp::InferShape( void FusedEmbeddingFCLSTMOp::InferShape(
framework::InferShapeContext* ctx) const { framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("Embeddings"), OP_INOUT_CHECK(ctx->HasInput("Embeddings"), "Input", "Embeddings",
"Assert only one Input(Embeddings) of LSTM."); "fused_embedding_fc_lstm");
PADDLE_ENFORCE(ctx->HasInput("WeightH"), OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH",
"Assert only one Input(WeightH) of LSTM."); "fused_embedding_fc_lstm");
PADDLE_ENFORCE(ctx->HasInput("Bias"), "Assert only one Input(Bias) of LSTM."); OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias",
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of LSTM."); "fused_embedding_fc_lstm");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"), OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX",
"Assert only one Output(Hidden) of LSTM."); "fused_embedding_fc_lstm");
PADDLE_ENFORCE(ctx->HasOutput("Cell"), OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden",
"Assert only one Output(Cell) of LSTM."); "fused_embedding_fc_lstm");
PADDLE_ENFORCE(ctx->HasInput("Ids"), OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell",
"Input(Ids) of LookupTableOp should not be null."); "fused_embedding_fc_lstm");
OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids",
"fused_embedding_fc_lstm");
auto table_dims = ctx->GetInputDim("Embeddings"); auto table_dims = ctx->GetInputDim("Embeddings");
auto ids_dims = ctx->GetInputDim("Ids"); auto ids_dims = ctx->GetInputDim("Ids");
int ids_rank = ids_dims.size(); int ids_rank = ids_dims.size();
PADDLE_ENFORCE_EQ(table_dims.size(), 2); PADDLE_ENFORCE_EQ(
table_dims.size(), 2,
platform::errors::InvalidArgument(
"The Embeddings's rank should be 2, but received value is:%d.",
table_dims.size()));
PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1, PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
"The last dimension of the 'Ids' tensor must be 1."); platform::errors::InvalidArgument(
"The last dimension of the 'Ids' tensor must be 1, but "
"received value is:%d.",
ids_dims[ids_rank - 1]));
auto x_dims = ctx->GetInputDim("Ids"); auto x_dims = ctx->GetInputDim("Ids");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(Ids)'s rank must be 2."); PADDLE_ENFORCE_EQ(
x_dims.size(), 2,
platform::errors::InvalidArgument(
"Input(Ids)'s rank must be 2, but received value is:%d.",
x_dims.size()));
if (ctx->HasInput("H0")) { if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"), PADDLE_ENFORCE_EQ(ctx->HasInput("C0"), true,
"Input(Cell) and Input(Hidden) of LSTM should not " platform::errors::InvalidArgument(
"be null at the same time."); "Input(Cell) and Input(Hidden) of LSTM should exist "
"at the same time."));
auto h_dims = ctx->GetInputDim("H0"); auto h_dims = ctx->GetInputDim("H0");
auto c_dims = ctx->GetInputDim("C0"); auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE(h_dims == c_dims, PADDLE_ENFORCE_EQ(
"The dimension of Input(H0) and Input(C0) " h_dims, c_dims,
"should be the same."); platform::errors::InvalidArgument(
"The dimension of Input(H0) and Input(C0) "
"should be the same, but received H0 dim is:[%s], C0 dim is[%s]",
h_dims, c_dims));
} }
auto embeddings_dims = ctx->GetInputDim("Embeddings");
PADDLE_ENFORCE_EQ(embeddings_dims.size(), 2,
"The rank of Input(Embeddings) should be 2.");
auto wh_dims = ctx->GetInputDim("WeightH"); auto wh_dims = ctx->GetInputDim("WeightH");
int frame_size = wh_dims[1] / 4; int frame_size = wh_dims[1] / 4;
PADDLE_ENFORCE_EQ(wh_dims.size(), 2, PADDLE_ENFORCE_EQ(
"The rank of Input(WeightH) should be 2."); wh_dims.size(), 2,
platform::errors::InvalidArgument(
"The rank of Input(WeightH) should be 2, but received value is:%d.",
wh_dims.size()));
PADDLE_ENFORCE_EQ(wh_dims[0], frame_size, PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
"The first dimension of Input(WeightH) " platform::errors::InvalidArgument(
"should be %d.", "The first dimension of Input(WeightH) should equal to "
frame_size); "frame size:%d, but received value is:%d.",
frame_size, wh_dims[0]));
PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size, PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size,
"The second dimension of Input(WeightH) " platform::errors::InvalidArgument(
"should be 4 * %d.", "The second dimension of Input(WeightH) should equal "
frame_size); "to 4 * %d, but received value is:%d.",
frame_size, wh_dims[1]));
auto b_dims = ctx->GetInputDim("Bias"); auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(b_dims[0], 1, b_dims.size(), 2,
"The first dimension of Input(Bias) should be 1."); platform::errors::InvalidArgument(
"The rank of Input(Bias) should be 2, but received value is:%d.",
b_dims.size()));
PADDLE_ENFORCE_EQ(b_dims[0], 1, platform::errors::InvalidArgument(
"The first dimension of Input(Bias) "
"should be 1, but received value is:%d.",
b_dims[0]));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
b_dims[1], (ctx->Attrs().Get<bool>("use_peepholes") ? 7 : 4) * frame_size, b_dims[1], (ctx->Attrs().Get<bool>("use_peepholes") ? 7 : 4) * frame_size,
"The second dimension of Input(Bias) should be " platform::errors::InvalidArgument(
"7 * %d if enable peepholes connection or" "The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes", "7 * %d if enable peepholes connection or"
frame_size, frame_size); "4 * %d if disable peepholes, bias dim is:%d, use_peepholes:%d",
frame_size, frame_size, b_dims[1],
ctx->Attrs().Get<bool>("use_peepholes")));
framework::DDim out_dims({x_dims[0], frame_size}); framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims); ctx->SetOutputDim("Hidden", out_dims);
...@@ -93,16 +119,17 @@ void FusedEmbeddingFCLSTMOp::InferShape( ...@@ -93,16 +119,17 @@ void FusedEmbeddingFCLSTMOp::InferShape(
ctx->ShareLoD("Ids", "Hidden"); ctx->ShareLoD("Ids", "Hidden");
ctx->ShareLoD("Ids", "Cell"); ctx->ShareLoD("Ids", "Cell");
if (!ctx->Attrs().Get<bool>("use_seq")) { if (!ctx->Attrs().Get<bool>("use_seq")) {
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"), "Output", "BatchedInput",
"Assert only one Output(BatchedInput) of LSTM."); "fused_embedding_fc_lstm");
PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"), OP_INOUT_CHECK(ctx->HasOutput("BatchedHidden"), "Output", "BatchedHidden",
"Assert only one Output(BatchedHidden) of LSTM."); "fused_embedding_fc_lstm");
PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"), OP_INOUT_CHECK(ctx->HasOutput("BatchedCell"), "Output", "BatchedCell",
"Assert only one Output(BatchedCell) of LSTM."); "fused_embedding_fc_lstm");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), OP_INOUT_CHECK(ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0",
"Assert only one Output(ReorderedH0) of LSTM"); "fused_embedding_fc_lstm");
PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"), OP_INOUT_CHECK(ctx->HasOutput("ReorderedC0"), "Output", "ReorderedC0",
"Assert only one Output(ReorderedC0) of LSTM."); "fused_embedding_fc_lstm");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wh_dims[1]}); ctx->SetOutputDim("BatchedInput", {x_dims[0], wh_dims[1]});
ctx->SetOutputDim("BatchedHidden", out_dims); ctx->SetOutputDim("BatchedHidden", out_dims);
ctx->SetOutputDim("BatchedCell", out_dims); ctx->SetOutputDim("BatchedCell", out_dims);
......
...@@ -24,51 +24,80 @@ namespace paddle { ...@@ -24,51 +24,80 @@ namespace paddle {
namespace operators { namespace operators {
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of GRU."); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_gru");
PADDLE_ENFORCE(ctx->HasInput("WeightX"), OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_gru");
"Assert only one Input(WeightX) of GRU."); OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_gru");
PADDLE_ENFORCE(ctx->HasInput("WeightH"),
"Assert only one Input(WeightH) of GRU."); OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_gru");
PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of GRU."); OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_gru");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Assert only one Output(Hidden) of GRU.");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); PADDLE_ENFORCE_EQ(x_dims.size(), 2,
platform::errors::InvalidArgument(
"Input(X)'s rank must be 2, but received input dim "
"size is:%d, input dim is:[%s]",
x_dims.size(), x_dims));
auto wx_dims = ctx->GetInputDim("WeightX"); auto wx_dims = ctx->GetInputDim("WeightX");
PADDLE_ENFORCE_EQ(wx_dims.size(), 2, PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
"The rank of Input(WeightX) should be 2."); platform::errors::InvalidArgument(
"The rank of Input(WeightX) should be 2, but received "
"WeightX dim size is:%d, WeightX dim is:[%s] ",
wx_dims.size(), wx_dims));
PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1], PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
"The first dimension of Input(WeightX) " platform::errors::InvalidArgument(
"should be %d.", "The first dimension of Input(WeightX) "
x_dims[1]); "should equal to second dimension of input x, but "
"received WeightX dimension is:%d, x dimension is:%d",
wx_dims[0], x_dims[1]));
int frame_size = wx_dims[1] / 3; int frame_size = wx_dims[1] / 3;
auto wh_dims = ctx->GetInputDim("WeightH"); auto wh_dims = ctx->GetInputDim("WeightH");
PADDLE_ENFORCE_EQ(wh_dims.size(), 2, PADDLE_ENFORCE_EQ(wh_dims.size(), 2,
"The rank of Input(WeightH) should be 2."); platform::errors::InvalidArgument(
"The rank of Input(WeightH) should be 2, but received "
"WeightH dim size is:%d, WeightH dim is:[%s]",
wh_dims.size(), wh_dims));
PADDLE_ENFORCE_EQ(wh_dims[0], frame_size, PADDLE_ENFORCE_EQ(wh_dims[0], frame_size,
"The first dimension of Input(WeightH) " platform::errors::InvalidArgument(
"should be %d.", "The first dimension of WeightH "
frame_size); "should equal to frame_size, but received WeightH's "
"first dimension is: "
"%d, frame size is:%d",
wh_dims[0], frame_size));
PADDLE_ENFORCE_EQ(wh_dims[1], 3 * frame_size, PADDLE_ENFORCE_EQ(wh_dims[1], 3 * frame_size,
"The second dimension of Input(WeightH) " platform::errors::InvalidArgument(
"should be 3 * %d.", "The second dimension of Input(WeightH) "
frame_size); "should equal to 3 * frame_size, but received WeightH "
"is:%d, frame size is:%d",
wh_dims[1], frame_size));
if (ctx->HasInput("H0")) { if (ctx->HasInput("H0")) {
auto h0_dims = ctx->GetInputDim("H0"); auto h0_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size, PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
"The width of H0 must be equal to frame_size."); platform::errors::InvalidArgument(
"The width of H0 must be equal to frame_size, but "
"receiced the width of H0 is:%d, frame size is:%d",
h0_dims[1], frame_size));
} }
if (ctx->HasInput("Bias")) { if (ctx->HasInput("Bias")) {
auto b_dims = ctx->GetInputDim("Bias"); auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); PADDLE_ENFORCE_EQ(b_dims.size(), 2,
platform::errors::InvalidArgument(
"The rank of Input(Bias) should be 2, but received "
"Bias rank is:%d, Bias dim is:[%s]",
b_dims.size(), b_dims));
PADDLE_ENFORCE_EQ(b_dims[0], 1, PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1."); platform::errors::InvalidArgument(
"The first dimension of Input(Bias) should be 1, but "
"received Bias first dim is:%d, Bias dim is:[%s]",
b_dims[0], b_dims));
PADDLE_ENFORCE_EQ(b_dims[1], frame_size * 3, PADDLE_ENFORCE_EQ(b_dims[1], frame_size * 3,
"The shape of Bias must be [1, frame_size * 3]."); platform::errors::InvalidArgument(
"The shape of Bias must be [1, frame_size * 3], but "
"received bias dim is:[%s], frame size is:%d",
b_dims, frame_size));
} }
framework::DDim out_dims({x_dims[0], frame_size}); framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims); ctx->SetOutputDim("Hidden", out_dims);
...@@ -78,12 +107,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -78,12 +107,12 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
xx_width = wx_dims[1]; xx_width = wx_dims[1];
} else { } else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), OP_INOUT_CHECK(ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0",
"Assert only one Output(ReorderedH0) of GRU."); "fusion_gru");
PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"), "Output", "BatchedInput",
"Assert only one Output(BatchedInput) of GRU."); "fusion_gru");
PADDLE_ENFORCE(ctx->HasOutput("BatchedOut"), OP_INOUT_CHECK(ctx->HasOutput("BatchedOut"), "Output", "BatchedOut",
"Assert only one Output(BatchedOut) of GRU."); "fusion_gru");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedOut", out_dims); ctx->SetOutputDim("BatchedOut", out_dims);
} }
......
...@@ -148,10 +148,14 @@ class PrintOp : public framework::OperatorBase { ...@@ -148,10 +148,14 @@ class PrintOp : public framework::OperatorBase {
const platform::Place &place) const override { const platform::Place &place) const override {
const auto in_var = scope.FindVar(Input("In")); const auto in_var = scope.FindVar(Input("In"));
auto out_var = scope.FindVar(Output("Out")); auto out_var = scope.FindVar(Output("Out"));
PADDLE_ENFORCE_NOT_NULL(in_var, "The input should not be found in scope",
Input("In")); PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE_NOT_NULL(out_var, "The output should not be found in scope", in_var, platform::errors::NotFound("The input:%s not found in scope",
Output("Out")); Input("In")));
PADDLE_ENFORCE_NOT_NULL(
out_var, platform::errors::NotFound("The output:%s not found in scope",
Output("Out")));
auto &in_tensor = in_var->Get<framework::LoDTensor>(); auto &in_tensor = in_var->Get<framework::LoDTensor>();
framework::LoDTensor *out_tensor = framework::LoDTensor *out_tensor =
out_var->GetMutable<framework::LoDTensor>(); out_var->GetMutable<framework::LoDTensor>();
...@@ -246,8 +250,8 @@ class PrintOpInferShape : public framework::InferShapeBase { ...@@ -246,8 +250,8 @@ class PrintOpInferShape : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext *ctx) const override { void operator()(framework::InferShapeContext *ctx) const override {
VLOG(10) << "PrintOpInferShape"; VLOG(10) << "PrintOpInferShape";
PADDLE_ENFORCE(ctx->HasInput("In"), "Input(In) should not be null."); OP_INOUT_CHECK(ctx->HasInput("In"), "Input", "In", "Print");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Print");
ctx->ShareDim("In", /*->*/ "Out"); ctx->ShareDim("In", /*->*/ "Out");
ctx->ShareLoD("In", /*->*/ "Out"); ctx->ShareLoD("In", /*->*/ "Out");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册