Commit 94855f4a authored by Kavya Srinet

Fixed changes proposed in the review

Parent 163d2871
@@ -25,25 +25,32 @@ class RmspropOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContextBase *ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Param"),
                    "Input(Param) of RmspropOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("MeanSquare"),
+                   "Input(MeanSquare) of RmspropOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
+                   "Input(LearningRate) of RmspropOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Grad"),
                    "Input(Grad) of RmspropOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Moment"),
                    "Input(Moment) of RmspropOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
-                   "Input(LearningRate) of RmspropOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
                    "Output(param_out) of RmspropOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
-                   "Output(moment_out) of RmspropOp should not be null.");
+                   "Output(MomentOut) of RmspropOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("MeanSquareOut"),
+                   "Output(MeanSquareOut) of RmspropOp should not be null.");
 
     auto param_dim = ctx->GetInputDim("Param");
     PADDLE_ENFORCE_EQ(
         param_dim, ctx->GetInputDim("Grad"),
         "Param and grad input of RmspropOp should have the same dimension.");
-    PADDLE_ENFORCE_EQ(
-        param_dim, ctx->GetInputDim("Moment"),
-        "Param and moment input of RmspropOp should have the same dimension.");
+    PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Moment"),
+                      "Param and Moment input of RmspropOp "
+                      "should have the same dimension.");
+    PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("MeanSquare"),
+                      "Param and MeanSquare input of RmspropOp "
+                      "should have the same dimension.");
 
     auto lr_dim = ctx->GetInputDim("LearningRate");
     PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
@@ -51,6 +58,7 @@ class RmspropOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("ParamOut", param_dim);
     ctx->SetOutputDim("MomentOut", param_dim);
+    ctx->SetOutputDim("MeanSquareOut", param_dim);
   }
 };
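The shape contract that InferShape enforces is easy to restate outside C++: Param, Grad, Moment, and MeanSquare must all share one shape, LearningRate holds exactly one element, and the three outputs inherit Param's shape. A minimal Python sketch of the same rules (the function check_rmsprop_shapes and its plain-dict interface are ours for illustration, not part of Paddle):

    import numpy as np

    def check_rmsprop_shapes(shapes):
        # Mirrors RmspropOp::InferShape; `shapes` maps input names to tuples.
        for name in ("Param", "MeanSquare", "LearningRate", "Grad", "Moment"):
            assert name in shapes, "Input(%s) of RmspropOp should not be null." % name
        param_dim = shapes["Param"]
        # Param, Grad, Moment, and MeanSquare must all share one shape.
        for name in ("Grad", "Moment", "MeanSquare"):
            assert shapes[name] == param_dim, name + " must match Param's dimension."
        # The learning rate is a tensor holding exactly one element.
        assert int(np.prod(shapes["LearningRate"])) == 1
        # All three outputs take the shape of Param.
        return {"ParamOut": param_dim, "MomentOut": param_dim, "MeanSquareOut": param_dim}

    # Example: the shapes used by the unit test at the end of this commit.
    print(check_rmsprop_shapes({"Param": (123, 321), "MeanSquare": (123, 321),
                                "LearningRate": (1,), "Grad": (123, 321),
                                "Moment": (123, 321)}))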
@@ -59,27 +67,46 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
   RmspropOpMaker(framework::OpProto *proto,
                  framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("Param", "Input parameter");
-    AddInput("Grad", "Input gradient");
-    AddInput("Moment", "Second moment");
-    AddInput("LearningRate", "Learning Rate");
+    AddInput("Param",
+             "(Tensor, default Tensor<float>) "
+             "Input parameter value that has to be updated");
+    AddInput("MeanSquare",
+             "(Tensor, default Tensor<float>) "
+             "The mean square value that gets updated");
+    AddInput("LearningRate",
+             "(Tensor, default Tensor<float>) "
+             "The learning rate should be a tensor of size 1");
+    AddInput("Grad",
+             "(Tensor, default Tensor<float>) "
+             "Input gradient of the parameter");
+    AddInput("Moment",
+             "(Tensor, default Tensor<float>) The moment that gets updated");
 
-    AddOutput("ParamOut", "Output parameter");
-    AddOutput("MomentOut", "Output second moment");
+    AddOutput("ParamOut", "(Tensor) Output updated parameter value");
+    AddOutput("MomentOut", "(Tensor) Output updated moment");
+    AddOutput("MeanSquareOut", "(Tensor) Output updated mean squared value");
 
-    AddAttr<float>("epsilon", "Constant for numerical stability");
-    AddAttr<float>("decayRate", "Decay rate for moving average of gradients");
+    AddAttr<float>("epsilon",
+                   "(float, default 1e-10) Constant "
+                   "for numerical stability.")
+        .SetDefault(1e-10);
+    AddAttr<float>("decay",
+                   "(float, default 0.9) Discounting factor for the "
+                   "moving average of the squared gradient.")
+        .SetDefault(0.9);
+    AddAttr<float>("momentum", "(float, default 0.0) Momentum coefficient.")
+        .SetDefault(0.0);
     AddComment(R"DOC(
 
 RMSprop
 
-MomentOut = decayRate * Moment + (1 - decayRate) * Grad * Grad
-ParamOut = Param - LearningRate * Grad / (sqrt(MomentOut) + epsilon)
+MeanSquareOut = decay * MeanSquare + (1 - decay) * Grad * Grad
+MomentOut = momentum * Moment +
+            LearningRate * Grad / sqrt(MeanSquareOut + epsilon)
+ParamOut = Param - MomentOut
 
-The original slide(Slide 29 of
-http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
-does not have the epsilon attribute. It is added here for numerical stability
+The original slides that proposed RMSprop (Slide 29 of
+http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
+do not have the epsilon attribute. It is added here for numerical stability,
+to avoid division by zero.
 
 )DOC");
   }
...
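The three equations in the DOC block map one-to-one onto NumPy. A sketch under that reading (the function rmsprop_step and its signature are ours; the defaults mirror the attribute defaults registered above):

    import numpy as np

    def rmsprop_step(param, mean_square, moment, grad, lr,
                     decay=0.9, momentum=0.0, epsilon=1e-10):
        # Moving average of the squared gradient.
        mean_square = decay * mean_square + (1 - decay) * grad * grad
        # Momentum-smoothed, RMS-normalized step.
        moment = momentum * moment + lr * grad / np.sqrt(mean_square + epsilon)
        # Gradient descent update.
        param = param - moment
        return param, mean_square, moment

With momentum = 0.0 this reduces to the classic RMSprop rule from the slides; a nonzero momentum folds a heavy-ball term into the step.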
@@ -30,23 +30,30 @@ class RmspropOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto param_out = ctx.Output<Tensor>("ParamOut");
     auto moment_out = ctx.Output<Tensor>("MomentOut");
+    auto mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
 
     param_out->mutable_data<T>(ctx.GetPlace());
     moment_out->mutable_data<T>(ctx.GetPlace());
+    mean_square_out->mutable_data<T>(ctx.GetPlace());
 
     float epsilon = ctx.Attr<float>("epsilon");
-    float decay = ctx.Attr<float>("decayRate");
+    float rho = ctx.Attr<float>("decay");
+    float momentum = ctx.Attr<float>("momentum");
 
     auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
-    auto g = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad"));
-    auto m = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
+    auto ms = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanSquare"));
     float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
+    auto g = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad"));
+    auto mom = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
 
     auto p_out = EigenVector<T>::Flatten(*param_out);
-    auto m_out = EigenVector<T>::Flatten(*moment_out);
+    auto mom_out = EigenVector<T>::Flatten(*moment_out);
+    auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
     auto place = ctx.GetEigenDevice<Place>();
 
-    m_out.device(place) = decay * m + (1 - decay) * g * g;
-    p_out.device(place) = p - lr * g / (m_out.sqrt() + epsilon);
+    ms_out.device(place) = rho * ms + (1 - rho) * g * g;
+    mom_out.device(place) = momentum * mom + lr * g / (ms_out + epsilon).sqrt();
+    p_out.device(place) = p - mom_out;
   }
 };
...
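One behavioral change in this hunk is easy to miss: epsilon moved inside the square root. The old kernel divided by sqrt(m_out) + epsilon; the new one divides by sqrt(ms_out + epsilon), which matches the DOC equations and the reference computation in the Python test below. For near-zero mean squares the two denominators differ by an order of magnitude, as a quick check shows (illustrative numbers, ours):

    import numpy as np

    ms, eps = 1e-12, 1e-10          # a tiny mean square, the op's default epsilon
    print(np.sqrt(ms) + eps)        # old denominator: ~1.0e-06
    print(np.sqrt(ms + eps))        # new denominator: ~1.0e-05, about 10x larger

The new placement clamps the denominator more aggressively, so early steps on flat directions are smaller and better behaved.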
@@ -8,27 +8,35 @@ class TestRmspropOp(OpTest):
         self.op_type = "rmsprop"
 
         param = np.random.random((123, 321)).astype("float32")
+        mean_square = np.random.random((123, 321)).astype("float32")
+        learning_rate = np.array([0.01]).astype("float32")
         grad = np.random.random((123, 321)).astype("float32")
         moment = np.zeros((123, 321)).astype("float32")
-        learning_rate = np.array([0.01]).astype("float32")
 
         epsilon = 1e-6
-        decay_rate = 0.9
+        decay = 0.9
+        momentum = 0.0
 
         self.inputs = {
             'Param': param,
+            'MeanSquare': mean_square,
+            'LearningRate': learning_rate,
             'Grad': grad,
             'Moment': moment,
-            'LearningRate': learning_rate
         }
 
-        self.attrs = {'epsilon': epsilon, 'decayRate': decay_rate}
+        self.attrs = {'epsilon': epsilon, 'decay': decay, 'momentum': momentum}
 
-        moment_out = decay_rate * moment + (1 - decay_rate) * grad * grad
-        param_out = param - learning_rate * grad / (np.sqrt(moment_out) +
-                                                    epsilon)
+        ms_out = decay * mean_square + (1 - decay) * grad * grad
+        moment_out = momentum * moment + \
+            learning_rate * grad / np.sqrt(ms_out + epsilon)
+        param_out = param - moment_out
 
-        self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out}
+        self.outputs = {
+            'ParamOut': param_out,
+            'MomentOut': moment_out,
+            'MeanSquareOut': ms_out
+        }
 
     def test_check_output(self):
         self.check_output()
...
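The test pins momentum at 0.0 and starts Moment at zeros, so the momentum * mom term in the kernel contributes nothing and that path goes unexercised. A second case along the following lines would cover it; this is a sketch that assumes the same OpTest harness as the original file, and the class name and constants are ours:

    import numpy as np
    from op_test import OpTest  # assumed harness, as used by TestRmspropOp above

    class TestRmspropOpWithMomentum(OpTest):
        def setUp(self):
            self.op_type = "rmsprop"

            param = np.random.random((123, 321)).astype("float32")
            mean_square = np.random.random((123, 321)).astype("float32")
            learning_rate = np.array([0.01]).astype("float32")
            grad = np.random.random((123, 321)).astype("float32")
            # Nonzero starting moment so momentum * mom actually contributes.
            moment = np.random.random((123, 321)).astype("float32")

            epsilon, decay, momentum = 1e-6, 0.9, 0.9

            self.inputs = {
                'Param': param,
                'MeanSquare': mean_square,
                'LearningRate': learning_rate,
                'Grad': grad,
                'Moment': moment,
            }
            self.attrs = {'epsilon': epsilon, 'decay': decay, 'momentum': momentum}

            # Reference computation, identical in form to TestRmspropOp.
            ms_out = decay * mean_square + (1 - decay) * grad * grad
            moment_out = momentum * moment + \
                learning_rate * grad / np.sqrt(ms_out + epsilon)
            param_out = param - moment_out

            self.outputs = {
                'ParamOut': param_out,
                'MomentOut': moment_out,
                'MeanSquareOut': ms_out
            }

        def test_check_output(self):
            self.check_output()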