diff --git a/paddle/operators/rmsprop_op.cc b/paddle/operators/rmsprop_op.cc index 1e06e08ede214caa7f4c2de12aeb237631152668..8f61c7fdda9f80c69745a9bc4569fcbc099630aa 100644 --- a/paddle/operators/rmsprop_op.cc +++ b/paddle/operators/rmsprop_op.cc @@ -89,13 +89,13 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr("epsilon", "(float, default 1e-10) Constant " "for numerical stability.") - .SetDefault(1e-10); + .SetDefault(1.0e-10f); AddAttr("decay", "(float, default 0.9) " "Discounting factor for coming gradient.") - .SetDefault(0.9); + .SetDefault(0.9f); AddAttr("momentum", "(float, default 0.0) Constant value") - .SetDefault(0.0); + .SetDefault(0.0f); AddComment(R"DOC( RMSprop diff --git a/paddle/operators/rmsprop_op.h b/paddle/operators/rmsprop_op.h index ed4b283ce46146240aa6810348214b75f02c250a..9c04276ec618bfa9da31fb301f5a0361c58017a8 100644 --- a/paddle/operators/rmsprop_op.h +++ b/paddle/operators/rmsprop_op.h @@ -28,9 +28,9 @@ template class RmspropOpKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto param_out = ctx.Output("ParamOut"); - auto moment_out = ctx.Output("MomentOut"); - auto mean_square_out = ctx.Output("MeanSquareOut"); + auto* param_out = ctx.Output("ParamOut"); + auto* moment_out = ctx.Output("MomentOut"); + auto* mean_square_out = ctx.Output("MeanSquareOut"); param_out->mutable_data(ctx.GetPlace()); moment_out->mutable_data(ctx.GetPlace()); diff --git a/python/paddle/v2/framework/tests/test_rmsprop_op.py b/python/paddle/v2/framework/tests/test_rmsprop_op.py index 84bd815c8ca2cdb99fba88f8aaead109e4606602..3e5ff733e9b55fe8c9727e9721e25083a494be15 100644 --- a/python/paddle/v2/framework/tests/test_rmsprop_op.py +++ b/python/paddle/v2/framework/tests/test_rmsprop_op.py @@ -3,7 +3,10 @@ import numpy as np from op_test import OpTest -class TestRmspropOp(OpTest): +class TestRmspropOp1(OpTest): + ''' Test RMSProp with explicit inputs + ''' + def setUp(self): self.op_type = "rmsprop" @@ -42,5 +45,45 @@ class TestRmspropOp(OpTest): self.check_output() +class TestRmspropOp2(OpTest): + '''Test RMSProp with defaukt values for attributes + ''' + + def setUp(self): + self.op_type = "rmsprop" + + param = np.random.random((123, 321)).astype("float32") + mean_square = np.random.random((123, 321)).astype("float32") + learning_rate = np.array([0.01]).astype("float32") + grad = np.random.random((123, 321)).astype("float32") + moment = np.zeros((123, 321)).astype("float32") + + epsilon = 1.0e-10 + decay = 0.9 + momentum = 0.0 + + self.inputs = { + 'Param': param, + 'MeanSquare': mean_square, + 'LearningRate': learning_rate, + 'Grad': grad, + 'Moment': moment, + } + + ms_out = decay * mean_square + (1 - decay) * grad * grad + moment_out = momentum * moment + \ + learning_rate * grad / np.sqrt(ms_out + epsilon) + param_out = param - moment_out + + self.outputs = { + 'ParamOut': param_out, + 'MomentOut': moment_out, + 'MeanSquareOut': ms_out + } + + def test_check_output(self): + self.check_output() + + if __name__ == "__main__": unittest.main()