diff --git a/paddle/fluid/operators/optimizers/dgc_momentum_op.cc b/paddle/fluid/operators/optimizers/dgc_momentum_op.cc index 92ce600f22b64f82a053233dbd130adefca964fa..7f0b2b7d064ed12875577fee2265ab17c1fce08f 100644 --- a/paddle/fluid/operators/optimizers/dgc_momentum_op.cc +++ b/paddle/fluid/operators/optimizers/dgc_momentum_op.cc @@ -25,15 +25,11 @@ class DGCMomentumOp : public MomentumOp { protected: void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("current_step"), true, - "current_step should be set."); - PADDLE_ENFORCE_EQ(ctx->HasInput("nranks"), true, - platform::errors::NotFound( - "Input(nranks) of DGCMomentumOp is not found.")); - - PADDLE_ENFORCE_EQ(ctx->HasOutput("Grad_out"), true, - platform::errors::NotFound( - "Output(Grad_out) of DGCMomentumOp is not found.")); + OP_INOUT_CHECK(ctx->HasInput("current_step"), "Input", "current_step", + "DGCMomentumOp"); + OP_INOUT_CHECK(ctx->HasInput("nranks"), "Input", "nranks", "DGCMomentumOp"); + OP_INOUT_CHECK(ctx->HasOutput("Grad_out"), "Output", "Grad_out", + "DGCMomentumOp"); return MomentumOp::InferShape(ctx); }