diff --git a/paddle/fluid/operators/adam_op.h b/paddle/fluid/operators/adam_op.h
index 84a584f424823a450effd4c36e9da600f5851da2..5b27068c9e805146b8bce03f4f676ef0d4d16c53 100644
--- a/paddle/fluid/operators/adam_op.h
+++ b/paddle/fluid/operators/adam_op.h
@@ -174,12 +174,13 @@ struct SparseAdamFunctor {
 
   const int64_t* rows_;
   int64_t row_numel_;
+  int64_t row_count_;
 
   SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
                     const T* beta2_pow, const T* mom1, T* mom1_out,
                     const T* mom2, T* mom2_out, const T* lr, const T* grad,
                     const T* param, T* param_out, const int64_t* rows,
-                    int64_t row_numel)
+                    int64_t row_numel, int64_t row_count)
       : beta1_(beta1),
         beta2_(beta2),
         epsilon_(epsilon),
@@ -194,28 +195,47 @@ struct SparseAdamFunctor {
         param_(param),
         param_out_(param_out),
         rows_(rows),
-        row_numel_(row_numel) {}
+        row_numel_(row_numel),
+        row_count_(row_count) {}
+
+  inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const {
+    int64_t beg = 0, end = row_count_ - 1;
+    while (beg <= end) {
+      auto mid = ((beg + end) >> 1);
+      if (rows_[mid] == row)
+        return mid;
+      else if (rows_[mid] < row)
+        beg = mid + 1;
+      else
+        end = mid - 1;
+    }
+    return -1;
+  }
 
   inline HOSTDEVICE void operator()(size_t i) const {
+    int64_t row = i / row_numel_;
+    auto row_idx = BinarySearchInRows(row);
+    T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
+
+    // The following code is the same as dense
+    T mom1 = moment1_[i];
+    T mom2 = moment2_[i];
+    T lr = *lr_;
     T beta1_pow = *beta1_pow_;
     T beta2_pow = *beta2_pow_;
-    for (int64_t j = 0; j < row_numel_; ++j) {
-      T g = grad_[i * row_numel_ + j];
-      T mom1 = moment1_[rows_[i] * row_numel_ + j];
-      T mom2 = moment2_[rows_[i] * row_numel_ + j];
-      T lr = *lr_;
-      T p = param_[rows_[i] * row_numel_ + j];
-
-      lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
-
-      mom1 = beta1_ * mom1 + (1 - beta1_) * g;
-      mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
-      p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
-
-      moment1_out_[rows_[i] * row_numel_ + j] = mom1;
-      moment2_out_[rows_[i] * row_numel_ + j] = mom2;
-      param_out_[rows_[i] * row_numel_ + j] = p;
-    }  // for col id
+    T p = param_[i];
+
+    // Calculation
+    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+
+    mom1 = beta1_ * mom1 + (1 - beta1_) * g;
+    mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
+    p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
+
+    // Write back to global memory
+    moment1_out_[i] = mom1;
+    moment2_out_[i] = mom2;
+    param_out_[i] = p;
   }
 };
 
@@ -287,9 +307,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
         return;
       }
       // merge duplicated rows if any.
+      // The rows of grad_merge have been sorted inside MergeAdd functor
       scatter::MergeAdd<DeviceContext, T> merge_func;
-      auto grad_merge =
-          merge_func(ctx.template device_context<DeviceContext>(), grad);
+      auto& grad_merge = *(ctx.scope()
+                               .NewScope()
+                               .Var("sparse_adam_grad_merge")
+                               ->GetMutable<framework::SelectedRows>());
+      merge_func(ctx.template device_context<DeviceContext>(), grad,
+                 &grad_merge);
       auto& grad_tensor = grad_merge.value();
       const T* grad_data = grad_tensor.template data<T>();
       int64_t* rows = nullptr;
@@ -314,10 +339,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
           mom2.template data<T>(),
           mom2_out.template mutable_data<T>(ctx.GetPlace()),
           lr.template data<T>(), grad_data, param.template data<T>(),
-          param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel);
+          param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
+          grad_merge.rows().size());
       platform::ForRange<DeviceContext> for_range(
           static_cast<const DeviceContext&>(ctx.device_context()),
-          grad_merge.rows().size());
+          param.numel());
       for_range(functor);
     } else {
       PADDLE_THROW("Variable type not supported by adam_op");
diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc
index a830dc5250a6aea7e622da4046b512d0c7c5d6f9..8e8baf49b2330e95ff1a868b0b0a03bc10d84484 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.cc
+++ b/paddle/fluid/operators/math/selected_rows_functor.cc
@@ -199,6 +199,14 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
   framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
                                      const framework::SelectedRows& input) {
     framework::SelectedRows out;
+    (*this)(context, input, &out);
+    return out;
+  }
+
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::SelectedRows& input,
+                  framework::SelectedRows* output) {
+    framework::SelectedRows& out = *output;
     auto input_rows = input.rows();
     std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
     std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
@@ -223,7 +231,6 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
         out_data[out_i * input_width + j] += input_data[i * input_width + j];
       }
     }
-    return out;
   }
 };
 
diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu
index a92762c7fea865fad2c7784736cce93a8af21892..94258f662993f0f9eee8978773d7925cdb26744c 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.cu
+++ b/paddle/fluid/operators/math/selected_rows_functor.cu
@@ -262,6 +262,14 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
   framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
                                      const framework::SelectedRows& input) {
     framework::SelectedRows out;
+    (*this)(context, input, &out);
+    return out;
+  }
+
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::SelectedRows& input,
+                  framework::SelectedRows* output) {
+    framework::SelectedRows& out = *output;
     framework::Vector<int64_t> input_rows(input.rows());
     std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
     std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
@@ -292,7 +300,6 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
         input_data, input_rows.CUDAData(context.GetPlace()), out_data,
         out.mutable_rows()->CUDAMutableData(context.GetPlace()),
         out.rows().size(), input_width);
-    return out;
   }
 };
 
diff --git a/paddle/fluid/operators/math/selected_rows_functor.h b/paddle/fluid/operators/math/selected_rows_functor.h
index 18304f83f8706f822ce628e2374b00a71f1cc171..aa419f74fcd2a53cdd734ec270bc154b78c9f2ff 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.h
+++ b/paddle/fluid/operators/math/selected_rows_functor.h
@@ -65,6 +65,9 @@ struct MergeAdd {
   // the input SelectedRows object.
   framework::SelectedRows operator()(const DeviceContext& context,
                                      const framework::SelectedRows& input);
+  void operator()(const DeviceContext& context,
+                  const framework::SelectedRows& input,
+                  framework::SelectedRows* output);
 };
 
 template <typename DeviceContext, typename T>
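Note: the patched SparseAdamFunctor iterates over every parameter element (ForRange now runs over param.numel() rather than grad_merge.rows().size()) and, for each element, binary-searches the sorted row index array produced by MergeAdd; a row absent from that array is treated as having a zero gradient. The standalone sketch below is not part of the patch; it mirrors the same lookup with illustrative names so the rows-absent-means-zero behavior can be checked in isolation.

// Minimal sketch of the row lookup performed by
// SparseAdamFunctor::BinarySearchInRows (illustrative, not Paddle code).
#include <cstdint>
#include <iostream>
#include <vector>

// Returns the slot of `row` inside the sorted `rows` array, or -1 if the
// row is absent (its gradient is then taken as zero by the functor).
int64_t BinarySearchInRows(const std::vector<int64_t>& rows, int64_t row) {
  int64_t beg = 0, end = static_cast<int64_t>(rows.size()) - 1;
  while (beg <= end) {
    int64_t mid = (beg + end) >> 1;
    if (rows[mid] == row) return mid;
    if (rows[mid] < row)
      beg = mid + 1;
    else
      end = mid - 1;
  }
  return -1;
}

int main() {
  // Rows 2, 5 and 7 carry sparse gradients; every other row reads g = 0.
  std::vector<int64_t> rows = {2, 5, 7};
  for (int64_t row : {0, 2, 5, 6, 7}) {
    std::cout << "row " << row << " -> slot " << BinarySearchInRows(rows, row)
              << "\n";  // prints -1, 0, 1, -1, 2
  }
  return 0;
}

The lookup is only correct because MergeAdd builds its output rows from a std::set, which yields them in ascending order; that is the invariant the patch records with the "rows of grad_merge have been sorted" comment in adam_op.h.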