diff --git a/paddle/operators/adam_op.h b/paddle/operators/adam_op.h index 3c4148ccc0a7d5e2a2ef095cf8f639db09be0fc2..9cc34bdded780e61e8700eb4fa4a295c84fb48bc 100644 --- a/paddle/operators/adam_op.h +++ b/paddle/operators/adam_op.h @@ -124,19 +124,20 @@ struct SparseAdamFunctor { row_numel_(row_numel) {} inline HOSTDEVICE void operator()(size_t i) const { + T beta1_pow = *beta1_pow_; + T beta2_pow = *beta2_pow_; for (int64_t j = 0; j < row_numel_; ++j) { T g = grad_[i * row_numel_ + j]; T mom1 = moment1_[rows_[i] * row_numel_ + j]; T mom2 = moment2_[rows_[i] * row_numel_ + j]; T lr = *lr_; - T beta1_pow = *beta1_pow_; - T beta2_pow = *beta2_pow_; T p = param_[rows_[i] * row_numel_ + j]; lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow); mom1 = beta1_ * mom1 + (1 - beta1_) * g; mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; p -= lr * (mom1 / (sqrt(mom2) + epsilon_)); + moment1_out_[rows_[i] * row_numel_ + j] = mom1; moment2_out_[rows_[i] * row_numel_ + j] = mom2; param_out_[rows_[i] * row_numel_ + j] = p;