diff --git a/paddle/operators/proximal_adagrad_op.cc b/paddle/operators/proximal_adagrad_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..39fbf800031cd559a49654667e5a6f634384523d --- /dev/null +++ b/paddle/operators/proximal_adagrad_op.cc @@ -0,0 +1,113 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/proximal_adagrad_op.h" + +namespace paddle { +namespace operators { + +class ProximalAdagradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Param"), + "Input(Param) of ProximalAdagradOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Moment"), + "Input(Moment) of ProximalAdagradOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(Grad) of ProximalAdagradOp should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("LearningRate"), + "Input(LearningRate) of ProximalAdagradOp should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), + "Output(ParamOut) of ProximalAdagradOp should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("MomentOut"), + "Output(MomentOut) of ProximalAdagradOp should not be null."); + + auto param_dim = ctx->GetInputDim("Param"); + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Grad"), + "Param and Grad of ProximalAdagrad Op must have same dimension."); + + PADDLE_ENFORCE_EQ( + param_dim, ctx->GetInputDim("Moment"), + "Param and Moment of ProximalAdagrad Op must have same dimension."); + + auto lr_dim = ctx->GetInputDim("LearningRate"); + PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1, + "Learning Rate should be a scalar."); + + ctx->SetOutputDim("ParamOut", param_dim); + ctx->SetOutputDim("MomentOut", param_dim); + } +}; + +class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ProximalAdagradOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Param", + "(Tensor, default Tensor) " + "Input parameter that has to be updated."); + AddInput("Moment", + "(Tensor, default Tensor) " + "Moment parameter that has to be updated."); + AddInput("Grad", + "(Tensor, default Tensor) " + "Input gradient of the parameter."); + AddInput("LearningRate", + "(Tensor, default Tensor) " + "The learning rate should be a tensor of size 1."); + + AddOutput("ParamOut", "(Tensor) Output updated parameter value."); + AddOutput("MomentOut", "(Tensor) Output updated moment value."); + + AddAttr("l1", + "(float, default 0.0) " + "L1 regularization strength.") + .SetDefault(0.0f); + AddAttr("l2", + "(float, default 0.0)" + "L2 regularization strength.") + .SetDefault(0.0f); + AddComment(R"DOC( + +Optimizer that implements the proximal adagrad algorithm. + +moment = moment + grad * grad +prox_param = param - learning_rate * grad * (1 / sqrt(moment)) +param = sign(prox_param) / (1 + learning_rate * l2) * + max { |prox_param| - learning_rate * l1 , 0 } + +The paper that proposed Proximal GD: +(http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting.pdf) +Here, we use the adagrad learning rate as specified here: +(http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(proximal_adagrad, ops::ProximalAdagradOp, + ops::ProximalAdagradOpMaker); +REGISTER_OP_CPU_KERNEL( + proximal_adagrad, + ops::ProximalAdagradOpKernel); diff --git a/paddle/operators/proximal_adagrad_op.cu b/paddle/operators/proximal_adagrad_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..d0ae0395184ae4f794565f2e28c57f960f0ccbeb --- /dev/null +++ b/paddle/operators/proximal_adagrad_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +You may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software distributed +under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/proximal_adagrad_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + proximal_adagrad, + ops::ProximalAdagradOpKernel); diff --git a/paddle/operators/proximal_adagrad_op.h b/paddle/operators/proximal_adagrad_op.h new file mode 100644 index 0000000000000000000000000000000000000000..7a1560e8cb339a306ab19513808aab165f82cc8a --- /dev/null +++ b/paddle/operators/proximal_adagrad_op.h @@ -0,0 +1,68 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + +template +class ProximalAdagradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* param_out = ctx.Output("ParamOut"); + auto* moment_out = ctx.Output("MomentOut"); + + param_out->mutable_data(ctx.GetPlace()); + moment_out->mutable_data(ctx.GetPlace()); + + auto l1 = static_cast(ctx.Attr("l1")); + auto l2 = static_cast(ctx.Attr("l2")); + + auto grad = ctx.Input("Grad"); + auto p = EigenVector::Flatten(*ctx.Input("Param")); + auto m = EigenVector::Flatten(*ctx.Input("Moment")); + auto g = EigenVector::Flatten(*grad); + auto lr = EigenVector::Flatten(*ctx.Input("LearningRate")); + + auto p_out = EigenVector::Flatten(*param_out); + auto m_out = EigenVector::Flatten(*moment_out); + auto place = ctx.GetEigenDevice(); + + Eigen::DSizes grad_dsize(grad->numel()); + + m_out.device(place) = m + g * g; + auto prox_param = p - lr.broadcast(grad_dsize) * g / m_out.sqrt(); + if (l1 > static_cast(0)) { + p_out.device(place) = + prox_param.sign() * + (((prox_param.abs() - (lr * l1).broadcast(grad_dsize)) + .cwiseMax(static_cast(0.0))) / + (static_cast(1.0) + (lr * l2).broadcast(grad_dsize))); + } else { + p_out.device(place) = + prox_param / (static_cast(1.0) + (lr * l2).broadcast(grad_dsize)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_proximal_adagrad_op.py b/python/paddle/v2/framework/tests/test_proximal_adagrad_op.py new file mode 100644 index 0000000000000000000000000000000000000000..f89a493ab7a7a3d841088b7db37bff4dfbe63735 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_proximal_adagrad_op.py @@ -0,0 +1,36 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestProximalAdagradOp(OpTest): + def setUp(self): + self.op_type = "proximal_adagrad" + w = np.random.random((102, 105)).astype("float32") + m = np.random.random((102, 105)).astype("float32") + g = np.random.random((102, 105)).astype("float32") + lr = np.array([0.1]).astype("float32") + l1 = 0.1 + l2 = 0.2 + + self.inputs = {'Param': w, 'Grad': g, 'Moment': m, 'LearningRate': lr} + self.attrs = {'l1': l1, 'l2': l2} + param_out = 0.0 + + moment_out = m + g * g + prox_param = w - lr * g / np.sqrt(moment_out) + if l1 > 0.0: + x = np.abs(prox_param) - lr * l1 + x[x < 0] = 0 + param_out = np.sign(prox_param) * (x / (1.0 + lr * l2)) + else: + param_out = prox_param / (1.0 + lr * l2) + + self.outputs = {'ParamOut': param_out, 'MomentOut': moment_out} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main()