From 903d5609c61046cfa37280af5506ca21e350b852 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Fri, 29 Dec 2017 14:11:37 +0800 Subject: [PATCH] follow comment1 --- paddle/operators/adam_op.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/operators/adam_op.h b/paddle/operators/adam_op.h index 3c4148ccc..9cc34bdde 100644 --- a/paddle/operators/adam_op.h +++ b/paddle/operators/adam_op.h @@ -124,19 +124,20 @@ struct SparseAdamFunctor { row_numel_(row_numel) {} inline HOSTDEVICE void operator()(size_t i) const { + T beta1_pow = *beta1_pow_; + T beta2_pow = *beta2_pow_; for (int64_t j = 0; j < row_numel_; ++j) { T g = grad_[i * row_numel_ + j]; T mom1 = moment1_[rows_[i] * row_numel_ + j]; T mom2 = moment2_[rows_[i] * row_numel_ + j]; T lr = *lr_; - T beta1_pow = *beta1_pow_; - T beta2_pow = *beta2_pow_; T p = param_[rows_[i] * row_numel_ + j]; lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow); mom1 = beta1_ * mom1 + (1 - beta1_) * g; mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; p -= lr * (mom1 / (sqrt(mom2) + epsilon_)); + moment1_out_[rows_[i] * row_numel_ + j] = mom1; moment2_out_[rows_[i] * row_numel_ + j] = mom2; param_out_[rows_[i] * row_numel_ + j] = p; -- GitLab