diff --git a/paddle/operators/rmsprop_op.h b/paddle/operators/rmsprop_op.h index 9c04276ec618bfa9da31fb301f5a0361c58017a8..7bf2129010f994966d79ef11d5cec30159b47068 100644 --- a/paddle/operators/rmsprop_op.h +++ b/paddle/operators/rmsprop_op.h @@ -32,6 +32,8 @@ class RmspropOpKernel : public framework::OpKernel { auto* moment_out = ctx.Output("MomentOut"); auto* mean_square_out = ctx.Output("MeanSquareOut"); + auto grad = ctx.Input("Grad"); + param_out->mutable_data(ctx.GetPlace()); moment_out->mutable_data(ctx.GetPlace()); mean_square_out->mutable_data(ctx.GetPlace()); @@ -42,8 +44,8 @@ class RmspropOpKernel : public framework::OpKernel { auto p = EigenVector::Flatten(*ctx.Input("Param")); auto ms = EigenVector::Flatten(*ctx.Input("MeanSquare")); - float lr = ctx.Input("LearningRate")->data()[0]; - auto g = EigenVector::Flatten(*ctx.Input("Grad")); + auto lr = EigenVector::Flatten(*ctx.Input("LearningRate")); + auto g = EigenVector::Flatten(*grad); auto mom = EigenVector::Flatten(*ctx.Input("Moment")); auto p_out = EigenVector::Flatten(*param_out); @@ -51,8 +53,12 @@ class RmspropOpKernel : public framework::OpKernel { auto ms_out = EigenVector::Flatten(*mean_square_out); auto place = ctx.GetEigenDevice(); + Eigen::DSizes grad_dsize(grad->numel()); + ms_out.device(place) = rho * ms + (1 - rho) * g * g; - mom_out.device(place) = momentum * mom + lr * g / (ms_out + epsilon).sqrt(); + mom_out.device(place) = + momentum * mom + + lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt(); p_out.device(place) = p - mom_out; } };