diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc index 1cc5df8600bd487c4d505e5820f12b192fa8e173..65cf4c170ac91823bfef2d3a202f4893a46dba3c 100644 --- a/paddle/fluid/operators/fused/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc @@ -23,68 +23,94 @@ namespace paddle { namespace operators { void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { - PADDLE_ENFORCE(ctx->HasInput("X"), "Assert only one Input(X) of LSTM."); - PADDLE_ENFORCE(ctx->HasInput("WeightX"), - "Assert only one Input(WeightX) of LSTM."); - PADDLE_ENFORCE(ctx->HasInput("WeightH"), - "Assert only one Input(WeightH) of LSTM."); - PADDLE_ENFORCE(ctx->HasInput("Bias"), "Assert only one Input(Bias) of LSTM."); - PADDLE_ENFORCE(ctx->HasOutput("XX"), "Assert only one Output(XX) of LSTM."); - PADDLE_ENFORCE(ctx->HasOutput("Hidden"), - "Assert only one Output(Hidden) of LSTM."); - PADDLE_ENFORCE(ctx->HasOutput("Cell"), - "Assert only one Output(Cell) of LSTM."); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_lstm"); + OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_lstm"); + OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_lstm"); + OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "fusion_lstm"); + OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_lstm"); + OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_lstm"); + OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "fusion_lstm"); auto x_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, + platform::errors::InvalidArgument( + "Input(X)'s rank must be 2, but received x's rank " + "is:%d, x dim is:[%s]", + x_dims.size(), x_dims)); if (ctx->HasInput("H0")) { - PADDLE_ENFORCE(ctx->HasInput("C0"), - "Input(Cell) and Input(Hidden) of LSTM should not " - "be null at the same time."); + OP_INOUT_CHECK(ctx->HasInput("C0"), "Input", "C0", "fusion_lstm"); auto h_dims = ctx->GetInputDim("H0"); auto c_dims = ctx->GetInputDim("C0"); - PADDLE_ENFORCE(h_dims == c_dims, - "The dimension of Input(H0) and Input(C0) " - "should be the same."); + PADDLE_ENFORCE_EQ(h_dims, c_dims, + platform::errors::InvalidArgument( + "The dimension of Input(H0) and Input(C0) should be " + "same, but received h0 dims is:[%s], c0 dims is:[%s]", + h_dims, c_dims)); } auto wx_dims = ctx->GetInputDim("WeightX"); PADDLE_ENFORCE_EQ(wx_dims.size(), 2, - "The rank of Input(WeightX) should be 2."); + platform::errors::InvalidArgument( + "The rank of Input(WeightX) should be 2, but received " + "WeightX's rank is:%d, WeightX dim is:[%s]", + wx_dims.size(), wx_dims)); PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1], - "The first dimension of Input(WeightX) " - "should be %d.", - x_dims[1]); + platform::errors::InvalidArgument( + "The first dimension of Input(WeightX) " + "should equal to second dimension of Input(X), but " + "received WeightX first dim is:%d, X second dim is:%d", + wx_dims[0], x_dims[1])); int frame_size = wx_dims[1] / 4; auto wh_dims = ctx->GetInputDim("WeightH"); + PADDLE_ENFORCE_EQ(wh_dims.size(), 2, - "The rank of Input(WeightH) should be 2."); + platform::errors::InvalidArgument( + "The rank of Input(WeightH) should be 2, but received " + "WeightH rank is:%d, WeightH dim is:[%s]", + wh_dims.size(), wh_dims)); PADDLE_ENFORCE_EQ(wh_dims[0], frame_size, - "The first dimension of Input(WeightH) " - "should be %d.", - frame_size); + platform::errors::InvalidArgument( + "The first dimension of Input(WeightH) " + "should equal to frame size, but received WeightH " + "first dim is:%d, frame size is:%d.", + wh_dims[0], frame_size)); + PADDLE_ENFORCE_EQ(wh_dims[1], 4 * frame_size, - "The second dimension of Input(WeightH) " - "should be 4 * %d.", - frame_size); + platform::errors::InvalidArgument( + "The second dimension of Input(WeightH) " + "should equal to 4 * frame_size, but received WeightH " + "second dimension is:%d, frame size is:%d.", + wh_dims[1], frame_size)); auto b_dims = ctx->GetInputDim("Bias"); - PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); + PADDLE_ENFORCE_EQ(b_dims.size(), 2, + platform::errors::InvalidArgument( + "The rank of Input(Bias) should be 2, but received " + "Bias rank is:%d, Bias dim is:[%s]", + b_dims.size(), b_dims)); PADDLE_ENFORCE_EQ(b_dims[0], 1, - "The first dimension of Input(Bias) should be 1."); + platform::errors::InvalidArgument( + "The first dimension of Input(Bias) should be 1, but " + "received Bias's dimension is:[%s]", + b_dims)); + if (ctx->Attrs().Get("use_peepholes")) { PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size, - "The second dimension of Input(Bias) should be " - "7 * %d if enable peepholes connection", - frame_size); + platform::errors::InvalidArgument( + "The second dimension of Input(Bias) should be " + "7 * %d if enable peepholes connection, but received " + "Bias dim is:[%s]", + frame_size, b_dims)); ctx->SetOutputDim("CheckedCell", {2, frame_size}); } else { - PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size, - "The second dimension of Input(Bias) should be " - "4 * %d if disable peepholes", - frame_size); + PADDLE_ENFORCE_EQ( + b_dims[1], 4 * frame_size, + platform::errors::InvalidArgument( + "The second dimension of Input(Bias) should be " + "4 * %d if disable peepholes, but received Bias dim is:[%s]", + frame_size, b_dims)); } framework::DDim out_dims({x_dims[0], frame_size}); @@ -97,16 +123,18 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { xx_width = wx_dims[1]; } else { xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; - PADDLE_ENFORCE(ctx->HasOutput("BatchedInput"), - "Assert only one Output(BatchedInput) of LSTM."); - PADDLE_ENFORCE(ctx->HasOutput("BatchedHidden"), - "Assert only one Output(BatchedHidden) of LSTM."); - PADDLE_ENFORCE(ctx->HasOutput("BatchedCell"), - "Assert only one Output(BatchedCell) of LSTM."); - PADDLE_ENFORCE(ctx->HasOutput("ReorderedH0"), - "Assert only one Output(ReorderedH0) of LSTM"); - PADDLE_ENFORCE(ctx->HasOutput("ReorderedC0"), - "Assert only one Output(ReorderedC0) of LSTM."); + + OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"), "Output", "BatchedInput", + "fusion_lstm"); + OP_INOUT_CHECK(ctx->HasOutput("BatchedHidden"), "Output", "BatchedHidden", + "fusion_lstm"); + OP_INOUT_CHECK(ctx->HasOutput("BatchedCell"), "Output", "BatchedCell", + "fusion_lstm"); + OP_INOUT_CHECK(ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0", + "fusion_lstm"); + OP_INOUT_CHECK(ctx->HasOutput("ReorderedC0"), "Output", "ReorderedC0", + "fusion_lstm"); + ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedHidden", out_dims); ctx->SetOutputDim("BatchedCell", out_dims);