diff --git a/paddle/fluid/operators/fsp_op.cc b/paddle/fluid/operators/fsp_op.cc index d1aedf41e9a2be3b2970f4ff2032c7e2e3ce3a42..f00ec6a1e140c1bf34a1dc8a7ac491db1c89ae9c 100644 --- a/paddle/fluid/operators/fsp_op.cc +++ b/paddle/fluid/operators/fsp_op.cc @@ -23,9 +23,9 @@ class FSPOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fsp_op"); - OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "fsp_op"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fsp_op"); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fsp"); + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "fsp"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "fsp"); auto x_dims = ctx->GetInputDim("X"); auto y_dims = ctx->GetInputDim("Y"); @@ -103,10 +103,11 @@ class FSPOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); - PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "Input(Out@GRAD) should not be null"); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fsp_grad"); + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "fsp_grad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + framework::GradVarName("Out"), "fsp_grad"); + auto x_dims = ctx->GetInputDim("X"); auto y_dims = ctx->GetInputDim("Y"); auto x_grad_name = framework::GradVarName("X");