Commit 94855f4a authored by Kavya Srinet

Fixed changes proposed in the review

Parent 163d2871
......@@ -25,25 +25,32 @@ class RmspropOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("MeanSquare"),
"Input(MeanSquare) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment"),
"Input(Moment) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(param_out) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
"Output(moment_out) of RmspropOp should not be null.");
"Output(Momentum_out) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MeanSquareOut"),
"Output(MeanSquareOut) of RmspropOp should not be null.");
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"Param and grad input of RmspropOp should have the same dimension.");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Moment"),
"Param and Moment input of RmspropOp "
"should have the same dimension.");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("MeanSquare"),
"Param and MeanSquare input of RmspropOp "
"should have the same dimension.");
auto lr_dim = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
......@@ -51,6 +58,7 @@ class RmspropOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("MomentOut", param_dim);
ctx->SetOutputDim("MeanSquareOut", param_dim);
}
};
......@@ -59,27 +67,46 @@ class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
RmspropOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param", "Input parameter");
AddInput("Grad", "Input gradient");
AddInput("Moment", "Second moment");
AddInput("LearningRate", "Learning Rate");
AddOutput("ParamOut", "Output parameter");
AddOutput("MomentOut", "Output second moment");
AddAttr<float>("epsilon", "Constant for numerical stability");
AddAttr<float>("decayRate", "Decay rate for moving average of gradients");
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter value that has to be updated");
AddInput("MeanSquare",
"(Tensor, default Tensor<float>)"
" The mean square value that gets updated");
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"The learning rate should be a tensor of size 1");
AddInput("Grad",
"(Tensor, default Tensor<float>) "
"Input gradient of the parameter");
AddInput("Moment",
"(Tensor, default Tensor<float>) The moment that gets updated");
AddOutput("ParamOut", "(Tensor) Output updated parameter value");
AddOutput("MomentOut", "(Tensor) Output updated moment");
AddOutput("MeanSquareOut", "(Tensor) Output Mean squared updated value");
AddAttr<float>("epsilon",
"(float, default 1e-10) Constant "
"for numerical stability.")
.SetDefault(1e-10);
AddAttr<float>("decay",
"(float, default 0.9) "
"Discounting factor for coming gradient.")
.SetDefault(0.9);
AddAttr<float>("momentum", "(float, default 0.0) Constant value")
.SetDefault(0.0);
AddComment(R"DOC(
RMSprop
MeanSquareOut = decay * MeanSquare + (1 - decay) * Grad * Grad
MomentOut = momentum * Moment +
LearningRate * Grad / sqrt(MeanSquareOut + epsilon)
ParamOut = Param - MomentOut
The original slides that proposed RMSprop (Slide 29 of
http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
do not have the epsilon attribute. It is added here for numerical stability
to avoid division by zero.
)DOC");
}
......
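The DOC equations above translate directly into NumPy. Below is a minimal reference sketch; the function name and signature are illustrative only, not part of the operator's API:

```python
import numpy as np

def rmsprop_step(param, mean_square, moment, grad, lr,
                 epsilon=1e-10, decay=0.9, momentum=0.0):
    """One RMSprop update, mirroring the equations in the DOC string."""
    # MeanSquareOut = decay * MeanSquare + (1 - decay) * Grad * Grad
    ms_out = decay * mean_square + (1 - decay) * grad * grad
    # MomentOut = momentum * Moment
    #           + LearningRate * Grad / sqrt(MeanSquareOut + epsilon)
    mom_out = momentum * moment + lr * grad / np.sqrt(ms_out + epsilon)
    # ParamOut = Param - MomentOut
    param_out = param - mom_out
    return param_out, ms_out, mom_out
```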
......@@ -30,23 +30,30 @@ class RmspropOpKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto param_out = ctx.Output<Tensor>("ParamOut");
auto moment_out = ctx.Output<Tensor>("MomentOut");
auto mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
param_out->mutable_data<T>(ctx.GetPlace());
moment_out->mutable_data<T>(ctx.GetPlace());
mean_square_out->mutable_data<T>(ctx.GetPlace());
float epsilon = ctx.Attr<float>("epsilon");
float rho = ctx.Attr<float>("decay");
float momentum = ctx.Attr<float>("momentum");
auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto ms = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanSquare"));
float lr = ctx.Input<Tensor>("LearningRate")->data<float>()[0];
auto g = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Grad"));
auto mom = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
auto p_out = EigenVector<T>::Flatten(*param_out);
auto mom_out = EigenVector<T>::Flatten(*moment_out);
auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
auto place = ctx.GetEigenDevice<Place>();
ms_out.device(place) = rho * ms + (1 - rho) * g * g;
mom_out.device(place) = momentum * mom + lr * g / (ms_out + epsilon).sqrt();
p_out.device(place) = p - mom_out;
}
};
......
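One subtlety in the reworked kernel: epsilon moved inside the square root (`(ms_out + epsilon).sqrt()`), whereas the old kernel added it outside (`m_out.sqrt() + epsilon`). Both guard against division by zero, but the resulting denominators differ noticeably when the mean square is near zero. A small NumPy comparison with illustrative values:

```python
import numpy as np

ms = np.array([0.0, 1e-12])  # mean-square values near zero
eps = 1e-10

outside = np.sqrt(ms) + eps  # old style: sqrt(ms) + epsilon
inside = np.sqrt(ms + eps)   # new style: sqrt(ms + epsilon)

print(outside)  # -> approximately [1.0e-10, 1.0e-06]
print(inside)   # -> approximately [1.0e-05, 1.005e-05]
```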
......@@ -8,27 +8,35 @@ class TestRmspropOp(OpTest):
self.op_type = "rmsprop"
param = np.random.random((123, 321)).astype("float32")
mean_square = np.random.random((123, 321)).astype("float32")
learning_rate = np.array([0.01]).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
moment = np.zeros((123, 321)).astype("float32")
epsilon = 1e-6
decay = 0.9
momentum = 0.0
self.inputs = {
'Param': param,
'MeanSquare': mean_square,
'LearningRate': learning_rate,
'Grad': grad,
'Moment': moment,
}
self.attrs = {'epsilon': epsilon, 'decay': decay, 'momentum': momentum}
ms_out = decay * mean_square + (1 - decay) * grad * grad
moment_out = momentum * moment + \
learning_rate * grad / np.sqrt(ms_out + epsilon)
param_out = param - moment_out
self.outputs = {
'ParamOut': param_out,
'MomentOut': moment_out,
'MeanSquareOut': ms_out
}
def test_check_output(self):
self.check_output()
......
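As a quick cross-check, the `rmsprop_step` sketch given earlier reproduces the expected outputs the test computes inline. This snippet assumes that sketch is in scope; the seed is illustrative, the shapes and hyperparameters are taken from the test:

```python
import numpy as np

np.random.seed(0)
param = np.random.random((123, 321)).astype("float32")
mean_square = np.random.random((123, 321)).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
moment = np.zeros((123, 321)).astype("float32")

p_out, ms_out, mom_out = rmsprop_step(param, mean_square, moment, grad,
                                      lr=0.01, epsilon=1e-6,
                                      decay=0.9, momentum=0.0)

# The same quantities, computed inline as in the test body:
ms_ref = 0.9 * mean_square + (1 - 0.9) * grad * grad
mom_ref = 0.0 * moment + 0.01 * grad / np.sqrt(ms_ref + 1e-6)
assert np.allclose(ms_out, ms_ref)
assert np.allclose(mom_out, mom_ref)
assert np.allclose(p_out, param - mom_ref)
```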