From c15270c5b20d31bff04bd66bbc8f37f188213d72 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Sun, 6 Jan 2019 15:50:26 +0800 Subject: [PATCH] optimize multi thread adam --- paddle/fluid/operators/optimizers/adam_op.h | 32 ++++++++++++--------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index 1f0dbedcf..b84d63f51 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -305,13 +305,6 @@ struct SparseAdamFunctor { param_out_[i] = p; } - inline void update_row(size_t row_id, int grad_row_offset) const { - for (size_t i = 0U; i < row_numel_; ++i) { - T g = grad_row_offset >= 0 ? grad_[grad_row_offset * row_numel_ + i] : 0; - adam_update(row_id * row_numel_ + i, g); - } - } - inline void operator()(size_t numel) const { // lr could be reuse T lr = *lr_; @@ -502,9 +495,6 @@ class AdamOpKernel : public framework::OpKernel { "multi thread, currently " << param_row_count; } - for (size_t i = 0; i < param_row_count; ++i) { - row_id_to_grad_row_offset[i] = -1; - } for (size_t i = 0; i < grad_rows.size(); ++i) { row_id_to_grad_row_offset[grad_rows[i]] = i; } @@ -520,10 +510,24 @@ class AdamOpKernel : public framework::OpKernel { if (end > param_row_count) { end = param_row_count; } - fs.push_back(framework::Async( - [&functor, &row_id_to_grad_row_offset, start, end]() { - for (int64_t i = start; i < end; ++i) { - functor.update_row(i, row_id_to_grad_row_offset[i]); + fs.push_back( + framework::Async([&functor, &row_id_to_grad_row_offset, + &grad_data, row_numel, start, end]() { + for (int64_t row_id = start; row_id < end; ++row_id) { + auto iter = row_id_to_grad_row_offset.find(row_id); + if (iter != row_id_to_grad_row_offset.end()) { + for (size_t row_offset = 0U; row_offset < row_numel; + ++row_offset) { + functor.adam_update( + row_id * row_numel + row_offset, + grad_data[iter->second * row_numel + row_offset]); + } + } else { + for (size_t row_offset = 0U; row_offset < row_numel; + ++row_offset) { + functor.adam_update(row_id * row_numel + row_offset, 0); + } + } } })); } -- GitLab