提交 c15270c5 编写于 作者: Q Qiao Longfei

optimize multi thread adam

上级 0e747e8d
...@@ -305,13 +305,6 @@ struct SparseAdamFunctor<T, CPUAdam> { ...@@ -305,13 +305,6 @@ struct SparseAdamFunctor<T, CPUAdam> {
param_out_[i] = p; param_out_[i] = p;
} }
inline void update_row(size_t row_id, int grad_row_offset) const {
for (size_t i = 0U; i < row_numel_; ++i) {
T g = grad_row_offset >= 0 ? grad_[grad_row_offset * row_numel_ + i] : 0;
adam_update(row_id * row_numel_ + i, g);
}
}
inline void operator()(size_t numel) const { inline void operator()(size_t numel) const {
// lr could be reuse // lr could be reuse
T lr = *lr_; T lr = *lr_;
...@@ -502,9 +495,6 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -502,9 +495,6 @@ class AdamOpKernel : public framework::OpKernel<T> {
"multi thread, currently " "multi thread, currently "
<< param_row_count; << param_row_count;
} }
for (size_t i = 0; i < param_row_count; ++i) {
row_id_to_grad_row_offset[i] = -1;
}
for (size_t i = 0; i < grad_rows.size(); ++i) { for (size_t i = 0; i < grad_rows.size(); ++i) {
row_id_to_grad_row_offset[grad_rows[i]] = i; row_id_to_grad_row_offset[grad_rows[i]] = i;
} }
...@@ -520,10 +510,24 @@ class AdamOpKernel : public framework::OpKernel<T> { ...@@ -520,10 +510,24 @@ class AdamOpKernel : public framework::OpKernel<T> {
if (end > param_row_count) { if (end > param_row_count) {
end = param_row_count; end = param_row_count;
} }
fs.push_back(framework::Async( fs.push_back(
[&functor, &row_id_to_grad_row_offset, start, end]() { framework::Async([&functor, &row_id_to_grad_row_offset,
for (int64_t i = start; i < end; ++i) { &grad_data, row_numel, start, end]() {
functor.update_row(i, row_id_to_grad_row_offset[i]); for (int64_t row_id = start; row_id < end; ++row_id) {
auto iter = row_id_to_grad_row_offset.find(row_id);
if (iter != row_id_to_grad_row_offset.end()) {
for (size_t row_offset = 0U; row_offset < row_numel;
++row_offset) {
functor.adam_update(
row_id * row_numel + row_offset,
grad_data[iter->second * row_numel + row_offset]);
}
} else {
for (size_t row_offset = 0U; row_offset < row_numel;
++row_offset) {
functor.adam_update(row_id * row_numel + row_offset, 0);
}
}
} }
})); }));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册