提交 163d2871 编写于 作者: K Kavya Srinet

Made learning rate the input

上级 61c03f9d
...@@ -24,11 +24,13 @@ class RmspropOp : public framework::OperatorWithKernel { ...@@ -24,11 +24,13 @@ class RmspropOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContextBase *ctx) const override { void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"), PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(param) of RmspropOp should not be null."); "Input(Param) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"), PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(grad) of RmspropOp should not be null."); "Input(Grad) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment"), PADDLE_ENFORCE(ctx->HasInput("Moment"),
"Input(moment) of RmspropOp should not be null."); "Input(Moment) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(param_out) of RmspropOp should not be null."); "Output(param_out) of RmspropOp should not be null.");
...@@ -43,6 +45,10 @@ class RmspropOp : public framework::OperatorWithKernel { ...@@ -43,6 +45,10 @@ class RmspropOp : public framework::OperatorWithKernel {
param_dim, ctx->GetInputDim("Moment"), param_dim, ctx->GetInputDim("Moment"),
"Param and moment input of RmspropOp should have the same dimension."); "Param and moment input of RmspropOp should have the same dimension.");
auto lr_dim = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
"Learning Rate should be a scalar.");
ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("MomentOut", param_dim); ctx->SetOutputDim("MomentOut", param_dim);
} }
...@@ -56,11 +62,11 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -56,11 +62,11 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Param", "Input parameter"); AddInput("Param", "Input parameter");
AddInput("Grad", "Input gradient"); AddInput("Grad", "Input gradient");
AddInput("Moment", "Second moment"); AddInput("Moment", "Second moment");
AddInput("LearningRate", "Learning Rate");
AddOutput("ParamOut", "Output parameter"); AddOutput("ParamOut", "Output parameter");
AddOutput("MomentOut", "Output second moment"); AddOutput("MomentOut", "Output second moment");
AddAttr<float>("learningRate", "Learning rate");
AddAttr<float>("epsilon", "Constant for numerical stability"); AddAttr<float>("epsilon", "Constant for numerical stability");
AddAttr<float>("decayRate", "Decay rate for moving average of gradients"); AddAttr<float>("decayRate", "Decay rate for moving average of gradients");
AddComment(R"DOC( AddComment(R"DOC(
...@@ -68,7 +74,7 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -68,7 +74,7 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
RMSprop RMSprop
MomentOut = decayRate * Moment + (1 - decayRate) * Grad * Grad MomentOut = decayRate * Moment + (1 - decayRate) * Grad * Grad
ParamOut = Param - learningRate * Grad / (sqrt(MomentOut) + epsilon) ParamOut = Param - LearningRate * Grad / (sqrt(MomentOut) + epsilon)
The original slide(Slide 29 of The original slide(Slide 29 of
http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
......
...@@ -34,13 +34,13 @@ class RmspropOpKernel : public framework::OpKernel<T> { ...@@ -34,13 +34,13 @@ class RmspropOpKernel : public framework::OpKernel<T> {
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
moment_out->mutable_data<T>(ctx.GetPlace()); moment_out->mutable_data<T>(ctx.GetPlace());
float lr = ctx.Attr<float>("learningRate");
float epsilon = ctx.Attr<float>("epsilon"); float epsilon = ctx.Attr<float>("epsilon");
float decay = ctx.Attr<float>("decayRate"); float decay = ctx.Attr<float>("decayRate");
auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param")); auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto g = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad")); auto g = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad"));
auto m = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment")); auto m = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
auto p_out = EigenVector<T>::Flatten(*param_out); auto p_out = EigenVector<T>::Flatten(*param_out);
auto m_out = EigenVector<T>::Flatten(*moment_out); auto m_out = EigenVector<T>::Flatten(*moment_out);
auto place = ctx.GetEigenDevice<Place>(); auto place = ctx.GetEigenDevice<Place>();
......
...@@ -10,19 +10,20 @@ class TestRmspropOp(OpTest): ...@@ -10,19 +10,20 @@ class TestRmspropOp(OpTest):
param = np.random.random((123, 321)).astype("float32") param = np.random.random((123, 321)).astype("float32")
grad = np.random.random((123, 321)).astype("float32") grad = np.random.random((123, 321)).astype("float32")
moment = np.zeros((123, 321)).astype("float32") moment = np.zeros((123, 321)).astype("float32")
learning_rate = np.array([0.01]).astype("float32")
learning_rate = 0.01
epsilon = 1e-6 epsilon = 1e-6
decay_rate = 0.9 decay_rate = 0.9
self.inputs = {'Param': param, 'Grad': grad, 'Moment': moment} self.inputs = {
'Param': param,
self.attrs = { 'Grad': grad,
'learningRate': learning_rate, 'Moment': moment,
'epsilon': epsilon, 'LearningRate': learning_rate
'decayRate': decay_rate
} }
self.attrs = {'epsilon': epsilon, 'decayRate': decay_rate}
moment_out = decay_rate * moment + (1 - decay_rate) * grad * grad moment_out = decay_rate * moment + (1 - decay_rate) * grad * grad
param_out = param - learning_rate * grad / (np.sqrt(moment_out) + param_out = param - learning_rate * grad / (np.sqrt(moment_out) +
epsilon) epsilon)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册