diff --git a/paddle/fluid/operators/optimizers/adadelta_op.cc b/paddle/fluid/operators/optimizers/adadelta_op.cc
index bde7131379a272e31fb1effe2f92204fa27f9a14..e3da79125be24f3156b10a4d1daedd3db2b841cf 100644
--- a/paddle/fluid/operators/optimizers/adadelta_op.cc
+++ b/paddle/fluid/operators/optimizers/adadelta_op.cc
@@ -24,49 +24,69 @@ class AdadeltaOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Param"),
-                   "Input(Param) of AdadeltaOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Grad"),
-                   "Input(Grad) of AdadeltaOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("AvgSquaredGrad"),
-                   "Input(AvgSquaredGrad) of AdadeltaOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("AvgSquaredUpdate"),
-                   "Input(AvgSquaredUpdate) of AdadeltaOp should not be null.");
-    PADDLE_ENFORCE(
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Param"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(Param) of AdadeltaOp should not be null."));
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Grad"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(Grad) of AdadeltaOp should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("AvgSquaredGrad"), true,
+        platform::errors::InvalidArgument(
+            "Input(AvgSquaredGrad) of AdadeltaOp should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("AvgSquaredUpdate"), true,
+        platform::errors::InvalidArgument(
+            "Input(AvgSquaredUpdate) of AdadeltaOp should not be null."));
+    PADDLE_ENFORCE_EQ(
         ctx->GetInputsVarType("Param").front() ==
             framework::proto::VarType::LOD_TENSOR,
-        "The input var's type should be LoDTensor, but the received is %s",
-        ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
-    PADDLE_ENFORCE(
+        true,
+        platform::errors::InvalidArgument(
+            "The input var's type should be LoDTensor, but the received is %s",
+            ctx->Inputs("Param").front(),
+            ctx->GetInputsVarType("Param").front()));
+    PADDLE_ENFORCE_EQ(
         ctx->GetInputsVarType("Grad").front() ==
             framework::proto::VarType::LOD_TENSOR,
-        "The input var's type should be LoDTensor, but the received is %s",
-        ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
-
-    PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
-                   "Output(ParamOut) of AdadeltaOp should not be null.");
-    PADDLE_ENFORCE(
-        ctx->HasOutput("AvgSquaredGradOut"),
-        "Output(AvgSquaredGradOut) of AdadeltaOp should not be null.");
-    PADDLE_ENFORCE(
-        ctx->HasOutput("AvgSquaredUpdateOut"),
-        "Output(AvgSquaredUpdateOut) of AdadeltaOp should not be null.");
+        true,
+        platform::errors::InvalidArgument(
+            "The input var's type should be LoDTensor, but the received is %s",
+            ctx->Inputs("Grad").front(),
+            ctx->GetInputsVarType("Grad").front()));
+
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput("ParamOut"), true,
+        platform::errors::InvalidArgument(
+            "Output(ParamOut) of AdadeltaOp should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput("AvgSquaredGradOut"), true,
+        platform::errors::InvalidArgument(
+            "Output(AvgSquaredGradOut) of AdadeltaOp should not be null."));
+    PADDLE_ENFORCE_EQ(
+        ctx->HasOutput("AvgSquaredUpdateOut"), true,
+        platform::errors::InvalidArgument(
+            "Output(AvgSquaredUpdateOut) of AdadeltaOp should not be null."));
 
     auto param_dim = ctx->GetInputDim("Param");
     PADDLE_ENFORCE_EQ(
         param_dim, ctx->GetInputDim("Grad"),
         "param and grad input of AdadeltaOp should have same dimension");
-    PADDLE_ENFORCE_NE(framework::product(ctx->GetInputDim("AvgSquaredGrad")), 0,
-                      "Maybe the Input variable AvgSquaredGrad has not "
-                      "been initialized. You may need to confirm if you put "
-                      "exe.run(startup_program) after optimizer.minimize "
-                      "function.");
+    PADDLE_ENFORCE_NE(
+        framework::product(ctx->GetInputDim("AvgSquaredGrad")), 0,
+        platform::errors::InvalidArgument(
+            "Maybe the Input variable AvgSquaredGrad has not "
+            "been initialized. You may need to confirm if you put "
+            "exe.run(startup_program) after optimizer.minimize "
+            "function."));
     PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredGrad"),
-                      "Param and AvgSquaredGrad input of AdadeltaOp "
-                      "should have same dimension");
+                      platform::errors::InvalidArgument(
+                          "Param and AvgSquaredGrad input of AdadeltaOp "
+                          "should have same dimension"));
     PADDLE_ENFORCE_EQ(param_dim, ctx->GetInputDim("AvgSquaredUpdate"),
-                      "Param and AvgSquaredUpdate input of AdadeltaOp "
-                      "should have same dimension");
+                      platform::errors::InvalidArgument(
+                          "Param and AvgSquaredUpdate input of AdadeltaOp "
+                          "should have same dimension"));
 
     ctx->SetOutputDim("ParamOut", param_dim);
     ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
diff --git a/paddle/fluid/operators/optimizers/adadelta_op.h b/paddle/fluid/operators/optimizers/adadelta_op.h
index e66dec7cf0ff686f91103e438b6374fce29af774..85cfad35858bbe6b112169f196c0711d981e9446 100644
--- a/paddle/fluid/operators/optimizers/adadelta_op.h
+++ b/paddle/fluid/operators/optimizers/adadelta_op.h
@@ -24,17 +24,19 @@ class AdadeltaOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     const auto* param_var = ctx.InputVar("Param");
-    PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
-                   "The Var(%s)'s type should be LoDTensor, "
-                   "but the received is %s",
-                   ctx.InputNames("Param").front(),
-                   framework::ToTypeName(param_var->Type()));
+    PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
+                      platform::errors::InvalidArgument(
+                          "The Var(%s)'s type should be LoDTensor, "
+                          "but the received is %s",
+                          ctx.InputNames("Param").front(),
+                          framework::ToTypeName(param_var->Type())));
     const auto* grad_var = ctx.InputVar("Grad");
-    PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
-                   "The Var(%s)'s type should be LoDTensor, "
-                   "but the received is %s",
-                   ctx.InputNames("Grad").front(),
-                   framework::ToTypeName(grad_var->Type()));
+    PADDLE_ENFORCE_EQ(grad_var->IsType<framework::LoDTensor>(), true,
+                      platform::errors::InvalidArgument(
+                          "The Var(%s)'s type should be LoDTensor, "
+                          "but the received is %s",
+                          ctx.InputNames("Grad").front(),
+                          framework::ToTypeName(grad_var->Type())));
 
     auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
     auto avg_squared_grad_out_tensor =
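Note: for reference, the update rule this kernel implements can be sketched in NumPy (an illustrative sketch mirroring the op's inputs/outputs and the math checked by the existing unit tests; the helper name is not part of the patch):

    import numpy as np

    def adadelta_step(param, grad, avg_sq_grad, avg_sq_update,
                      rho=0.95, epsilon=1.0e-6):
        # E(g_t^2) = rho * E(g_{t-1}^2) + (1 - rho) * g^2
        avg_sq_grad = rho * avg_sq_grad + (1 - rho) * grad * grad
        # step scales the gradient by sqrt((E(dx^2) + eps) / (E(g^2) + eps))
        update = -np.sqrt((avg_sq_update + epsilon) /
                          (avg_sq_grad + epsilon)) * grad
        # E(dx_t^2) = rho * E(dx_{t-1}^2) + (1 - rho) * update^2
        avg_sq_update = rho * avg_sq_update + (1 - rho) * update * update
        return param + update, avg_sq_grad, avg_sq_update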
"Output(Out) of TopkOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Indices"), true, + platform::errors::InvalidArgument( + "Output(Indices) of TopkOp should not be null.")); auto input_dims = ctx->GetInputDim("X"); const int k = static_cast(ctx->Attrs().Get("k")); PADDLE_ENFORCE_GE(k, 1, "k must >= 1"); - PADDLE_ENFORCE_GE(input_dims.size(), 1, "input must have >= 1d shape"); + PADDLE_ENFORCE_GE(input_dims.size(), 1, platform::errors::InvalidArgument( + "input must have >= 1d shape")); if (ctx->IsRuntime()) { - PADDLE_ENFORCE_GE(input_dims[input_dims.size() - 1], k, - "input must have >= k columns"); + PADDLE_ENFORCE_GE( + input_dims[input_dims.size() - 1], k, + platform::errors::InvalidArgument("input must have >= k columns")); } framework::DDim dims = input_dims; diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index d8b2e92616091a8c822c6fd0bfdfb1148c25534d..0a694e1ad5b012d70a89ddcca2d70fbe8c9e24ba 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -43,8 +43,9 @@ template class TopkOpCUDAKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use CUDAPlace."); + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::InvalidArgument("It must use CUDAPlace.")); auto* input = ctx.Input("X"); auto* output = ctx.Output("Out"); auto* indices = ctx.Output("Indices"); diff --git a/python/paddle/fluid/tests/unittests/test_adadelta_op.py b/python/paddle/fluid/tests/unittests/test_adadelta_op.py index 969a7da3b71b69296f3313342adbf989c60edb50..2c6c018b9dfac13d97c242e1f36adbddf9dbf3f1 100644 --- a/python/paddle/fluid/tests/unittests/test_adadelta_op.py +++ b/python/paddle/fluid/tests/unittests/test_adadelta_op.py @@ -17,6 +17,8 @@ from __future__ import print_function import unittest import numpy as np from op_test import OpTest +import paddle +import paddle.fluid as fluid class TestAdadeltaOp1(OpTest): @@ -108,5 +110,54 @@ class TestAdadeltaOp2(OpTest): self.check_output() +class TestAdadeltaV2(unittest.TestCase): + def test_adadelta_dygraph(self): + paddle.disable_static(paddle.CPUPlace()) + value = np.arange(26).reshape(2, 13).astype("float32") + a = paddle.to_tensor(value) + linear = paddle.nn.Linear(13, 5) + # This can be any optimizer supported by dygraph. 
diff --git a/python/paddle/fluid/tests/unittests/test_adadelta_op.py b/python/paddle/fluid/tests/unittests/test_adadelta_op.py
index 969a7da3b71b69296f3313342adbf989c60edb50..2c6c018b9dfac13d97c242e1f36adbddf9dbf3f1 100644
--- a/python/paddle/fluid/tests/unittests/test_adadelta_op.py
+++ b/python/paddle/fluid/tests/unittests/test_adadelta_op.py
@@ -17,6 +17,8 @@ from __future__ import print_function
 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
 
 
 class TestAdadeltaOp1(OpTest):
@@ -108,5 +110,54 @@ class TestAdadeltaOp2(OpTest):
         self.check_output()
 
 
+class TestAdadeltaV2(unittest.TestCase):
+    def test_adadelta_dygraph(self):
+        paddle.disable_static(paddle.CPUPlace())
+        value = np.arange(26).reshape(2, 13).astype("float32")
+        a = paddle.to_tensor(value)
+        linear = paddle.nn.Linear(13, 5)
+        # This can be any optimizer supported by dygraph.
+        adam = paddle.optimizer.Adadelta(
+            learning_rate=0.01,
+            parameters=linear.parameters(),
+            weight_decay=0.01)
+        out = linear(a)
+        out.backward()
+        adam.step()
+        adam.clear_gradients()
+
+    def test_adadelta(self):
+        place = fluid.CPUPlace()
+        main = fluid.Program()
+        with fluid.program_guard(main):
+            x = fluid.layers.data(name='x', shape=[13], dtype='float32')
+            y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+            y_predict = fluid.layers.fc(input=x, size=1, act=None)
+            cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+            avg_cost = fluid.layers.mean(cost)
+
+            rms_optimizer = paddle.optimizer.Adadelta(learning_rate=0.1)
+            rms_optimizer.minimize(avg_cost)
+
+            fetch_list = [avg_cost]
+            train_reader = paddle.batch(
+                paddle.dataset.uci_housing.train(), batch_size=1)
+            feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
+            exe = fluid.Executor(place)
+            exe.run(fluid.default_startup_program())
+            for data in train_reader():
+                exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
+
+    def test_raise_error(self):
+        self.assertRaises(ValueError, paddle.optimizer.Adadelta, None)
+        self.assertRaises(
+            ValueError, paddle.optimizer.Adadelta, learning_rate=0.1, rho=None)
+        self.assertRaises(
+            ValueError,
+            paddle.optimizer.Adadelta,
+            learning_rate=0.1,
+            epsilon=None)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_base_3.py b/python/paddle/fluid/tests/unittests/test_fleet_base_3.py
index f5e888ab0eb3ca597bf62245ff9f3024fe81ee95..25801793f1f2e70c404727ed4f64c7d3c830aec9 100644
--- a/python/paddle/fluid/tests/unittests/test_fleet_base_3.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_base_3.py
@@ -43,7 +43,7 @@ class TestFleetBase(unittest.TestCase):
         role = role_maker.PaddleCloudRoleMaker(is_collective=True)
         fleet.init(role)
         strategy = fleet.DistributedStrategy()
-        optimizer = paddle.optimizer.SGD(learning_rate=0.001)
+        optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.001)
         optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
         optimizer.minimize(avg_cost)
 
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_optimizer_v2.py b/python/paddle/fluid/tests/unittests/test_imperative_optimizer_v2.py
index f7756e54168cc90311c443e7a768d4befa2ceda3..619e9e8e90783365b5f0d718783a14468520c8d4 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_optimizer_v2.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_optimizer_v2.py
@@ -658,7 +658,7 @@ class TestImperativeExponentialMovingAverage(TestImperativeOptimizerBase):
 class TestImperativePipelineOptimizer(TestImperativeOptimizerBase):
     def get_optimizer_dygraph(self, parameter_list):
         optimizer = paddle.optimizer.SGD(learning_rate=0.5,
-                                         parameter_list=parameter_list)
+                                         parameters=parameter_list)
         optimizer = PipelineOptimizer(optimizer)
         return optimizer
 
@@ -670,7 +670,7 @@ class TestImperativePipelineOptimizer(TestImperativeOptimizerBase):
 class TestImperativeLookaheadOptimizer(TestImperativeOptimizerBase):
     def get_optimizer_dygraph(self, parameter_list):
         optimizer = paddle.optimizer.SGD(learning_rate=0.5,
-                                         parameter_list=parameter_list)
+                                         parameters=parameter_list)
         optimizer = LookaheadOptimizer(optimizer, alpha=0.5, k=5)
         return optimizer
 
@@ -682,7 +682,7 @@ class TestImperativeLookaheadOptimizer(TestImperativeOptimizerBase):
 class TestImperativeRecomputeOptimizer(TestImperativeOptimizerBase):
     def get_optimizer_dygraph(self, parameter_list):
         optimizer = paddle.optimizer.SGD(learning_rate=0.5,
-                                         parameter_list=parameter_list)
+                                         parameters=parameter_list)
         optimizer = RecomputeOptimizer(optimizer)
         return optimizer
 
diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py
index 77ec6f9b6bcda7568325698634fd4f86557cd1be..a535ef5e60397718e97100332b945b360838bbf4 100644
--- a/python/paddle/fluid/tests/unittests/test_momentum_op.py
+++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py
@@ -19,6 +19,8 @@ import numpy as np
 import paddle.fluid.core as core
 from paddle.fluid.op import Operator
 from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
 
 
 class TestMomentumOp1(OpTest):
@@ -234,5 +236,48 @@ class TestSparseMomentumOp2(TestSparseMomentumOp):
         self.use_nesterov = True
 
 
+class TestMomentumV2(unittest.TestCase):
+    def test_momentum_dygraph(self):
+        paddle.disable_static()
+        value = np.arange(26).reshape(2, 13).astype("float32")
+        a = paddle.to_tensor(value)
+        linear = paddle.nn.Linear(13, 5)
+        # This can be any optimizer supported by dygraph.
+        adam = paddle.optimizer.Momentum(
+            learning_rate=0.01, momentum=0.9, parameters=linear.parameters())
+        out = linear(a)
+        out.backward()
+        adam.step()
+        adam.clear_gradients()
+
+    def test_momentum(self):
+        place = fluid.CPUPlace()
+        main = fluid.Program()
+        with fluid.program_guard(main):
+            x = fluid.layers.data(name='x', shape=[13], dtype='float32')
+            y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+            y_predict = fluid.layers.fc(input=x, size=1, act=None)
+            cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+            avg_cost = fluid.layers.mean(cost)
+
+            rms_optimizer = paddle.optimizer.Momentum(
+                learning_rate=0.1, momentum=0.9)
+            rms_optimizer.minimize(avg_cost)
+
+            fetch_list = [avg_cost]
+            train_reader = paddle.batch(
+                paddle.dataset.uci_housing.train(), batch_size=1)
+            feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
+            exe = fluid.Executor(place)
+            exe.run(fluid.default_startup_program())
+            for data in train_reader():
+                exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
+
+    def test_raise_error(self):
+        self.assertRaises(
+            ValueError, paddle.optimizer.Momentum, learning_rate=None)
+        self.assertRaises(ValueError, paddle.optimizer.Momentum, momentum=None)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_sgd_op.py b/python/paddle/fluid/tests/unittests/test_sgd_op.py
index fb3fc8735566fcf601a7cb507e3826dd92a5651e..2c87e06e893a4d6495ad81ac3dcdf375a41272fb 100644
--- a/python/paddle/fluid/tests/unittests/test_sgd_op.py
+++ b/python/paddle/fluid/tests/unittests/test_sgd_op.py
@@ -20,6 +20,7 @@ import paddle.fluid as fluid
 import paddle.fluid.core as core
 from paddle.fluid.op import Operator
 from op_test import OpTest
+import paddle
 
 
 class TestSGDOp(OpTest):
@@ -208,5 +209,46 @@ class TestSGDOpWithLargeInput(unittest.TestCase):
         result = exe.run(compiled_prog, fetch_list=[avg_cost])
 
 
+class TestSGDV2(unittest.TestCase):
+    def test_sgd_dygraph(self):
+        paddle.disable_static()
+        value = np.arange(26).reshape(2, 13).astype("float32")
+        a = paddle.to_tensor(value)
+        linear = paddle.nn.Linear(13, 5)
+        # This can be any optimizer supported by dygraph.
+        adam = paddle.optimizer.SGD(learning_rate=0.01,
+                                    parameters=linear.parameters(),
+                                    weight_decay=0.01)
+        out = linear(a)
+        out.backward()
+        adam.step()
+        adam.clear_gradients()
+
+    def test_sgd(self):
+        place = fluid.CPUPlace()
+        main = fluid.Program()
+        with fluid.program_guard(main):
+            x = fluid.layers.data(name='x', shape=[13], dtype='float32')
+            y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+            y_predict = fluid.layers.fc(input=x, size=1, act=None)
+            cost = fluid.layers.square_error_cost(input=y_predict, label=y)
+            avg_cost = fluid.layers.mean(cost)
+
+            rms_optimizer = paddle.optimizer.SGD(learning_rate=0.1)
+            rms_optimizer.minimize(avg_cost)
+
+            fetch_list = [avg_cost]
+            train_reader = paddle.batch(
+                paddle.dataset.uci_housing.train(), batch_size=1)
+            feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
+            exe = fluid.Executor(place)
+            exe.run(fluid.default_startup_program())
+            for data in train_reader():
+                exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
+
+    def test_raise_error(self):
+        self.assertRaises(ValueError, paddle.optimizer.SGD, learning_rate=None)
+
+
 if __name__ == "__main__":
     unittest.main()
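Note: the test edits above track the 2.0 keyword rename: the new paddle.optimizer classes take `parameters=`, while the old fluid optimizers took `parameter_list=`. A minimal sketch of the new spelling, assuming the 2.0-beta dygraph API:

    import paddle

    paddle.disable_static()
    linear = paddle.nn.Linear(13, 5)
    # paddle.optimizer.* uses `parameters=`; fluid.optimizer.* kept `parameter_list=`
    opt = paddle.optimizer.SGD(learning_rate=0.5, parameters=linear.parameters())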
diff --git a/python/paddle/optimizer/__init__.py b/python/paddle/optimizer/__init__.py
index 49314c9832dd389411dffb3f498b34d09337a3f0..095a34cb6fc68cda6900790141d226208b203f82 100644
--- a/python/paddle/optimizer/__init__.py
+++ b/python/paddle/optimizer/__init__.py
@@ -26,9 +26,8 @@ __all__ = [
 ]
 
 
-from ..fluid.optimizer import SGD, Momentum, Adagrad, Dpsgd, DecayedAdagrad, \
-    Ftrl, Adadelta, \
-    SGDOptimizer, MomentumOptimizer, AdagradOptimizer,DpsgdOptimizer,\
+from ..fluid.optimizer import Momentum, Adagrad, Dpsgd, DecayedAdagrad, Ftrl,\
+    AdagradOptimizer,DpsgdOptimizer,\
     DecayedAdagradOptimizer,FtrlOptimizer,AdadeltaOptimizer, \
     ModelAverage, LarsMomentum, DGCMomentumOptimizer, LambOptimizer,\
     ExponentialMovingAverage, PipelineOptimizer, LookaheadOptimizer, \
@@ -39,6 +38,9 @@ from .adam import Adam
 from .adamw import AdamW
 from .adamax import Adamax
 from .rmsprop import RMSProp
+from .adadelta import Adadelta
+from .sgd import SGD
+from .momentum import Momentum
 
 from . import lr_scheduler
 from .lr_scheduler import _LRScheduler, NoamLR, PiecewiseLR, NaturalExpLR, InverseTimeLR, PolynomialLR, \
diff --git a/python/paddle/optimizer/adadelta.py b/python/paddle/optimizer/adadelta.py
new file mode 100644
index 0000000000000000000000000000000000000000..bba2c11ea07490804573189bac8b315dfc80fd37
--- /dev/null
+++ b/python/paddle/optimizer/adadelta.py
@@ -0,0 +1,143 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .optimizer import Optimizer
+from ..fluid import core
+from ..fluid import framework
+from ..fluid.framework import Variable, name_scope
+
+__all__ = ["Adadelta"]
+
+
+class Adadelta(Optimizer):
+    """
+    **Notes: This API does not support sparse parameter optimization.**
+
+    Adadelta Optimizer. Please refer to this for details:
+    `ADADELTA: AN ADAPTIVE LEARNING RATE METHOD <https://arxiv.org/abs/1212.5701>`_.
+
+    The update is done as follows:
+
+    .. math::
+
+        E(g_t^2) &= \\rho * E(g_{t-1}^2) + (1-\\rho) * g^2
+
+        learning\_rate &= \sqrt{ ( E(dx_{t-1}^2) + \\epsilon ) / ( E(g_t^2) + \\epsilon ) }
+
+        E(dx_t^2) &= \\rho * E(dx_{t-1}^2) + (1-\\rho) * (-g*learning\_rate)^2
+
+    Args:
+        learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
+            It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
+        epsilon (float): a small float number for numeric stability. Default 1.0e-6.
+        rho (float): a floating point value indicating the decay rate. Default 0.95.
+        parameters (list, optional): List of ``Tensor`` to update to minimize ``loss``. \
+            This parameter is required in dygraph mode. \
+            The default value is None in static mode, in which case all parameters will be updated.
+        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
+            It can be a float value as the coefficient of L2 regularization or \
+            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
+            If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \
+            the regularization setting here in optimizer will be ignored for this parameter. \
+            Otherwise, the regularization setting here in optimizer will take effect. \
+            Default None, meaning there is no regularization.
+        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
+            some derived class of ``GradientClipBase`` . There are three clipping strategies
+            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
+            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
+        name (str, optional): The default value is None. Normally there is no need for user
+            to set this property. For more information, please refer to
+            :ref:`api_guide_Name` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import numpy as np
+            paddle.disable_static()
+            inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
+            linear = paddle.nn.Linear(10, 10)
+            inp = paddle.to_tensor(inp)
+            out = linear(inp)
+            loss = paddle.mean(out)
+            adadelta = paddle.optimizer.Adadelta(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
+            out.backward()
+            adadelta.step()
+            adadelta.clear_grad()
+
+    """
+
+    _avg_squared_grad_acc_str = "_avg_squared_grad"
+    _avg_squared_update_acc_str = "_avg_squared_update"
+
+    def __init__(self,
+                 learning_rate=0.001,
+                 epsilon=1.0e-6,
+                 rho=0.95,
+                 parameters=None,
+                 weight_decay=None,
+                 grad_clip=None,
+                 name=None):
+        if learning_rate is None:
+            raise ValueError("learning_rate is not set.")
+        if epsilon is None:
+            raise ValueError("epsilon is not set.")
+        if rho is None:
+            raise ValueError("rho is not set.")
+        super(Adadelta, self).__init__(
+            learning_rate=learning_rate,
+            parameters=parameters,
+            weight_decay=weight_decay,
+            grad_clip=grad_clip,
+            name=name)
+        self.type = "adadelta"
+        self._epsilon = epsilon
+        self._rho = rho
+
+    def _create_accumulators(self, block, parameters):
+        if not isinstance(block, framework.Block):
+            raise TypeError("block is not instance of framework.Block.")
+
+        for p in parameters:
+            self._add_accumulator(self._avg_squared_grad_acc_str, p)
+            self._add_accumulator(self._avg_squared_update_acc_str, p)
+
+    def _append_optimize_op(self, block, param_and_grad):
+        if not isinstance(block, framework.Block):
+            raise TypeError("block is not instance of framework.Block.")
+
+        avg_squared_grad_acc = self._get_accumulator(
+            self._avg_squared_grad_acc_str, param_and_grad[0])
+        avg_squared_update_acc = self._get_accumulator(
+            self._avg_squared_update_acc_str, param_and_grad[0])
+
+        # Create the adadelta optimizer op
+        adadelta_op = block.append_op(
+            type=self.type,
+            inputs={
+                "Param": param_and_grad[0],
+                "Grad": param_and_grad[1],
+                "AvgSquaredGrad": avg_squared_grad_acc,
+                "AvgSquaredUpdate": avg_squared_update_acc
+            },
+            outputs={
+                "ParamOut": param_and_grad[0],
+                "AvgSquaredGradOut": avg_squared_grad_acc,
+                "AvgSquaredUpdateOut": avg_squared_update_acc
+            },
+            attrs={"epsilon": self._epsilon,
+                   "rho": self._rho},
+            stop_gradient=True)
+
+        return adadelta_op
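Note: the constructor guards above are what the new `test_raise_error` cases in test_adadelta_op.py exercise; a quick sketch of the contract:

    import paddle

    try:
        paddle.optimizer.Adadelta(learning_rate=None)
    except ValueError as e:
        print(e)  # learning_rate is not set.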
diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py
new file mode 100644
index 0000000000000000000000000000000000000000..87fa86c17615ef8cc455e95517608a246d677e74
--- /dev/null
+++ b/python/paddle/optimizer/momentum.py
@@ -0,0 +1,148 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .optimizer import Optimizer
+from ..fluid import core
+from ..fluid import framework
+from ..fluid.framework import Variable, name_scope
+
+__all__ = ["Momentum"]
+
+
+class Momentum(Optimizer):
+    """
+
+    Simple Momentum optimizer with velocity state
+
+    This optimizer has a flag for Nesterov Momentum.
+
+    The update equations are as follows:
+
+    .. math::
+
+        & velocity = mu * velocity + gradient
+
+        & if (use\_nesterov):
+
+        &\quad param = param - (gradient + mu * velocity) * learning\_rate
+
+        & else:
+
+        &\quad param = param - learning\_rate * velocity
+
+    Parameters:
+
+        learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
+            It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
+        momentum (float): Momentum factor. The default value is 0.9.
+        parameters (list, optional): List of ``Tensor`` to update to minimize ``loss``. \
+            This parameter is required in dygraph mode. \
+            The default value is None in static mode, in which case all parameters will be updated.
+        use_nesterov (bool, optional): Enables Nesterov momentum. The default value is False.
+        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
+            It can be a float value as the coefficient of L2 regularization or \
+            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
+            If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \
+            the regularization setting here in optimizer will be ignored for this parameter. \
+            Otherwise, the regularization setting here in optimizer will take effect. \
+            Default None, meaning there is no regularization.
+        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
+            some derived class of ``GradientClipBase`` . There are three clipping strategies
+            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
+            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
+        name (str, optional): The default value is None. Normally there is no need for user
+            to set this property. For more information, please refer to
+            :ref:`api_guide_Name` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import numpy as np
+            paddle.disable_static()
+            inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
+            linear = paddle.nn.Linear(10, 10)
+            inp = paddle.to_tensor(inp)
+            out = linear(inp)
+            loss = paddle.mean(out)
+            momentum = paddle.optimizer.Momentum(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
+            out.backward()
+            momentum.step()
+            momentum.clear_grad()
+    """
+    _velocity_acc_str = "velocity"
+
+    def __init__(self,
+                 learning_rate=0.001,
+                 momentum=0.9,
+                 parameters=None,
+                 use_nesterov=False,
+                 weight_decay=None,
+                 grad_clip=None,
+                 name=None):
+        if learning_rate is None:
+            raise ValueError("learning_rate is not set")
+        if momentum is None:
+            raise ValueError("momentum is not set")
+        super(Momentum, self).__init__(
+            learning_rate=learning_rate,
+            parameters=parameters,
+            weight_decay=weight_decay,
+            grad_clip=grad_clip,
+            name=name)
+        self.type = "momentum"
+        self._momentum = momentum
+        self._use_nesterov = bool(use_nesterov)
+
+    def _create_accumulators(self, block, parameters):
+        assert isinstance(block, framework.Block)
+
+        for p in parameters:
+            self._add_accumulator(self._velocity_acc_str, p)
+
+    def _append_optimize_op(self, block, param_and_grad):
+        assert isinstance(block, framework.Block)
+
+        velocity_acc = self._get_accumulator(self._velocity_acc_str,
+                                             param_and_grad[0])
+        lr = self._create_param_lr(param_and_grad)
+
+        if framework.in_dygraph_mode():
+            _, _ = core.ops.momentum(param_and_grad[0], param_and_grad[1],
+                                     velocity_acc, lr, param_and_grad[0],
+                                     velocity_acc, 'mu', self._momentum,
+                                     'use_nesterov', self._use_nesterov)
+            return None
+
+        attrs = {"mu": self._momentum, "use_nesterov": self._use_nesterov}
+        inputs = {
+            "Param": [param_and_grad[0]],
+            "Grad": [param_and_grad[1]],
+            "Velocity": [velocity_acc],
+            "LearningRate": [lr]
+        }
+
+        outputs = {
+            "ParamOut": [param_and_grad[0]],
+            "VelocityOut": [velocity_acc]
+        }
+        # create the momentum optimize op
+        momentum_op = block.append_op(
+            type=self.type,
+            inputs=inputs,
+            outputs=outputs,
+            attrs=attrs,
+            stop_gradient=True)
+
+        return momentum_op
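Note: the docstring equations above reduce to the following NumPy sketch of one update, covering both the plain and the Nesterov branch (illustrative helper, not part of the patch):

    import numpy as np

    def momentum_step(param, grad, velocity, lr, mu=0.9, use_nesterov=False):
        velocity = mu * velocity + grad
        if use_nesterov:
            param = param - (grad + mu * velocity) * lr
        else:
            param = param - lr * velocity
        return param, velocity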
diff --git a/python/paddle/optimizer/sgd.py b/python/paddle/optimizer/sgd.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb3a578e15724e9501d69dc209bdedc65afeb82b
--- /dev/null
+++ b/python/paddle/optimizer/sgd.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .optimizer import Optimizer
+from ..fluid import core
+from ..fluid import framework
+from ..fluid.framework import Variable, name_scope
+from ..fluid.dygraph import no_grad
+__all__ = ["SGD"]
+
+
+class SGD(Optimizer):
+    """
+    Optimizer of the stochastic gradient descent algorithm.
+
+    .. math::
+
+        param\_out = param - learning\_rate * grad
+
+    Parameters:
+        learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
+            It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
+        parameters (list, optional): List of ``Tensor`` to update to minimize ``loss``. \
+            This parameter is required in dygraph mode. \
+            The default value is None in static mode, in which case all parameters will be updated.
+        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \
+            It can be a float value as the coefficient of L2 regularization or \
+            :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`.
+            If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \
+            the regularization setting here in optimizer will be ignored for this parameter. \
+            Otherwise, the regularization setting here in optimizer will take effect. \
+            Default None, meaning there is no regularization.
+        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
+            some derived class of ``GradientClipBase`` . There are three clipping strategies
+            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
+            :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
+        name (str, optional): The default value is None. Normally there is no need for user
+            to set this property. For more information, please refer to
+            :ref:`api_guide_Name` .
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            import numpy as np
+            paddle.disable_static()
+            inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32")
+            linear = paddle.nn.Linear(10, 10)
+            inp = paddle.to_tensor(inp)
+            out = linear(inp)
+            loss = paddle.mean(out)
+            sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), weight_decay=0.01)
+            out.backward()
+            sgd.step()
+            sgd.clear_grad()
+
+    """
+
+    def __init__(self,
+                 learning_rate=0.001,
+                 parameters=None,
+                 weight_decay=None,
+                 grad_clip=None,
+                 name=None):
+        if learning_rate is None:
+            raise ValueError("learning_rate is not set")
+        super(SGD, self).__init__(
+            learning_rate=learning_rate,
+            parameters=parameters,
+            weight_decay=weight_decay,
+            grad_clip=grad_clip,
+            name=name)
+        self.type = "sgd"
+
+    @no_grad()
+    def _append_optimize_op(self, block, param_and_grad):
+        lr = self._create_param_lr(param_and_grad)
+        if framework.in_dygraph_mode():
+            core.ops.sgd(param_and_grad[0], lr, param_and_grad[1],
+                         param_and_grad[0])
+            return None
+
+        assert isinstance(block, framework.Block)
+        # create the optimize op
+        sgd_op = block.append_op(
+            type=self.type,
+            inputs={
+                "Param": param_and_grad[0],
+                "Grad": param_and_grad[1],
+                "LearningRate": lr
+            },
+            outputs={"ParamOut": param_and_grad[0]},
+            stop_gradient=True)
+
+        return sgd_op
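Note: the three new classes share the same constructor surface; a quick dygraph smoke test against the 2.0-beta API (a sketch mirroring the unit tests above, not part of the patch):

    import numpy as np
    import paddle

    paddle.disable_static()
    x = paddle.to_tensor(
        np.random.uniform(-0.1, 0.1, [4, 13]).astype("float32"))
    linear = paddle.nn.Linear(13, 5)
    for opt in [
            paddle.optimizer.SGD(
                learning_rate=0.01, parameters=linear.parameters()),
            paddle.optimizer.Momentum(
                learning_rate=0.01, momentum=0.9,
                parameters=linear.parameters()),
            paddle.optimizer.Adadelta(
                learning_rate=0.01, rho=0.95, epsilon=1.0e-6,
                parameters=linear.parameters()),
    ]:
        loss = paddle.mean(linear(x))
        loss.backward()
        opt.step()
        opt.clear_grad()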