From d3b8bffaf1615f58622c384ec3a5f200d020e539 Mon Sep 17 00:00:00 2001 From: kexinzhao <19hskevin87@gmail.com> Date: Wed, 11 Oct 2017 23:49:21 -0700 Subject: [PATCH] Implementing the Decayed Adagrad optimizer operator (#4645) * Implementing the DecayedAdagrad optimizer step operator * implementing DecayedAdagrad operator * remove file * small fix --- paddle/operators/decayed_adagrad_op.cc | 96 +++++++++++++++++++ paddle/operators/decayed_adagrad_op.cu | 21 ++++ paddle/operators/decayed_adagrad_op.h | 56 +++++++++++ .../tests/test_decayed_adagrad_op.py | 71 ++++++++++++++ 4 files changed, 244 insertions(+) create mode 100644 paddle/operators/decayed_adagrad_op.cc create mode 100644 paddle/operators/decayed_adagrad_op.cu create mode 100644 paddle/operators/decayed_adagrad_op.h create mode 100644 python/paddle/v2/framework/tests/test_decayed_adagrad_op.py diff --git a/paddle/operators/decayed_adagrad_op.cc b/paddle/operators/decayed_adagrad_op.cc new file mode 100644 index 0000000000..ca5141dabc --- /dev/null +++ b/paddle/operators/decayed_adagrad_op.cc @@ -0,0 +1,96 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/decayed_adagrad_op.h" + +namespace paddle { +namespace operators { + +class DecayedAdagradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Param"), + "Input(Param) of DecayedAdagradOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(Grad) of DecayedAdagradOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Moment"), + "Input(Moment) of DecayedAdagradOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("LearningRate"), + "Input(LearningRate) of DecayedAdagradOp should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), + "Output(ParamOut) of DecayedAdagradOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("MomentOut"), + "Output(MomentOut) of DecayedAdagradOp should not be null."); + + auto lr_dims = ctx->GetInputDim("LearningRate"); + PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1, + "LearningRate should have one element"); + auto param_dims = ctx->GetInputDim("Param"); + PADDLE_ENFORCE_EQ(param_dims, ctx->GetInputDim("Grad"), + "Param and Grad input of DecayedAdagradOp should have " + "the same dimension."); + PADDLE_ENFORCE_EQ(param_dims, ctx->GetInputDim("Moment"), + "Param and Moment input of DecayedAdagradOp should have " + "the same dimension."); + + ctx->SetOutputDim("ParamOut", param_dims); + ctx->SetOutputDim("MomentOut", param_dims); + } +}; + +class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker { + public: + DecayedAdagradOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Param", "(Tensor) Input parameter"); + AddInput("Grad", "(Tensor) Input gradient"); + AddInput("Moment", "(Tensor) Second moment"); + AddInput("LearningRate", "(Tensor) Learning rate"); + + AddOutput("ParamOut", "(Tensor) Output parameter"); + AddOutput("MomentOut", "(Tensor) Output second moment"); + + AddAttr("decay", + "(float, default 0.95) " + "Discounting factor for coming gradient") + .SetDefault(0.95); + AddAttr("epsilon", + "(float, default 1.0e-6) " + "Constant for numerical stability") + .SetDefault(1.0e-6f); + AddComment(R"DOC( + +Decayed Adagrad + +moment_out = decay * moment + (1 - decay) * grad * grad +param_out = param - learning_rate * grad / (sqrt(moment_out) + epsilon) + +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(decayed_adagrad, ops::DecayedAdagradOp, + ops::DecayedAdagradOpMaker); +REGISTER_OP_CPU_KERNEL( + decayed_adagrad, + ops::DecayedAdagradOpKernel); diff --git a/paddle/operators/decayed_adagrad_op.cu b/paddle/operators/decayed_adagrad_op.cu new file mode 100644 index 0000000000..6fce77fe4e --- /dev/null +++ b/paddle/operators/decayed_adagrad_op.cu @@ -0,0 +1,21 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/decayed_adagrad_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + decayed_adagrad, + ops::DecayedAdagradOpKernel); diff --git a/paddle/operators/decayed_adagrad_op.h b/paddle/operators/decayed_adagrad_op.h new file mode 100644 index 0000000000..0fe0fc5acd --- /dev/null +++ b/paddle/operators/decayed_adagrad_op.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class DecayedAdagradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto param_out_tensor = ctx.Output("ParamOut"); + auto moment_out_tensor = ctx.Output("MomentOut"); + + param_out_tensor->mutable_data(ctx.GetPlace()); + moment_out_tensor->mutable_data(ctx.GetPlace()); + + float decay = ctx.Attr("decay"); + float epsilon = ctx.Attr("epsilon"); + + auto param = framework::EigenVector::Flatten( + *ctx.Input("Param")); + auto grad = framework::EigenVector::Flatten( + *ctx.Input("Grad")); + auto moment = framework::EigenVector::Flatten( + *ctx.Input("Moment")); + auto lr = framework::EigenVector::Flatten( + *ctx.Input("LearningRate")); + + auto param_out = framework::EigenVector::Flatten(*param_out_tensor); + auto moment_out = framework::EigenVector::Flatten(*moment_out_tensor); + auto place = ctx.GetEigenDevice(); + + moment_out.device(place) = decay * moment + (1 - decay) * grad * grad; + Eigen::DSizes m_dsize(moment_out_tensor->numel()); + param_out.device(place) = + param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_decayed_adagrad_op.py b/python/paddle/v2/framework/tests/test_decayed_adagrad_op.py new file mode 100644 index 0000000000..674c3fda5c --- /dev/null +++ b/python/paddle/v2/framework/tests/test_decayed_adagrad_op.py @@ -0,0 +1,71 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestDecayedAdagradOp1(OpTest): + ''' Test DecayedAdagrad operator with explicit attributes + ''' + + def setUp(self): + self.op_type = "decayed_adagrad" + + param = np.random.random((123, 321)).astype("float32") + grad = np.random.random((123, 321)).astype("float32") + moment = np.zeros((123, 321)).astype("float32") + lr = 0.01 + decay = 0.80 + epsilon = 1e-8 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment': moment, + 'LearningRate': np.array([lr]).astype("float32") + } + + self.attrs = {'decay': decay, 'epsilon': epsilon} + + moment_out = decay * moment + (1 - decay) * grad * grad + param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon) + + self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out} + + def test_check_output(self): + self.check_output() + + +class TestDecayedAdagradOp2(OpTest): + ''' Test DecayedAdagrad operator with default attributes + ''' + + def setUp(self): + self.op_type = "decayed_adagrad" + + param = np.random.random((123, 321)).astype("float32") + grad = np.random.random((123, 321)).astype("float32") + moment = np.zeros((123, 321)).astype("float32") + lr = 0.01 + decay = 0.95 + epsilon = 1e-6 + + self.inputs = { + 'Param': param, + 'Grad': grad, + 'Moment': moment, + 'LearningRate': np.array([lr]).astype("float32") + } + + self.attrs = {'decay': decay, 'epsilon': epsilon} + + moment_out = decay * moment + (1 - decay) * grad * grad + param_out = param - lr * grad / (np.sqrt(moment_out) + epsilon) + + self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() -- GitLab