diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cc b/paddle/fluid/operators/optimizers/lars_momentum_op.cc
index e6b04a8a3caaa5b6ce46c56ab696276b745be4e7..57abbab884e30235984771acaea49228d2aecd3c 100644
--- a/paddle/fluid/operators/optimizers/lars_momentum_op.cc
+++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cc
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/infermeta/multiary.h"
 
 namespace paddle {
 namespace operators {
@@ -22,117 +24,6 @@ class LarsMomentumOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
- protected:
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInputs("Param"), "Input", "Param", "LarsMomentum");
-    OP_INOUT_CHECK(ctx->HasInputs("Grad"), "Input", "Grad", "LarsMomentum");
-    OP_INOUT_CHECK(
-        ctx->HasInputs("Velocity"), "Input", "Velocity", "LarsMomentum");
-    OP_INOUT_CHECK(ctx->HasInputs("LearningRate"),
-                   "Input",
-                   "LearningRate",
-                   "LarsMomentum");
-    OP_INOUT_CHECK(
-        ctx->HasOutputs("ParamOut"), "Output", "ParamOut", "LarsMomentum");
-    OP_INOUT_CHECK(ctx->HasOutputs("VelocityOut"),
-                   "Output",
-                   "VelocityOut",
-                   "LarsMomentum");
-    PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Param").front(),
-                      framework::proto::VarType::LOD_TENSOR,
-                      platform::errors::InvalidArgument(
-                          "The input var's type should be phi::DenseTensor, "
-                          "but the received is %s",
-                          ctx->GetInputsVarType("Param").front()));
-
-    auto lr_dims = ctx->GetInputsDim("LearningRate");
-    auto grad_dim = ctx->GetInputsDim("Grad");
-    auto param_dim = ctx->GetInputsDim("Param");
-    auto velocity_dim = ctx->GetInputsDim("Velocity");
-    auto lars_weight_decays =
-        ctx->Attrs().Get<std::vector<float>>("lars_weight_decay");
-    auto multi_precision = ctx->Attrs().Get<bool>("multi_precision");
-
-    PADDLE_ENFORCE_EQ(
-        param_dim.size(),
-        grad_dim.size(),
-        platform::errors::InvalidArgument(
-            "Input(Param) and Input(Grad) of LarsMomentumOp should have "
-            "same quantity. But number of Param is [%d] and Grad is [%d].",
-            param_dim.size(),
-            grad_dim.size()));
-    PADDLE_ENFORCE_EQ(
-        param_dim.size(),
-        velocity_dim.size(),
-        platform::errors::InvalidArgument(
-            "Input(Param) and Input(Velocity) of LarsMomentumOp should "
-            "have same quantity. But number of Param is [%d] and Velocity "
-            "is [%d].",
-            param_dim.size(),
-            velocity_dim.size()));
-    PADDLE_ENFORCE_EQ(
-        lars_weight_decays.size(),
-        grad_dim.size(),
-        platform::errors::InvalidArgument(
-            "Attr(Lars_weight_decay) and "
-            "Input(Grad) of LarsMomentumOp should have same quantity. "
-            "But number of Lars_weight_decay is [%d] and Grad is [%d].",
-            lars_weight_decays.size(),
-            grad_dim.size()));
-
-    if (multi_precision) {
-      OP_INOUT_CHECK(ctx->HasInputs("MasterParam"),
-                     "Input",
-                     "MasterParam",
-                     "LarsMomentumMultiPrecision");
-      OP_INOUT_CHECK(ctx->HasOutputs("MasterParamOut"),
-                     "Output",
-                     "MasterParamOut",
-                     "LarsMomentumMultiPrecision");
-    }
-    for (auto& lr_dim : lr_dims) {
-      PADDLE_ENFORCE_EQ(phi::product(lr_dim),
-                        1,
-                        platform::errors::InvalidArgument(
-                            "Learning_rate should be a scalar. But Received "
But Received " - "LearningRate's dim [%s]", - phi::product(lr_dim))); - } - - for (size_t i = 0; i < param_dim.size(); ++i) { - PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Grad")[i], - framework::proto::VarType::LOD_TENSOR, - platform::errors::InvalidArgument( - "The Var(%s)'s type should be phi::DenseTensor, " - "but the received is %s", - ctx->Inputs("Grad")[i].front(), - ctx->GetInputsVarType("Grad")[i])); - PADDLE_ENFORCE_EQ( - param_dim[i], - grad_dim[i], - platform::errors::InvalidArgument( - "Input(Param) and Input(Grad) input of LarsMomentumOp shall " - "have same dimension. But Param`s dim is [%s] and Grad's dim " - "is [%s].", - param_dim[i], - grad_dim[i])); - PADDLE_ENFORCE_EQ( - param_dim[i], - velocity_dim[i], - platform::errors::InvalidArgument( - "Input(Param) and Input(Velocity) of LarsMomentumOp shall have " - "same dimension. But Param dim [%s] differs with Velocity dim " - "[%s].", - param_dim[i], - velocity_dim[i])); - } - ctx->SetOutputsDim("ParamOut", param_dim); - ctx->SetOutputsDim("VelocityOut", param_dim); - if (ctx->HasOutputs("MasterParamOut")) { - ctx->SetOutputsDim("MasterParamOut", param_dim); - } - } - protected: phi::KernelKey GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { @@ -226,6 +117,10 @@ class LarsMomentumOpVarTypeInference : public framework::VarTypeInference { } // namespace operators } // namespace paddle +DECLARE_INFER_SHAPE_FUNCTOR(lars_momentum, + LarsMomentumInferShapeFunctor, + PD_INFER_META(phi::LarsMomentumInferMeta)); + namespace ops = paddle::operators; REGISTER_OPERATOR( lars_momentum, @@ -233,4 +128,5 @@ REGISTER_OPERATOR( ops::LarsMomentumOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker, - ops::LarsMomentumOpVarTypeInference); + ops::LarsMomentumOpVarTypeInference, + LarsMomentumInferShapeFunctor); diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 958bfa76fe9c5c5729b113154f61367b7055fc20..0701afa51af797dc927c2fa97fd333b54b0dcc8c 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2577,6 +2577,92 @@ void LambInferMeta(const MetaTensor& param, } } +void LarsMomentumInferMeta( + const std::vector& param, + const std::vector& velocity, + const std::vector& learning_rate, + const std::vector& grad, + const paddle::optional>& master_param, + const std::vector& lars_weight_decay, + float mu, + float lars_coeff, + float epsilon, + bool multi_precision, + float rescale_grad, + std::vector param_out, + std::vector velocity_out, + std::vector master_param_out) { + std::vector lr_dims = GetMetaTensorsDim(learning_rate); + std::vector grad_dim = GetMetaTensorsDim(grad); + std::vector param_dim = GetMetaTensorsDim(param); + std::vector velocity_dim = GetMetaTensorsDim(velocity); + + PADDLE_ENFORCE_EQ( + param_dim.size(), + grad_dim.size(), + phi::errors::InvalidArgument( + "Input(Param) and Input(Grad) of LarsMomentumOp should have " + "same quantity. But number of Param is [%d] and Grad is [%d].", + param_dim.size(), + grad_dim.size())); + PADDLE_ENFORCE_EQ( + param_dim.size(), + velocity_dim.size(), + phi::errors::InvalidArgument( + "Input(Param) and Input(Velocity) of LarsMomentumOp should " + "have same quantity. But number of Param is [%d] and Velocity " + "is [%d].", + param_dim.size(), + velocity_dim.size())); + PADDLE_ENFORCE_EQ( + lars_weight_decay.size(), + grad_dim.size(), + phi::errors::InvalidArgument( + "Attr(Lars_weight_decay) and " + "Input(Grad) of LarsMomentumOp should have same quantity. 
" + "But number of Lars_weight_decay is [%d] and Grad is [%d].", + lars_weight_decay.size(), + grad_dim.size())); + + for (auto& lr_dim : lr_dims) { + PADDLE_ENFORCE_EQ(phi::product(lr_dim), + 1, + phi::errors::InvalidArgument( + "Learning_rate should be a scalar. But Received " + "LearningRate's dim [%s]", + phi::product(lr_dim))); + } + + for (size_t i = 0; i < param_dim.size(); ++i) { + PADDLE_ENFORCE_EQ( + param_dim[i], + grad_dim[i], + phi::errors::InvalidArgument( + "Input(Param) and Input(Grad) input of LarsMomentumOp shall " + "have same dimension. But Param`s dim is [%s] and Grad's dim " + "is [%s].", + param_dim[i], + grad_dim[i])); + PADDLE_ENFORCE_EQ( + param_dim[i], + velocity_dim[i], + phi::errors::InvalidArgument( + "Input(Param) and Input(Velocity) of LarsMomentumOp shall have " + "same dimension. But Param dim [%s] differs with Velocity dim " + "[%s].", + param_dim[i], + velocity_dim[i])); + } + + for (size_t i = 0; i < param_out.size(); i++) { + param_out[i]->set_dims(param_dim[i]); + velocity_out[i]->set_dims(param_dim[i]); + if (master_param != nullptr) { + master_param_out[i]->set_dims(param_dim[i]); + } + } +} + void LLMInt8LinearInferMeta(const MetaTensor& x, const MetaTensor& weight, const MetaTensor& bias, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 9beb9d213899d954597e2def53e98245e322772e..ee62d6d51d65502933f8b2de6a59b5b4ebd57646 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -459,6 +459,22 @@ void LambInferMeta(const MetaTensor& param, MetaTensor* beta2_pow_out, MetaTensor* master_param_outs); +void LarsMomentumInferMeta( + const std::vector& param, + const std::vector& velocity, + const std::vector& learning_rate, + const std::vector& grad, + const paddle::optional>& master_param, + const std::vector& lars_weight_decay, + float mu, + float lars_coeff, + float epsilon, + bool multi_precision, + float rescale_grad, + std::vector param_out, + std::vector velocity_out, + std::vector master_param_out); + void LLMInt8LinearInferMeta(const MetaTensor& x, const MetaTensor& weight, const MetaTensor& bias,