diff --git a/paddle/fluid/operators/optimizers/adamw_op.cu b/paddle/fluid/operators/optimizers/adamw_op.cu index 49b7fe771be1331772f6b5918fd81b7164f5790f..8b152bc67a30bd4eaa3d2cc99b846943e39bac06 100644 --- a/paddle/fluid/operators/optimizers/adamw_op.cu +++ b/paddle/fluid/operators/optimizers/adamw_op.cu @@ -27,25 +27,25 @@ __global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff, T* param_out, const MT* master_param, MT* master_param_out, int ndim) { MT lr = *lr_ * lr_ratio; - MT lr_orig = lr; MT beta1_pow = beta1_pow_; MT beta2_pow = beta2_pow_; - lr *= sqrt(static_cast(1.0) - beta2_pow) / - (static_cast(1.0) - beta1_pow); - int id = blockIdx.x * blockDim.x + threadIdx.x; for (; id < ndim; id += gridDim.x * blockDim.x) { MT p = master_param ? master_param[id] : static_cast(param[id]); MT g = static_cast(grad[id]); - MT mom1 = moment1[id]; - MT mom2 = moment2[id]; + MT mom1 = static_cast(moment1[id]); + MT mom2 = static_cast(moment2[id]); + + p *= (static_cast(1.0) - lr * coeff); + mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - p -= lr_orig * coeff * p; - p -= lr * (mom1 / - (sqrt(mom2) + epsilon * sqrt(static_cast(1.0) - beta2_pow))); + + MT denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; + + p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); moment1_out[id] = mom1; moment2_out[id] = mom2; @@ -63,13 +63,9 @@ __global__ void AdamWKernelMEM( MT* moment2_out, const MT* lr_, const T* grad, const T* param, T* param_out, const MT* master_param, MT* master_param_out, int ndim) { MT lr = *lr_ * lr_ratio; - MT lr_orig = lr; MT beta1_pow = *beta1_pow_; MT beta2_pow = *beta2_pow_; - lr *= sqrt(static_cast(1.0) - beta2_pow) / - (static_cast(1.0) - beta1_pow); - int id = blockIdx.x * blockDim.x + threadIdx.x; for (; id < ndim; id += gridDim.x * blockDim.x) { @@ -77,11 +73,15 @@ __global__ void AdamWKernelMEM( MT g = static_cast(grad[id]); MT mom1 = static_cast(moment1[id]); MT mom2 = static_cast(moment2[id]); + + p *= (static_cast(1.0) - lr * coeff); + mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - p -= lr_orig * coeff * p; - p -= lr * (mom1 / - (sqrt(mom2) + epsilon * sqrt(static_cast(1.0) - beta2_pow))); + + MT denom = (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; + + p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); moment1_out[id] = mom1; moment2_out[id] = mom2; @@ -109,10 +109,6 @@ __global__ void SparseAdamWCUDAKernelREG( int ndim) { int id = blockIdx.x * blockDim.x + threadIdx.x; MT lr = *lr_ * lr_ratio; - MT lr_orig = lr; - - lr *= sqrt(static_cast(1.0) - beta2_pow) / - (static_cast(1.0) - beta1_pow); for (; id < ndim; id += blockDim.x * gridDim.x) { auto row_idx = @@ -120,17 +116,23 @@ __global__ void SparseAdamWCUDAKernelREG( if (lazy_mode && row_idx < 0) { return; } else { - MT mom1 = mom1_[id]; - MT mom2 = mom2_[id]; + MT mom1 = static_cast(mom1_[id]); + MT mom2 = static_cast(mom2_[id]); + MT p = master_param ? master_param[id] : static_cast(param_[id]); MT g = row_idx >= 0 ? static_cast(grad_[row_idx * row_numel + id % row_numel]) : static_cast(0); + + p *= (static_cast(1.0) - lr * coeff); + mom1 = beta1 * mom1 + (static_cast(1.0) - beta1) * g; mom2 = beta2 * mom2 + (static_cast(1.0) - beta2) * g * g; - p -= lr_orig * coeff * p; - p -= lr * (mom1 / (sqrt(mom2) + - epsilon * sqrt(static_cast(1.0) - beta2_pow))); + + MT denom = + (sqrt(mom2) / sqrt(static_cast(1.0) - beta2_pow)) + epsilon; + + p += (mom1 / denom) * (-(lr / (static_cast(1.0) - beta1_pow))); // Write back to global memory mom1_out_[id] = mom1;