diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index 0c2e6d40241218a6697b896262608b3a002e0571..1a11b584e2bab7eeb395bf391da080ec0ba62ae4 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -253,23 +253,26 @@ elementwise_add_to(const DeviceContext& ctx, BlasT* blas, template struct MergeAdd { framework::SelectedRows operator()(const platform::CPUDeviceContext& context, - const framework::SelectedRows& input) { + const framework::SelectedRows& input, + const bool sorted_result = false) { framework::SelectedRows out; - (*this)(context, input, &out); + (*this)(context, input, &out, sorted_result); return out; } void operator()(const platform::CPUDeviceContext& context, const framework::SelectedRows& input, - framework::SelectedRows* output) { + framework::SelectedRows* output, + const bool sorted_result = false) { std::vector inputs; inputs.push_back(&input); - (*this)(context, inputs, output); + (*this)(context, inputs, output, sorted_result); } void operator()(const platform::CPUDeviceContext& context, const std::vector& inputs, - framework::SelectedRows* output) { + framework::SelectedRows* output, + const bool sorted_result = false) { if (inputs.size() == 0) { VLOG(3) << "no input! return"; return; @@ -302,8 +305,8 @@ struct MergeAdd { } std::vector merge_rows(merged_row_set.begin(), merged_row_set.end()); - if (sorted_result_) { - std::sort(merge_rows); + if (sorted_result) { + std::sort(merge_rows.begin(), merge_rows.end()); } std::unordered_map rows_to_id; for (size_t i = 0; i < merge_rows.size(); ++i) { diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index c4fccdbf862fda8a599869c30ae598573ca367aa..b87c9461e883492d3bf658c69307ad2fe4ef0a0f 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -266,7 +266,8 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows, template struct MergeAdd { framework::SelectedRows operator()(const platform::CUDADeviceContext& context, - const framework::SelectedRows& input) { + const framework::SelectedRows& input, + const bool sorted_result = false) { framework::SelectedRows out; (*this)(context, input, &out); return out; diff --git a/paddle/fluid/operators/math/selected_rows_functor.h b/paddle/fluid/operators/math/selected_rows_functor.h index b7b19f130e5269629c27ae51b5b346ef1a789e3e..222d761ef91d8aee4843d717dabba7edf131f8dc 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.h +++ b/paddle/fluid/operators/math/selected_rows_functor.h @@ -78,23 +78,19 @@ namespace scatter { // functors for manuplating SelectedRows data template struct MergeAdd { - MergeAdd() : sorted_result_(false) {} - - explicit MergeAdd(bool sorted_result) : sorted_result_(sorted_result) {} - // unary functor, merge by adding duplicated rows in // the input SelectedRows object. framework::SelectedRows operator()(const DeviceContext& context, - const framework::SelectedRows& input); + const framework::SelectedRows& input, + const bool sorted_result = false); void operator()(const DeviceContext& context, const framework::SelectedRows& input, - framework::SelectedRows* output); + framework::SelectedRows* output, + const bool sorted_result = false); void operator()(const DeviceContext& context, const std::vector& inputs, - framework::SelectedRows* output); - - private: - bool sorted_result_; + framework::SelectedRows* output, + const bool sorted_result = false); }; enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY }; diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index c2bf7040d77315030b5108edb64dd1bb914d81d8..c9e27b754728a2a10d9c5c60ff123c347401d2f2 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -157,6 +157,9 @@ struct AdamFunctor { } }; +template +struct SparseAdamFunctor; + template struct SparseAdamFunctor { T beta1_; @@ -283,6 +286,7 @@ struct SparseAdamFunctor { // Calculation if (i == *(rows_ + j)) { + T g = grad_[j * row_numel_]; mom1 = beta1_ * mom1 + (1 - beta1_) * g; mom2 = beta2_ * mom2 + (1 - beta2_) * g * g; ++j; @@ -388,12 +392,12 @@ class AdamOpKernel : public framework::OpKernel { } else { // merge duplicated rows if any. // The rows of grad_merge have been sorted inside MergeAdd functor - scatter::MergeAdd merge_func(true); + scatter::MergeAdd merge_func; auto* grad_merge_var = const_cast(ctx.scope()) .Var() ->GetMutable(); merge_func(ctx.template device_context(), grad, - grad_merge_var); + grad_merge_var, true); grad_merge_ptr = grad_merge_var; }