From 230325906483b3e3b473f5177ede1a0de2132415 Mon Sep 17 00:00:00 2001
From: wangzhen38 <41941775+wangzhen38@users.noreply.github.com>
Date: Tue, 11 Apr 2023 11:20:35 +0800
Subject: [PATCH] [BUG Fixs] adadelta lr support (#49732)

---
 .../fluid/operators/optimizers/adadelta_op.cc |  1 +
 paddle/fluid/pybind/eager_generator.h         |  7 ++-
 paddle/phi/api/yaml/legacy_ops.yaml           |  2 +-
 paddle/phi/infermeta/multiary.cc              |  6 +++
 paddle/phi/infermeta/multiary.h               |  1 +
 paddle/phi/kernels/adadelta_kernel.h          |  1 +
 .../phi/kernels/impl/adadelta_kernel_impl.h   | 47 ++++++++++++-------
 paddle/phi/kernels/xpu/adadelta_kernel.cc     |  1 +
 paddle/phi/ops/compat/adadelta_sig.cc         | 20 ++++----
 python/paddle/fluid/optimizer.py              |  2 +
 .../fluid/tests/unittests/test_adadelta_op.py |  9 +++-
 python/paddle/optimizer/adadelta.py           |  2 +
 12 files changed, 69 insertions(+), 30 deletions(-)

diff --git a/paddle/fluid/operators/optimizers/adadelta_op.cc b/paddle/fluid/operators/optimizers/adadelta_op.cc
index 2df8ff971ce..cb2c374d017 100644
--- a/paddle/fluid/operators/optimizers/adadelta_op.cc
+++ b/paddle/fluid/operators/optimizers/adadelta_op.cc
@@ -39,6 +39,7 @@ class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("AvgSquaredGrad", "(Tensor) Input average of squared gradient");
     AddInput("AvgSquaredUpdate",
              "(Tensor) Input average of squared parameter updates");
+    AddInput("LearningRate", "(Tensor) Learning rate");
     AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();
 
     AddOutput("ParamOut", "(Tensor) Output parameter");
diff --git a/paddle/fluid/pybind/eager_generator.h b/paddle/fluid/pybind/eager_generator.h
index 2eb7934c911..03b8690569c 100644
--- a/paddle/fluid/pybind/eager_generator.h
+++ b/paddle/fluid/pybind/eager_generator.h
@@ -220,7 +220,12 @@ std::map<std::string, std::vector<std::string>> op_ins_map = {
     {"sgd", {"Param", "LearningRate", "Grad", "MasterParam"}},
     {"adagrad", {"Param", "Grad", "Moment", "LearningRate", "MasterParam"}},
     {"adadelta",
-     {"Param", "Grad", "AvgSquaredGrad", "AvgSquaredUpdate", "MasterParam"}},
+     {"Param",
+      "Grad",
+      "AvgSquaredGrad",
+      "AvgSquaredUpdate",
+      "LearningRate",
+      "MasterParam"}},
     {"graph_khop_sampler", {"Row", "Eids", "Col_Ptr", "X"}},
     {"nce",
      {"Input",
diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml
index e44bbe7e6dd..2d0aadcf536 100755
--- a/paddle/phi/api/yaml/legacy_ops.yaml
+++ b/paddle/phi/api/yaml/legacy_ops.yaml
@@ -11,7 +11,7 @@
   backward : abs_grad
 
 - op : adadelta_
-  args : (Tensor param, Tensor grad, Tensor avg_squared_grad, Tensor avg_squared_update, Tensor master_param, float rho, float epsilon, bool multi_precision)
+  args : (Tensor param, Tensor grad, Tensor avg_squared_grad, Tensor avg_squared_update, Tensor learning_rate, Tensor master_param, float rho, float epsilon, bool multi_precision)
   output : Tensor(param_out), Tensor(moment_out), Tensor(inf_norm_out), Tensor(master_param_out)
   infer_meta :
     func : AdadeltaInferMeta
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index ea93a587493..7364f85e751 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -40,6 +40,7 @@ void AdadeltaInferMeta(const MetaTensor& param,
                        const MetaTensor& grad,
                        const MetaTensor& avg_squared_grad,
                        const MetaTensor& avg_squared_update,
+                       const MetaTensor& learning_rate,
                        const MetaTensor& master_param,
                        float rho,
                        float epsilon,
@@ -48,6 +49,11 @@ void AdadeltaInferMeta(const MetaTensor& param,
                        MetaTensor* avg_squared_grad_out,
                        MetaTensor* avg_squared_update_out,
                        MetaTensor* master_param_out) {
+  auto lr_dims = learning_rate.dims();
+  PADDLE_ENFORCE_EQ(
+      phi::product(lr_dims),
+      1,
+      phi::errors::InvalidArgument("LearningRate should have one element"));
   auto param_dims = param.dims();
   PADDLE_ENFORCE_EQ(
       param_dims,
diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h
index cf6ca3c2a9f..178910e3620 100644
--- a/paddle/phi/infermeta/multiary.h
+++ b/paddle/phi/infermeta/multiary.h
@@ -43,6 +43,7 @@ void AdadeltaInferMeta(const MetaTensor& param,
                        const MetaTensor& grad,
                        const MetaTensor& avg_squared_grad,
                        const MetaTensor& avg_squared_update,
+                       const MetaTensor& learning_rate,
                        const MetaTensor& master_param,
                        float rho,
                        float epsilon,
diff --git a/paddle/phi/kernels/adadelta_kernel.h b/paddle/phi/kernels/adadelta_kernel.h
index 15c07b3e6f9..16f4e6ca269 100644
--- a/paddle/phi/kernels/adadelta_kernel.h
+++ b/paddle/phi/kernels/adadelta_kernel.h
@@ -24,6 +24,7 @@ void AdadeltaKernel(const Context& dev_ctx,
                     const DenseTensor& grad,
                     const DenseTensor& avg_squared_grad,
                     const DenseTensor& avg_squared_update,
+                    const DenseTensor& learning_rate,
                     const paddle::optional<DenseTensor>& master_param,
                     float rho,
                     float epsilon,
diff --git a/paddle/phi/kernels/impl/adadelta_kernel_impl.h b/paddle/phi/kernels/impl/adadelta_kernel_impl.h
index b0c0a072acd..c432c72d832 100644
--- a/paddle/phi/kernels/impl/adadelta_kernel_impl.h
+++ b/paddle/phi/kernels/impl/adadelta_kernel_impl.h
@@ -13,11 +13,14 @@
 // limitations under the License.
 
 #pragma once
 
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/eigen/eigen_function.h"
 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/kernels/adadelta_kernel.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
-#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
 
@@ -27,6 +30,7 @@ void AdadeltaKernel(const Context& dev_ctx,
                     const DenseTensor& grad,
                     const DenseTensor& avg_squared_grad,
                     const DenseTensor& avg_squared_update,
+                    const DenseTensor& learning_rate,
                     const paddle::optional<DenseTensor>& master_param,
                     float rho,
                     float epsilon,
@@ -56,29 +60,36 @@ void AdadeltaKernel(const Context& dev_ctx,
   auto eigen_avg_squared_update_out =
       EigenVector<MPDType>::Flatten(*avg_squared_update_out);
   auto& place = *dev_ctx.eigen_device();
-
   auto eigen_grad_cast = eigen_grad.template cast<MPDType>();
 
   eigen_avg_squared_grad_out.device(place) =
       rho_ * eigen_avg_squared_grad + (1 - rho_) * eigen_grad_cast.square();
-  auto update = -((eigen_avg_squared_update + epsilon_) /
-                  (eigen_avg_squared_grad_out + epsilon_))
-                     .sqrt() *
-                eigen_grad_cast;
-  eigen_avg_squared_update_out.device(place) =
-      rho_ * eigen_avg_squared_update + (1 - rho_) * update.square();
-
-  if (multi_precision) {
-    auto eigen_master_param_out =
-        EigenVector<MPDType>::Flatten(*master_param_outs);
-    auto eigen_master_param = EigenVector<MPDType>::Flatten(*master_param);
-
-    eigen_master_param_out.device(place) = eigen_master_param + update;
+  auto update =
+      -(((eigen_avg_squared_update + epsilon_).sqrt()) /
+        ((eigen_avg_squared_grad_out + epsilon_).sqrt()) * eigen_grad_cast);
+  Eigen::DSizes<int, 1> m_dsize(avg_squared_update_out->numel());
+  if (paddle::platform::is_cpu_place(dev_ctx.GetPlace())) {
+    auto* lr = learning_rate.data<T>();
     eigen_param_out.device(place) =
-        (eigen_param.template cast<MPDType>() + update).template cast<T>();
+        eigen_param + lr[0] * update.template cast<T>();
   } else {
-    eigen_param_out.device(place) = eigen_param + update.template cast<T>();
+    auto lr = EigenVector<MPDType>::Flatten(learning_rate);
+    if (multi_precision) {
+      auto eigen_master_param_out =
+          EigenVector<MPDType>::Flatten(*master_param_outs);
+      auto eigen_master_param = EigenVector<MPDType>::Flatten(*master_param);
+
+      eigen_master_param_out.device(place) =
+          eigen_master_param + lr.broadcast(m_dsize) * update;
+      eigen_param_out.device(place) = (eigen_param.template cast<MPDType>() +
+                                       lr.broadcast(m_dsize) * update)
+                                          .template cast<T>();
+    } else {
+      eigen_param_out.device(place) =
+          eigen_param + (lr.broadcast(m_dsize) * update).template cast<T>();
+    }
   }
+  eigen_avg_squared_update_out.device(place) =
+      rho_ * eigen_avg_squared_update + (1 - rho_) * update.square();
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/xpu/adadelta_kernel.cc b/paddle/phi/kernels/xpu/adadelta_kernel.cc
index e02a5aeabad..b87ec1afbdc 100644
--- a/paddle/phi/kernels/xpu/adadelta_kernel.cc
+++ b/paddle/phi/kernels/xpu/adadelta_kernel.cc
@@ -25,6 +25,7 @@ void AdadeltaKernel(const Context& dev_ctx,
                     const DenseTensor& grad,
                     const DenseTensor& avg_squared_grad,
                     const DenseTensor& avg_squared_update,
+                    const DenseTensor& learning_rate,
                     const paddle::optional<DenseTensor>& master_param,
                     float rho,
                     float epsilon,
diff --git a/paddle/phi/ops/compat/adadelta_sig.cc b/paddle/phi/ops/compat/adadelta_sig.cc
index fd285e7e5d0..da7e4229a0d 100644
--- a/paddle/phi/ops/compat/adadelta_sig.cc
+++ b/paddle/phi/ops/compat/adadelta_sig.cc
@@ -18,14 +18,18 @@ namespace phi {
 
 KernelSignature AdadeltaOpArgumentMapping(const ArgumentMappingContext& ctx) {
   if (ctx.IsDenseTensorInput("Grad")) {
-    return KernelSignature(
-        "adadelta",
-        {"Param", "Grad", "AvgSquaredGrad", "AvgSquaredUpdate", "MasterParam"},
-        {"rho", "epsilon", "multi_precision"},
-        {"ParamOut",
-         "AvgSquaredGradOut",
-         "AvgSquaredUpdateOut",
-         "MasterParamOut"});
+    return KernelSignature("adadelta",
+                           {"Param",
+                            "Grad",
+                            "AvgSquaredGrad",
+                            "AvgSquaredUpdate",
+                            "LearningRate",
+                            "MasterParam"},
+                           {"rho", "epsilon", "multi_precision"},
+                           {"ParamOut",
+                            "AvgSquaredGradOut",
+                            "AvgSquaredUpdateOut",
+                            "MasterParamOut"});
   }
 
   return KernelSignature("unregistered", {}, {}, {});
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 6ed9e674689..db483b151e4 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -3215,6 +3215,7 @@ class AdadeltaOptimizer(Optimizer):
                 param_and_grad[1],
                 avg_squared_grad_acc,
                 avg_squared_update_acc,
+                self._create_param_lr(param_and_grad),
                 master_weight,
                 self._rho,
                 self._epsilon,
@@ -3227,6 +3228,7 @@ class AdadeltaOptimizer(Optimizer):
                 "Grad": param_and_grad[1],
                 "AvgSquaredGrad": avg_squared_grad_acc,
                 "AvgSquaredUpdate": avg_squared_update_acc,
+                "LearningRate": self._create_param_lr(param_and_grad),
             }
             outputs = {
                 "ParamOut": param_and_grad[0],
diff --git a/python/paddle/fluid/tests/unittests/test_adadelta_op.py b/python/paddle/fluid/tests/unittests/test_adadelta_op.py
index 11db47b2475..f3eca8fec9c 100644
--- a/python/paddle/fluid/tests/unittests/test_adadelta_op.py
+++ b/python/paddle/fluid/tests/unittests/test_adadelta_op.py
@@ -26,6 +26,7 @@ def adadelta_wrapper(
     Grad,
     AvgSquaredGrad,
     AvgSquaredUpdate,
+    LearningRate,
     master_weight=None,
     rho=0.95,
     epsilon=1e-6,
@@ -35,12 +36,13 @@
         Grad,
         AvgSquaredGrad,
         AvgSquaredUpdate,
+        LearningRate,
        None,
        rho,
        epsilon,
        False,
     )
-    return Param, AvgSquaredGrad, AvgSquaredUpdate
+    return Param, AvgSquaredGrad, AvgSquaredUpdate, LearningRate
 
 
 class TestAdadeltaOp1(OpTest):
@@ -58,11 +60,13 @@ class TestAdadeltaOp1(OpTest):
         rho = 0.95
         epsilon = 1e-6
+        learning_rate = 1.0
 
         self.inputs = {
             'Param': param,
             'Grad': grad,
             'AvgSquaredGrad': avg_squared_grad,
             'AvgSquaredUpdate': avg_squared_update,
+            'LearningRate': np.array([learning_rate]).astype("float32"),
         }
 
         self.attrs = {'rho': rho, 'epsilon': epsilon}
@@ -113,12 +117,13 @@ class TestAdadeltaOp2(OpTest):
         epsilon = 1e-6
 
         self.attrs = {'rho': rho, 'epsilon': epsilon}
-
+        learning_rate = 1.0
         self.inputs = {
             'Param': param,
             'Grad': grad,
             'AvgSquaredGrad': avg_squared_grad,
             'AvgSquaredUpdate': avg_squared_update,
+            'LearningRate': np.array([learning_rate]).astype("float32"),
         }
 
         avg_squared_grad_out = rho * avg_squared_grad + (1 - rho) * np.square(
diff --git a/python/paddle/optimizer/adadelta.py b/python/paddle/optimizer/adadelta.py
index 1cdb61f698e..c760c535da0 100644
--- a/python/paddle/optimizer/adadelta.py
+++ b/python/paddle/optimizer/adadelta.py
@@ -197,6 +197,7 @@ class Adadelta(Optimizer):
                 param_and_grad[1],
                 avg_squared_grad_acc,
                 avg_squared_update_acc,
+                self._create_param_lr(param_and_grad),
                 master_weight,
                 self._rho,
                 self._epsilon,
@@ -213,6 +214,7 @@ class Adadelta(Optimizer):
                 "Grad": param_and_grad[1],
                 "AvgSquaredGrad": avg_squared_grad_acc,
                 "AvgSquaredUpdate": avg_squared_update_acc,
+                "LearningRate": self._create_param_lr(param_and_grad),
             }
             outputs = {
                 "ParamOut": param_and_grad[0],
-- 
GitLab
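
For reference, a minimal NumPy sketch (not part of the patch) of the update rule that the kernel and the unit tests above compute once LearningRate is wired in. The function name adadelta_step and the standalone-NumPy form are illustrative only; the Paddle kernel does the same math with Eigen expressions.

import numpy as np

def adadelta_step(param, grad, avg_sq_grad, avg_sq_update,
                  learning_rate=1.0, rho=0.95, epsilon=1e-6):
    # Accumulate the squared gradient.
    avg_sq_grad = rho * avg_sq_grad + (1 - rho) * np.square(grad)
    # Adadelta update direction, scaled by the ratio of the accumulators.
    update = -np.sqrt((avg_sq_update + epsilon) / (avg_sq_grad + epsilon)) * grad
    # Accumulate the squared update.
    avg_sq_update = rho * avg_sq_update + (1 - rho) * np.square(update)
    # The fix: the step is scaled by the learning rate instead of being
    # applied as "param + update" with the learning rate ignored.
    param = param + learning_rate * update
    return param, avg_sq_grad, avg_sq_update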