diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h
index c9e27b754728a2a10d9c5c60ff123c347401d2f2..fa51c66eeb71bdeeea592db0f3360cad9fe0ae8b 100644
--- a/paddle/fluid/operators/optimizers/adam_op.h
+++ b/paddle/fluid/operators/optimizers/adam_op.h
@@ -279,26 +279,42 @@ struct SparseAdamFunctor {
     T beta1_pow = *beta1_pow_;
     T beta2_pow = *beta2_pow_;
     lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
-    for (size_t i = 0U, j = 0U; i != numel; ++i) {
-      T mom1 = moment1_[i];
-      T mom2 = moment2_[i];
-      T p = param_[i];
+    size_t row_count = numel / row_numel_;
 
-      // Calculation
+    for (size_t i = 0U, j = 0U; i != row_count; ++i) {
       if (i == *(rows_ + j)) {
-        T g = grad_[j * row_numel_];
-        mom1 = beta1_ * mom1 + (1 - beta1_) * g;
-        mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
+        for (size_t k = 0U; k != row_numel_; ++k) {
+          T mom1 = moment1_[i * row_numel_ + k];
+          T mom2 = moment2_[i * row_numel_ + k];
+          T p = param_[i * row_numel_ + k];
+
+          T g = grad_[j * row_numel_ + k];
+          mom1 = beta1_ * mom1 + (1 - beta1_) * g;
+          mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
+
+          p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
+          // Write back to global memory
+          moment1_out_[i * row_numel_ + k] = mom1;
+          moment2_out_[i * row_numel_ + k] = mom2;
+          param_out_[i * row_numel_ + k] = p;
+        }
         ++j;
       } else {
-        mom1 = beta1_ * mom1;
-        mom2 = beta2_ * mom2;
+        for (size_t k = 0U; k != row_numel_; ++k) {
+          T mom1 = moment1_[i * row_numel_ + k];
+          T mom2 = moment2_[i * row_numel_ + k];
+          T p = param_[i * row_numel_ + k];
+
+          mom1 = beta1_ * mom1;
+          mom2 = beta2_ * mom2;
+
+          p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
+          // Write back to global memory
+          moment1_out_[i * row_numel_ + k] = mom1;
+          moment2_out_[i * row_numel_ + k] = mom2;
+          param_out_[i * row_numel_ + k] = p;
+        }
       }
-      p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
-      // Write back to global memory
-      moment1_out_[i] = mom1;
-      moment2_out_[i] = mom2;
-      param_out_[i] = p;
     }
   }
 };
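
For readers skimming the patch, here is a rough standalone sketch of the row-wise update the new loop performs. It is not PaddlePaddle code: the free function sparse_adam_update, its std::vector-based signature, and the has_grad flag are invented for illustration. Rows listed in rows receive a full Adam step from the corresponding gradient row, while rows absent from the sparse gradient only decay their moment estimates, which is the same as applying the update with a zero gradient.

#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical standalone sketch (not PaddlePaddle API) of the row-wise
// sparse Adam update performed by the patched loop above.
template <typename T>
void sparse_adam_update(const std::vector<T>& param,
                        const std::vector<T>& grad,            // one gradient row per entry in rows
                        const std::vector<std::size_t>& rows,  // sorted indices of rows that have gradients
                        std::size_t row_numel,
                        std::vector<T>* mom1, std::vector<T>* mom2,
                        std::vector<T>* param_out,
                        T lr, T beta1, T beta2, T epsilon,
                        T beta1_pow, T beta2_pow) {
  lr *= std::sqrt(1 - beta2_pow) / (1 - beta1_pow);
  const std::size_t row_count = param.size() / row_numel;
  for (std::size_t i = 0, j = 0; i != row_count; ++i) {
    const bool has_grad = j < rows.size() && i == rows[j];
    for (std::size_t k = 0; k != row_numel; ++k) {
      const std::size_t idx = i * row_numel + k;
      // A row missing from the sparse gradient behaves as if g == 0:
      // only the exponential decay of the moment estimates is applied.
      const T g = has_grad ? grad[j * row_numel + k] : static_cast<T>(0);
      (*mom1)[idx] = beta1 * (*mom1)[idx] + (1 - beta1) * g;
      (*mom2)[idx] = beta2 * (*mom2)[idx] + (1 - beta2) * g * g;
      (*param_out)[idx] =
          param[idx] - lr * ((*mom1)[idx] / (std::sqrt((*mom2)[idx]) + epsilon));
    }
    if (has_grad) ++j;
  }
}

Written this way, the relationship between the two branches of the patch is explicit: the else branch is exactly the g == 0 special case of the if branch, so both branches update the moments and the parameter for every element of the row.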