From 5bd1e73f5e6e7532bd1b13b1c0924ba70ae5cd1a Mon Sep 17 00:00:00 2001
From: dangqingqing
Date: Tue, 5 Dec 2017 00:25:39 +0800
Subject: [PATCH] Refine and speedup momentum operator.

---
 paddle/operators/momentum_op.cc |  4 +-
 paddle/operators/momentum_op.cu | 66 +++++++++++++++++++++++++++++++--
 paddle/operators/momentum_op.h  | 13 +++----
 3 files changed, 70 insertions(+), 13 deletions(-)

diff --git a/paddle/operators/momentum_op.cc b/paddle/operators/momentum_op.cc
index 19954006195..fde253b0b38 100644
--- a/paddle/operators/momentum_op.cc
+++ b/paddle/operators/momentum_op.cc
@@ -101,5 +101,5 @@ $$
 
 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(momentum, ops::MomentumOp, ops::MomentumOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    momentum, ops::MomentumOpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(momentum, ops::MomentumOpKernel<float>,
+                       ops::MomentumOpKernel<double>);
diff --git a/paddle/operators/momentum_op.cu b/paddle/operators/momentum_op.cu
index efc24e795e0..d856df40027 100644
--- a/paddle/operators/momentum_op.cu
+++ b/paddle/operators/momentum_op.cu
@@ -12,9 +12,67 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#define EIGEN_USE_GPU
-#include "paddle/operators/momentum_op.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+__global__ void MomentumKernel(const T* p, const T* g, const T* v,
+                               const T* learning_rate, const T mu,
+                               const int64_t num, bool use_nesterov, T* p_out,
+                               T* v_out) {
+  T lr = learning_rate[0];
+  if (use_nesterov) {
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
+         i += blockDim.x * gridDim.x) {
+      T g_val = g[i];
+      T v_new = v[i] * mu + g_val;
+      v_out[i] = v_new;
+      p_out[i] = p[i] - g_val * lr + v_new * mu * lr;
+    }
+  } else {
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
+         i += blockDim.x * gridDim.x) {
+      T v_new = v[i] * mu + g[i];
+      v_out[i] = v_new;
+      p_out[i] = p[i] - lr * v_new;
+    }
+  }
+}
+
+template <typename T>
+class MomentumOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
+    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
+    auto param = ctx.Input<framework::Tensor>("Param");
+    auto velocity = ctx.Input<framework::Tensor>("Velocity");
+    auto grad = ctx.Input<framework::Tensor>("Grad");
+    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
+
+    T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
+    T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
+
+    T mu = static_cast<T>(ctx.Attr<float>("mu"));
+    bool use_nesterov = ctx.Attr<bool>("use_nesterov");
+
+    auto* p = param->data<T>();
+    auto* v = velocity->data<T>();
+    auto* g = grad->data<T>();
+    auto* lr = learning_rate->data<T>();
+
+    int block = 512;
+    int grid = (param->numel() + block - 1) / block;
+    MomentumKernel<T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
+        p, g, v, lr, mu, param->numel(), use_nesterov, p_out, v_out);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(
-    momentum, ops::MomentumOpKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(momentum, ops::MomentumOpCUDAKernel<float>,
+                       ops::MomentumOpCUDAKernel<double>);
diff --git a/paddle/operators/momentum_op.h b/paddle/operators/momentum_op.h
index 8f7f5eb5c21..2d919573d20 100644
--- a/paddle/operators/momentum_op.h
+++ b/paddle/operators/momentum_op.h
@@ -19,7 +19,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-template <typename Place, typename T>
+template <typename T>
 class MomentumOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -33,7 +33,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
     param_out->mutable_data<T>(ctx.GetPlace());
     velocity_out->mutable_data<T>(ctx.GetPlace());
 
-    float mu = ctx.Attr<float>("mu");
+    T mu = static_cast<T>(ctx.Attr<float>("mu"));
     bool use_nesterov = ctx.Attr<bool>("use_nesterov");
 
     auto p_out = framework::EigenVector<T>::Flatten(*param_out);
@@ -42,18 +42,17 @@ class MomentumOpKernel : public framework::OpKernel<T> {
     auto p = framework::EigenVector<T>::Flatten(*param);
     auto v = framework::EigenVector<T>::Flatten(*velocity);
     auto g = framework::EigenVector<T>::Flatten(*grad);
-    auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
+    auto* lr = learning_rate->data<T>();
 
-    auto place = ctx.GetEigenDevice<Place>();
+    auto place = ctx.GetEigenDevice<platform::CPUPlace>();
 
     Eigen::DSizes<int, 1> grad_dsize(grad->numel());
 
     v_out.device(place) = v * mu + g;
     if (use_nesterov) {
-      p_out.device(place) = p - g * lr.broadcast(grad_dsize) +
-                            v_out * mu * lr.broadcast(grad_dsize);
+      p_out.device(place) = p - (g - v_out * mu) * lr[0];
     } else {
-      p_out.device(place) = p - lr.broadcast(grad_dsize) * v_out;
+      p_out.device(place) = p - lr[0] * v_out;
     }
   }
 };
-- 
GitLab
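
Note (not part of the patch): the Nesterov branch relies on the identity p - g*lr + v_new*mu*lr = p - (g - v_new*mu)*lr, which is what lets the CPU path drop the Eigen lr broadcast and the CUDA path read learning_rate[0] once per thread. The standalone C++ sketch below checks that equivalence; momentum_update is a hypothetical helper that mirrors MomentumKernel's arithmetic on plain arrays and assumes nothing from the Paddle codebase.

// Standalone sketch: verifies the factored Nesterov update used in the
// new kernels against the original unfactored expression.
#include <cassert>
#include <cmath>
#include <cstdio>

// Plain-momentum and Nesterov updates, mirroring MomentumKernel's math.
void momentum_update(const float* p, const float* g, const float* v,
                     float lr, float mu, bool use_nesterov, int n,
                     float* p_out, float* v_out) {
  for (int i = 0; i < n; ++i) {
    float v_new = v[i] * mu + g[i];  // v_out = v * mu + g
    v_out[i] = v_new;
    if (use_nesterov) {
      // Factored form from the patch: p - (g - v_new * mu) * lr
      p_out[i] = p[i] - (g[i] - v_new * mu) * lr;
    } else {
      p_out[i] = p[i] - lr * v_new;  // p_out = p - lr * v_out
    }
  }
}

int main() {
  float p[] = {1.0f, 2.0f}, g[] = {0.1f, -0.2f}, v[] = {0.5f, 0.25f};
  float p_out[2], v_out[2];
  const float lr = 0.01f, mu = 0.9f;
  momentum_update(p, g, v, lr, mu, /*use_nesterov=*/true, 2, p_out, v_out);
  for (int i = 0; i < 2; ++i) {
    // Original (unfactored) Nesterov form: p - g * lr + v_new * mu * lr.
    float expected = p[i] - g[i] * lr + v_out[i] * mu * lr;
    assert(std::fabs(p_out[i] - expected) < 1e-6f);
  }
  std::printf("p_out = {%f, %f}\n", p_out[0], p_out[1]);
  return 0;
}

The same factoring explains the speedup claimed in the subject line: the CPU kernel replaces two broadcasts of the learning-rate tensor with a single scalar read, and the GPU path replaces the Eigen expression with a hand-written grid-stride kernel registered for both float and double.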