提交 903d5609 编写于 作者: T typhoonzero

follow comment1

上级 1039c1e3
...@@ -124,19 +124,20 @@ struct SparseAdamFunctor { ...@@ -124,19 +124,20 @@ struct SparseAdamFunctor {
row_numel_(row_numel) {} row_numel_(row_numel) {}
inline HOSTDEVICE void operator()(size_t i) const { inline HOSTDEVICE void operator()(size_t i) const {
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
for (int64_t j = 0; j < row_numel_; ++j) { for (int64_t j = 0; j < row_numel_; ++j) {
T g = grad_[i * row_numel_ + j]; T g = grad_[i * row_numel_ + j];
T mom1 = moment1_[rows_[i] * row_numel_ + j]; T mom1 = moment1_[rows_[i] * row_numel_ + j];
T mom2 = moment2_[rows_[i] * row_numel_ + j]; T mom2 = moment2_[rows_[i] * row_numel_ + j];
T lr = *lr_; T lr = *lr_;
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
T p = param_[rows_[i] * row_numel_ + j]; T p = param_[rows_[i] * row_numel_ + j];
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow); lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
mom1 = beta1_ * mom1 + (1 - beta1_) * g; mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
p -= lr * (mom1 / (sqrt(mom2) + epsilon_)); p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
moment1_out_[rows_[i] * row_numel_ + j] = mom1; moment1_out_[rows_[i] * row_numel_ + j] = mom1;
moment2_out_[rows_[i] * row_numel_ + j] = mom2; moment2_out_[rows_[i] * row_numel_ + j] = mom2;
param_out_[rows_[i] * row_numel_ + j] = p; param_out_[rows_[i] * row_numel_ + j] = p;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册