diff --git a/paddle/fluid/operators/optimizers/adamw_op.h b/paddle/fluid/operators/optimizers/adamw_op.h
index d87f827bbf22bb103f390878c8531eba15aeb74f..b6dce0a68c694693bbdf4f879e6d20b59f3e6316 100644
--- a/paddle/fluid/operators/optimizers/adamw_op.h
+++ b/paddle/fluid/operators/optimizers/adamw_op.h
@@ -31,19 +31,22 @@ class AdamWFunctor;
 template <typename T>
 class AdamWFunctor<T, CPUAdamW> {
  private:
-  const float coeff_;
-  const float learning_rate_;
+  const T coeff_;
+  const T* lr_;
   T* param_;
 
  public:
-  AdamWFunctor(const float& coeff, const float& learning_rate, T* param)
-      : coeff_(coeff), learning_rate_(learning_rate), param_(param) {}
+  AdamWFunctor(const T coeff, const T* lr, T* param)
+      : coeff_(coeff), lr_(lr), param_(param) {}
 
   inline HOSTDEVICE void operator()(size_t numel) const {
     Eigen::Map<Eigen::Array<T, 1, Eigen::Dynamic>> param{
         param_, static_cast<Eigen::Index>(numel)};
+
+    T lr = *lr_;
+
     // Calculation
-    param = param * (1.0f - learning_rate_ * coeff_);
+    param = param * (1 - lr * coeff_);
   }
 };
 
@@ -183,7 +186,7 @@ class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> {
       return;
     }
 
-    float coeff = ctx.Attr<float>("coeff");
+    T coeff = static_cast<T>(ctx.Attr<float>("coeff"));
     auto* lr = ctx.Input<LoDTensor>("LearningRate");
 
     LoDTensor* param;
@@ -195,9 +198,7 @@ class AdamWOpKernel : public AdamOpKernel<DeviceContext, T> {
       param = const_cast<LoDTensor*>(ctx.Input<LoDTensor>("Param"));
     }
 
-    // AdamWFunctor(float coeff, const float* learning_rate, T* parma)
-    AdamWFunctor<T, CPUAdamW> functor(coeff, *lr->data<float>(),
-                                      param->data<T>());
+    AdamWFunctor<T, CPUAdamW> functor(coeff, lr->data<T>(), param->data<T>());
     functor(param->numel());
 
     AdamOpKernel<DeviceContext, T>::Compute(ctx);
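
For context, the functor changed above implements the decoupled weight decay step of AdamW, applied to the parameters before the regular Adam update in AdamOpKernel::Compute runs. The sketch below is a standalone illustration under that reading of the diff, not Paddle code; the helper name apply_decoupled_weight_decay and the scalar-loop formulation are assumptions made for clarity.

// Standalone sketch of the decoupled weight decay step (illustrative only,
// not part of the Paddle sources; the helper name is hypothetical).
#include <cstddef>
#include <cstdio>
#include <vector>

template <typename T>
void apply_decoupled_weight_decay(T* param, size_t numel, T coeff,
                                  const T* lr) {
  // The kernel reads the learning rate from a tensor; here it is just a
  // pointer to a scalar, mirroring the new `const T* lr_` member.
  const T learning_rate = *lr;
  for (size_t i = 0; i < numel; ++i) {
    // AdamW shrinks each parameter by lr * coeff, independent of the
    // gradient-based Adam update that follows.
    param[i] *= (static_cast<T>(1) - learning_rate * coeff);
  }
}

int main() {
  std::vector<float> param{1.0f, -2.0f, 0.5f};
  const float lr = 0.01f;
  const float coeff = 0.1f;  // weight decay coefficient
  apply_decoupled_weight_decay(param.data(), param.size(), coeff, &lr);
  for (float p : param) std::printf("%f\n", p);  // each value scaled by 0.999
  return 0;
}

Reading the learning rate through a pointer inside operator(), as in the sketch, mirrors the diff's change from a stored float value to a const T* member dereferenced at call time.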