diff --git a/paddle/fluid/operators/optimizers/dgc_momentum_op.cc b/paddle/fluid/operators/optimizers/dgc_momentum_op.cc index d491593d0af11a1127477a99a74dec59d102941d..6241395d3b70f16e300590496716a3913747de9a 100644 --- a/paddle/fluid/operators/optimizers/dgc_momentum_op.cc +++ b/paddle/fluid/operators/optimizers/dgc_momentum_op.cc @@ -14,7 +14,9 @@ #include +#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/infermeta/multiary.h" namespace paddle { namespace operators { @@ -24,92 +26,6 @@ class DGCMomentumOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("current_step"), - "Input", - "current_step", - "DGCMomentumOp"); - OP_INOUT_CHECK(ctx->HasInput("nranks"), "Input", "nranks", "DGCMomentumOp"); - OP_INOUT_CHECK( - ctx->HasOutput("Grad_out"), "Output", "Grad_out", "DGCMomentumOp"); - - PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), - true, - platform::errors::NotFound( - "Input(param) of Momentum should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), - true, - platform::errors::NotFound( - "Input(grad) of Momentum should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Velocity"), - true, - platform::errors::NotFound( - "Input(velocity) of Momentum should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("LearningRate"), - true, - platform::errors::NotFound( - "Input(LearningRate) of Momentum should not be null.")); - PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Param").front(), - framework::proto::VarType::LOD_TENSOR, - platform::errors::InvalidArgument( - "The input var's type should be phi::DenseTensor, " - "but the received is %s", - ctx->GetInputsVarType("Param").front())); - - PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), - true, - platform::errors::NotFound( - "Output(ParamOut) of Momentum should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("VelocityOut"), - true, - platform::errors::NotFound( - "Output(VelocityOut) of Momentum should not be null.")); - - auto lr_dims = ctx->GetInputDim("LearningRate"); - PADDLE_ENFORCE_NE(phi::product(lr_dims), - 0, - platform::errors::InvalidArgument( - "Maybe the Input variable LearningRate has not " - "been initialized. You may need to confirm " - "if you put exe.run(startup_program) " - "after optimizer.minimize function.")); - PADDLE_ENFORCE_EQ(phi::product(lr_dims), - 1, - platform::errors::InvalidArgument( - "Learning_rate should be a scalar. But Received " - "LearningRate's dim [%s]", - phi::product(lr_dims))); - - auto param_dim = ctx->GetInputDim("Param"); - if (ctx->GetInputsVarType("Grad")[0] == - framework::proto::VarType::LOD_TENSOR) { - PADDLE_ENFORCE_EQ( - param_dim, - ctx->GetInputDim("Grad"), - platform::errors::InvalidArgument( - "Param and Grad input of MomentumOp should have the same " - "dimension. But received Param's dim [%s] and Grad's dim [%s].", - param_dim, - ctx->GetInputDim("Grad"))); - PADDLE_ENFORCE_EQ( - param_dim, - ctx->GetInputDim("Velocity"), - platform::errors::InvalidArgument( - "Param and Velocity of MomentumOp should have the same " - "dimension. But received Param's dim [%s] and Velocity [%s].", - param_dim, - ctx->GetInputDim("Velocity"))); - } - - ctx->SetOutputDim("ParamOut", param_dim); - ctx->SetOutputDim("VelocityOut", param_dim); - if (ctx->HasOutput("MasterParamOut")) { - ctx->SetOutputDim("MasterParamOut", param_dim); - } - } - phi::KernelKey GetKernelTypeForVar( const std::string& var_name, const phi::DenseTensor& tensor, @@ -199,7 +115,12 @@ DGC Momentum Operator. } // namespace operators } // namespace paddle +DECLARE_INFER_SHAPE_FUNCTOR(dgc_momentum, + DGCMomentumInferShapeFunctor, + PD_INFER_META(phi::DGCMomentumInferMeta)); + namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(dgc_momentum, ops::DGCMomentumOp, - ops::DGCMomentumOpMaker); + ops::DGCMomentumOpMaker, + DGCMomentumInferShapeFunctor); diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 06232c06907169f2fb3163decf0a79689c6c5c0a..677c568df56c7d0ed41e64bdecb31c0a8266549b 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -1305,6 +1305,67 @@ void DeformableConvInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void DGCMomentumInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& velocity, + const MetaTensor& learning_rate, + const MetaTensor& master_param, + const MetaTensor& current_step_tensor, + const MetaTensor& nranks_tensor, + float mu, + bool use_nesterov, + const std::string& regularization_method, + float regularization_coeff, + bool multi_precision, + float rescale_grad, + float rampup_begin_step, + MetaTensor* param_out, + MetaTensor* velocity_out, + MetaTensor* master_param_out, + MetaTensor* grad_out) { + auto lr_dims = learning_rate.dims(); + + PADDLE_ENFORCE_NE(phi::product(lr_dims), + 0, + phi::errors::InvalidArgument( + "Maybe the Input variable LearningRate has not " + "been initialized. You may need to confirm " + "if you put exe.run(startup_program) " + "after optimizer.minimize function.")); + PADDLE_ENFORCE_EQ(phi::product(lr_dims), + 1, + phi::errors::InvalidArgument( + "Learning_rate should be a scalar. But Received " + "LearningRate's dim [%s]", + phi::product(lr_dims))); + + auto param_dims = param.dims(); + auto grad_dims = grad.dims(); + auto velocity_dims = velocity.dims(); + PADDLE_ENFORCE_EQ( + param_dims, + grad_dims, + phi::errors::InvalidArgument( + "Param and Grad input of MomentumOp should have the same " + "dimension. But received Param's dim [%s] and Grad's dim [%s].", + param_dims, + grad_dims)); + PADDLE_ENFORCE_EQ( + param_dims, + velocity_dims, + phi::errors::InvalidArgument( + "Param and Velocity of MomentumOp should have the same " + "dimension. But received Param's dim [%s] and Velocity [%s].", + param_dims, + velocity_dims)); + + param_out->set_dims(param_dims); + velocity_out->set_dims(param_dims); + if (master_param != nullptr) { + master_param_out->set_dims(param_dims); + } +} + void EditDistanceInferMeta(const MetaTensor& hyps, const MetaTensor& refs, const MetaTensor& hypslength, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index e1bd5bd1fe1e850ba807b5adc08ebf2985c4052c..630342b3b6ffe46cb4201d7dc8af90071ef46311 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -280,6 +280,25 @@ void DeformableConvInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void DGCMomentumInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& velocity, + const MetaTensor& learning_rate, + const MetaTensor& master_param, + const MetaTensor& current_step_tensor, + const MetaTensor& nranks_tensor, + float mu, + bool use_nesterov, + const std::string& regularization_method, + float regularization_coeff, + bool multi_precision, + float rescale_grad, + float rampup_begin_step, + MetaTensor* param_out, + MetaTensor* velocity_out, + MetaTensor* master_param_out, + MetaTensor* grad_out); + void EditDistanceInferMeta(const MetaTensor& hyps, const MetaTensor& refs, const MetaTensor& hypslength,