From f52cdaa0cee682ddc3588286af42d960141596f0 Mon Sep 17 00:00:00 2001 From: Kavya Srinet Date: Thu, 5 Oct 2017 19:12:27 -0700 Subject: [PATCH] Updated RMSProp to have learning rate as an input and work with GPU --- paddle/operators/rmsprop_op.h | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/paddle/operators/rmsprop_op.h b/paddle/operators/rmsprop_op.h index 9c04276ec6..7bf2129010 100644 --- a/paddle/operators/rmsprop_op.h +++ b/paddle/operators/rmsprop_op.h @@ -32,6 +32,8 @@ class RmspropOpKernel : public framework::OpKernel { auto* moment_out = ctx.Output("MomentOut"); auto* mean_square_out = ctx.Output("MeanSquareOut"); + auto grad = ctx.Input("Grad"); + param_out->mutable_data(ctx.GetPlace()); moment_out->mutable_data(ctx.GetPlace()); mean_square_out->mutable_data(ctx.GetPlace()); @@ -42,8 +44,8 @@ class RmspropOpKernel : public framework::OpKernel { auto p = EigenVector::Flatten(*ctx.Input("Param")); auto ms = EigenVector::Flatten(*ctx.Input("MeanSquare")); - float lr = ctx.Input("LearningRate")->data()[0]; - auto g = EigenVector::Flatten(*ctx.Input("Grad")); + auto lr = EigenVector::Flatten(*ctx.Input("LearningRate")); + auto g = EigenVector::Flatten(*grad); auto mom = EigenVector::Flatten(*ctx.Input("Moment")); auto p_out = EigenVector::Flatten(*param_out); @@ -51,8 +53,12 @@ class RmspropOpKernel : public framework::OpKernel { auto ms_out = EigenVector::Flatten(*mean_square_out); auto place = ctx.GetEigenDevice(); + Eigen::DSizes grad_dsize(grad->numel()); + ms_out.device(place) = rho * ms + (1 - rho) * g * g; - mom_out.device(place) = momentum * mom + lr * g / (ms_out + epsilon).sqrt(); + mom_out.device(place) = + momentum * mom + + lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt(); p_out.device(place) = p - mom_out; } }; -- GitLab