diff --git a/paddle/operators/rmsprop_op.cc b/paddle/operators/rmsprop_op.cc index 602efab3db475676a2a06f27d0a115e8c7c36b2e..1e06e08ede214caa7f4c2de12aeb237631152668 100644 --- a/paddle/operators/rmsprop_op.cc +++ b/paddle/operators/rmsprop_op.cc @@ -25,25 +25,32 @@ class RmspropOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContextBase *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(Param) of RmspropOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("MeanSquare"), + "Input(MeanSquare) of RmspropOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LearningRate"), + "Input(LearningRate) of RmspropOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), "Input(Grad) of RmspropOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Moment"), "Input(Moment) of RmspropOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("LearningRate"), - "Input(LearningRate) of RmspropOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), "Output(param_out) of RmspropOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("MomentOut"), - "Output(moment_out) of RmspropOp should not be null."); + "Output(Momentum_out) of RmspropOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("MeanSquareOut"), + "Output(MeanSquareOut) of RmspropOp should not be null."); auto param_dim = ctx->GetInputDim("Param"); PADDLE_ENFORCE_EQ( param_dim, ctx->GetInputDim("Grad"), "Param and grad input of RmspropOp should have the same dimension."); - PADDLE_ENFORCE_EQ( - param_dim, ctx->GetInputDim("Moment"), - "Param and moment input of RmspropOp should have the same dimension."); + PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Moment"), + "Param and Momentum input of RmspropOp " + "should have the same dimension."); + PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("MeanSquare"), + "Param and MeanSquare input of RmspropOp " + "should have the same dimension."); auto lr_dim = ctx->GetInputDim("LearningRate"); 
PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1, @@ -51,6 +58,7 @@ class RmspropOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("MomentOut", param_dim); + ctx->SetOutputDim("MeanSquareOut", param_dim); } }; @@ -59,27 +67,46 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker { RmspropOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("Param", "Input parameter"); - AddInput("Grad", "Input gradient"); - AddInput("Moment", "Second moment"); - AddInput("LearningRate", "Learning Rate"); - - AddOutput("ParamOut", "Output parameter"); - AddOutput("MomentOut", "Output second moment"); - - AddAttr("epsilon", "Constant for numerical stability"); - AddAttr("decayRate", "Decay rate for moving average of gradients"); + AddInput("Param", + "(Tensor, default Tensor) " + "Input parameter value that has to be updated"); + AddInput("MeanSquare", + "(Tensor, default Tensor)" + " The mean square value that gets updated"); + AddInput("LearningRate", + "(Tensor, default Tensor) " + "The learning rate should be a tensor of size 1"); + AddInput("Grad", + "(Tensor, default Tensor) " + "Input gradient of the parameter"); + AddInput("Moment", + "(Tensor, default Tensor) The moment that gets updated"); + + AddOutput("ParamOut", "(Tensor) Output updated parameter value"); + AddOutput("MomentOut", "(Tensor) Output updated moment"); + AddOutput("MeanSquareOut", "(Tensor) Output Mean squared updated value"); + + AddAttr("epsilon", + "(float, default 1e-10) Constant " + "for numerical stability.") + .SetDefault(1e-10); + AddAttr("decay", + "(float, default 0.9) " + "Discounting factor for coming gradient.") + .SetDefault(0.9); + AddAttr("momentum", "(float, default 0.0) Constant value") + .SetDefault(0.0); AddComment(R"DOC( RMSprop -MomentOut = decayRate * Moment + (1 - decayRate) * Grad * Grad -ParamOut = Param - LearningRate * Grad / 
(sqrt(MomentOut) + epsilon) +MeanSquareOut = decay * MeanSquare + (1 - decay) * Grad * Grad +MomentOut = momentum * Moment + + LearningRate * Grad / sqrt(MeanSquareOut + epsilon) +ParamOut = Param - MomentOut -The original slide(Slide 29 of +The original slides that proposed RMSprop: Slide 29 of http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) -does not have the epsilon attribute. It is added here for numerical stability -to avoid division by zero. )DOC"); } diff --git a/paddle/operators/rmsprop_op.h b/paddle/operators/rmsprop_op.h index 65b9edd35b97a866d6074c654b273b5a3bcb3498..ed4b283ce46146240aa6810348214b75f02c250a 100644 --- a/paddle/operators/rmsprop_op.h +++ b/paddle/operators/rmsprop_op.h @@ -30,23 +30,30 @@ class RmspropOpKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto param_out = ctx.Output("ParamOut"); auto moment_out = ctx.Output("MomentOut"); + auto mean_square_out = ctx.Output("MeanSquareOut"); param_out->mutable_data(ctx.GetPlace()); moment_out->mutable_data(ctx.GetPlace()); + mean_square_out->mutable_data(ctx.GetPlace()); float epsilon = ctx.Attr("epsilon"); - float decay = ctx.Attr("decayRate"); + float rho = ctx.Attr("decay"); + float momentum = ctx.Attr("momentum"); auto p = EigenVector::Flatten(*ctx.Input("Param")); - auto g = EigenVector::Flatten(*ctx.Input("Grad")); - auto m = EigenVector::Flatten(*ctx.Input("Moment")); + auto ms = EigenVector::Flatten(*ctx.Input("MeanSquare")); float lr = ctx.Input("LearningRate")->data()[0]; + auto g = EigenVector::Flatten(*ctx.Input("Grad")); + auto mom = EigenVector::Flatten(*ctx.Input("Moment")); + auto p_out = EigenVector::Flatten(*param_out); - auto m_out = EigenVector::Flatten(*moment_out); + auto mom_out = EigenVector::Flatten(*moment_out); + auto ms_out = EigenVector::Flatten(*mean_square_out); auto place = ctx.GetEigenDevice(); - m_out.device(place) = decay * m + (1 - decay) * g * g; - p_out.device(place) = p - 
lr * g / (m_out.sqrt() + epsilon); + ms_out.device(place) = rho * ms + (1 - rho) * g * g; + mom_out.device(place) = momentum * mom + lr * g / (ms_out + epsilon).sqrt(); + p_out.device(place) = p - mom_out; } }; diff --git a/python/paddle/v2/framework/tests/test_rmsprop_op.py b/python/paddle/v2/framework/tests/test_rmsprop_op.py index 64ca5da48e72e129fce89d5e9914b4e6cc8e58f5..84bd815c8ca2cdb99fba88f8aaead109e4606602 100644 --- a/python/paddle/v2/framework/tests/test_rmsprop_op.py +++ b/python/paddle/v2/framework/tests/test_rmsprop_op.py @@ -8,27 +8,35 @@ class TestRmspropOp(OpTest): self.op_type = "rmsprop" param = np.random.random((123, 321)).astype("float32") + mean_square = np.random.random((123, 321)).astype("float32") + learning_rate = np.array([0.01]).astype("float32") grad = np.random.random((123, 321)).astype("float32") moment = np.zeros((123, 321)).astype("float32") - learning_rate = np.array([0.01]).astype("float32") epsilon = 1e-6 - decay_rate = 0.9 + decay = 0.9 + momentum = 0.0 self.inputs = { 'Param': param, + 'MeanSquare': mean_square, + 'LearningRate': learning_rate, 'Grad': grad, 'Moment': moment, - 'LearningRate': learning_rate } - self.attrs = {'epsilon': epsilon, 'decayRate': decay_rate} + self.attrs = {'epsilon': epsilon, 'decay': decay, 'momentum': momentum} - moment_out = decay_rate * moment + (1 - decay_rate) * grad * grad - param_out = param - learning_rate * grad / (np.sqrt(moment_out) + - epsilon) + ms_out = decay * mean_square + (1 - decay) * grad * grad + moment_out = momentum * moment + \ + learning_rate * grad / np.sqrt(ms_out + epsilon) + param_out = param - moment_out - self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out} + self.outputs = { + 'ParamOut': param_out, + 'MomentOut': moment_out, + 'MeanSquareOut': ms_out + } def test_check_output(self): self.check_output()