From 3672480b22a4090ebc580bd65a2909fe0f71b0f6 Mon Sep 17 00:00:00 2001
From: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Date: Wed, 29 Dec 2021 11:01:38 +0800
Subject: [PATCH] fix lamb beta1pow beta2pow update (#38518)

---
 paddle/fluid/operators/optimizers/lamb_op.h | 180 ++++++++++++--------
 1 file changed, 108 insertions(+), 72 deletions(-)

diff --git a/paddle/fluid/operators/optimizers/lamb_op.h b/paddle/fluid/operators/optimizers/lamb_op.h
index 9eba8df9992..df17b5e5f40 100644
--- a/paddle/fluid/operators/optimizers/lamb_op.h
+++ b/paddle/fluid/operators/optimizers/lamb_op.h
@@ -52,19 +52,16 @@ struct LambMomentREGUpdateFunctor {
   const bool* skip_update_;
 
   LambMomentREGUpdateFunctor(MT weight_decay, MT beta1, MT beta2, MT epsilon,
-                             MT beta1_pow, MT* beta1_pow_out, MT beta2_pow,
-                             MT* beta2_pow_out, const MT* mom1, MT* mom1_out,
-                             const MT* mom2, MT* mom2_out, const T* grad,
-                             const MT* param, MT* trust_ratio_div,
-                             const bool* skip_update)
+                             MT beta1_pow, MT beta2_pow, const MT* mom1,
+                             MT* mom1_out, const MT* mom2, MT* mom2_out,
+                             const T* grad, const MT* param,
+                             MT* trust_ratio_div, const bool* skip_update)
       : weight_decay_(weight_decay),
         beta1_(beta1),
         beta2_(beta2),
         epsilon_(epsilon),
         beta1_pow_(beta1_pow),
-        beta1_pow_out_(beta1_pow_out),
         beta2_pow_(beta2_pow),
-        beta2_pow_out_(beta2_pow_out),
         moment1_(mom1),
         moment1_out_(mom1_out),
         moment2_(mom2),
@@ -95,10 +92,6 @@ struct LambMomentREGUpdateFunctor {
     trust_ratio_div_[i] =
         mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
         weight_decay_ * p;
-    if (beta1_pow_out_ && beta2_pow_out_) {
-      beta1_pow_out_[0] = beta1_pow * beta1_;
-      beta2_pow_out_[0] = beta2_pow * beta2_;
-    }
   }
 };
 
@@ -113,9 +106,7 @@ struct LambMomentMENUpdateFunctor {
   MT epsilon_;
 
   const MT* beta1_pow_;
-  MT* beta1_pow_out_;
   const MT* beta2_pow_;
-  MT* beta2_pow_out_;
   const MT* moment1_;
   MT* moment1_out_;
   const MT* moment2_;
@@ -126,8 +117,7 @@ struct LambMomentMENUpdateFunctor {
   const bool* skip_update_;
 
   LambMomentMENUpdateFunctor(MT weight_decay, MT beta1, MT beta2, MT epsilon,
-                             const MT* beta1_pow, MT* beta1_pow_out,
-                             const MT* beta2_pow, MT* beta2_pow_out,
+                             const MT* beta1_pow, const MT* beta2_pow,
                              const MT* mom1, MT* mom1_out, const MT* mom2,
                              MT* mom2_out, const T* grad, const MT* param,
                              MT* trust_ratio_div, const bool* skip_update)
@@ -136,9 +126,7 @@ struct LambMomentMENUpdateFunctor {
         beta2_(beta2),
         epsilon_(epsilon),
         beta1_pow_(beta1_pow),
-        beta1_pow_out_(beta1_pow_out),
         beta2_pow_(beta2_pow),
-        beta2_pow_out_(beta2_pow_out),
         moment1_(mom1),
         moment1_out_(mom1_out),
         moment2_(mom2),
@@ -168,10 +156,6 @@ struct LambMomentMENUpdateFunctor {
     trust_ratio_div_[i] =
         mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
         weight_decay_ * p;
-    if (beta1_pow_out_ && beta2_pow_out_) {
-      beta1_pow_out_[0] = beta1_pow * beta1_;
-      beta2_pow_out_[0] = beta2_pow * beta2_;
-    }
   }
 };
 
@@ -183,9 +167,7 @@ struct SparseLambMomentREGUpdateFunctor {
   T epsilon_;
 
   T beta1_pow_;
-  T* beta1_pow_out_;
   T beta2_pow_;
-  T* beta2_pow_out_;
   const T* moment1_;
   T* moment1_out_;
   const T* moment2_;
@@ -201,20 +183,18 @@ struct SparseLambMomentREGUpdateFunctor {
   const bool* skip_update_;
 
   SparseLambMomentREGUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
-                                   T beta1_pow, T* beta1_pow_out, T beta2_pow,
-                                   T* beta2_pow_out, const T* mom1, T* mom1_out,
-                                   const T* mom2, T* mom2_out, const T* grad,
-                                   const T* param, T* trust_ratio_div,
-                                   const int64_t* rows, int64_t row_numel,
-                                   int64_t row_count, const bool* skip_update)
+                                   T beta1_pow, T beta2_pow, const T* mom1,
+                                   T* mom1_out, const T* mom2, T* mom2_out,
+                                   const T* grad, const T* param,
+                                   T* trust_ratio_div, const int64_t* rows,
+                                   int64_t row_numel, int64_t row_count,
+                                   const bool* skip_update)
       : weight_decay_(weight_decay),
         beta1_(beta1),
         beta2_(beta2),
         epsilon_(epsilon),
         beta1_pow_(beta1_pow),
-        beta1_pow_out_(beta1_pow_out),
         beta2_pow_(beta2_pow),
-        beta2_pow_out_(beta2_pow_out),
         moment1_(mom1),
         moment1_out_(mom1_out),
         moment2_(mom2),
@@ -246,10 +226,6 @@ struct SparseLambMomentREGUpdateFunctor {
     trust_ratio_div_[i] =
         mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
         weight_decay_ * p;
-    if (beta1_pow_out_ && beta1_pow_out_) {
-      beta1_pow_out_[0] = beta1_pow * beta1_;
-      beta2_pow_out_[0] = beta2_pow * beta2_;
-    }
   }
 
   inline HOSTDEVICE void operator()(size_t i) const {
@@ -270,9 +246,7 @@ struct SparseLambMomentMENUpdateFunctor {
   T epsilon_;
 
   const T* beta1_pow_;
-  T* beta1_pow_out_;
   const T* beta2_pow_;
-  T* beta2_pow_out_;
   const T* moment1_;
   T* moment1_out_;
   const T* moment2_;
@@ -288,8 +262,7 @@ struct SparseLambMomentMENUpdateFunctor {
   const bool* skip_update_;
 
   SparseLambMomentMENUpdateFunctor(T weight_decay, T beta1, T beta2, T epsilon,
-                                   const T* beta1_pow, T* beta1_pow_out,
-                                   const T* beta2_pow, T* beta2_pow_out,
+                                   const T* beta1_pow, const T* beta2_pow,
                                    const T* mom1, T* mom1_out, const T* mom2,
                                    T* mom2_out, const T* grad, const T* param,
                                    T* trust_ratio_div, const int64_t* rows,
@@ -300,9 +273,7 @@ struct SparseLambMomentMENUpdateFunctor {
         beta2_(beta2),
         epsilon_(epsilon),
         beta1_pow_(beta1_pow),
-        beta1_pow_out_(beta1_pow_out),
         beta2_pow_(beta2_pow),
-        beta2_pow_out_(beta2_pow_out),
         moment1_(mom1),
         moment1_out_(mom1_out),
         moment2_(mom2),
@@ -334,10 +305,6 @@ struct SparseLambMomentMENUpdateFunctor {
     trust_ratio_div_[i] =
         mom1_unbiased / (Eigen::numext::sqrt(mom2_unbiased) + epsilon_) +
         weight_decay_ * p;
-    if (beta1_pow_out_ && beta1_pow_out_) {
-      beta1_pow_out_[0] = beta1_pow * beta1_;
-      beta2_pow_out_[0] = beta2_pow * beta2_;
-    }
   }
 
   inline HOSTDEVICE void operator()(size_t i) const {
@@ -350,11 +317,44 @@ struct SparseLambMomentMENUpdateFunctor {
     update(i, g);
   }
 };
 
-template <typename T, bool IsMultiPrecision>
-struct LambParamUpateFunctor {
-  using MT = typename std::conditional<
-      IsMultiPrecision, typename details::MPTypeTrait<T>::Type, T>::type;
+template <typename MT, bool NeedUpdateBetaPow>
+struct LambBetaPowUpdateFunctor {
+  void SetBetaPows(const MT* beta1pow, const MT* beta2pow, MT* beta1pow_out,
+                   MT* beta2pow_out, MT beta1, MT beta2) {
+    beta1pow_ = beta1pow;
+    beta2pow_ = beta2pow;
+    beta1pow_out_ = beta1pow_out;
+    beta2pow_out_ = beta2pow_out;
+    beta1_ = beta1;
+    beta2_ = beta2;
+  }
+  HOSTDEVICE void UpdateBetaPow(size_t i) const {
+    if (i == 0) {
+      beta1pow_out_[0] = beta1pow_[0] * beta1_;
+      beta2pow_out_[0] = beta2pow_[0] * beta2_;
+    }
+  }
+
+ private:
+  const MT* beta1pow_;
+  const MT* beta2pow_;
+  MT* beta1pow_out_;
+  MT* beta2pow_out_;
+  MT beta1_;
+  MT beta2_;
+};
+
+template <typename MT>
+struct LambBetaPowUpdateFunctor<MT, false> {
+  void SetBetaPows(const MT* beta1pow, const MT* beta2pow, MT* beta1pow_out,
+                   MT* beta2pow_out, MT beta1, MT beta2) {}
+  HOSTDEVICE void UpdateBetaPow(size_t) const {}
+};
+
+template <typename T, typename MT, bool IsMultiPrecision, bool UpdateBetaPow>
+struct LambParamUpateFunctor
+    : public LambBetaPowUpdateFunctor<MT, UpdateBetaPow> {
   const MT* lr_;
   const T* param_;
   const MT* master_param_;
@@ -396,6 +396,7 @@ struct LambParamUpateFunctor {
     if (IsMultiPrecision) {
       master_param_out_[i] = param_out;
     }
+    this->UpdateBetaPow(i);
   }
 };
 
@@ -501,6 +502,11 @@ class LambOpKernel : public framework::OpKernel<T> {
             : nullptr;
 
     // Update moments
+    bool should_update_beta_pow_later = false;
+    const MT *beta1_pow_ptr = nullptr, *beta2_pow_ptr = nullptr;
+    MT *beta1_pow_out_ptr = nullptr, *beta2_pow_out_ptr = nullptr;
+    VLOG(10) << "Beta1Pow place: " << beta1_pow.place()
+             << " , Beta2Pow place: " << beta2_pow.place();
     if (grad_var->IsType<framework::LoDTensor>()) {
       auto& grad = grad_var->Get<framework::LoDTensor>();
       if (platform::is_gpu_place(ctx.GetPlace()) &&
@@ -508,8 +514,7 @@ class LambOpKernel : public framework::OpKernel<T> {
           beta2_pow.place() == platform::CPUPlace()) {
         LambMomentREGUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
             weight_decay, beta1, beta2, epsilon, *beta1_pow.template data<MT>(),
-            nullptr, *beta2_pow.template data<MT>(), nullptr,
-            mom1.template data<MT>(),
+            *beta2_pow.template data<MT>(), mom1.template data<MT>(),
             mom1_out.template mutable_data<MT>(ctx.GetPlace()),
             mom2.template data<MT>(),
             mom2_out.template mutable_data<MT>(ctx.GetPlace()),
@@ -523,12 +528,17 @@ class LambOpKernel : public framework::OpKernel<T> {
         beta2_pow_out.template mutable_data<MT>(platform::CPUPlace())[0] =
             beta2 * beta2_pow.template data<MT>()[0];
       } else {
+        beta1_pow_ptr = beta1_pow.template data<MT>();
+        beta2_pow_ptr = beta2_pow.template data<MT>();
+        beta1_pow_out_ptr =
+            beta1_pow_out.template mutable_data<MT>(ctx.GetPlace());
+        beta2_pow_out_ptr =
+            beta2_pow_out.template mutable_data<MT>(ctx.GetPlace());
+        should_update_beta_pow_later = true;
         LambMomentMENUpdateFunctor<T, IsMultiPrecision> moment_update_functor(
-            weight_decay, beta1, beta2, epsilon, beta1_pow.template data<MT>(),
-            beta1_pow_out.template mutable_data<MT>(ctx.GetPlace()),
-            beta2_pow.template data<MT>(),
-            beta2_pow_out.template mutable_data<MT>(ctx.GetPlace()),
-            mom1.template data<MT>(),
+            weight_decay, beta1, beta2, epsilon,
+            static_cast<const MT*>(beta1_pow_ptr),
+            static_cast<const MT*>(beta2_pow_ptr), mom1.template data<MT>(),
             mom1_out.template mutable_data<MT>(ctx.GetPlace()),
             mom2.template data<MT>(),
             mom2_out.template mutable_data<MT>(ctx.GetPlace()),
@@ -542,7 +552,12 @@ class LambOpKernel : public framework::OpKernel<T> {
       PADDLE_ENFORCE_EQ(IsMultiPrecision, false,
                         platform::errors::Unimplemented(
                             "SelectedRows gradient is not supported when "
-                            "multi_precision=True"));
+                            "multi_precision=True."));
+      constexpr bool kIsSameType = std::is_same<T, MT>::value;
+      PADDLE_ENFORCE_EQ(kIsSameType, true,
+                        platform::errors::Unimplemented(
+                            "SelectedRows gradient is not supported when "
+                            "multi_precision=True."));
       auto& grad = GET_DATA_SAFELY(ctx.Input<framework::SelectedRows>("Grad"),
                                    "Input", "Grad", "Lamb");
       if (grad.rows().size() == 0) {
@@ -582,8 +597,8 @@ class LambOpKernel : public framework::OpKernel<T> {
         SparseLambMomentREGUpdateFunctor<T> moment_update_functor(
             static_cast<T>(weight_decay), static_cast<T>(beta1),
             static_cast<T>(beta2), static_cast<T>(epsilon),
-            *beta1_pow.template data<T>(), nullptr,
-            *beta2_pow.template data<T>(), nullptr, mom1.template data<T>(),
+            *beta1_pow.template data<T>(), *beta2_pow.template data<T>(),
+            mom1.template data<T>(),
             mom1_out.template mutable_data<T>(ctx.GetPlace()),
             mom2.template data<T>(),
             mom2_out.template mutable_data<T>(ctx.GetPlace()), grad_data,
@@ -595,14 +610,18 @@ class LambOpKernel : public framework::OpKernel<T> {
         beta2_pow_out.template mutable_data<T>(platform::CPUPlace())[0] =
             static_cast<T>(beta2) * beta2_pow.template data<T>()[0];
       } else {
+        beta1_pow_ptr = beta1_pow.template data<MT>();
+        beta2_pow_ptr = beta2_pow.template data<MT>();
+        beta1_pow_out_ptr =
+            beta1_pow_out.template mutable_data<MT>(ctx.GetPlace());
+        beta2_pow_out_ptr =
+            beta2_pow_out.template mutable_data<MT>(ctx.GetPlace());
+        should_update_beta_pow_later = true;
         SparseLambMomentMENUpdateFunctor<T> moment_update_functor(
             static_cast<T>(weight_decay), static_cast<T>(beta1),
             static_cast<T>(beta2), static_cast<T>(epsilon),
-            beta1_pow.template data<T>(),
-            beta1_pow_out.template mutable_data<T>(ctx.GetPlace()),
-            beta2_pow.template data<T>(),
-            beta2_pow_out.template mutable_data<T>(ctx.GetPlace()),
-            mom1.template data<T>(),
+            reinterpret_cast<const T*>(beta1_pow_ptr),
+            reinterpret_cast<const T*>(beta2_pow_ptr), mom1.template data<T>(),
             mom1_out.template mutable_data<T>(ctx.GetPlace()),
             mom2.template data<T>(),
             mom2_out.template mutable_data<T>(ctx.GetPlace()), grad_data,
@@ -639,14 +658,31 @@ class LambOpKernel : public framework::OpKernel<T> {
     }
     trust_ratio_div_norm.device(*place) = t.square().sum().sqrt();
 
-    LambParamUpateFunctor<T, IsMultiPrecision> param_update_functor(
-        lr.template data<MT>(), static_cast<const T*>(param_ptr),
-        static_cast<const MT*>(master_param_ptr), p_norm_t.template data<MT>(),
-        trust_ratio_div.template data<MT>(),
-        trust_ratio_div_norm_t.template data<MT>(),
-        static_cast<T*>(param_out_ptr), static_cast<MT*>(master_param_out_ptr),
-        skip_update_flag);
-    for_range(param_update_functor);
+#define CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(__should_update_beta_pow)          \
+  do {                                                                        \
+    LambParamUpateFunctor<T, MT, IsMultiPrecision, __should_update_beta_pow>  \
+        param_update_functor(                                                 \
+            lr.template data<MT>(), static_cast<const T*>(param_ptr),         \
+            static_cast<const MT*>(master_param_ptr),                         \
+            p_norm_t.template data<MT>(), trust_ratio_div.template data<MT>(),\
+            trust_ratio_div_norm_t.template data<MT>(),                       \
+            static_cast<T*>(param_out_ptr),                                   \
+            static_cast<MT*>(master_param_out_ptr), skip_update_flag);        \
+    if (__should_update_beta_pow) {                                           \
+      param_update_functor.SetBetaPows(beta1_pow_ptr, beta2_pow_ptr,          \
+                                       beta1_pow_out_ptr, beta2_pow_out_ptr,  \
+                                       beta1, beta2);                         \
+    }                                                                         \
+    for_range(param_update_functor);                                          \
+  } while (0)
+
+    if (should_update_beta_pow_later) {
+      CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(true);
+    } else {
+      CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC(false);
+    }
+
+#undef CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC
   }
 };
 
-- 
GitLab
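Not part of the patch above: a minimal, self-contained C++ sketch of the pattern the commit introduces, for readers who want to see it outside the Paddle sources. The names BetaPowUpdater, ParamUpdater, and the main() driver are hypothetical stand-ins for LambBetaPowUpdateFunctor, LambParamUpateFunctor, and the for_range launch; plain host-side C++ replaces the HOSTDEVICE functors. It illustrates the point of the fix: when the beta-pow tensors stay on the device, their update is fused into the element-wise parameter update and performed by exactly one index (i == 0), instead of being written redundantly from the moment-update functors.

#include <cstdio>
#include <vector>

// Hypothetical, simplified mirror of LambBetaPowUpdateFunctor: when
// NeedUpdateBetaPow is true, exactly one index (i == 0) writes the scalar
// beta1^t / beta2^t outputs as part of the element-wise parameter update.
template <typename MT, bool NeedUpdateBetaPow>
struct BetaPowUpdater {
  void SetBetaPows(const MT* beta1pow, const MT* beta2pow, MT* beta1pow_out,
                   MT* beta2pow_out, MT beta1, MT beta2) {
    beta1pow_ = beta1pow;
    beta2pow_ = beta2pow;
    beta1pow_out_ = beta1pow_out;
    beta2pow_out_ = beta2pow_out;
    beta1_ = beta1;
    beta2_ = beta2;
  }
  void UpdateBetaPow(size_t i) const {
    if (i == 0) {  // only the first "thread" updates the scalars
      beta1pow_out_[0] = beta1pow_[0] * beta1_;
      beta2pow_out_[0] = beta2pow_[0] * beta2_;
    }
  }

 private:
  const MT* beta1pow_ = nullptr;
  const MT* beta2pow_ = nullptr;
  MT* beta1pow_out_ = nullptr;
  MT* beta2pow_out_ = nullptr;
  MT beta1_ = 0;
  MT beta2_ = 0;
};

// No-op specialization: chosen when beta_pow is instead updated on the CPU.
template <typename MT>
struct BetaPowUpdater<MT, false> {
  void SetBetaPows(const MT*, const MT*, MT*, MT*, MT, MT) {}
  void UpdateBetaPow(size_t) const {}
};

// Toy per-element parameter update standing in for LambParamUpateFunctor;
// it inherits the (possibly no-op) beta-pow hook and calls it per element.
template <typename MT, bool NeedUpdateBetaPow>
struct ParamUpdater : public BetaPowUpdater<MT, NeedUpdateBetaPow> {
  const MT* lr_ = nullptr;
  const MT* param_ = nullptr;
  const MT* trust_ratio_div_ = nullptr;
  MT* param_out_ = nullptr;

  void operator()(size_t i) const {
    param_out_[i] = param_[i] - lr_[0] * trust_ratio_div_[i];
    this->UpdateBetaPow(i);  // no-op unless NeedUpdateBetaPow == true
  }
};

int main() {
  const size_t n = 4;
  std::vector<float> param{1.f, 2.f, 3.f, 4.f};
  std::vector<float> trust_ratio_div{0.1f, 0.1f, 0.1f, 0.1f};
  std::vector<float> param_out(n);
  float lr = 0.5f, beta1 = 0.9f, beta2 = 0.999f;
  float beta1pow = 0.9f, beta2pow = 0.999f;
  float beta1pow_out = 0.f, beta2pow_out = 0.f;

  ParamUpdater<float, true> updater;
  updater.lr_ = &lr;
  updater.param_ = param.data();
  updater.trust_ratio_div_ = trust_ratio_div.data();
  updater.param_out_ = param_out.data();
  updater.SetBetaPows(&beta1pow, &beta2pow, &beta1pow_out, &beta2pow_out,
                      beta1, beta2);

  for (size_t i = 0; i < n; ++i) updater(i);  // emulates for_range(functor)

  std::printf("beta1pow_out=%f beta2pow_out=%f\n", beta1pow_out, beta2pow_out);
  return 0;
}

Selecting the true/false specialization at compile time, as the patch's CALL_PADDLE_UPDATE_LAMB_PARAM_FUNC macro does, keeps the path that does not update beta pow free of any extra state or per-element branch.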