diff --git a/paddle/operators/rmsprop_op.cc b/paddle/operators/rmsprop_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8f61c7fdda9f80c69745a9bc4569fcbc099630aa --- /dev/null +++ b/paddle/operators/rmsprop_op.cc @@ -0,0 +1,120 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/rmsprop_op.h" + +namespace paddle { +namespace operators { + +class RmspropOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Param"), + "Input(Param) of RmspropOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("MeanSquare"), + "Input(MeanSquare) of RmspropOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LearningRate"), + "Input(LearningRate) of RmspropOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(Grad) of RmspropOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Moment"), + "Input(Moment) of RmspropOp should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), + "Output(param_out) of RmspropOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("MomentOut"), + "Output(Momentum_out) of RmspropOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("MeanSquareOut"), + "Output(MeanSquareOut) of RmspropOp should not be null."); + + auto param_dim = ctx->GetInputDim("Param"); + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Grad"), + "Param and grad input of RmspropOp should have the same dimension."); + PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Moment"), + "Param and Momentum input of RmspropOp " + "should have the same dimension."); + PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("MeanSquare"), + "Param and Momentum input of RmspropOp " + "should have the same dimension."); + + auto lr_dim = ctx->GetInputDim("LearningRate"); + PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1, + "Learning Rate should be a scalar."); + + ctx->SetOutputDim("ParamOut", param_dim); + ctx->SetOutputDim("MomentOut", param_dim); + ctx->SetOutputDim("MeanSquareOut", param_dim); + } +}; + +class RmspropOpMaker : public framework::OpProtoAndCheckerMaker { + public: + RmspropOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Param", + "(Tensor, default Tensor) " + "Input parameter value that has to be updated"); + AddInput("MeanSquare", + "(Tensor, default Tensor)" + " The mean square value that gets updated"); + AddInput("LearningRate", + "(Tensor, default Tensor) " + "The learning rate should be a tensor of size 1"); + AddInput("Grad", + "(Tensor, default Tensor) " + "Input gradient of the parameter"); + AddInput("Moment", + "(Tensor, default Tensor) The moment that gets updated"); + + AddOutput("ParamOut", "(Tensor) Output updated parameter value"); + AddOutput("MomentOut", "(Tensor) Output updated moment"); + AddOutput("MeanSquareOut", "(Tensor) Output Mean squared updated value"); + + AddAttr("epsilon", + "(float, default 1e-10) Constant " + "for numerical stability.") + .SetDefault(1.0e-10f); + AddAttr("decay", + "(float, default 0.9) " + "Discounting factor for coming gradient.") + .SetDefault(0.9f); + AddAttr("momentum", "(float, default 0.0) Constant value") + .SetDefault(0.0f); + AddComment(R"DOC( + +RMSprop + +MeanSquareOut = decay * MeanSquare + (1 - decay) * Grad * Grad +MomentOut = momentum * Moment + + LearningRate * Grad / sqrt(MeanSquareOut + epsilon) +ParamOut = Param - MomentOut + +The original slides that proposed RMSprop: Slide 29 of +http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) + +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(rmsprop, ops::RmspropOp, ops::RmspropOpMaker); +REGISTER_OP_CPU_KERNEL(rmsprop, + ops::RmspropOpKernel); diff --git a/paddle/operators/rmsprop_op.cu b/paddle/operators/rmsprop_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..52634a54816bcd5ad0ba82a56f1df95110112265 --- /dev/null +++ b/paddle/operators/rmsprop_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/rmsprop_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(rmsprop, + ops::RmspropOpKernel); diff --git a/paddle/operators/rmsprop_op.h b/paddle/operators/rmsprop_op.h new file mode 100644 index 0000000000000000000000000000000000000000..7bf2129010f994966d79ef11d5cec30159b47068 --- /dev/null +++ b/paddle/operators/rmsprop_op.h @@ -0,0 +1,67 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + +template +class RmspropOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* param_out = ctx.Output("ParamOut"); + auto* moment_out = ctx.Output("MomentOut"); + auto* mean_square_out = ctx.Output("MeanSquareOut"); + + auto grad = ctx.Input("Grad"); + + param_out->mutable_data(ctx.GetPlace()); + moment_out->mutable_data(ctx.GetPlace()); + mean_square_out->mutable_data(ctx.GetPlace()); + + float epsilon = ctx.Attr("epsilon"); + float rho = ctx.Attr("decay"); + float momentum = ctx.Attr("momentum"); + + auto p = EigenVector::Flatten(*ctx.Input("Param")); + auto ms = EigenVector::Flatten(*ctx.Input("MeanSquare")); + auto lr = EigenVector::Flatten(*ctx.Input("LearningRate")); + auto g = EigenVector::Flatten(*grad); + auto mom = EigenVector::Flatten(*ctx.Input("Moment")); + + auto p_out = EigenVector::Flatten(*param_out); + auto mom_out = EigenVector::Flatten(*moment_out); + auto ms_out = EigenVector::Flatten(*mean_square_out); + auto place = ctx.GetEigenDevice(); + + Eigen::DSizes grad_dsize(grad->numel()); + + ms_out.device(place) = rho * ms + (1 - rho) * g * g; + mom_out.device(place) = + momentum * mom + + lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt(); + p_out.device(place) = p - mom_out; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_rmsprop_op.py b/python/paddle/v2/framework/tests/test_rmsprop_op.py new file mode 100644 index 0000000000000000000000000000000000000000..3e5ff733e9b55fe8c9727e9721e25083a494be15 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_rmsprop_op.py @@ -0,0 +1,89 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestRmspropOp1(OpTest): + ''' Test RMSProp with explicit inputs + ''' + + def setUp(self): + self.op_type = "rmsprop" + + param = np.random.random((123, 321)).astype("float32") + mean_square = np.random.random((123, 321)).astype("float32") + learning_rate = np.array([0.01]).astype("float32") + grad = np.random.random((123, 321)).astype("float32") + moment = np.zeros((123, 321)).astype("float32") + + epsilon = 1e-6 + decay = 0.9 + momentum = 0.0 + + self.inputs = { + 'Param': param, + 'MeanSquare': mean_square, + 'LearningRate': learning_rate, + 'Grad': grad, + 'Moment': moment, + } + + self.attrs = {'epsilon': epsilon, 'decay': decay, 'momentum': momentum} + + ms_out = decay * mean_square + (1 - decay) * grad * grad + moment_out = momentum * moment + \ + learning_rate * grad / np.sqrt(ms_out + epsilon) + param_out = param - moment_out + + self.outputs = { + 'ParamOut': param_out, + 'MomentOut': moment_out, + 'MeanSquareOut': ms_out + } + + def test_check_output(self): + self.check_output() + + +class TestRmspropOp2(OpTest): + '''Test RMSProp with defaukt values for attributes + ''' + + def setUp(self): + self.op_type = "rmsprop" + + param = np.random.random((123, 321)).astype("float32") + mean_square = np.random.random((123, 321)).astype("float32") + learning_rate = np.array([0.01]).astype("float32") + grad = np.random.random((123, 321)).astype("float32") + moment = np.zeros((123, 321)).astype("float32") + + epsilon = 1.0e-10 + decay = 0.9 + momentum = 0.0 + + self.inputs = { + 'Param': param, + 'MeanSquare': mean_square, + 'LearningRate': learning_rate, + 'Grad': grad, + 'Moment': moment, + } + + ms_out = decay * mean_square + (1 - decay) * grad * grad + moment_out = momentum * moment + \ + learning_rate * grad / np.sqrt(ms_out + epsilon) + param_out = param - moment_out + + self.outputs = { + 'ParamOut': param_out, + 'MomentOut': moment_out, + 'MeanSquareOut': ms_out + } + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main()