From c10b8e808fc88d96ce0b4f864014bd461098de87 Mon Sep 17 00:00:00 2001 From: kavyasrinet Date: Wed, 18 Oct 2017 16:21:16 -0700 Subject: [PATCH] Adding Proximal Gradient Descent (#4848) * Adding Proximal Gradient Descent * Fixing review comments --- paddle/operators/proximal_gd_op.cc | 93 +++++++++++++++++++ paddle/operators/proximal_gd_op.cu | 19 ++++ paddle/operators/proximal_gd_op.h | 64 +++++++++++++ .../v2/framework/tests/test_proximal_gd_op.py | 33 +++++++ 4 files changed, 209 insertions(+) create mode 100644 paddle/operators/proximal_gd_op.cc create mode 100644 paddle/operators/proximal_gd_op.cu create mode 100644 paddle/operators/proximal_gd_op.h create mode 100644 python/paddle/v2/framework/tests/test_proximal_gd_op.py diff --git a/paddle/operators/proximal_gd_op.cc b/paddle/operators/proximal_gd_op.cc new file mode 100644 index 0000000000..e4b014b9f5 --- /dev/null +++ b/paddle/operators/proximal_gd_op.cc @@ -0,0 +1,93 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/proximal_gd_op.h" + +namespace paddle { +namespace operators { + +class ProximalGDOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Param"), + "Input(Param) of ProximalGDOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Grad"), + "Input(Grad) of ProximalGDOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("LearningRate"), + "Input(LearningRate) of ProximalGDOp should not be null."); + + PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), + "Output(ParamOut) of ProximalGDOp should not be null."); + + auto param_dim = ctx->GetInputDim("Param"); + PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("Grad"), + "Two input of ProximalGD Op's dimension must be same."); + + auto lr_dim = ctx->GetInputDim("LearningRate"); + PADDLE_ENFORCE_EQ(framework::product(lr_dim), 1, + "Learning Rate should be a scalar."); + + ctx->SetOutputDim("ParamOut", param_dim); + } +}; + +class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ProximalGDOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Param", + "(Tensor, default Tensor) " + "Input parameter value that has to be updated."); + AddInput("Grad", + "(Tensor, default Tensor) " + "Input gradient of the parameter."); + AddInput("LearningRate", + "(Tensor, default Tensor) " + "The learning rate should be a tensor of size 1."); + + AddOutput("ParamOut", "(Tensor) Output updated parameter value."); + + AddAttr("l1", + "(float, default 0.0) " + "L1 regularization strength.") + .SetDefault(0.0f); + AddAttr("l2", + "(float, default 0.0)" + "L2 regularization strength.") + .SetDefault(0.0f); + AddComment(R"DOC( + +Optimizer that implements the proximal gradient descent algorithm. + +prox_param = param - learning_rate * grad +param = sign(prox_param) / (1 + learning_rate * l2) * + max { |prox_param| - learning_rate * l1 , 0 } + +The paper that proposed Proximal Gradient Descent: +(http://papers.nips.cc/paper/3793-efficient-learning-using-forward-backward-splitting.pdf) +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(proximal_gd, ops::ProximalGDOp, + ops::ProximalGDOpMaker); +REGISTER_OP_CPU_KERNEL( + proximal_gd, ops::ProximalGDOpKernel); diff --git a/paddle/operators/proximal_gd_op.cu b/paddle/operators/proximal_gd_op.cu new file mode 100644 index 0000000000..26f4ebaa0f --- /dev/null +++ b/paddle/operators/proximal_gd_op.cu @@ -0,0 +1,19 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +You may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software distributed +under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +CONDITIONS OF ANY KIND, either express or implied. See the License for the +specific language governing permissions and limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/proximal_gd_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + proximal_gd, ops::ProximalGDOpKernel); diff --git a/paddle/operators/proximal_gd_op.h b/paddle/operators/proximal_gd_op.h new file mode 100644 index 0000000000..bebda02041 --- /dev/null +++ b/paddle/operators/proximal_gd_op.h @@ -0,0 +1,64 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + +template +class ProximalGDOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* param_out = ctx.Output("ParamOut"); + + param_out->mutable_data(ctx.GetPlace()); + + auto grad = ctx.Input("Grad"); + + auto l1 = static_cast(ctx.Attr("l1")); + auto l2 = static_cast(ctx.Attr("l2")); + + auto p = EigenVector::Flatten(*ctx.Input("Param")); + auto g = EigenVector::Flatten(*grad); + auto lr = EigenVector::Flatten(*ctx.Input("LearningRate")); + + auto p_out = EigenVector::Flatten(*param_out); + auto place = ctx.GetEigenDevice(); + + Eigen::DSizes grad_dsize(grad->numel()); + + auto prox_param = p - lr.broadcast(grad_dsize) * g; + if (l1 > 0) { + p_out.device(place) = + prox_param.sign() * + (((prox_param.abs() - (lr * l1).broadcast(grad_dsize)) + .cwiseMax(T(0.0))) / + (1.0 + (lr * l2).broadcast(grad_dsize))); + } else { + p_out.device(place) = + prox_param / (1.0 + (lr * l2).broadcast(grad_dsize)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_proximal_gd_op.py b/python/paddle/v2/framework/tests/test_proximal_gd_op.py new file mode 100644 index 0000000000..9ca79ce6b3 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_proximal_gd_op.py @@ -0,0 +1,33 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestProximalGDOp(OpTest): + def setUp(self): + self.op_type = "proximal_gd" + w = np.random.random((102, 105)).astype("float32") + g = np.random.random((102, 105)).astype("float32") + lr = np.array([0.1]).astype("float32") + l1 = 0.1 + l2 = 0.2 + + self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr} + self.attrs = {'l1': l1, 'l2': l2} + prox_param = w - lr * g + param_out = 0.0 + if l1 > 0.0: + x = np.abs(prox_param) - lr * l1 + x[x < 0] = 0 + param_out = np.sign(prox_param) * (x / (1.0 + lr * l2)) + else: + param_out = prox_param / (1.0 + lr * l2) + + self.outputs = {'ParamOut': param_out} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() -- GitLab