From 163d28714349d51596be1cb165f93be2b8290bda Mon Sep 17 00:00:00 2001 From: Kavya Srinet Date: Mon, 2 Oct 2017 19:23:05 -0700 Subject: [PATCH] Made learning rate the input --- paddle/operators/rmsprop_op.cc | 16 +++++++++++----- paddle/operators/rmsprop_op.h | 2 +- .../paddle/v2/framework/tests/test_rmsprop_op.py | 15 ++++++++------- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/paddle/operators/rmsprop_op.cc b/paddle/operators/rmsprop_op.cc index dcf3599f4d..602efab3db 100644 --- a/paddle/operators/rmsprop_op.cc +++ b/paddle/operators/rmsprop_op.cc @@ -24,11 +24,13 @@ class RmspropOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContextBase *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), - "Input(param) of RmspropOp should not be null."); + "Input(Param) of RmspropOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), - "Input(grad) of RmspropOp should not be null."); + "Input(Grad) of RmspropOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Moment"), - "Input(moment) of RmspropOp should not be null."); + "Input(Moment) of RmspropOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LearningRate"), + "Input(LearningRate) of RmspropOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), "Output(param_out) of RmspropOp should not be null."); @@ -43,6 +45,10 @@ class RmspropOp : public framework::OperatorWithKernel { param_dim, ctx->GetInputDim("Moment"), "Param and moment input of RmspropOp should have the same dimension."); + auto lr_dim = ctx->GetInputDim("LearningRate"); + PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1, + "Learning Rate should be a scalar."); + ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("MomentOut", param_dim); } @@ -56,11 +62,11 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Param", "Input parameter"); AddInput("Grad", "Input gradient"); AddInput("Moment", "Second moment"); + AddInput("LearningRate", "Learning Rate"); AddOutput("ParamOut", "Output parameter"); AddOutput("MomentOut", "Output second moment"); - AddAttr("learningRate", "Learning rate"); AddAttr("epsilon", "Constant for numerical stability"); AddAttr("decayRate", "Decay rate for moving average of gradients"); AddComment(R"DOC( @@ -68,7 +74,7 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker { RMSprop MomentOut = decayRate * Moment + (1 - decayRate) * Grad * Grad -ParamOut = Param - learningRate * Grad / (sqrt(MomentOut) + epsilon) +ParamOut = Param - LearningRate * Grad / (sqrt(MomentOut) + epsilon) The original slide(Slide 29 of http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) diff --git a/paddle/operators/rmsprop_op.h b/paddle/operators/rmsprop_op.h index c94c24bddd..65b9edd35b 100644 --- a/paddle/operators/rmsprop_op.h +++ b/paddle/operators/rmsprop_op.h @@ -34,13 +34,13 @@ class RmspropOpKernel : public framework::OpKernel { param_out->mutable_data(ctx.GetPlace()); moment_out->mutable_data(ctx.GetPlace()); - float lr = ctx.Attr("learningRate"); float epsilon = ctx.Attr("epsilon"); float decay = ctx.Attr("decayRate"); auto p = EigenVector::Flatten(*ctx.Input("Param")); auto g = EigenVector::Flatten(*ctx.Input("Grad")); auto m = EigenVector::Flatten(*ctx.Input("Moment")); + float lr = ctx.Input("LearningRate")->data()[0]; auto p_out = EigenVector::Flatten(*param_out); auto m_out = EigenVector::Flatten(*moment_out); auto place = ctx.GetEigenDevice(); diff --git a/python/paddle/v2/framework/tests/test_rmsprop_op.py b/python/paddle/v2/framework/tests/test_rmsprop_op.py index 1fc59a0f11..64ca5da48e 100644 --- a/python/paddle/v2/framework/tests/test_rmsprop_op.py +++ b/python/paddle/v2/framework/tests/test_rmsprop_op.py @@ -10,19 +10,20 @@ class TestRmspropOp(OpTest): param = np.random.random((123, 321)).astype("float32") grad = np.random.random((123, 321)).astype("float32") moment = np.zeros((123, 321)).astype("float32") + learning_rate = np.array([0.01]).astype("float32") - learning_rate = 0.01 epsilon = 1e-6 decay_rate = 0.9 - self.inputs = {'Param': param, 'Grad': grad, 'Moment': moment} - - self.attrs = { - 'learningRate': learning_rate, - 'epsilon': epsilon, - 'decayRate': decay_rate + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment': moment, + 'LearningRate': learning_rate } + self.attrs = {'epsilon': epsilon, 'decayRate': decay_rate} + moment_out = decay_rate * moment + (1 - decay_rate) * grad * grad param_out = param - learning_rate * grad / (np.sqrt(moment_out) + epsilon) -- GitLab