提交 24eb8f03 编写于 作者: M minqiyang

Fix bug

test=develop
上级 bd0067b2
...@@ -279,26 +279,42 @@ struct SparseAdamFunctor<T, CPUAdam> { ...@@ -279,26 +279,42 @@ struct SparseAdamFunctor<T, CPUAdam> {
T beta1_pow = *beta1_pow_; T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_; T beta2_pow = *beta2_pow_;
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow); lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
for (size_t i = 0U, j = 0U; i != numel; ++i) { size_t row_count = numel / row_numel_;
T mom1 = moment1_[i];
T mom2 = moment2_[i];
T p = param_[i];
// Calculation for (size_t i = 0U, j = 0U; i != row_count; ++i) {
if (i == *(rows_ + j)) { if (i == *(rows_ + j)) {
T g = grad_[j * row_numel_]; for (size_t k = 0U; k != row_numel_; ++k) {
mom1 = beta1_ * mom1 + (1 - beta1_) * g; T mom1 = moment1_[i * row_numel_ + k];
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; T mom2 = moment2_[i * row_numel_ + k];
T p = param_[i * row_numel_ + k];
T g = grad_[j * row_numel_ + k];
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
// Write back to global memory
moment1_out_[i * row_numel_ + k] = mom1;
moment2_out_[i * row_numel_ + k] = mom2;
param_out_[i * row_numel_ + k] = p;
}
++j; ++j;
} else { } else {
mom1 = beta1_ * mom1; for (size_t k = 0U; k != row_numel_; ++k) {
mom2 = beta2_ * mom2; T mom1 = moment1_[i * row_numel_ + k];
T mom2 = moment2_[i * row_numel_ + k];
T p = param_[i * row_numel_ + k];
mom1 = beta1_ * mom1;
mom2 = beta2_ * mom2;
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
// Write back to global memory
moment1_out_[i * row_numel_ + k] = mom1;
moment2_out_[i * row_numel_ + k] = mom2;
param_out_[i * row_numel_ + k] = p;
}
} }
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
// Write back to global memory
moment1_out_[i] = mom1;
moment2_out_[i] = mom2;
param_out_[i] = p;
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册