From 584c3f048fcd221be5095575f50f837793f946c0 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Sat, 29 Sep 2018 08:01:02 +0000 Subject: [PATCH] fix sparse rmsprop --- paddle/fluid/operators/adam_op.h | 19 +- paddle/fluid/operators/math/algorithm.h | 44 ++++ paddle/fluid/operators/rmsprop_op.h | 270 ++++++++++++++++++++---- 3 files changed, 276 insertions(+), 57 deletions(-) create mode 100644 paddle/fluid/operators/math/algorithm.h diff --git a/paddle/fluid/operators/adam_op.h b/paddle/fluid/operators/adam_op.h index 4cb1f3a80e9..8d664e3e9ab 100644 --- a/paddle/fluid/operators/adam_op.h +++ b/paddle/fluid/operators/adam_op.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" +#include "paddle/fluid/operators/math/algorithm.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/for_range.h" @@ -199,23 +200,9 @@ struct SparseAdamFunctor { row_numel_(row_numel), row_count_(row_count) {} - inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const { - int64_t beg = 0, end = row_count_ - 1; - while (beg <= end) { - auto mid = ((beg + end) >> 1); - if (rows_[mid] == row) - return mid; - else if (rows_[mid] < row) - beg = mid + 1; - else - end = mid - 1; - } - return -1; - } - inline HOSTDEVICE void operator()(size_t i) const { - int64_t row = i / row_numel_; - auto row_idx = BinarySearchInRows(row); + auto row_idx = + math::BinarySearch(rows_, row_count_, i / row_numel_); T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0; // The following code is the same as dense diff --git a/paddle/fluid/operators/math/algorithm.h b/paddle/fluid/operators/math/algorithm.h new file mode 100644 index 00000000000..262469beea7 --- /dev/null +++ b/paddle/fluid/operators/math/algorithm.h @@ -0,0 +1,44 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include // for int64_t +#include + +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace operators { +namespace math { + +template +HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) { + int64_t beg = 0, end = num - 1; + while (beg <= end) { + auto mid = ((beg + end) >> 1); + if (x[mid] == val) + return mid; + else if (x[mid] < val) + beg = mid + 1; + else + end = mid - 1; + } + return -1; +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/rmsprop_op.h b/paddle/fluid/operators/rmsprop_op.h index 25ed32c5ebb..406730407d4 100644 --- a/paddle/fluid/operators/rmsprop_op.h +++ b/paddle/fluid/operators/rmsprop_op.h @@ -13,66 +13,254 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/algorithm.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" +#include "paddle/fluid/platform/for_range.h" namespace paddle { namespace operators { -using Tensor = framework::Tensor; template using EigenVector = framework::EigenVector; +template +struct DenseRmspropGradFunctor { + inline explicit DenseRmspropGradFunctor(const T *grad) : grad_(grad) {} + + HOSTDEVICE inline T operator()(int64_t idx) const { return grad_[idx]; } + + const T *grad_; +}; + +template +struct SparseRmspropGradFunctor { + inline SparseRmspropGradFunctor(const T *grad, const int64_t *rows, + int64_t row_numel, int64_t row_count) + : grad_(grad), + rows_(rows), + row_numel_(row_numel), + row_count_(row_count) {} + + HOSTDEVICE inline T operator()(int64_t idx) const { + auto row_idx = math::BinarySearch(rows_, row_count_, idx / row_numel_); + return row_idx >= 0 ? grad_[row_idx * row_numel_ + idx % row_numel_] : 0; + } + + const T *grad_; + const int64_t *rows_; + int64_t row_numel_; + int64_t row_count_; +}; + +template +struct UncenteredRmspropFunctor { + UncenteredRmspropFunctor(T *param, T *ms, T *mom, const T *lr, T rho, + T epsilon, T momentum, + const GradFunctor &grad_functor) + : param_(param), + ms_(ms), + mom_(mom), + lr_(lr), + rho_(rho), + epsilon_(epsilon), + momentum_(momentum), + grad_functor_(grad_functor) {} + + HOSTDEVICE inline void operator()(int64_t idx) const { + T g = grad_functor_(idx); + T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g; + T mom_out = momentum_ * mom_[idx] + lr_[0] * g / sqrt(ms_out + epsilon_); + param_[idx] -= mom_out; + ms_[idx] = ms_out; + mom_[idx] = mom_out; + } + + T *param_; + T *ms_; + T *mom_; + const T *lr_; + T rho_; + T epsilon_; + T momentum_; + GradFunctor grad_functor_; +}; + +template +struct CenteredRmspropFunctor { + CenteredRmspropFunctor(T *param, T *ms, T *mom, T *mean_grad, const T *lr, + T rho, T epsilon, T momentum, + const GradFunctor &grad_functor) + : param_(param), + ms_(ms), + mom_(mom), + mean_grad_(mean_grad), + lr_(lr), + rho_(rho), + epsilon_(epsilon), + momentum_(momentum), + grad_functor_(grad_functor) {} + + HOSTDEVICE inline void operator()(int64_t idx) const { + T g = grad_functor_(idx); + T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g; + T mg_out = rho_ * mean_grad_[idx] + (1 - rho_) * g; + T mom_out = momentum_ * mom_[idx] + + lr_[0] * g / sqrt(ms_out - mg_out * mg_out + epsilon_); + param_[idx] -= mom_out; + ms_[idx] = ms_out; + mom_[idx] = mom_out; + mean_grad_[idx] = mg_out; + } + + T *param_; + T *ms_; + T *mom_; + T *mean_grad_; + const T *lr_; + T rho_; + T epsilon_; + T momentum_; + GradFunctor grad_functor_; +}; + template class RmspropOpKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* param_out = ctx.Output("ParamOut"); - auto* moment_out = ctx.Output("MomentOut"); - auto* mean_square_out = ctx.Output("MeanSquareOut"); + void Compute(const framework::ExecutionContext &ctx) const override { + using Tensor = framework::LoDTensor; + auto *grad_var = ctx.InputVar("Grad"); + auto *param_out = ctx.Output("ParamOut"); + auto *moment_out = ctx.Output("MomentOut"); + auto *mean_square_out = ctx.Output("MeanSquareOut"); - auto grad = ctx.Input("Grad"); + auto epsilon = static_cast(ctx.Attr("epsilon")); + auto rho = static_cast(ctx.Attr("decay")); + auto momentum = static_cast(ctx.Attr("momentum")); + bool centered = ctx.Attr("centered"); - param_out->mutable_data(ctx.GetPlace()); - moment_out->mutable_data(ctx.GetPlace()); - mean_square_out->mutable_data(ctx.GetPlace()); + auto &p_tensor = *ctx.Input("Param"); + auto &ms_tensor = *ctx.Input("MeanSquare"); + auto &lr_tensor = *ctx.Input("LearningRate"); + auto &mom_tensor = *ctx.Input("Moment"); - float epsilon = ctx.Attr("epsilon"); - float rho = ctx.Attr("decay"); - float momentum = ctx.Attr("momentum"); - bool centered = ctx.Attr("centered"); + PADDLE_ENFORCE_EQ(&p_tensor, param_out, + "Param and ParamOut must be the same Tensor"); + PADDLE_ENFORCE_EQ(&mom_tensor, moment_out, + "Moment and MomentOut must be the same Tensor"); + PADDLE_ENFORCE_EQ(&ms_tensor, mean_square_out, + "MeanSquare and MeanSquareOut must be the same Tensor"); + + auto &dev_ctx = ctx.template device_context(); + size_t limit = static_cast(ms_tensor.numel()); + + if (grad_var->IsType()) { + auto &grad_tensor = grad_var->Get(); + + if (std::is_same::value) { + auto &place = + *ctx.template device_context().eigen_device(); + auto lr_value = lr_tensor.data()[0]; + + auto p = EigenVector::Flatten(p_tensor); + auto ms = EigenVector::Flatten(ms_tensor); + auto g = EigenVector::Flatten(grad_tensor); + auto mom = EigenVector::Flatten(mom_tensor); + + auto p_out = EigenVector::Flatten(*param_out); + auto mom_out = EigenVector::Flatten(*moment_out); + auto ms_out = EigenVector::Flatten(*mean_square_out); + + ms_out.device(place) = rho * ms + (1 - rho) * g * g; + if (centered) { + auto &mg_tensor = *ctx.Input("MeanGrad"); + auto mg = EigenVector::Flatten(mg_tensor); + auto *mean_grad_out = ctx.Output("MeanGradOut"); + PADDLE_ENFORCE(&mg_tensor, mean_grad_out, + "MeanGrad and MeanGradOut must be the same Tensor"); + auto mg_out = EigenVector::Flatten(*mean_grad_out); + + mg_out.device(place) = rho * mg + (1 - rho) * g; + mom_out.device(place) = + momentum * mom + + lr_value * g / (ms_out - mg_out.square() + epsilon).sqrt(); + } else { + mom_out.device(place) = + momentum * mom + lr_value * g / (ms_out + epsilon).sqrt(); + } + p_out.device(place) = p - mom_out; + } else { + DenseRmspropGradFunctor grad_func(grad_tensor.data()); + platform::ForRange for_range(dev_ctx, limit); + if (centered) { + auto &mg_tensor = *ctx.Input("MeanGrad"); + auto *mean_grad_out = ctx.Output("MeanGradOut"); + PADDLE_ENFORCE(&mg_tensor, mean_grad_out, + "MeanGrad and MeanGradOut must be the same Tensor"); + for_range(CenteredRmspropFunctor>( + param_out->mutable_data(ctx.GetPlace()), + mean_square_out->mutable_data(ctx.GetPlace()), + moment_out->mutable_data(ctx.GetPlace()), + mean_grad_out->mutable_data(ctx.GetPlace()), + lr_tensor.data(), rho, epsilon, momentum, grad_func)); + } else { + for_range(UncenteredRmspropFunctor>( + param_out->mutable_data(ctx.GetPlace()), + mean_square_out->mutable_data(ctx.GetPlace()), + moment_out->mutable_data(ctx.GetPlace()), lr_tensor.data(), + rho, epsilon, momentum, grad_func)); + } + } + } else if (grad_var->IsType()) { + auto &grad = grad_var->Get(); + auto *merged_grad = const_cast(ctx.scope()) + .Var() + ->GetMutable(); + + math::scatter::MergeAdd merge_func; + merge_func(dev_ctx, grad, merged_grad); + + platform::ForRange for_range(dev_ctx, limit); + const int64_t *rows; +#ifdef PADDLE_WITH_CUDA + if (platform::is_gpu_place(ctx.GetPlace())) { + rows = merged_grad->rows().CUDAData(ctx.GetPlace()); + } else { +#endif + rows = merged_grad->rows().data(); +#ifdef PADDLE_WITH_CUDA + } +#endif + auto &merged_tensor = merged_grad->value(); + int64_t row_count = merged_grad->rows().size(); + int64_t row_numel = merged_tensor.numel() / row_count; + SparseRmspropGradFunctor grad_func(merged_tensor.data(), rows, + row_numel, row_count); - auto p = EigenVector::Flatten(*ctx.Input("Param")); - auto ms = EigenVector::Flatten(*ctx.Input("MeanSquare")); - auto lr = EigenVector::Flatten(*ctx.Input("LearningRate")); - auto g = EigenVector::Flatten(*grad); - auto mom = EigenVector::Flatten(*ctx.Input("Moment")); - - auto p_out = EigenVector::Flatten(*param_out); - auto mom_out = EigenVector::Flatten(*moment_out); - auto ms_out = EigenVector::Flatten(*mean_square_out); - auto& place = *ctx.template device_context().eigen_device(); - - Eigen::DSizes grad_dsize(static_cast(grad->numel())); - - ms_out.device(place) = rho * ms + (1 - rho) * g * g; - if (centered) { - auto mg = EigenVector::Flatten(*ctx.Input("MeanGrad")); - auto* mean_grad_out = ctx.Output("MeanGradOut"); - mean_grad_out->mutable_data(ctx.GetPlace()); - auto mg_out = EigenVector::Flatten(*mean_grad_out); - - mg_out.device(place) = rho * mg + (1 - rho) * g; - mom_out.device(place) = momentum * mom + - lr.broadcast(grad_dsize) * g / - (ms_out - mg_out.square() + epsilon).sqrt(); + if (centered) { + auto &mg_tensor = *ctx.Input("MeanGrad"); + auto *mean_grad_out = ctx.Output("MeanGradOut"); + PADDLE_ENFORCE(&mg_tensor, mean_grad_out, + "MeanGrad and MeanGradOut must be the same Tensor"); + for_range(CenteredRmspropFunctor>( + param_out->mutable_data(ctx.GetPlace()), + mean_square_out->mutable_data(ctx.GetPlace()), + moment_out->mutable_data(ctx.GetPlace()), + mean_grad_out->mutable_data(ctx.GetPlace()), lr_tensor.data(), + rho, epsilon, momentum, grad_func)); + } else { + for_range(UncenteredRmspropFunctor>( + param_out->mutable_data(ctx.GetPlace()), + mean_square_out->mutable_data(ctx.GetPlace()), + moment_out->mutable_data(ctx.GetPlace()), lr_tensor.data(), + rho, epsilon, momentum, grad_func)); + } } else { - mom_out.device(place) = - momentum * mom + - lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt(); + PADDLE_THROW("RMSProp only supports LoDTensor or SelectedRows gradient"); } - p_out.device(place) = p - mom_out; } }; -- GitLab