diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc
index 3eba268cfa9712e4bc5475dd44076bc768552bce..1a11b584e2bab7eeb395bf391da080ec0ba62ae4 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.cc
+++ b/paddle/fluid/operators/math/selected_rows_functor.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <algorithm>
 #include <set>
 #include <vector>
 
@@ -252,23 +253,26 @@ elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
 template <typename T>
 struct MergeAdd<platform::CPUDeviceContext, T> {
   framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
-                                     const framework::SelectedRows& input) {
+                                     const framework::SelectedRows& input,
+                                     const bool sorted_result = false) {
     framework::SelectedRows out;
-    (*this)(context, input, &out);
+    (*this)(context, input, &out, sorted_result);
     return out;
   }
 
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::SelectedRows& input,
-                  framework::SelectedRows* output) {
+                  framework::SelectedRows* output,
+                  const bool sorted_result = false) {
     std::vector<const framework::SelectedRows*> inputs;
     inputs.push_back(&input);
-    (*this)(context, inputs, output);
+    (*this)(context, inputs, output, sorted_result);
   }
 
   void operator()(const platform::CPUDeviceContext& context,
                   const std::vector<const framework::SelectedRows*>& inputs,
-                  framework::SelectedRows* output) {
+                  framework::SelectedRows* output,
+                  const bool sorted_result = false) {
     if (inputs.size() == 0) {
       VLOG(3) << "no input! return";
       return;
@@ -301,6 +305,9 @@ struct MergeAdd {
     }
     std::vector<int64_t> merge_rows(merged_row_set.begin(),
                                     merged_row_set.end());
+    if (sorted_result) {
+      std::sort(merge_rows.begin(), merge_rows.end());
+    }
     std::unordered_map<int64_t, size_t> rows_to_id;
     for (size_t i = 0; i < merge_rows.size(); ++i) {
       rows_to_id[merge_rows[i]] = i;
diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu
index c4fccdbf862fda8a599869c30ae598573ca367aa..0d63f641c8670f8629c52b9e5fc380a250d80dd7 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.cu
+++ b/paddle/fluid/operators/math/selected_rows_functor.cu
@@ -266,7 +266,8 @@ __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
 template <typename T>
 struct MergeAdd<platform::CUDADeviceContext, T> {
   framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
-                                     const framework::SelectedRows& input) {
+                                     const framework::SelectedRows& input,
+                                     const bool sorted_result = false) {
     framework::SelectedRows out;
     (*this)(context, input, &out);
     return out;
   }
@@ -274,7 +275,8 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
 
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::SelectedRows& input,
-                  framework::SelectedRows* output) {
+                  framework::SelectedRows* output,
+                  const bool sorted_result = false) {
     framework::Vector<int64_t> input_rows(input.rows());
     if (input_rows.size() == 0) {
       return;
@@ -312,7 +314,8 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
 
   void operator()(const platform::CUDADeviceContext& context,
                   const std::vector<const framework::SelectedRows*>& inputs,
-                  framework::SelectedRows* output) {
+                  framework::SelectedRows* output,
+                  const bool sorted_result = false) {
     if (inputs.size() == 0) {
       VLOG(3) << "no input! return";
       return;
diff --git a/paddle/fluid/operators/math/selected_rows_functor.h b/paddle/fluid/operators/math/selected_rows_functor.h
index 6d146d39d6d07678e859b82b25ba60ed7661546d..222d761ef91d8aee4843d717dabba7edf131f8dc 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.h
+++ b/paddle/fluid/operators/math/selected_rows_functor.h
@@ -81,13 +81,16 @@ struct MergeAdd {
   // unary functor, merge by adding duplicated rows in
   // the input SelectedRows object.
   framework::SelectedRows operator()(const DeviceContext& context,
-                                     const framework::SelectedRows& input);
+                                     const framework::SelectedRows& input,
+                                     const bool sorted_result = false);
   void operator()(const DeviceContext& context,
                   const framework::SelectedRows& input,
-                  framework::SelectedRows* output);
+                  framework::SelectedRows* output,
+                  const bool sorted_result = false);
   void operator()(const DeviceContext& context,
                   const std::vector<const framework::SelectedRows*>& inputs,
-                  framework::SelectedRows* output);
+                  framework::SelectedRows* output,
+                  const bool sorted_result = false);
 };
 
 enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
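Note on the change above: MergeAdd folds duplicated rows of a SelectedRows into one row by summation. The merged row list is gathered through an unordered container, so its order is unspecified; the new sorted_result flag lets a caller request ascending row order (the CPU Adam path below relies on this). What follows is a minimal sketch of just the row-gathering step, using plain STL types instead of framework::SelectedRows; MergeRows is an illustrative name, not a Paddle API.

#include <algorithm>
#include <cstdint>
#include <unordered_set>
#include <vector>

// Stand-in for the row-gathering phase of MergeAdd (illustrative only).
std::vector<int64_t> MergeRows(
    const std::vector<std::vector<int64_t>>& input_rows,
    bool sorted_result = false) {
  std::unordered_set<int64_t> merged_row_set;  // deduplicate row indices
  for (const auto& rows : input_rows) {
    merged_row_set.insert(rows.begin(), rows.end());
  }
  std::vector<int64_t> merge_rows(merged_row_set.begin(),
                                  merged_row_set.end());
  if (sorted_result) {
    // Callers that sweep rows in order (e.g. the CPU Adam functor below)
    // ask for ascending rows here.
    std::sort(merge_rows.begin(), merge_rows.end());
  }
  return merge_rows;
}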
diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h
index f214d8272f5cc5f1cb2e32c9bb59ca60a1066500..1138bb7400e0e7a00983e7bfaad2b2d9704b77ab 100644
--- a/paddle/fluid/operators/optimizers/adam_op.h
+++ b/paddle/fluid/operators/optimizers/adam_op.h
@@ -157,8 +157,14 @@ struct AdamFunctor {
   }
 };
 
+struct GPUAdam;
+struct CPUAdam;
+
+template <typename T, typename Flag>
+struct SparseAdamFunctor;
+
 template <typename T>
-struct SparseAdamFunctor {
+struct SparseAdamFunctor<T, GPUAdam> {
   T beta1_;
   T beta2_;
   T epsilon_;
@@ -236,6 +242,109 @@ struct SparseAdamFunctor {
   }
 };
 
+template <typename T>
+struct SparseAdamFunctor<T, CPUAdam> {
+  T beta1_;
+  T beta2_;
+  T epsilon_;
+
+  const T* beta1_pow_;
+  const T* beta2_pow_;
+  const T* moment1_;
+  T* moment1_out_;
+  const T* moment2_;
+  T* moment2_out_;
+  const T* lr_;
+  const T* grad_;
+  const T* param_;
+  T* param_out_;
+
+  const int64_t* rows_;
+  int64_t row_numel_;
+  int64_t row_count_;
+
+  SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
+                    const T* beta2_pow, const T* mom1, T* mom1_out,
+                    const T* mom2, T* mom2_out, const T* lr, const T* grad,
+                    const T* param, T* param_out, const int64_t* rows,
+                    int64_t row_numel, int64_t row_count, bool lazy_mode)
+      : beta1_(beta1),
+        beta2_(beta2),
+        epsilon_(epsilon),
+        beta1_pow_(beta1_pow),
+        beta2_pow_(beta2_pow),
+        moment1_(mom1),
+        moment1_out_(mom1_out),
+        moment2_(mom2),
+        moment2_out_(mom2_out),
+        lr_(lr),
+        grad_(grad),
+        param_(param),
+        param_out_(param_out),
+        rows_(rows),
+        row_numel_(row_numel),
+        row_count_(row_count) {}
+
+  inline HOSTDEVICE void adam_update(size_t i, T g) const {
+    // The following code is the same as the dense Adam update
+    T mom1 = moment1_[i];
+    T mom2 = moment2_[i];
+    T lr = *lr_;
+    T beta1_pow = *beta1_pow_;
+    T beta2_pow = *beta2_pow_;
+    T p = param_[i];
+
+    // Calculation
+    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+
+    mom1 = beta1_ * mom1 + (1 - beta1_) * g;
+    mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
+    p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
+
+    // Write back to global memory
+    moment1_out_[i] = mom1;
+    moment2_out_[i] = mom2;
+    param_out_[i] = p;
+  }
+
+  inline void operator()(size_t numel) const {
+    // lr can be reused for every row in this call
+    T lr = *lr_;
+    T beta1_pow = *beta1_pow_;
+    T beta2_pow = *beta2_pow_;
+    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+    size_t row_count = numel / row_numel_;
+
+    for (size_t i = 0U, j = 0U; i != row_count; ++i) {
+      // guard j so we never read past the end of rows_ once every
+      // gradient row has been consumed
+      if (j < static_cast<size_t>(row_count_) &&
+          i == static_cast<size_t>(rows_[j])) {
+        for (size_t k = 0U; k != static_cast<size_t>(row_numel_); ++k) {
+          T g = grad_[j * row_numel_ + k];
+          adam_update(i * row_numel_ + k, g);
+        }
+        ++j;
+      } else {
+        for (size_t k = 0U; k != static_cast<size_t>(row_numel_); ++k) {
+          // rows without a gradient still decay the moments
+          T mom1 = moment1_[i * row_numel_ + k];
+          T mom2 = moment2_[i * row_numel_ + k];
+          T p = param_[i * row_numel_ + k];
+
+          mom1 = beta1_ * mom1;
+          mom2 = beta2_ * mom2;
+
+          p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
+          // Write back to global memory
+          moment1_out_[i * row_numel_ + k] = mom1;
+          moment2_out_[i * row_numel_ + k] = mom2;
+          param_out_[i * row_numel_ + k] = p;
+        }
+      }
+    }
+  }
+};
+
 template <typename DeviceContext, typename T>
 class AdamOpKernel : public framework::OpKernel<T> {
  public:
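Note on SparseAdamFunctor<T, CPUAdam>::operator(): it sweeps every parameter row i exactly once while a second index j walks the gradient's merged rows, which is why the kernel below now calls MergeAdd with sorted_result = true. Rows present in the gradient get the full Adam update; absent rows still decay their moments and take the bias-corrected step. Below is a hedged, self-contained sketch of that two-index sweep reduced to one element per row; Sweep and its printouts are illustrative only, not Paddle code.

#include <cstdint>
#include <cstdio>
#include <vector>

// grad_rows must be sorted ascending, mirroring sorted_result = true.
void Sweep(size_t param_row_count, const std::vector<int64_t>& grad_rows) {
  for (size_t i = 0, j = 0; i != param_row_count; ++i) {
    if (j < grad_rows.size() && static_cast<int64_t>(i) == grad_rows[j]) {
      std::printf("row %zu: full Adam update (grad row %zu)\n", i, j);
      ++j;  // consume this gradient row
    } else {
      std::printf("row %zu: moment decay only\n", i);
    }
  }
}

int main() {
  Sweep(5, {1, 3});  // rows 1 and 3 carry gradients; rows 0, 2, 4 only decay
  return 0;
}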
@@ -331,7 +440,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
         auto* grad_merge_var = const_cast<framework::Scope&>(ctx.scope())
                                    .Var()
                                    ->GetMutable<framework::SelectedRows>();
         merge_func(ctx.template device_context<DeviceContext>(), grad,
-                   grad_merge_var);
+                   grad_merge_var, true);
         grad_merge_ptr = grad_merge_var;
       }
@@ -347,32 +456,46 @@
       } else {
 #endif
         rows = grad_merge.rows().data();
-
 #if defined(PADDLE_WITH_CUDA)
       }
 #endif
       auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
 
-      SparseAdamFunctor<T> functor(
-          beta1, beta2, epsilon, beta1_pow.template data<T>(),
-          beta2_pow.template data<T>(), mom1.template data<T>(),
-          mom1_out.template mutable_data<T>(ctx.GetPlace()),
-          mom2.template data<T>(),
-          mom2_out.template mutable_data<T>(ctx.GetPlace()),
-          lr.template data<T>(), grad_data, param.template data<T>(),
-          param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
-          grad_merge.rows().size(), lazy_mode);
-      VLOG(3) << "lazy_mode :" << lazy_mode;
-      if (lazy_mode && platform::is_cpu_place(ctx.GetPlace())) {
-        size_t row_count = grad_merge.rows().size();
-        std::vector<int64_t> cpu_rows(grad_merge.rows());
-        for (size_t row_index = 0; row_index < row_count; ++row_index) {
-          for (size_t offset = 0; offset < row_numel; ++offset) {
-            size_t i = cpu_rows[row_index] * row_numel + offset;
-            functor.adam_update(i, grad_data[row_index * row_numel + offset]);
+      if (platform::is_cpu_place(ctx.GetPlace())) {
+        SparseAdamFunctor<T, CPUAdam> functor(
+            beta1, beta2, epsilon, beta1_pow.template data<T>(),
+            beta2_pow.template data<T>(), mom1.template data<T>(),
+            mom1_out.template mutable_data<T>(ctx.GetPlace()),
+            mom2.template data<T>(),
+            mom2_out.template mutable_data<T>(ctx.GetPlace()),
+            lr.template data<T>(), grad_data, param.template data<T>(),
+            param_out.template mutable_data<T>(ctx.GetPlace()), rows,
+            row_numel, grad_merge.rows().size(), lazy_mode);
+
+        if (lazy_mode) {
+          size_t row_count = grad_merge.rows().size();
+          std::vector<int64_t> cpu_rows(grad_merge.rows());
+          for (size_t row_index = 0; row_index < row_count; ++row_index) {
+            for (size_t offset = 0; offset < row_numel; ++offset) {
+              size_t i = cpu_rows[row_index] * row_numel + offset;
+              functor.adam_update(i,
+                                  grad_data[row_index * row_numel + offset]);
+            }
           }
+        } else {
+          functor(param.numel());
         }
-      } else {
+      } else if (platform::is_gpu_place(ctx.GetPlace())) {
+        SparseAdamFunctor<T, GPUAdam> functor(
+            beta1, beta2, epsilon, beta1_pow.template data<T>(),
+            beta2_pow.template data<T>(), mom1.template data<T>(),
+            mom1_out.template mutable_data<T>(ctx.GetPlace()),
+            mom2.template data<T>(),
+            mom2_out.template mutable_data<T>(ctx.GetPlace()),
+            lr.template data<T>(), grad_data, param.template data<T>(),
+            param_out.template mutable_data<T>(ctx.GetPlace()), rows,
+            row_numel, grad_merge.rows().size(), lazy_mode);
+
+        // FIXME(minqiyang): remove BinarySearch in GPU later
         platform::ForRange<DeviceContext> for_range(
             static_cast<const DeviceContext&>(ctx.device_context()),
             param.numel());
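Note on the device split: the kernel now builds SparseAdamFunctor<T, CPUAdam> on CPU (whole-parameter sweep or lazy per-row updates) and SparseAdamFunctor<T, GPUAdam> on GPU (per-element ForRange), selected through empty tag types and partial specialization; the reconstruction above assumes Paddle's CPUAdam/GPUAdam tag names. A minimal sketch of that dispatch pattern, with SparseUpdate as a hypothetical stand-in:

#include <cstdio>

struct CPUAdam;  // tag types: declarations suffice, no definitions needed
struct GPUAdam;

template <typename T, typename Flag>
struct SparseUpdate;  // primary template intentionally left undefined

template <typename T>
struct SparseUpdate<T, CPUAdam> {
  void operator()() const { std::printf("CPU path: whole-parameter sweep\n"); }
};

template <typename T>
struct SparseUpdate<T, GPUAdam> {
  void operator()() const { std::printf("GPU path: per-element ForRange\n"); }
};

int main() {
  SparseUpdate<float, CPUAdam>()();  // branch chosen at compile time by tag
  SparseUpdate<float, GPUAdam>()();
  return 0;
}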