提交 48f98a67 编写于 作者: K kavyasrinet 提交者: GitHub

Merge pull request #4565 from kavyasrinet/rmsprop

Adding the implementation for rmsprop operator
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/rmsprop_op.h"
namespace paddle {
namespace operators {
class RmspropOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("MeanSquare"),
"Input(MeanSquare) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment"),
"Input(Moment) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(param_out) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MomentOut"),
"Output(Momentum_out) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("MeanSquareOut"),
"Output(MeanSquareOut) of RmspropOp should not be null.");
auto param_dim = ctx->GetInputDim("Param");
PADDLE_ENFORCE_EQ(
param_dim, ctx->GetInputDim("Grad"),
"Param and grad input of RmspropOp should have the same dimension.");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Moment"),
"Param and Momentum input of RmspropOp "
"should have the same dimension.");
PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("MeanSquare"),
"Param and Momentum input of RmspropOp "
"should have the same dimension.");
auto lr_dim = ctx->GetInputDim("LearningRate");
PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1,
"Learning Rate should be a scalar.");
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("MomentOut", param_dim);
ctx->SetOutputDim("MeanSquareOut", param_dim);
}
};
class RmspropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RmspropOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Param",
"(Tensor, default Tensor<float>) "
"Input parameter value that has to be updated");
AddInput("MeanSquare",
"(Tensor, default Tensor<float>)"
" The mean square value that gets updated");
AddInput("LearningRate",
"(Tensor, default Tensor<float>) "
"The learning rate should be a tensor of size 1");
AddInput("Grad",
"(Tensor, default Tensor<float>) "
"Input gradient of the parameter");
AddInput("Moment",
"(Tensor, default Tensor<float>) The moment that gets updated");
AddOutput("ParamOut", "(Tensor) Output updated parameter value");
AddOutput("MomentOut", "(Tensor) Output updated moment");
AddOutput("MeanSquareOut", "(Tensor) Output Mean squared updated value");
AddAttr<float>("epsilon",
"(float, default 1e-10) Constant "
"for numerical stability.")
.SetDefault(1.0e-10f);
AddAttr<float>("decay",
"(float, default 0.9) "
"Discounting factor for coming gradient.")
.SetDefault(0.9f);
AddAttr<float>("momentum", "(float, default 0.0) Constant value")
.SetDefault(0.0f);
AddComment(R"DOC(
RMSprop
MeanSquareOut = decay * MeanSquare + (1 - decay) * Grad * Grad
MomentOut = momentum * Moment +
LearningRate * Grad / sqrt(MeanSquareOut + epsilon)
ParamOut = Param - MomentOut
The original slides that proposed RMSprop: Slide 29 of
http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(rmsprop, ops::RmspropOp, ops::RmspropOpMaker);
REGISTER_OP_CPU_KERNEL(rmsprop,
ops::RmspropOpKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/rmsprop_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(rmsprop,
ops::RmspropOpKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class RmspropOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* param_out = ctx.Output<Tensor>("ParamOut");
auto* moment_out = ctx.Output<Tensor>("MomentOut");
auto* mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
auto grad = ctx.Input<Tensor>("Grad");
param_out->mutable_data<T>(ctx.GetPlace());
moment_out->mutable_data<T>(ctx.GetPlace());
mean_square_out->mutable_data<T>(ctx.GetPlace());
float epsilon = ctx.Attr<float>("epsilon");
float rho = ctx.Attr<float>("decay");
float momentum = ctx.Attr<float>("momentum");
auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
auto ms = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanSquare"));
auto lr = EigenVector<T>::Flatten(*ctx.Input<Tensor>("LearningRate"));
auto g = EigenVector<T>::Flatten(*grad);
auto mom = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
auto p_out = EigenVector<T>::Flatten(*param_out);
auto mom_out = EigenVector<T>::Flatten(*moment_out);
auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
auto place = ctx.GetEigenDevice<Place>();
Eigen::DSizes<int, 1> grad_dsize(grad->numel());
ms_out.device(place) = rho * ms + (1 - rho) * g * g;
mom_out.device(place) =
momentum * mom +
lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt();
p_out.device(place) = p - mom_out;
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
class TestRmspropOp1(OpTest):
''' Test RMSProp with explicit inputs
'''
def setUp(self):
self.op_type = "rmsprop"
param = np.random.random((123, 321)).astype("float32")
mean_square = np.random.random((123, 321)).astype("float32")
learning_rate = np.array([0.01]).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
moment = np.zeros((123, 321)).astype("float32")
epsilon = 1e-6
decay = 0.9
momentum = 0.0
self.inputs = {
'Param': param,
'MeanSquare': mean_square,
'LearningRate': learning_rate,
'Grad': grad,
'Moment': moment,
}
self.attrs = {'epsilon': epsilon, 'decay': decay, 'momentum': momentum}
ms_out = decay * mean_square + (1 - decay) * grad * grad
moment_out = momentum * moment + \
learning_rate * grad / np.sqrt(ms_out + epsilon)
param_out = param - moment_out
self.outputs = {
'ParamOut': param_out,
'MomentOut': moment_out,
'MeanSquareOut': ms_out
}
def test_check_output(self):
self.check_output()
class TestRmspropOp2(OpTest):
'''Test RMSProp with defaukt values for attributes
'''
def setUp(self):
self.op_type = "rmsprop"
param = np.random.random((123, 321)).astype("float32")
mean_square = np.random.random((123, 321)).astype("float32")
learning_rate = np.array([0.01]).astype("float32")
grad = np.random.random((123, 321)).astype("float32")
moment = np.zeros((123, 321)).astype("float32")
epsilon = 1.0e-10
decay = 0.9
momentum = 0.0
self.inputs = {
'Param': param,
'MeanSquare': mean_square,
'LearningRate': learning_rate,
'Grad': grad,
'Moment': moment,
}
ms_out = decay * mean_square + (1 - decay) * grad * grad
moment_out = momentum * moment + \
learning_rate * grad / np.sqrt(ms_out + epsilon)
param_out = param - moment_out
self.outputs = {
'ParamOut': param_out,
'MomentOut': moment_out,
'MeanSquareOut': ms_out
}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册