From 5fea8cd47809f56ef232dd187f480c251898c762 Mon Sep 17 00:00:00 2001 From: minqiyang Date: Fri, 14 Dec 2018 18:26:33 +0800 Subject: [PATCH] Add sorted_result parameter to SelectedRows Functor test=develop --- .../operators/math/selected_rows_functor.cc | 17 ++++++++++------- .../operators/math/selected_rows_functor.cu | 3 ++- .../operators/math/selected_rows_functor.h | 16 ++++++---------- paddle/fluid/operators/optimizers/adam_op.h | 8 ++++++-- 4 files changed, 24 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index 0c2e6d402..1a11b584e 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -253,23 +253,26 @@ elementwise_add_to(const DeviceContext& ctx, BlasT* blas, template struct MergeAdd { framework::SelectedRows operator()(const platform::CPUDeviceContext& context, - const framework::SelectedRows& input) { + const framework::SelectedRows& input, + const bool sorted_result = false) { framework::SelectedRows out; - (*this)(context, input, &out); + (*this)(context, input, &out, sorted_result); return out; } void operator()(const platform::CPUDeviceContext& context, const framework::SelectedRows& input, - framework::SelectedRows* output) { + framework::SelectedRows* output, + const bool sorted_result = false) { std::vector inputs; inputs.push_back(&input); - (*this)(context, inputs, output); + (*this)(context, inputs, output, sorted_result); } void operator()(const platform::CPUDeviceContext& context, const std::vector& inputs, - framework::SelectedRows* output) { + framework::SelectedRows* output, + const bool sorted_result = false) { if (inputs.size() == 0) { VLOG(3) << "no input! return"; return; @@ -302,8 +305,8 @@ struct MergeAdd { } std::vector merge_rows(merged_row_set.begin(), merged_row_set.end()); - if (sorted_result_) { - std::sort(merge_rows); + if (sorted_result) { + std::sort(merge_rows.begin(), merge_rows.end()); } std::unordered_map rows_to_id; for (size_t i = 0; i < merge_rows.size(); ++i) { diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index c4fccdbf8..b87c9461e 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -266,7 +266,8 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows, template struct MergeAdd { framework::SelectedRows operator()(const platform::CUDADeviceContext& context, - const framework::SelectedRows& input) { + const framework::SelectedRows& input, + const bool sorted_result = false) { framework::SelectedRows out; (*this)(context, input, &out); return out; diff --git a/paddle/fluid/operators/math/selected_rows_functor.h b/paddle/fluid/operators/math/selected_rows_functor.h index b7b19f130..222d761ef 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.h +++ b/paddle/fluid/operators/math/selected_rows_functor.h @@ -78,23 +78,19 @@ namespace scatter { // functors for manuplating SelectedRows data template struct MergeAdd { - MergeAdd() : sorted_result_(false) {} - - explicit MergeAdd(bool sorted_result) : sorted_result_(sorted_result) {} - // unary functor, merge by adding duplicated rows in // the input SelectedRows object. framework::SelectedRows operator()(const DeviceContext& context, - const framework::SelectedRows& input); + const framework::SelectedRows& input, + const bool sorted_result = false); void operator()(const DeviceContext& context, const framework::SelectedRows& input, - framework::SelectedRows* output); + framework::SelectedRows* output, + const bool sorted_result = false); void operator()(const DeviceContext& context, const std::vector& inputs, - framework::SelectedRows* output); - - private: - bool sorted_result_; + framework::SelectedRows* output, + const bool sorted_result = false); }; enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY }; diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index c2bf7040d..c9e27b754 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -157,6 +157,9 @@ struct AdamFunctor { } }; +template +struct SparseAdamFunctor; + template struct SparseAdamFunctor { T beta1_; @@ -283,6 +286,7 @@ struct SparseAdamFunctor { // Calculation if (i == *(rows_ + j)) { + T g = grad_[j * row_numel_]; mom1 = beta1_ * mom1 + (1 - beta1_) * g; mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; ++j; @@ -388,12 +392,12 @@ class AdamOpKernel : public framework::OpKernel { } else { // merge duplicated rows if any. // The rows of grad_merge have been sorted inside MergeAdd functor - scatter::MergeAdd merge_func(true); + scatter::MergeAdd merge_func; auto* grad_merge_var = const_cast(ctx.scope()) .Var() ->GetMutable(); merge_func(ctx.template device_context(), grad, - grad_merge_var); + grad_merge_var, true); grad_merge_ptr = grad_merge_var; } -- GitLab