From a9f5f822e604e4eb1811617b2fa985a4620c66f7 Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Wed, 17 Oct 2018 16:34:52 +0800
Subject: [PATCH] use binary search. test=develop

---
 paddle/fluid/operators/momentum_op.cc |  11 +-
 paddle/fluid/operators/momentum_op.cu | 124 +--------
 paddle/fluid/operators/momentum_op.h  | 371 ++++++++++++++++++++++----
 3 files changed, 335 insertions(+), 171 deletions(-)

diff --git a/paddle/fluid/operators/momentum_op.cc b/paddle/fluid/operators/momentum_op.cc
index 257aa7661..fad6f8016 100644
--- a/paddle/fluid/operators/momentum_op.cc
+++ b/paddle/fluid/operators/momentum_op.cc
@@ -74,9 +74,13 @@ class MomentumOpInferVarType : public framework::VarTypeInference {
         framework::proto::VarType::SELECTED_ROWS) {
       block->FindRecursiveOrCreateVar(out_var).SetType(
           framework::proto::VarType::SELECTED_ROWS);
-    } else {
+    } else if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
+               framework::proto::VarType::LOD_TENSOR) {
       block->FindRecursiveOrCreateVar(out_var).SetType(
           framework::proto::VarType::LOD_TENSOR);
+    } else {
+      PADDLE_THROW(
+          "Only support LoDTensor and SelectedRows, unexpected input type.");
     }
   }
 }
@@ -135,5 +139,6 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(momentum, ops::MomentumOp, ops::MomentumOpMaker,
                   paddle::framework::EmptyGradOpMaker,
                   ops::MomentumOpInferVarType);
-REGISTER_OP_CPU_KERNEL(momentum, ops::MomentumOpKernel<float>,
-                       ops::MomentumOpKernel<double>);
+REGISTER_OP_CPU_KERNEL(
+    momentum, ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/momentum_op.cu b/paddle/fluid/operators/momentum_op.cu
index a336f6e67..b68fec34d 100644
--- a/paddle/fluid/operators/momentum_op.cu
+++ b/paddle/fluid/operators/momentum_op.cu
@@ -15,125 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/momentum_op.h"
 
-namespace paddle {
-namespace operators {
-
-template <typename T>
-__global__ void MomentumKernel(const T* p, const T* g, const T* v,
-                               const T* learning_rate, const T mu,
-                               const int64_t num, bool use_nesterov, T* p_out,
-                               T* v_out) {
-  T lr = learning_rate[0];
-  if (use_nesterov) {
-    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
-         i += blockDim.x * gridDim.x) {
-      T g_val = g[i];
-      T v_new = v[i] * mu + g_val;
-      v_out[i] = v_new;
-      p_out[i] = p[i] - (g_val + v_new * mu) * lr;
-    }
-  } else {
-    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
-         i += blockDim.x * gridDim.x) {
-      T v_new = v[i] * mu + g[i];
-      v_out[i] = v_new;
-      p_out[i] = p[i] - lr * v_new;
-    }
-  }
-}
-
-template <typename T>
-__global__ void SparseMomentumKernel(const T* p, const T* g, const T* v,
-                                     const T* lr, const T mu,
-                                     const int64_t* grad_rows,
-                                     const size_t grad_row_numel,
-                                     const size_t grad_row_size,
-                                     const T use_nesterov, T* p_out, T* v_out) {
-  for (int i = blockIdx.x; i < grad_row_size; i += gridDim.x) {
-    for (int j = threadIdx.x; j < grad_row_numel; j += blockDim.x) {
-      size_t p_i = grad_rows[i] * grad_row_numel + j;
-      size_t g_i = i * grad_row_numel + j;
-      v_out[g_i] = v[g_i] * mu + g[g_i];
-      if (use_nesterov) {
-        p_out[p_i] = p[p_i] - (g[g_i] + v_out[g_i] * mu) * lr[0];
-      } else {
-        p_out[p_i] = p[p_i] - v_out[g_i] * lr[0];
-      }
-    }
-  }
-}
-
-template <typename T>
-class MomentumOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    T mu = static_cast<T>(ctx.Attr<float>("mu"));
-    bool use_nesterov = ctx.Attr<bool>("use_nesterov");
-
-    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
-    auto param = ctx.Input<framework::Tensor>("Param");
-    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
-    auto* velocity_var = ctx.InputVar("Velocity");
-    auto* grad_var = ctx.InputVar("Grad");
-
-    if (grad_var->IsType<framework::LoDTensor>()) {
-      PADDLE_ENFORCE(velocity_var->IsType<framework::LoDTensor>(),
-                     "Unmatched Type of Param and Grad");
-      auto velocity = ctx.Input<framework::Tensor>("Velocity");
-      auto grad = ctx.Input<framework::Tensor>("Grad");
-      auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
-      T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
-      T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
-      auto* p = param->data<T>();
-      auto* v = velocity->data<T>();
-      auto* g = grad->data<T>();
-      auto* lr = learning_rate->data<T>();
-
-      const int kThreadPerBlock = 256;
-      int grid = (param->numel() + kThreadPerBlock - 1) / kThreadPerBlock;
-      MomentumKernel<
-          T><<<grid, kThreadPerBlock, 0, ctx.cuda_device_context().stream()>>>(
-          p, g, v, lr, mu, param->numel(), use_nesterov, p_out, v_out);
-    } else if (grad_var->IsType<framework::SelectedRows>()) {
-      // sparse update embedding with selectedrows
-      PADDLE_ENFORCE(velocity_var->IsType<framework::SelectedRows>(),
-                     "Unmatched Type of Param and Grad");
-      auto velocity = ctx.Input<framework::SelectedRows>("Velocity");
-      auto grad = ctx.Input<framework::SelectedRows>("Grad");
-      auto velocity_out = ctx.Output<framework::SelectedRows>("VelocityOut");
-
-      // sparse update maybe empty.
-      if (grad->rows().size() == 0) {
-        return;
-      }
-      PADDLE_ENFORCE(grad->height() == velocity->height(),
-                     "Unmatched gradient and velocity.");
-      auto* p_out = param_out->mutable_data<T>(ctx.GetPlace());
-      auto* v_out =
-          velocity_out->mutable_value()->mutable_data<T>(ctx.GetPlace());
-      auto* lr = learning_rate->data<T>();
-      auto* p = param->data<T>();
-      auto* g = grad->value().data<T>();
-      auto* v = velocity->value().data<T>();
-      size_t grad_row_numel = grad->value().numel() / grad->rows().size();
-      size_t grad_row_size = grad->rows().size();
-      framework::Vector<int64_t> rows(grad->rows());
-
-      const int kThreadPerBlock = 256;
-      int grid = (param->numel() + kThreadPerBlock - 1) / kThreadPerBlock;
-      SparseMomentumKernel<
-          T><<<grid, kThreadPerBlock, 0, ctx.cuda_device_context().stream()>>>(
-          p, g, v, lr, mu, rows.CUDAData(ctx.GetPlace()), grad_row_numel,
-          grad->rows().size(), use_nesterov, p_out, v_out);
-    } else {
-      PADDLE_THROW("Unsupported Variable Type of Grad");
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(momentum, ops::MomentumOpCUDAKernel<float>,
-                        ops::MomentumOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(
+    momentum,
+    ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/momentum_op.h b/paddle/fluid/operators/momentum_op.h
index aee6d094e..dae74a5ad 100644
--- a/paddle/fluid/operators/momentum_op.h
+++ b/paddle/fluid/operators/momentum_op.h
@@ -15,11 +15,265 @@ limitations under the License. */
 #pragma once
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/algorithm.h"
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
+#include "paddle/fluid/platform/for_range.h"
 
 namespace paddle {
 namespace operators {
 
+using framework::Tensor;
+using framework::SelectedRows;
+struct NoNesterov;
+struct UseNesterov;
+
+template <typename T>
+class CPUDenseMomentumFunctor {
+ private:
+  const Tensor* param;
+  const Tensor* grad;
+  const Tensor* velocity;
+  const Tensor* learning_rate;
+  const T mu;
+  const bool use_nesterov;
+  Tensor* param_out;
+  Tensor* velocity_out;
+
+ public:
+  CPUDenseMomentumFunctor(const Tensor* param, const Tensor* grad,
+                          const Tensor* velocity, const Tensor* learning_rate,
+                          const T mu, const bool use_nesterov,
+                          Tensor* param_out, Tensor* velocity_out)
+      : param(param),
+        grad(grad),
+        velocity(velocity),
+        learning_rate(learning_rate),
+        mu(mu),
+        use_nesterov(use_nesterov),
+        param_out(param_out),
+        velocity_out(velocity_out) {}
+
+  inline void operator()() {
+    auto p_out = framework::EigenVector<T>::Flatten(*param_out);
+    auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
+
+    auto p = framework::EigenVector<T>::Flatten(*param);
+    auto v = framework::EigenVector<T>::Flatten(*velocity);
+    auto g = framework::EigenVector<T>::Flatten(*grad);
+    auto* lr = learning_rate->data<T>();
+
+    v_out = v * mu + g;
+    if (use_nesterov) {
+      p_out = p - (g + v_out * mu) * lr[0];
+    } else {
+      p_out = p - lr[0] * v_out;
+    }
+  }
+};
+
+template <typename T, typename UpdateMethod>
+class DenseMomentumFunctor;
+
+// NOTE(dzh): for performance, avoid an if/else inside the kernel and
+// implement the GPU UseNesterov/NoNesterov paths as two separate functors.
+template <typename T>
+class DenseMomentumFunctor<T, UseNesterov> {
+ private:
+  const T* p_;
+  const T* g_;
+  const T* v_;
+  const T* lr_;
+  const T mu_;
+  const int64_t num_;
+  T* p_out_;
+  T* v_out_;
+
+ public:
+  DenseMomentumFunctor(const T* p, const T* g, const T* v,
+                       const T* learning_rate, const T mu, const int64_t num,
+                       T* p_out, T* v_out)
+      : p_(p),
+        g_(g),
+        v_(v),
+        lr_(learning_rate),
+        mu_(mu),
+        num_(num),
+        p_out_(p_out),
+        v_out_(v_out) {}
+  inline HOSTDEVICE void operator()(size_t i) const {
+    // put memory accesses into registers
+    const T p = p_[i];
+    const T g = g_[i];
+    const T lr = lr_[0];
+    const T v = v_[i];
+    T v_out = v * mu_ + g;
+    T p_out = p - (g + v_out * mu_) * lr;
+    // write registers back to memory
+    v_out_[i] = v_out;
+    p_out_[i] = p_out;
+  }
+};
+
+template <typename T>
+class DenseMomentumFunctor<T, NoNesterov> {
+ private:
+  const T* p_;
+  const T* g_;
+  const T* v_;
+  const T* lr_;
+  const T mu_;
+  const int64_t num_;
+  T* p_out_;
+  T* v_out_;
+
+ public:
+  DenseMomentumFunctor(const T* p, const T* g, const T* v,
+                       const T* learning_rate, const T mu, const int64_t num,
+                       T* p_out, T* v_out)
+      : p_(p),
+        g_(g),
+        v_(v),
+        lr_(learning_rate),
+        mu_(mu),
+        num_(num),
+        p_out_(p_out),
+        v_out_(v_out) {}
+  inline HOSTDEVICE void operator()(size_t i) const {
+    // put memory accesses into registers
+    const T p = p_[i];
+    const T g = g_[i];
+    const T lr = lr_[0];
+    const T v = v_[i];
+    T v_out = v * mu_ + g;
+    T p_out = p - lr * v_out;
+    // write registers back to memory
+    v_out_[i] = v_out;
+    p_out_[i] = p_out;
+  }
+};
+
+// TODO(dzh): speed this up with Eigen
+// template <typename T>
+// class CPUSparseMomentumFunctor {
+//  private:
+//   const T* p_;
+//   const T* g_;
+//   const T* v_;
+//   const T* lr_;
+//   const T mu_;
+//   const bool use_nesterov_;
+//   const int64_t* rows_;
+//   const int64_t row_numel_;
+//   const int64_t row_height_;
+//   T* p_out_;
+//   T* v_out_;
+
+//  public:
+//   CPUSparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr,
+//   const T mu, const bool use_nesterov, const int64_t* rows, const int64_t
+//   row_numel, const int64_t row_height, T* p_out, T* v_out) : p_(p), g_(g),
+//   v_(v), lr_(lr), mu_(mu), rows_(rows), row_numel_(row_numel),
+//   row_height_(row_height), p_out_(p_out), v_out_(v_out) {}
+//   inline void operator()() {
+
+//   }
+// };
+
+template <typename T, typename UpdateMethod>
+class SparseMomentumFunctor;
+
+template <typename T>
+class SparseMomentumFunctor<T, UseNesterov> {
+ private:
+  const T* p_;
+  const T* g_;
+  const T* v_;
+  const T* lr_;
+  const T mu_;
+  const int64_t* rows_;
+  const int64_t row_numel_;
+  const int64_t row_height_;
+  T* p_out_;
+  T* v_out_;
+
+ public:
+  SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr,
+                        const T mu, const int64_t* rows, int64_t row_numel,
+                        int64_t row_height, T* p_out, T* v_out)
+      : p_(p),
+        g_(g),
+        v_(v),
+        lr_(lr),
+        mu_(mu),
+        rows_(rows),
+        row_numel_(row_numel),
+        row_height_(row_height),
+        p_out_(p_out),
+        v_out_(v_out) {}
+
+  inline HOSTDEVICE void operator()(size_t i) {
+    auto row_idx =
+        math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
+    T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0;
+    // put memory accesses into registers
+    const T p = p_[i];
+    const T lr = lr_[0];
+    const T v = v_[i];
+    T v_out = v * mu_ + g;
+    T p_out = p - (g + v_out * mu_) * lr;
+    // write registers back to memory
+    v_out_[i] = v_out;
+    p_out_[i] = p_out;
+  }
+};
+
+template <typename T>
+class SparseMomentumFunctor<T, NoNesterov> {
+ private:
+  const T* p_;
+  const T* g_;
+  const T* v_;
+  const T* lr_;
+  const T mu_;
+  const int64_t* rows_;
+  const int64_t row_numel_;
+  const int64_t row_height_;
+  T* p_out_;
+  T* v_out_;
+
+ public:
+  SparseMomentumFunctor(const T* p, const T* g, const T* v, const T* lr,
+                        const T mu, const int64_t* rows, int64_t row_numel,
+                        int64_t row_height, T* p_out, T* v_out)
+      : p_(p),
+        g_(g),
+        v_(v),
+        lr_(lr),
+        mu_(mu),
+        rows_(rows),
+        row_numel_(row_numel),
+        row_height_(row_height),
+        p_out_(p_out),
+        v_out_(v_out) {}
+
+  inline HOSTDEVICE void operator()(size_t i) {
+    auto row_idx =
+        math::BinarySearch<int64_t>(rows_, row_height_, i / row_numel_);
+    T g = row_idx >= 0 ? g_[row_idx * row_numel_ + i % row_numel_] : 0;
+    // put memory accesses into registers
+    const T p = p_[i];
+    const T lr = lr_[0];
+    const T v = v_[i];
+    T v_out = v * mu_ + g;
+    T p_out = p - v_out * lr;
+    // write registers back to memory
+    v_out_[i] = v_out;
+    p_out_[i] = p_out;
+  }
+};
+
-template <typename T>
+template <typename DeviceContext, typename T>
 class MomentumOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -29,65 +283,88 @@ class MomentumOpKernel : public framework::OpKernel<T> {
     auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
     auto param = ctx.Input<framework::Tensor>("Param");
     auto param_out = ctx.Output<framework::Tensor>("ParamOut");
-    auto* velocity_var = ctx.InputVar("Velocity");
+    auto* velocity = ctx.Input<framework::Tensor>("Velocity");
+    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
+    param_out->mutable_data<T>(ctx.GetPlace());
+    velocity_out->mutable_data<T>(ctx.GetPlace());
+
     auto* grad_var = ctx.InputVar("Grad");
     if (grad_var->IsType<framework::LoDTensor>()) {
-      PADDLE_ENFORCE(velocity_var->IsType<framework::LoDTensor>(),
-                     "Unmatched Type of Param and Grad");
-      auto velocity = ctx.Input<framework::Tensor>("Velocity");
       auto grad = ctx.Input<framework::Tensor>("Grad");
-      auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
-      param_out->mutable_data<T>(ctx.GetPlace());
-      velocity_out->mutable_data<T>(ctx.GetPlace());
-      auto p_out = framework::EigenVector<T>::Flatten(*param_out);
-      auto v_out = framework::EigenVector<T>::Flatten(*velocity_out);
-
-      auto p = framework::EigenVector<T>::Flatten(*param);
-      auto v = framework::EigenVector<T>::Flatten(*velocity);
-      auto g = framework::EigenVector<T>::Flatten(*grad);
-      auto* lr = learning_rate->data<T>();
-
-      v_out = v * mu + g;
-      if (use_nesterov) {
-        p_out = p - (g + v_out * mu) * lr[0];
-      } else {
-        p_out = p - lr[0] * v_out;
+      if (platform::is_cpu_place(ctx.GetPlace())) {
+        CPUDenseMomentumFunctor<T> functor(param, grad, velocity,
+                                           learning_rate, mu, use_nesterov,
+                                           param_out, velocity_out);
+        functor();
+      } else if (platform::is_gpu_place(ctx.GetPlace())) {
+        platform::ForRange<DeviceContext> for_range(
+            static_cast<const DeviceContext&>(ctx.device_context()),
+            param->numel());
+        if (use_nesterov) {
+          DenseMomentumFunctor<T, UseNesterov> functor(
+              param->data<T>(), grad->data<T>(), velocity->data<T>(),
+              learning_rate->data<T>(), mu, param->numel(),
+              param_out->mutable_data<T>(ctx.GetPlace()),
+              velocity_out->mutable_data<T>(ctx.GetPlace()));
+          for_range(functor);
+
+        } else {
+          DenseMomentumFunctor<T, NoNesterov> functor(
+              param->data<T>(), grad->data<T>(), velocity->data<T>(),
+              learning_rate->data<T>(), mu, param->numel(),
+              param_out->mutable_data<T>(ctx.GetPlace()),
+              velocity_out->mutable_data<T>(ctx.GetPlace()));
+          for_range(functor);
+        }
       }
+
     } else if (grad_var->IsType<framework::SelectedRows>()) {
       // sparse update embedding with selectedrows
-      PADDLE_ENFORCE(velocity_var->IsType<framework::SelectedRows>(),
-                     "Unmatched Type of Param and Grad");
-      auto velocity = ctx.Input<framework::SelectedRows>("Velocity");
       auto grad = ctx.Input<framework::SelectedRows>("Grad");
-      auto velocity_out = ctx.Output<framework::SelectedRows>("VelocityOut");
 
       // sparse update maybe empty.
       if (grad->rows().size() == 0) {
+        VLOG(3) << "Grad SelectedRows contains no data!";
         return;
       }
-      PADDLE_ENFORCE(grad->height() == velocity->height(),
-                     "Unmatched gradient and velocity.");
-      auto* p_out = param_out->mutable_data<T>(ctx.GetPlace());
-      auto* v_out =
-          velocity_out->mutable_value()->mutable_data<T>(ctx.GetPlace());
-      auto* lr = learning_rate->data<T>();
-      auto* p = param->data<T>();
-      auto* g = grad->value().data<T>();
-      auto* v = velocity->value().data<T>();
-      size_t grad_row_numel = grad->value().numel() / grad->rows().size();
-
-      for (size_t i = 0; i < grad->rows().size(); ++i) {
-        size_t grad_row_index = grad->rows()[i];
-        for (size_t j = 0; j < grad_row_numel; ++j) {
-          size_t p_i = grad_row_index * grad_row_numel + j;
-          size_t g_i = i * grad_row_numel + j;
-          v_out[g_i] = v[g_i] * mu + g[g_i];
-          if (use_nesterov) {
-            p_out[p_i] = p[p_i] - (g[g_i] + v_out[g_i] * mu) * lr[0];
-          } else {
-            p_out[p_i] = p[p_i] - v_out[g_i] * lr[0];
-          }
-        }
-      }
+      auto* merged_grad = const_cast<framework::Scope&>(ctx.scope())
+                              .Var()
+                              ->GetMutable<framework::SelectedRows>();
+
+      math::scatter::MergeAdd<DeviceContext, T> merge_func;
+      merge_func(ctx.template device_context<DeviceContext>(), *grad,
+                 merged_grad);
+
+      platform::ForRange<DeviceContext> for_range(
+          static_cast<const DeviceContext&>(ctx.device_context()),
+          param->numel());
+
+      const int64_t* rows = nullptr;
+      if (platform::is_gpu_place(ctx.GetPlace())) {
+        rows = merged_grad->rows().CUDAData(ctx.GetPlace());
+      } else {
+        rows = merged_grad->rows().data();
+      }
+
+      if (use_nesterov) {
+        SparseMomentumFunctor<T, UseNesterov> functor(
+            param->data<T>(), merged_grad->value().data<T>(),
+            velocity->data<T>(), learning_rate->data<T>(), mu, rows,
+            static_cast<int64_t>(merged_grad->rows().size()),
+            static_cast<int64_t>(merged_grad->height()),
+            param_out->mutable_data<T>(ctx.GetPlace()),
+            velocity_out->mutable_data<T>(ctx.GetPlace()));
+        for_range(functor);
+
+      } else {
+        SparseMomentumFunctor<T, NoNesterov> functor(
+            param->data<T>(), merged_grad->value().data<T>(),
+            velocity->data<T>(), learning_rate->data<T>(), mu, rows,
+            static_cast<int64_t>(merged_grad->rows().size()),
+            static_cast<int64_t>(merged_grad->height()),
+            param_out->mutable_data<T>(ctx.GetPlace()),
+            velocity_out->mutable_data<T>(ctx.GetPlace()));
+        for_range(functor);
       }
     } else {
       PADDLE_THROW("Unsupported Variable Type of Grad");
--
GitLab
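
Note on the approach: the old CUDA path launched one thread block per sparse gradient row, while the new SparseMomentumFunctor makes a single element-wise pass over the whole dense parameter. For each flattened index i it binary-searches the merged row list (which the search assumes is sorted) for row i / row_numel; a miss means a zero gradient, so rows absent from the sparse gradient still decay their velocity by mu. The stand-alone sketch below mirrors that lookup and the NoNesterov update rule; binary_search_rows and the toy 4x2 layout are illustrative stand-ins, not Paddle APIs.

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative stand-in for Paddle's math::BinarySearch: index of `val`
// in the sorted array `x` of length `num`, or -1 if it is absent.
int64_t binary_search_rows(const int64_t* x, int64_t num, int64_t val) {
  int64_t lo = 0, hi = num - 1;
  while (lo <= hi) {
    int64_t mid = lo + (hi - lo) / 2;
    if (x[mid] == val) return mid;
    if (x[mid] < val) {
      lo = mid + 1;
    } else {
      hi = mid - 1;
    }
  }
  return -1;
}

int main() {
  // Toy layout: a 4x2 dense parameter whose sparse gradient touches only
  // rows 1 and 3; the merged row list is sorted and duplicate-free.
  const int64_t row_numel = 2;            // elements per row
  std::vector<int64_t> rows = {1, 3};     // rows that carry a gradient
  std::vector<double> grad = {0.1, 0.2,   // gradient of row 1
                              0.3, 0.4};  // gradient of row 3
  std::vector<double> param(8, 1.0);
  std::vector<double> velocity(8, 0.0);
  const double mu = 0.9, lr = 0.01;

  // One pass over every dense element, as in the NoNesterov variant of
  // SparseMomentumFunctor: a missing row contributes a zero gradient,
  // so its velocity still decays by mu each step.
  for (int64_t i = 0; i < static_cast<int64_t>(param.size()); ++i) {
    int64_t row_idx = binary_search_rows(
        rows.data(), static_cast<int64_t>(rows.size()), i / row_numel);
    double g = row_idx >= 0 ? grad[row_idx * row_numel + i % row_numel] : 0.0;
    velocity[i] = velocity[i] * mu + g;
    param[i] -= lr * velocity[i];
  }

  for (double p : param) std::cout << p << ' ';  // rows 0 and 2 unchanged
  std::cout << '\n';
}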
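
The NOTE(dzh) comment in momentum_op.h concerns keeping the hot loop branch-free: the Nesterov decision is lifted into the UseNesterov/NoNesterov tag types, so each functor specialization compiles to a straight-line body and the branch is resolved once at the call site rather than once per element. Below is a minimal, generic illustration of that tag-dispatch pattern; the names (Update, momentum) and the plain double types are invented for this sketch and are not Paddle's.

#include <cstddef>
#include <iostream>

struct UseNesterov {};
struct NoNesterov {};

// The primary template is only declared; each update rule gets its own
// specialization, so the selected body contains no runtime branch.
template <typename UpdateMethod>
struct Update;

template <>
struct Update<UseNesterov> {
  static double step(double p, double g, double v_out, double mu, double lr) {
    return p - (g + v_out * mu) * lr;
  }
};

template <>
struct Update<NoNesterov> {
  static double step(double p, double g, double v_out, double mu, double lr) {
    return p - lr * v_out;
  }
};

template <typename UpdateMethod>
void momentum(const double* p, const double* g, double* v, double* p_out,
              size_t n, double mu, double lr) {
  // hot loop: the update rule is fixed at compile time
  for (size_t i = 0; i < n; ++i) {
    v[i] = v[i] * mu + g[i];
    p_out[i] = Update<UpdateMethod>::step(p[i], g[i], v[i], mu, lr);
  }
}

int main() {
  double p[2] = {1.0, 2.0}, g[2] = {0.5, 0.5}, v[2] = {0.0, 0.0}, out[2];
  momentum<NoNesterov>(p, g, v, out, 2, 0.9, 0.1);  // branch picked once, here
  std::cout << out[0] << ' ' << out[1] << '\n';     // prints: 0.95 1.95
}

Since use_nesterov is uniform across all elements, a runtime branch would predict well anyway; the template split mainly keeps each HOSTDEVICE functor body small and lets the compiler optimize the two instantiations independently.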