Commit c15270c5 authored by Qiao Longfei

optimize multi thread adam

Parent 0e747e8d
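This commit removes the per-row `update_row` helper (and the pass that pre-filled `row_id_to_grad_row_offset` with -1 for every parameter row); instead each async worker looks its rows up in a hash map and applies the element-wise Adam update inline, passing a zero gradient when a row has no sparse gradient entry. Below is a minimal, self-contained sketch of that pattern, not PaddlePaddle's implementation: `std::async` stands in for `framework::Async`, `adam_update` is a stub for the per-element step, and the function name, shard math, and toy data are assumptions for illustration.

```cpp
// Minimal sketch of the sharded sparse-Adam update this commit switches to.
// std::async stands in for framework::Async; adam_update() is a stub for the
// per-element Adam step (the real functor also updates the moment estimates).
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <future>
#include <unordered_map>
#include <vector>

void adam_update(size_t /*param_idx*/, float /*grad*/) {}

void multi_thread_update(size_t param_row_count, size_t row_numel,
                         const std::vector<int64_t>& grad_rows,
                         const std::vector<float>& grad_data,
                         size_t num_threads) {
  // Map a parameter row id to its row offset inside the sparse gradient.
  // Only rows that actually carry a gradient appear in the map, so there is
  // no O(param_row_count) pass to pre-fill an offset array with -1.
  std::unordered_map<int64_t, int64_t> row_id_to_grad_row_offset;
  for (size_t i = 0; i < grad_rows.size(); ++i) {
    row_id_to_grad_row_offset[grad_rows[i]] = static_cast<int64_t>(i);
  }

  // Shard the parameter rows across worker tasks.
  size_t shard_size = (param_row_count + num_threads - 1) / num_threads;
  std::vector<std::future<void>> fs;
  for (size_t t = 0; t < num_threads; ++t) {
    size_t start = t * shard_size;
    size_t end = std::min(start + shard_size, param_row_count);
    fs.push_back(std::async(std::launch::async, [&, start, end]() {
      for (size_t row_id = start; row_id < end; ++row_id) {
        auto iter =
            row_id_to_grad_row_offset.find(static_cast<int64_t>(row_id));
        for (size_t j = 0; j < row_numel; ++j) {
          // Rows with no sparse gradient entry still get an update with
          // g = 0: Adam's moments decay even when the gradient is zero.
          float g = iter != row_id_to_grad_row_offset.end()
                        ? grad_data[iter->second * row_numel + j]
                        : 0.0f;
          adam_update(row_id * row_numel + j, g);
        }
      }
    }));
  }
  for (auto& f : fs) f.wait();
}

int main() {
  // Hypothetical toy call: 6 parameter rows of width 4, sparse gradients
  // only for rows 1 and 4, updated by 2 worker threads.
  multi_thread_update(/*param_row_count=*/6, /*row_numel=*/4,
                      /*grad_rows=*/{1, 4},
                      /*grad_data=*/std::vector<float>(8, 0.5f),
                      /*num_threads=*/2);
  return 0;
}
```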
......@@ -305,13 +305,6 @@ struct SparseAdamFunctor<T, CPUAdam> {
     param_out_[i] = p;
   }
 
-  inline void update_row(size_t row_id, int grad_row_offset) const {
-    for (size_t i = 0U; i < row_numel_; ++i) {
-      T g = grad_row_offset >= 0 ? grad_[grad_row_offset * row_numel_ + i] : 0;
-      adam_update(row_id * row_numel_ + i, g);
-    }
-  }
-
   inline void operator()(size_t numel) const {
     // lr could be reused
     T lr = *lr_;
......@@ -502,9 +495,6 @@ class AdamOpKernel : public framework::OpKernel<T> {
"multi thread, currently "
<< param_row_count;
}
for (size_t i = 0; i < param_row_count; ++i) {
row_id_to_grad_row_offset[i] = -1;
}
for (size_t i = 0; i < grad_rows.size(); ++i) {
row_id_to_grad_row_offset[grad_rows[i]] = i;
}
......@@ -520,10 +510,24 @@ class AdamOpKernel : public framework::OpKernel<T> {
         if (end > param_row_count) {
           end = param_row_count;
         }
-        fs.push_back(framework::Async(
-            [&functor, &row_id_to_grad_row_offset, start, end]() {
-              for (int64_t i = start; i < end; ++i) {
-                functor.update_row(i, row_id_to_grad_row_offset[i]);
-              }
-            }));
+        fs.push_back(
+            framework::Async([&functor, &row_id_to_grad_row_offset,
+                              &grad_data, row_numel, start, end]() {
+              for (int64_t row_id = start; row_id < end; ++row_id) {
+                auto iter = row_id_to_grad_row_offset.find(row_id);
+                if (iter != row_id_to_grad_row_offset.end()) {
+                  for (size_t row_offset = 0U; row_offset < row_numel;
+                       ++row_offset) {
+                    functor.adam_update(
+                        row_id * row_numel + row_offset,
+                        grad_data[iter->second * row_numel + row_offset]);
+                  }
+                } else {
+                  for (size_t row_offset = 0U; row_offset < row_numel;
+                       ++row_offset) {
+                    functor.adam_update(row_id * row_numel + row_offset, 0);
+                  }
+                }
+              }
+            }));
       }
......
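The design choice here is to trade a dense offset array for a hash-map lookup: the removed loop had to write -1 into an offset slot for every parameter row, so setup cost grew with the parameter size, while the map only holds the rows that actually carry a gradient. Rows absent from the map still go through `functor.adam_update` with a gradient of 0, since Adam's moment estimates decay, and the parameter still moves, even when the gradient is zero.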