提交 163d2871 作者: Kavya Srinet

Made the learning rate an input instead of an attribute

上级 61c03f9d
......@@ -24,11 +24,13 @@ class RmspropOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(param) of RmspropOp should not be null.");
"Input(Param) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(grad) of RmspropOp should not be null.");
"Input(Grad) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment"),
"Input(moment) of RmspropOp should not be null.");
"Input(Moment) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(param_out) of RmspropOp should not be null.");
......@@ -43,6 +45,10 @@ class RmspropOp : public framework::OperatorWithKernel {
param_dim, ctx->GetInputDim("Moment"),
"Param and moment input of RmspropOp should have the same dimension.");
auto lr_dim = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
"Learning Rate should be a scalar.");
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("MomentOut", param_dim);
}
......@@ -56,11 +62,11 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Param", "Input parameter");
AddInput("Grad", "Input gradient");
AddInput("Moment", "Second moment");
AddInput("LearningRate", "Learning Rate");
AddOutput("ParamOut", "Output parameter");
AddOutput("MomentOut", "Output second moment");
AddAttr<float>("learningRate", "Learning rate");
AddAttr<float>("epsilon", "Constant for numerical stability");
AddAttr<float>("decayRate", "Decay rate for moving average of gradients");
AddComment(R"DOC(
......@@ -68,7 +74,7 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
RMSprop
MomentOut = decayRate * Moment + (1 - decayRate) * Grad * Grad
ParamOut = Param - learningRate * Grad / (sqrt(MomentOut) + epsilon)
ParamOut = Param - LearningRate * Grad / (sqrt(MomentOut) + epsilon)
The original slide(Slide 29 of
http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
......
......@@ -34,13 +34,13 @@ class RmspropOpKernel : public framework::OpKernel<T> {
param_out->mutable_data<T>(ctx.GetPlace());
moment_out->mutable_data<T>(ctx.GetPlace());
float lr = ctx.Attr<float>("learningRate");
float epsilon = ctx.Attr<float>("epsilon");
float decay = ctx.Attr<float>("decayRate");
auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto g = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad"));
auto m = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
auto p_out = EigenVector<T>::Flatten(*param_out);
auto m_out = EigenVector<T>::Flatten(*moment_out);
auto place = ctx.GetEigenDevice<Place>();
......
......@@ -10,19 +10,20 @@ class TestRmspropOp(OpTest):
param = np.random.random((123, 321)).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
moment = np.zeros((123, 321)).astype("float32")
learning_rate = np.array([0.01]).astype("float32")
learning_rate = 0.01
epsilon = 1e-6
decay_rate = 0.9
self.inputs = {'Param': param, 'Grad': grad, 'Moment': moment}
self.attrs = {
'learningRate': learning_rate,
'epsilon': epsilon,
'decayRate': decay_rate
self.inputs = {
'Param': param,
'Grad': grad,
'Moment': moment,
'LearningRate': learning_rate
}
self.attrs = {'epsilon': epsilon, 'decayRate': decay_rate}
moment_out = decay_rate * moment + (1 - decay_rate) * grad * grad
param_out = param - learning_rate * grad / (np.sqrt(moment_out) +
epsilon)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册