提交 3dc29b39 编写于 作者: Q Qiao Longfei

change sparse_update to adam_update

上级 fc6ec6bd
...@@ -202,7 +202,7 @@ struct SparseAdamFunctor { ...@@ -202,7 +202,7 @@ struct SparseAdamFunctor {
row_count_(row_count), row_count_(row_count),
sparse_mode_(sparse_mode) {} sparse_mode_(sparse_mode) {}
inline HOSTDEVICE void sparse_update(size_t i, T g) const { inline HOSTDEVICE void adam_update(size_t i, T g) const {
// The following code is the same as dense // The following code is the same as dense
T mom1 = moment1_[i]; T mom1 = moment1_[i];
T mom2 = moment2_[i]; T mom2 = moment2_[i];
...@@ -228,7 +228,7 @@ struct SparseAdamFunctor { ...@@ -228,7 +228,7 @@ struct SparseAdamFunctor {
auto row_idx = auto row_idx =
math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_); math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0; T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
sparse_update(i, g); adam_update(i, g);
} }
}; };
...@@ -364,7 +364,7 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -364,7 +364,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
for (size_t offset = 0; offset < row_numel; ++offset) { for (size_t offset = 0; offset < row_numel; ++offset) {
size_t i = rows[row_index] * row_numel + offset; size_t i = rows[row_index] * row_numel + offset;
T g = grad_data[row_index * row_numel + offset]; T g = grad_data[row_index * row_numel + offset];
functor.sparse_update(i, g); functor.adam_update(i, g);
} }
} }
} else { } else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册