diff --git a/python/paddle/fluid/tests/unittests/test_adadelta_op.py b/python/paddle/fluid/tests/unittests/test_adadelta_op.py
index 2c6c018b9dfac13d97c242e1f36adbddf9dbf3f1..44dd3d60bdca1af0c81373dae60689cd579d35ec 100644
--- a/python/paddle/fluid/tests/unittests/test_adadelta_op.py
+++ b/python/paddle/fluid/tests/unittests/test_adadelta_op.py
@@ -127,6 +127,7 @@ class TestAdadeltaV2(unittest.TestCase):
         adam.clear_gradients()
 
     def test_adadelta(self):
+        paddle.enable_static()
         place = fluid.CPUPlace()
         main = fluid.Program()
         with fluid.program_guard(main):
@@ -159,5 +160,29 @@ class TestAdadeltaV2(unittest.TestCase):
             epsilon=None)
 
 
+class TestAdadeltaV2Group(TestAdadeltaV2):
+    def test_adadelta_dygraph(self):
+        paddle.disable_static(paddle.CPUPlace())
+        value = np.arange(26).reshape(2, 13).astype("float32")
+        a = paddle.to_tensor(value)
+        linear_1 = paddle.nn.Linear(13, 5)
+        linear_2 = paddle.nn.Linear(5, 5)
+        # This can be any optimizer supported by dygraph.
+        adam = paddle.optimizer.Adadelta(
+            learning_rate=0.01,
+            parameters=[{
+                'params': linear_1.parameters()
+            }, {
+                'params': linear_2.parameters(),
+                'weight_decay': 0.001,
+            }],
+            weight_decay=0.1)
+        out = linear_1(a)
+        out = linear_2(out)
+        out.backward()
+        adam.step()
+        adam.clear_gradients()
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_adagrad_op_v2.py b/python/paddle/fluid/tests/unittests/test_adagrad_op_v2.py
index 0ccd42aa674dd410bdd2ea34a27929bede345332..c6a69c0723ce9142980ca3529c2c0c1fef7585c0 100644
--- a/python/paddle/fluid/tests/unittests/test_adagrad_op_v2.py
+++ b/python/paddle/fluid/tests/unittests/test_adagrad_op_v2.py
@@ -37,5 +37,28 @@ class TestAdagradOpV2(unittest.TestCase):
         adagrad.clear_grad()
 
 
+class TestAdagradOpV2Group(TestAdagradOpV2):
+    def test_v20_coverage(self):
+        paddle.disable_static()
+        inp = paddle.rand(shape=[10, 10])
+        linear_1 = paddle.nn.Linear(10, 10)
+        linear_2 = paddle.nn.Linear(10, 10)
+        out = linear_1(inp)
+        out = linear_2(out)
+        loss = paddle.mean(out)
+        adagrad = paddle.optimizer.Adagrad(
+            learning_rate=0.01,
+            parameters=[{
+                'params': linear_1.parameters()
+            }, {
+                'params': linear_2.parameters(),
+                'weight_decay': 0.001,
+            }],
+            weight_decay=0.1)
+        out.backward()
+        adagrad.step()
+        adagrad.clear_grad()
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py
index 1e316c3383ea76f968868fbc7f90ccc898bc61a8..aea2a074aedd58a1152efbaa8d276f7d1c82387c 100644
--- a/python/paddle/fluid/tests/unittests/test_adam_op.py
+++ b/python/paddle/fluid/tests/unittests/test_adam_op.py
@@ -810,5 +810,31 @@ class TestNetWithEpsilonTensor(unittest.TestCase):
         paddle.enable_static()
 
 
+class TestAdamOpV2Group(TestAdamOpV2):
+    def test_adam_op(self):
+        paddle.disable_static()
+        value = np.arange(26).reshape(2, 13).astype("float32")
+        a = paddle.to_tensor(value)
+        linear_1 = paddle.nn.Linear(13, 5)
+        linear_2 = paddle.nn.Linear(5, 3)
+        # This can be any optimizer supported by dygraph.
+ adam = paddle.optimizer.Adam( + learning_rate=0.01, + parameters=[{ + 'params': linear_1.parameters() + }, { + 'params': linear_2.parameters(), + 'weight_decay': 0.001, + 'beta1': 0.1, + 'beta2': 0.99 + }], + weight_decay=0.1) + out = linear_1(a) + out = linear_2(out) + out.backward() + adam.step() + adam.clear_gradients() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_adamax_api.py b/python/paddle/fluid/tests/unittests/test_adamax_api.py index 6d2ec0eefbb1c5157fdbcb5a2e04e97e918a95c9..57cb9d3cb5f7ddef60f6577ba0d8217ab3d16b45 100644 --- a/python/paddle/fluid/tests/unittests/test_adamax_api.py +++ b/python/paddle/fluid/tests/unittests/test_adamax_api.py @@ -37,6 +37,7 @@ class TestAdamaxAPI(unittest.TestCase): adam.clear_gradients() def test_adamax_api(self): + paddle.enable_static() place = fluid.CPUPlace() shape = [2, 3, 8, 8] exe = fluid.Executor(place) @@ -63,5 +64,31 @@ class TestAdamaxAPI(unittest.TestCase): assert rets[0] is not None +class TestAdamaxAPIGroup(TestAdamaxAPI): + def test_adamax_api_dygraph(self): + paddle.disable_static() + value = np.arange(26).reshape(2, 13).astype("float32") + a = paddle.to_tensor(value) + linear_1 = paddle.nn.Linear(13, 5) + linear_2 = paddle.nn.Linear(5, 3) + # This can be any optimizer supported by dygraph. + adam = paddle.optimizer.Adamax( + learning_rate=0.01, + parameters=[{ + 'params': linear_1.parameters() + }, { + 'params': linear_2.parameters(), + 'weight_decay': 0.001, + 'beta1': 0.1, + 'beta2': 0.99 + }], + weight_decay=0.1) + out = linear_1(a) + out = linear_2(out) + out.backward() + adam.step() + adam.clear_gradients() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_adamw_op.py b/python/paddle/fluid/tests/unittests/test_adamw_op.py index 9b77dae1afed2d58601724fed033119cffe6a8e6..ce01ca042c123d17ae629c11a86cb38f123251b3 100644 --- a/python/paddle/fluid/tests/unittests/test_adamw_op.py +++ b/python/paddle/fluid/tests/unittests/test_adamw_op.py @@ -121,5 +121,31 @@ class TestAdamWOp(unittest.TestCase): adam.clear_gradients() +class TestAdamWOpGroup(TestAdamWOp): + def test_adamw_op_dygraph(self): + paddle.disable_static() + value = np.arange(26).reshape(2, 13).astype("float32") + a = paddle.to_tensor(value) + linear_1 = paddle.nn.Linear(13, 5) + linear_2 = paddle.nn.Linear(5, 3) + adam = paddle.optimizer.AdamW( + learning_rate=0.01, + parameters=[{ + 'params': linear_1.parameters() + }, { + 'params': linear_2.parameters(), + 'weight_decay': 0.001 + }], + apply_decay_param_fun=lambda name: True, + weight_decay=0.01) + + for _ in range(2): + out = linear_1(a) + out = linear_2(out) + out.backward() + adam.step() + adam.clear_gradients() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_lambv2_op.py b/python/paddle/fluid/tests/unittests/test_lambv2_op.py index 7ffc056812f2ba1f1d1ace5d1fdc3fcf226dbd05..861418679a36620d2a31bf375de50c65cc10b5ea 100644 --- a/python/paddle/fluid/tests/unittests/test_lambv2_op.py +++ b/python/paddle/fluid/tests/unittests/test_lambv2_op.py @@ -155,5 +155,31 @@ class TestLambOpWithCombinedOp(unittest.TestCase): self.assertTrue(np.allclose(out, output)) +class TestLambOpV2Group(TestLambOpV2): + def test_lamb_op(self): + paddle.disable_static() + value = np.arange(26).reshape(2, 13).astype("float32") + a = paddle.to_tensor(value) + linear_1 = paddle.nn.Linear(13, 5) + linear_2 = paddle.nn.Linear(5, 3) + # This can be any optimizer supported by dygraph. 
+ adam = paddle.optimizer.Lamb( + learning_rate=0.01, + parameters=[{ + 'params': linear_1.parameters() + }, { + 'params': linear_2.parameters(), + 'lamb_weight_decay': 0.001, + 'beta1': 0.9, + 'beta2': 0.99 + }], + lamb_weight_decay=0.01) + out = linear_1(a) + out = linear_2(out) + out.backward() + adam.step() + adam.clear_gradients() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py index 8f629b15224287bdb4f53de90cfc526bf12ad4d8..ba4c1458c7791debf8da23ca6860220bdee6fc94 100644 --- a/python/paddle/fluid/tests/unittests/test_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py @@ -610,5 +610,32 @@ class TestMomentumOpVsMomentumOpWithDecayAPI(unittest.TestCase): self.__test_vs(place=place) +class TestMomentumV2Group(TestMomentumV2): + def test_momentum_dygraph(self): + paddle.disable_static() + value = np.arange(26).reshape(2, 13).astype("float32") + a = paddle.to_tensor(value) + linear_1 = paddle.nn.Linear(13, 5) + linear_2 = paddle.nn.Linear(5, 3) + # This can be any optimizer supported by dygraph. + adam = paddle.optimizer.Momentum( + learning_rate=0.01, + parameters=[{ + 'params': linear_1.parameters() + }, { + 'params': linear_2.parameters(), + 'weight_decay': 0.001, + 'learning_rate': 0.1, + 'momentum': 0.99 + }], + weight_decay=0.1, + momentum=0.9) + out = linear_1(a) + out = linear_2(out) + out.backward() + adam.step() + adam.clear_gradients() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py index ddac7f6b98b19d204d20ccdff75c6d4fcae50d4d..08ab2e18c733a6ba4bad904f10abce2baf9517ed 100644 --- a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py +++ b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py @@ -240,6 +240,7 @@ class TestRMSPropV2(unittest.TestCase): adam.clear_gradients() def test_rmsprop(self): + paddle.enable_static() place = fluid.CPUPlace() main = fluid.Program() with fluid.program_guard(main): @@ -290,5 +291,29 @@ class TestRMSPropV2(unittest.TestCase): 0.1, rho=-1, parameters=linear.parameters()) +class TestRMSPropV2Group(TestRMSPropV2): + def test_rmsprop_dygraph(self): + paddle.disable_static() + value = np.arange(26).reshape(2, 13).astype("float32") + a = paddle.to_tensor(value) + linear_1 = paddle.nn.Linear(13, 5) + linear_2 = paddle.nn.Linear(5, 3) + # This can be any optimizer supported by dygraph. 
+ adam = paddle.optimizer.RMSProp( + learning_rate=0.01, + parameters=[{ + 'params': linear_1.parameters() + }, { + 'params': linear_2.parameters(), + 'weight_decay': 0.001 + }], + weight_decay=0.01) + out = linear_1(a) + out = linear_2(out) + out.backward() + adam.step() + adam.clear_gradients() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_sgd_op.py b/python/paddle/fluid/tests/unittests/test_sgd_op.py index 2c87e06e893a4d6495ad81ac3dcdf375a41272fb..afa004e769e092317a7fbf9551d067dc19f9c0f8 100644 --- a/python/paddle/fluid/tests/unittests/test_sgd_op.py +++ b/python/paddle/fluid/tests/unittests/test_sgd_op.py @@ -225,6 +225,7 @@ class TestSGDV2(unittest.TestCase): adam.clear_gradients() def test_sgd(self): + paddle.enable_static() place = fluid.CPUPlace() main = fluid.Program() with fluid.program_guard(main): @@ -250,5 +251,29 @@ class TestSGDV2(unittest.TestCase): self.assertRaises(ValueError, paddle.optimizer.SGD, learning_rate=None) +class TestSGDV2Group(TestSGDV2): + def test_sgd_dygraph(self): + paddle.disable_static() + value = np.arange(26).reshape(2, 13).astype("float32") + a = paddle.to_tensor(value) + linear_1 = paddle.nn.Linear(13, 5) + linear_2 = paddle.nn.Linear(5, 3) + # This can be any optimizer supported by dygraph. + adam = paddle.optimizer.SGD(learning_rate=0.01, + parameters=[{ + 'params': linear_1.parameters() + }, { + 'params': linear_2.parameters(), + 'weight_decay': 0.001, + 'learning_rate': 0.1 + }], + weight_decay=0.01) + out = linear_1(a) + out = linear_2(out) + out.backward() + adam.step() + adam.clear_gradients() + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/optimizer/adadelta.py b/python/paddle/optimizer/adadelta.py index 6c10d9bc2690a09b23ed2238ddd548d65f21df36..dd088b18ca27d9b749e602988ebd3954dbaacebf 100644 --- a/python/paddle/optimizer/adadelta.py +++ b/python/paddle/optimizer/adadelta.py @@ -43,7 +43,10 @@ class Adadelta(Optimizer): epsilon (float): a small float number for numeric stability. Default 1.0e-6. rho (float): a floating point value indicating the decay rate. Default 0.95. parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. \ - This parameter is required in dygraph mode. \ + This parameter is required in dygraph mode. And you can specify different options for \ + different parameter groups such as the learning rate, weight decay, etc, \ + then the parameters are list of dict. Note that the learning_rate in paramter groups \ + represents the scale of base learning_rate. \ The default value is None in static mode, at this time all parameters will be updated. weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \ It canbe a float value as coeff of L2 regularization or \ @@ -77,6 +80,27 @@ class Adadelta(Optimizer): adadelta.step() adadelta.clear_grad() + #Note that the learning_rate of linear_2 is 0.01. 
+ linear_1 = paddle.nn.Linear(10, 10) + linear_2 = paddle.nn.Linear(10, 10) + inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1) + out = linear_1(inp) + out = linear_2(out) + loss = paddle.mean(out) + adadelta = paddle.optimizer.Adadelta( + learning_rate=0.1, + parameters=[{ + 'params': linear_1.parameters() + }, { + 'params': linear_2.parameters(), + 'weight_decay': 0.001, + 'learning_rate': 0.1, + }], + weight_decay=0.01) + out.backward() + adadelta.step() + adadelta.clear_grad() + """ _avg_squared_grad_acc_str = "_avg_squared_grad" @@ -105,10 +129,16 @@ class Adadelta(Optimizer): self.type = "adadelta" self._epsilon = epsilon self._rho = rho + self._default_dict = { + 'epsilon': epsilon, + 'rho': rho, + } def _create_accumulators(self, block, parameters): if not isinstance(block, framework.Block): raise TypeError("block is not instance of framework.Block.") + if isinstance(parameters, dict): + parameters = parameters.get('params') for p in parameters: self._add_accumulator(self._avg_squared_grad_acc_str, p) @@ -118,6 +148,9 @@ class Adadelta(Optimizer): if not isinstance(block, framework.Block): raise TypeError("block is not instance of framework.Block.") + if isinstance(param_and_grad, dict): + param_and_grad = self._update_param_group(param_and_grad) + avg_squared_grad_acc = self._get_accumulator( self._avg_squared_grad_acc_str, param_and_grad[0]) avg_squared_update_acc = self._get_accumulator( @@ -142,3 +175,9 @@ class Adadelta(Optimizer): stop_gradient=True) return adadelta_op + + def _update_param_group(self, parameters): + self._epsilon = parameters.get('epsilon', self._default_dict['epsilon']) + self._rho = parameters.get('rho', self._default_dict['rho']) + parameters = parameters.get('params') + return parameters diff --git a/python/paddle/optimizer/adagrad.py b/python/paddle/optimizer/adagrad.py index bb934e5a9262c778029df3b29d84b6dd7a71bde3..6238d32e9c49dfa4664f2e269f415c44f06ffb3f 100644 --- a/python/paddle/optimizer/adagrad.py +++ b/python/paddle/optimizer/adagrad.py @@ -45,16 +45,19 @@ class Adagrad(Optimizer): It can be a float value or a ``Variable`` with a float type. epsilon (float, optional): A small float value for numerical stability. The default value is 1e-06. - parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. \ - This parameter is required in dygraph mode. \ - The default value is None in static mode, at this time all parameters will be updated. - weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \ - It canbe a float value as coeff of L2 regularization or \ - :ref:`api_paddle_regularizer_L1Decay`, :ref:`api_paddle_regularizer_L2Decay`. - If a parameter has set regularizer using :ref:`api_paddle_fluid_param_attr_aramAttr` already, \ - the regularization setting here in optimizer will be ignored for this parameter. \ - Otherwise, the regularization setting here in optimizer will take effect. \ - Default None, meaning there is no regularization. + parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. \ + This parameter is required in dygraph mode. And you can specify different options for \ + different parameter groups such as the learning rate, weight decay, etc, \ + then the parameters are list of dict. Note that the learning_rate in paramter groups \ + represents the scale of base learning_rate. \ + The default value is None in static mode, at this time all parameters will be updated. 
+ weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \ + It canbe a float value as coeff of L2 regularization or \ + :ref:`api_paddle_regularizer_L1Decay`, :ref:`api_paddle_regularizer_L2Decay`. + If a parameter has set regularizer using :ref:`api_paddle_fluid_param_attr_aramAttr` already, \ + the regularization setting here in optimizer will be ignored for this parameter. \ + Otherwise, the regularization setting here in optimizer will take effect. \ + Default None, meaning there is no regularization. grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of some derived class of ``GradientClipBase`` . There are three cliping strategies, ClipGradByGlobalNorm, ClipGradByNorm and ClipGradByValue. Default None, @@ -81,6 +84,27 @@ class Adagrad(Optimizer): adagrad.step() adagrad.clear_grad() + #Note that the learning_rate of linear_2 is 0.01. + linear_1 = paddle.nn.Linear(10, 10) + linear_2 = paddle.nn.Linear(10, 10) + inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1) + out = linear_1(inp) + out = linear_2(out) + loss = paddle.mean(out) + adagrad = paddle.optimizer.Adagrad( + learning_rate=0.1, + parameters=[{ + 'params': linear_1.parameters() + }, { + 'params': linear_2.parameters(), + 'weight_decay': 0.001, + 'learning_rate': 0.1, + }], + weight_decay=0.01) + out.backward() + adagrad.step() + adagrad.clear_grad() + """ _moment_acc_str = "moment" @@ -103,10 +127,17 @@ class Adagrad(Optimizer): self.type = "adagrad" self._epsilon = epsilon self.initial_accumulator_value = initial_accumulator_value + self._default_dict = { + 'epsilon': epsilon, + 'initial_accumulator_value': initial_accumulator_value, + } def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) + if isinstance(parameters, dict): + parameters = self._update_param_group(parameters) + for p in parameters: self._add_accumulator( self._moment_acc_str, @@ -116,6 +147,9 @@ class Adagrad(Optimizer): def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) + if isinstance(param_and_grad, dict): + param_and_grad = self._update_param_group(param_and_grad) + moment_acc = self._get_accumulator(self._moment_acc_str, param_and_grad[0]) # Create the adagrad optimizer op @@ -133,3 +167,11 @@ class Adagrad(Optimizer): stop_gradient=True) return adagrad_op + + def _update_param_group(self, parameters): + self._epsilon = parameters.get('epsilon', self._default_dict['epsilon']) + self.initial_accumulator_value = parameters.get( + 'initial_accumulator_value', + self._default_dict['initial_accumulator_value']) + parameters = parameters.get('params') + return parameters diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index 63ca462d1a26b8a17e540a1fac2284b77a523a21..baa6a307176dd5feb70f1b6f2201a89f298e6153 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -21,6 +21,7 @@ from ..fluid import unique_name from ..fluid.layer_helper import LayerHelper import warnings from ..fluid.dygraph import base as imperative_base +from collections import defaultdict import paddle @@ -63,16 +64,19 @@ class Adam(Optimizer): epsilon (float|Tensor, optional): A small float value for numerical stability. It should be a float number or a Tensor with shape [1] and data type as float32. The default value is 1e-08. - parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. \ - This parameter is required in dygraph mode. 
\ - The default value is None in static mode, at this time all parameters will be updated. - weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \ - It canbe a float value as coeff of L2 regularization or \ - :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`. - If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \ - the regularization setting here in optimizer will be ignored for this parameter. \ - Otherwise, the regularization setting here in optimizer will take effect. \ - Default None, meaning there is no regularization. + parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. \ + This parameter is required in dygraph mode. And you can specify different options for \ + different parameter groups such as the learning rate, weight decay, etc, \ + then the parameters are list of dict. Note that the learning_rate in paramter groups \ + represents the scale of base learning_rate. \ + The default value is None in static mode, at this time all parameters will be updated. + weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \ + It canbe a float value as coeff of L2 regularization or \ + :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`. + If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \ + the regularization setting here in optimizer will be ignored for this parameter. \ + Otherwise, the regularization setting here in optimizer will take effect. \ + Default None, meaning there is no regularization. grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of some derived class of ``GradientClipBase`` . There are three cliping strategies ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , @@ -126,6 +130,29 @@ class Adam(Optimizer): adam.step() adam.clear_grad() + #Note that the learning_rate of linear_2 is 0.01. 
+ linear_1 = paddle.nn.Linear(10, 10) + linear_2 = paddle.nn.Linear(10, 10) + inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1) + out = linear_1(inp) + out = linear_2(out) + loss = paddle.mean(out) + adam = paddle.optimizer.Adam( + learning_rate=0.1, + parameters=[{ + 'params': linear_1.parameters() + }, { + 'params': linear_2.parameters(), + 'weight_decay': 0.001, + 'learning_rate': 0.1, + 'beta1': 0.8 + }], + weight_decay=0.01, + beta1=0.9) + out.backward() + adam.step() + adam.clear_grad() + """ _moment1_acc_str = "moment1" _moment2_acc_str = "moment2" @@ -172,6 +199,12 @@ class Adam(Optimizer): self._lazy_mode = lazy_mode self._multi_precision = multi_precision self._master_weights = {} + self._default_dict = { + 'beta1': beta1, + 'beta2': beta2, + 'epsilon': epsilon, + 'lazy_mode': lazy_mode, + } def _create_master_weight(self, param): assert isinstance(self.helper, LayerHelper) @@ -241,6 +274,8 @@ class Adam(Optimizer): def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) + if isinstance(parameters, dict): + parameters = self._update_param_group(parameters) # Create accumulator tensors for first and second moments for p in parameters: @@ -257,6 +292,8 @@ class Adam(Optimizer): def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) + if isinstance(param_and_grad, dict): + param_and_grad = self._update_param_group(param_and_grad) moment1 = self._get_accumulator(self._moment1_acc_str, param_and_grad[0]) @@ -274,6 +311,7 @@ class Adam(Optimizer): # create the adam optimize op if framework.in_dygraph_mode(): + _beta1 = self._beta1 if not isinstance( self._beta1, Variable) else self._beta1.numpy().item(0) _beta2 = self._beta2 if not isinstance( @@ -359,18 +397,43 @@ class Adam(Optimizer): adam.step() adam.clear_grad() """ - params_grads = [] - for param in self._parameter_list: - if param.stop_gradient: - continue - if param._grad_ivar() is not None: - grad_var = param._grad_ivar() - if hasattr(grad_var, "_is_sparse") and grad_var._is_sparse( - ) and self.regularization is not None: - raise RuntimeError( - "Adam don't support weight_decay with sparse parameters, please set it to None." - ) - params_grads.append((param, grad_var)) - - optimize_ops = self._apply_optimize( - loss=None, startup_program=None, params_grads=params_grads) + if not isinstance(self._parameter_list[0], dict): + params_grads = [] + for param in self._parameter_list: + if param.stop_gradient: + continue + if param._grad_ivar() is not None: + grad_var = param._grad_ivar() + if hasattr(grad_var, "_is_sparse") and grad_var._is_sparse( + ) and self.regularization is not None: + raise RuntimeError( + "Adam don't support weight_decay with sparse parameters, please set it to None." 
+ ) + params_grads.append((param, grad_var)) + + optimize_ops = self._apply_optimize( + loss=None, startup_program=None, params_grads=params_grads) + else: + # optimize parameters in groups + for param_group in self._param_groups: + params_grads = defaultdict(lambda: list()) + for param in param_group['params']: + if param.stop_gradient: + continue + if param._grad_ivar() is not None: + grad_var = param._grad_ivar() + params_grads['params'].append((param, grad_var)) + params_grads.update( + {k: v + for k, v in param_group.items() if k != 'params'}) + self._apply_optimize( + loss=None, startup_program=None, params_grads=params_grads) + + def _update_param_group(self, parameters): + self._beta1 = parameters.get('beta1', self._default_dict['beta1']) + self._beta2 = parameters.get('beta2', self._default_dict['beta2']) + self._epsilon = parameters.get('epsilon', self._default_dict['epsilon']) + self._lazy_mode = parameters.get('lazy_mode', + self._default_dict['lazy_mode']) + parameters = parameters.get('params') + return parameters diff --git a/python/paddle/optimizer/adamax.py b/python/paddle/optimizer/adamax.py index 44ae89f49d1c0502a2f18ca9c4d58f10a6a9a69e..867b7703720ba3ffac3004ad886240fb53fc39ee 100644 --- a/python/paddle/optimizer/adamax.py +++ b/python/paddle/optimizer/adamax.py @@ -55,16 +55,19 @@ class Adamax(Optimizer): The default value is 0.999. epsilon (float, optional): A small float value for numerical stability. The default value is 1e-08. - parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. \ - This parameter is required in dygraph mode. \ - The default value is None in static mode, at this time all parameters will be updated. - weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \ - It canbe a float value as coeff of L2 regularization or \ - :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`. - If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \ - the regularization setting here in optimizer will be ignored for this parameter. \ - Otherwise, the regularization setting here in optimizer will take effect. \ - Default None, meaning there is no regularization. + parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. \ + This parameter is required in dygraph mode. And you can specify different options for \ + different parameter groups such as the learning rate, weight decay, etc, \ + then the parameters are list of dict. Note that the learning_rate in paramter groups \ + represents the scale of base learning_rate. \ + The default value is None in static mode, at this time all parameters will be updated. + weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \ + It canbe a float value as coeff of L2 regularization or \ + :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`. + If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \ + the regularization setting here in optimizer will be ignored for this parameter. \ + Otherwise, the regularization setting here in optimizer will take effect. \ + Default None, meaning there is no regularization. grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of some derived class of ``GradientClipBase`` . 
There are three cliping strategies ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , @@ -100,6 +103,29 @@ class Adamax(Optimizer): adam.step() adam.clear_grad() + + #Note that the learning_rate of linear_2 is 0.01. + linear_1 = paddle.nn.Linear(10, 10) + linear_2 = paddle.nn.Linear(10, 10) + inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1) + out = linear_1(inp) + out = linear_2(out) + loss = paddle.mean(out) + adam = paddle.optimizer.Adamax( + learning_rate=0.1, + parameters=[{ + 'params': linear_1.parameters() + }, { + 'params': linear_2.parameters(), + 'weight_decay': 0.001, + 'learning_rate': 0.1, + 'beta1': 0.8 + }], + weight_decay=0.01, + beta1=0.9) + out.backward() + adam.step() + adam.clear_grad() """ _moment_acc_str = "moment" _inf_norm_acc_str = "inf_norm" @@ -134,8 +160,16 @@ class Adamax(Optimizer): self._beta1 = beta1 self._beta2 = beta2 self._epsilon = epsilon + self._default_dict = { + 'beta1': beta1, + 'beta2': beta2, + 'epsilon': epsilon + } def _create_accumulators(self, block, parameters): + if isinstance(parameters, dict): + parameters = self._update_param_group(parameters) + # Create accumulator tensors for first moment and infinity norm for p in parameters: self._add_accumulator(self._moment_acc_str, p) @@ -148,6 +182,8 @@ class Adamax(Optimizer): def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) + if isinstance(param_and_grad, dict): + param_and_grad = self._update_param_group(param_and_grad) moment = self._get_accumulator(self._moment_acc_str, param_and_grad[0]) inf_norm = self._get_accumulator(self._inf_norm_acc_str, @@ -183,16 +219,40 @@ class Adamax(Optimizer): """Update Beta1 Power accumulator """ assert isinstance(block, framework.Block) - for param, grad in parameters_and_grads: - if grad is None or param.stop_gradient is True: - continue - with param.block.program._optimized_guard( - [param, grad]), name_scope('adamax'): - beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, - param) - block.append_op( - type="scale", - inputs={"X": beta1_pow_acc}, - outputs={"Out": beta1_pow_acc}, - attrs={"scale": self._beta1}, - stop_gradient=True) + if isinstance(parameters_and_grads, list): + for param, grad in parameters_and_grads: + if grad is None or param.stop_gradient is True: + continue + with param.block.program._optimized_guard( + [param, grad]), name_scope('adamax'): + beta1_pow_acc = self._get_accumulator( + self._beta1_pow_acc_str, param) + block.append_op( + type="scale", + inputs={"X": beta1_pow_acc}, + outputs={"Out": beta1_pow_acc}, + attrs={"scale": self._beta1}, + stop_gradient=True) + else: + for param, grad in parameters_and_grads['params']: + if grad is None or param.stop_gradient is True: + continue + with param.block.program._optimized_guard( + [param, grad]), name_scope('adamax'): + beta1_pow_acc = self._get_accumulator( + self._beta1_pow_acc_str, param) + self._beta1 = parameters_and_grads.get( + 'beta1', self._default_dict['beta1']) + block.append_op( + type="scale", + inputs={"X": beta1_pow_acc}, + outputs={"Out": beta1_pow_acc}, + attrs={"scale": self._beta1}, + stop_gradient=True) + + def _update_param_group(self, parameters): + self._beta1 = parameters.get('beta1', self._default_dict['beta1']) + self._beta2 = parameters.get('beta2', self._default_dict['beta2']) + self._epsilon = parameters.get('epsilon', self._default_dict['epsilon']) + parameters = parameters.get('params') + return parameters diff --git a/python/paddle/optimizer/adamw.py 
b/python/paddle/optimizer/adamw.py index 304f0b771826c946b7a28f17959aef7d426174c4..c3cffa2998f6cc0956412be7709251720f8a51db 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -45,9 +45,12 @@ class AdamW(Adam): Args: learning_rate (float|LRScheduler, optional): The learning rate used to update ``Parameter``. It can be a float value or a LRScheduler. The default value is 0.001. - parameters (list|tuple, optional): List/Tuple of ``Tensor`` names to update to minimize ``loss``. \ - This parameter is required in dygraph mode. \ - The default value is None in static mode, at this time all parameters will be updated. + parameters (list|tuple, optional): List/Tuple of ``Tensor`` names to update to minimize ``loss``. \ + This parameter is required in dygraph mode. And you can specify different options for \ + different parameter groups such as the learning rate, weight decay, etc, \ + then the parameters are list of dict. Note that the learning_rate in paramter groups \ + represents the scale of base learning_rate. \ + The default value is None in static mode, at this time all parameters will be updated. beta1 (float|Tensor, optional): The exponential decay rate for the 1st moment estimates. It should be a float number or a Tensor with shape [1] and data type as float32. The default value is 0.9. @@ -101,6 +104,30 @@ class AdamW(Adam): adam.step() adam.clear_grad() + + #Note that the learning_rate of linear_2 is 0.01. + linear_1 = paddle.nn.Linear(10, 10) + linear_2 = paddle.nn.Linear(10, 10) + inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1) + out = linear_1(inp) + out = linear_2(out) + loss = paddle.mean(out) + adam = paddle.optimizer.AdamW( + learning_rate=0.1, + parameters=[{ + 'params': linear_1.parameters() + }, { + 'params': linear_2.parameters(), + 'weight_decay': 0.001, + 'learning_rate': 0.1, + 'beta1': 0.8 + }], + weight_decay=0.01, + beta1=0.9) + out.backward() + adam.step() + adam.clear_grad() + """ def __init__(self, @@ -143,6 +170,7 @@ class AdamW(Adam): name=name, lazy_mode=lazy_mode, multi_precision=multi_precision) + self._default_dict = {'coeff': coeff} def _append_decoupled_weight_decay(self, block, param_and_grad): """ @@ -156,7 +184,10 @@ class AdamW(Adam): Raises: Exception: The type of coeff and parameter is not consistent. """ - param, grad = param_and_grad + if not isinstance(param_and_grad, dict): + param, grad = param_and_grad + else: + param, grad = self._update_param_group(param_and_grad) if self._apply_decay_param_fun is not None \ and not self._apply_decay_param_fun(param.name): @@ -207,3 +238,8 @@ class AdamW(Adam): def __str__(self): return " ".join(["Weight Decay, params:", ",".join(self._params_name)]) + + def _update_param_group(self, parameters): + self._coeff = parameters.get('coeff', self._default_dict['coeff']) + parameters = parameters.get('params') + return parameters diff --git a/python/paddle/optimizer/lamb.py b/python/paddle/optimizer/lamb.py index bff24e71c815366b6d12108436a82edb27d271a7..b2044ab3ca1715b749f074a4737cfc092aa29666 100644 --- a/python/paddle/optimizer/lamb.py +++ b/python/paddle/optimizer/lamb.py @@ -59,7 +59,10 @@ class Lamb(Optimizer): Default 0.999. epsilon (float, optional): A small float value for numerical stability. Default 1e-6. parameters (Iterable, optional): Iterable of ``Variable`` names to update to minimize ``loss``. \ - This parameter is required in dygraph mode. \ + This parameter is required in dygraph mode. 
And you can specify different options for \ + different parameter groups such as the learning rate, weight decay, etc, \ + then the parameters are list of dict. Note that the learning_rate in paramter groups \ + represents the scale of base learning_rate. \ The default value is None in static mode, at this time all parameters will be updated. grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of some derived class of ``GradientClipBase`` . There are three cliping strategies @@ -83,6 +86,31 @@ class Lamb(Optimizer): back = out.backward() lamb.step() lamb.clear_grad() + + + #Note that the learning_rate of linear_2 is 0.01. + linear_1 = paddle.nn.Linear(10, 10) + linear_2 = paddle.nn.Linear(10, 10) + inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1) + out = linear_1(inp) + out = linear_2(out) + loss = paddle.mean(out) + lamb = paddle.optimizer.Lamb( + learning_rate=0.1, + parameters=[{ + 'params': linear_1.parameters() + }, { + 'params': linear_2.parameters(), + 'weight_decay': 0.001, + 'learning_rate': 0.1, + 'lamb_weight_decay': 0.02 + }], + weight_decay=0.01, + lamb_weight_decay=0.01) + out.backward() + lamb.step() + lamb.clear_grad() + """ _moment1_acc_str = "moment1" _moment2_acc_str = "moment2" @@ -115,9 +143,18 @@ class Lamb(Optimizer): self._epsilon = epsilon self._lamb_weight_decay = lamb_weight_decay self._exclude_from_weight_decay_fn = exclude_from_weight_decay_fn + self._default_dict = { + 'beta1': beta1, + 'beta2': beta2, + 'epsilon': epsilon, + 'lamb_weight_decay': lamb_weight_decay, + 'exclude_from_weight_decay_fn': exclude_from_weight_decay_fn, + } def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) + if isinstance(parameters, dict): + parameters = self._update_param_group(parameters) # Create accumulator tensors for first and second moments for p in parameters: @@ -140,6 +177,9 @@ class Lamb(Optimizer): def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) + if isinstance(param_and_grad, dict): + param_and_grad = self._update_param_group(param_and_grad) + block.program._use_lamb = True moment1 = self._get_accumulator(self._moment1_acc_str, @@ -199,3 +239,15 @@ class Lamb(Optimizer): stop_gradient=True) return lamb_op + + def _update_param_group(self, parameters): + self._beta1 = parameters.get('beta1', self._default_dict['beta1']) + self._beta2 = parameters.get('beta2', self._default_dict['beta2']) + self._epsilon = parameters.get('epsilon', self._default_dict['epsilon']) + self._lamb_weight_decay = parameters.get( + 'lamb_weight_decay', self._default_dict['lamb_weight_decay']) + self._exclude_from_weight_decay_fn = parameters.get( + 'exclude_from_weight_decay_fn', + self._default_dict['exclude_from_weight_decay_fn']) + parameters = parameters.get('params') + return parameters diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py index 372143553e0c39988f5d0456125ee91bb94d3329..faff090bcb1f4ec2e906d2a3071930176a9c339f 100644 --- a/python/paddle/optimizer/momentum.py +++ b/python/paddle/optimizer/momentum.py @@ -51,8 +51,11 @@ class Momentum(Optimizer): learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``. It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001. momentum (float): Momentum factor. The default value is 0.9. - parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. 
\ - This parameter is required in dygraph mode. \ + parameters (list|tuple, optional): List|Tuple of ``Tensor`` to update to minimize ``loss``. \ + This parameter is required in dygraph mode. And you can specify different options for \ + different parameter groups such as the learning rate, weight decay, etc, \ + then the parameters are list of dict. Note that the learning_rate in paramter groups \ + represents the scale of base learning_rate. \ The default value is None in static mode, at this time all parameters will be updated. weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \ It canbe a float value as coeff of L2 regularization or \ @@ -88,6 +91,29 @@ class Momentum(Optimizer): back = out.backward() momentum.step() momentum.clear_grad() + + #Note that the learning_rate of linear_2 is 0.01. + linear_1 = paddle.nn.Linear(10, 10) + linear_2 = paddle.nn.Linear(10, 10) + inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1) + out = linear_1(inp) + out = linear_2(out) + loss = paddle.mean(out) + momentum = paddle.optimizer.Momentum( + learning_rate=0.1, + parameters=[{ + 'params': linear_1.parameters() + }, { + 'params': linear_2.parameters(), + 'weight_decay': 0.001, + 'learning_rate': 0.1 + }], + weight_decay=0.01, + momentum=0.9) + out.backward() + momentum.step() + momentum.clear_grad() + """ _velocity_acc_str = "velocity" @@ -105,7 +131,19 @@ class Momentum(Optimizer): raise ValueError("learning_rate is not set") if momentum is None: raise ValueError("momentum is not set") + predicate = lambda regular: isinstance(regular, (L2DecayRegularizer, float)) + if isinstance(parameters, list): + if isinstance(parameters[0], dict): + for param_group in parameters: + decay = param_group[ + 'weight_decay'] if 'weight_decay' in param_group else weight_decay + reg_method, reg_coeff = self._update_regularization(decay) + param_group['regularization_method'] = reg_method + param_group['regularization_coeff'] = reg_coeff + py_regular = None if predicate(decay) else decay + param_group['weight_decay'] = py_regular + py_regular = None if predicate(weight_decay) else weight_decay super(Momentum, self).__init__( learning_rate=learning_rate, @@ -116,22 +154,41 @@ class Momentum(Optimizer): self.type = "momentum" self._momentum = momentum self._use_nesterov = bool(use_nesterov) - self._regularization_method = "" - self._regularization_coeff = 0 - if (isinstance(weight_decay, L2DecayRegularizer)): - self._regularization_method = "l2_decay" - self._regularization_coeff = weight_decay._regularization_coeff - if (isinstance(weight_decay, float)): - self._regularization_method = "l2_decay" - self._regularization_coeff = weight_decay + self._regularization_method, self._regularization_coeff = self._update_regularization( + weight_decay) self._multi_precision = multi_precision self._rescale_grad = rescale_grad self._master_weights = {} + self._default_dict = { + 'momentum': momentum, + 'use_nesterov': use_nesterov, + 'rescale_grad': rescale_grad, + 'regularization_method': self._regularization_method, + 'regularization_coeff': self._regularization_coeff, + } + if framework.in_dygraph_mode(): self.helper = LayerHelper(self.__class__.__name__) - for p in parameters: - self._add_accumulator(self._velocity_acc_str, p) + if isinstance(self._parameter_list[0], dict): + for parameters in self._param_groups: + for p in parameters['params']: + self._add_accumulator(self._velocity_acc_str, p) + else: + for p in parameters: + self._add_accumulator(self._velocity_acc_str, p) + + def 
_update_regularization(self, weight_decay): + reg_method = "" + reg_coeff = 0 + + if (isinstance(weight_decay, L2DecayRegularizer)): + reg_method = "l2_decay" + reg_coeff = weight_decay._regularization_coeff + if (isinstance(weight_decay, float)): + reg_method = "l2_decay" + reg_coeff = weight_decay + return reg_method, reg_coeff def _create_master_weight(self, param): assert isinstance(self.helper, LayerHelper) @@ -197,12 +254,16 @@ class Momentum(Optimizer): def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) + if isinstance(param_and_grad, dict): + param_and_grad = self._update_param_group(param_and_grad) velocity_acc = self._get_accumulator(self._velocity_acc_str, param_and_grad[0]) lr = self._create_param_lr(param_and_grad) if framework.in_dygraph_mode(): + if isinstance(param_and_grad, dict): + self._update_regularization(param_and_grad['weight_decay']) _, _ = core.ops.momentum( param_and_grad[0], param_and_grad[1], velocity_acc, lr, param_and_grad[0], velocity_acc, 'mu', self._momentum, @@ -250,3 +311,18 @@ class Momentum(Optimizer): stop_gradient=True) return momentum_op + + def _update_param_group(self, parameters): + self._momentum = parameters.get('momentum', + self._default_dict['momentum']) + self._use_nesterov = parameters.get('use_nesterov', + self._default_dict['use_nesterov']) + self._rescale_grad = parameters.get('rescale_grad', + self._default_dict['rescale_grad']) + self._regularization_method = parameters.get( + 'regularization_method', + self._default_dict['regularization_method']) + self._regularization_coeff = parameters.get( + 'regularization_coeff', self._default_dict['regularization_coeff']) + parameters = parameters.get('params') + return parameters diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index b06bd2a2b0be9539ed33f5c898da7d15f92a09a6..0f22b920b17deba923b945115f4f274c84f2ddf6 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -28,7 +28,7 @@ from ..fluid import layers from ..fluid import unique_name from ..fluid.backward import append_backward, _some_in_set_, _append_grad_suffix_, _get_no_grad_set_name from ..fluid.clip import GradientClipBase, GradientClipByNorm, error_clip_callback, append_gradient_clip_ops -from ..fluid.framework import program_guard +from ..fluid.framework import program_guard, Parameter from ..fluid.initializer import Constant from ..fluid.layer_helper import LayerHelper from ..fluid.layers import ops @@ -41,6 +41,7 @@ from functools import reduce from ..fluid.wrapped_decorator import signature_safe_contextmanager from .. import compat as cpt from .lr import LRScheduler +import copy __all__ = [] @@ -56,7 +57,10 @@ class Optimizer(object): learning_rate (float|LRScheduler): The learning rate used to update ``Parameter``. It can be a float value or any subclass of ``LRScheduler`` . parameters (list|tuple, optional): List/Tuple of ``Tensor`` names to update to minimize ``loss``. \ - This parameter is required in dygraph mode. \ + This parameter is required in dygraph mode. And you can specify different options for \ + different parameter groups such as the learning rate, weight decay, etc, \ + then the parameters are list of dict. Note that the learning_rate in paramter groups \ + represents the scale of base learning_rate. \ The default value is None in static mode, at this time all parameters will be updated. weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. 
\ It canbe a float value as coeff of L2 regularization or \ @@ -91,6 +95,29 @@ class Optimizer(object): adam.step() adam.clear_grad() + #Take the subclass sgd as an example + #optimize parameters in linear_1 and linear2 in different options. + #Note that the learning_rate of linear_2 is 0.01. + linear_1 = paddle.nn.Linear(10, 10) + linear_2 = paddle.nn.Linear(10, 10) + inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1) + out = linear_1(inp) + out = linear_2(out) + loss = paddle.mean(out) + sgd = paddle.optimizer.SGD( + learning_rate=0.1, + parameters=[{ + 'params': linear_1.parameters() + }, { + 'params': linear_2.parameters(), + 'weight_decay': 0.001, + 'learning_rate': 0.1 + }], + weight_decay=0.01) + out.backward() + sgd.step() + sgd.clear_grad() + """ @imperative_base.no_grad @@ -100,6 +127,7 @@ class Optimizer(object): weight_decay=None, grad_clip=None, name=None): + if parameters is not None: # paddle.Tensor is also iterable, so here we don't check whether # the input is iterable, if the input is paddle.Tensor, the @@ -109,6 +137,11 @@ class Optimizer(object): "`parameters` argument given to the optimizer should be " "an iterable of paddle Tensors, but got argument type is `{}`.". format(type(parameters))) + if isinstance(parameters, dict): + raise TypeError( + "`parameters` argument should not get dict type, " + "if parameter groups is needed, please set `parameters`" + " as list of dict") self._parameter_list = list(parameters) else: self._parameter_list = None @@ -120,14 +153,17 @@ class Optimizer(object): "parameters argument given to the Optimizer should not be None in dygraph mode." ) if weight_decay is not None: - for param in self._parameter_list: - if hasattr(param, - 'regularizer') and param.regularizer is not None: - logging.info( - "If regularizer of a Parameter has been set by 'paddle.ParamAttr' or 'static.WeightNormParamAttr' already. " - "The weight_decay[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!" - % weight_decay.__str__()) - break + if not isinstance(self._parameter_list[0], dict): + for param in self._parameter_list: + if hasattr( + param, + 'regularizer') and param.regularizer is not None: + logging.info( + "If regularizer of a Parameter has been set by 'paddle.ParamAttr' or 'static.WeightNormParamAttr' already. " + "The weight_decay[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!" 
+ % weight_decay.__str__()) + break + if not isinstance(learning_rate, (float, LRScheduler)): raise TypeError( "learning rate should be float or LRScheduler, got %s here" % @@ -148,7 +184,13 @@ class Optimizer(object): self._dtype = None # Infer the dtype form parameter if self._parameter_list: - self._dtype = self._parameter_list[0].dtype + if isinstance(self._parameter_list[0], dict): + for param_group in self._parameter_list: + assert 'params' in param_group, \ + 'params should be set in parameters if parameter groups are optimized in different options' + self._dtype = self._parameter_list[0]['params'][0].dtype + else: + self._dtype = self._parameter_list[0].dtype # each program should have a independent learning rate # program -> tensor(learning_rate) @@ -163,6 +205,18 @@ class Optimizer(object): self._accumulators_holder = {} self._param_device_map = dict() self.clear_gradients = self.clear_grad + self._default_dict = { + 'learning_rate': self._learning_rate, + 'weight_decay': self.regularization, + 'grad_clip': self._grad_clip + } + + self._param_groups = [] + if self._parameter_list and isinstance(self._parameter_list[0], dict): + for param_group in self._parameter_list: + self._add_param_group(param_group.copy()) + else: + self._param_groups = self._parameter_list @framework.dygraph_only def state_dict(self): @@ -610,18 +664,45 @@ class Optimizer(object): start = len(target_block.ops) self.helper = LayerHelper(self.__class__.__name__) - self._update_param_device_map(parameters_and_grads, target_block) - self._create_accumulators( - target_block, - [p[0] for p in parameters_and_grads if not p[0].stop_gradient]) + params_grads_device_map = parameters_and_grads['params'] if isinstance( + parameters_and_grads, dict) else parameters_and_grads + self._update_param_device_map(params_grads_device_map, target_block) + if isinstance(parameters_and_grads, list): + self._create_accumulators( + target_block, + [p[0] for p in parameters_and_grads if not p[0].stop_gradient]) + + else: + params_acc_dict = parameters_and_grads.copy() + params_acc_dict['params'] = [ + p[0] for p in params_acc_dict['params'] + if not p[0].stop_gradient + ] + self._create_accumulators(target_block, params_acc_dict) + self._create_global_learning_rate() if framework.in_dygraph_mode(): - for param_and_grad in parameters_and_grads: - if param_and_grad[1] is None: - continue - if param_and_grad[0].stop_gradient is False: - self._append_optimize_op(target_block, param_and_grad) + + if isinstance(parameters_and_grads, list): + for param_and_grad in parameters_and_grads: + if param_and_grad[1] is None: + continue + if param_and_grad[0].stop_gradient is False: + self._append_optimize_op(target_block, param_and_grad) + else: + for param_and_grad in parameters_and_grads['params']: + if param_and_grad[1] is None: + continue + if param_and_grad[0].stop_gradient is False: + param_grad_dict = dict() + param_grad_dict['params'] = param_and_grad + param_grad_dict.update({ + k: v + for k, v in parameters_and_grads.items() + if k != 'params' + }) + self._append_optimize_op(target_block, param_grad_dict) else: for param_and_grad in parameters_and_grads: if param_and_grad[1] is None: @@ -790,10 +871,19 @@ class Optimizer(object): if framework.in_dygraph_mode(): with program_guard(framework.default_main_program(), framework.default_startup_program()): - if self._grad_clip is not None: - params_grads = self._grad_clip(params_grads) - params_grads = append_regularization_ops(params_grads, - self.regularization) + if isinstance(params_grads, 
list): + if self._grad_clip is not None: + params_grads = self._grad_clip(params_grads) + params_grads = append_regularization_ops( + params_grads, self.regularization) + else: + grad_clip = params_grads['grad_clip'] + if grad_clip is not None: + params_grads['params'] = grad_clip(params_grads[ + 'params']) + + params_grads['params'] = append_regularization_ops( + params_grads['params'], self.regularization) optimize_ops = self._create_optimization_pass(params_grads) else: program = loss.block.program @@ -840,9 +930,16 @@ class Optimizer(object): adam.clear_grad() """ - for p in self._parameter_list: - if not p.stop_gradient: - p.clear_gradient() + if self._parameter_list is None or not isinstance( + self._parameter_list[0], dict): + for p in self._parameter_list: + if not p.stop_gradient: + p.clear_gradient() + else: + for param_group in self._param_groups: + for p in param_group['params']: + if not p.stop_gradient: + p.clear_gradient() @imperative_base.no_grad def minimize(self, @@ -934,13 +1031,82 @@ class Optimizer(object): adam.step() adam.clear_grad() """ - params_grads = [] - for param in self._parameter_list: - if param.stop_gradient: - continue - if param._grad_ivar() is not None: - grad_var = param._grad_ivar() - params_grads.append((param, grad_var)) - - self._apply_optimize( - loss=None, startup_program=None, params_grads=params_grads) + + if not isinstance(self._param_groups[0], dict): + params_grads = [] + for param in self._param_groups: + if param.stop_gradient: + continue + if param._grad_ivar() is not None: + grad_var = param._grad_ivar() + params_grads.append((param, grad_var)) + + self._apply_optimize( + loss=None, startup_program=None, params_grads=params_grads) + + else: + # optimize parameters in groups + for param_group in self._param_groups: + params_grads = defaultdict(lambda: list()) + for param in param_group['params']: + if param.stop_gradient: + continue + if param._grad_ivar() is not None: + grad_var = param._grad_ivar() + params_grads['params'].append((param, grad_var)) + params_grads.update( + {k: v + for k, v in param_group.items() if k != 'params'}) + self._apply_optimize( + loss=None, startup_program=None, params_grads=params_grads) + + def _add_param_group(self, param_group): + """ + Add a param group to parameter_list. + + Args: + param_group (dict): The group of Tensors to be optimzed with + different optimization options. 
+ """ + params = param_group['params'] + if isinstance(params, Parameter): + param_group['params'] = [params] + elif isinstance(params, set): + raise TypeError( + "optimizer parameters should be in ordered collections," + "but received set, please use list instead.") + else: + param_group['params'] = list(params) + + # Update optimization options for each groups + for k, v in self._default_dict.items(): + param_group.setdefault(k, v) + + param_set = set() + for group in self._param_groups: + param_set.update(set(group['params'])) + + if not param_set.isdisjoint(set(param_group['params'])): + raise ValueError( + "some parameters appear in more than one parameter group") + + for param in param_group['params']: + weight_decay = param_group['weight_decay'] + if isinstance(weight_decay, float): + from ..fluid.regularizer import L2Decay + regularization = L2Decay(weight_decay) + else: + regularization = weight_decay + param.regularizer = regularization + param.optimize_attr['learning_rate'] = param_group['learning_rate'] + + self._param_groups.append(param_group) + + def _update_param_group(self, parameters): + """ + Update the param group with new entry + Args: + parameters (dict): The extra group of Tensors to be optimzed with + different optimization options. Only used in child class. + """ + pass diff --git a/python/paddle/optimizer/rmsprop.py b/python/paddle/optimizer/rmsprop.py index b0bb0228c8ca82acc40b62e1a9074636b4def097..14249df3f5628fff3823e770d843f5af0a7e8c1e 100644 --- a/python/paddle/optimizer/rmsprop.py +++ b/python/paddle/optimizer/rmsprop.py @@ -80,16 +80,19 @@ class RMSProp(Optimizer): the gradient; if False, by the uncentered second moment. Setting this to True may help with training, but is slightly more expensive in terms of computation and memory. Defaults to False. - parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. \ - This parameter is required in dygraph mode. \ - The default value is None in static mode, at this time all parameters will be updated. - weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \ - It canbe a float value as coeff of L2 regularization or \ - :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`. - If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \ - the regularization setting here in optimizer will be ignored for this parameter. \ - Otherwise, the regularization setting here in optimizer will take effect. \ - Default None, meaning there is no regularization. + parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. \ + This parameter is required in dygraph mode. And you can specify different options for \ + different parameter groups such as the learning rate, weight decay, etc, \ + then the parameters are list of dict. Note that the learning_rate in paramter groups \ + represents the scale of base learning_rate. \ + The default value is None in static mode, at this time all parameters will be updated. + weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \ + It canbe a float value as coeff of L2 regularization or \ + :ref:`api_fluid_regularizer_L1Decay`, :ref:`api_fluid_regularizer_L2Decay`. + If a parameter has set regularizer using :ref:`api_fluid_ParamAttr` already, \ + the regularization setting here in optimizer will be ignored for this parameter. \ + Otherwise, the regularization setting here in optimizer will take effect. 
\ + Default None, meaning there is no regularization. grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of some derived class of ``GradientClipBase`` . There are three cliping strategies ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , @@ -117,6 +120,26 @@ class RMSProp(Optimizer): rmsprop.step() rmsprop.clear_grad() + #Note that the learning_rate of linear_2 is 0.01. + linear_1 = paddle.nn.Linear(10, 10) + linear_2 = paddle.nn.Linear(10, 10) + inp = paddle.uniform(shape=[10, 10], min=-0.1, max=0.1) + out = linear_1(inp) + out = linear_2(out) + loss = paddle.mean(out) + rmsprop = paddle.optimizer.RMSProp( + learning_rate=0.1, + parameters=[{ + 'params': linear_1.parameters() + }, { + 'params': linear_2.parameters(), + 'weight_decay': 0.001, + 'learning_rate': 0.1 + }], + weight_decay=0.01) + out.backward() + rmsprop.step() + rmsprop.clear_grad() """ _momentum_acc_str = "momentum" @@ -160,11 +183,20 @@ class RMSProp(Optimizer): self._epsilon = epsilon self._momentum = momentum self._centered = centered + self._default_dict = { + 'rho': rho, + 'epsilon': epsilon, + 'momentum': momentum, + 'centered': centered, + } def _create_accumulators(self, block, parameters): if not isinstance(block, framework.Block): raise TypeError("block is not instance of framework.Block.") + if isinstance(parameters, dict): + parameters = parameters.get('params') + for p in parameters: self._add_accumulator(self._momentum_acc_str, p) self._add_accumulator(self._mean_square_acc_str, p) @@ -174,6 +206,9 @@ class RMSProp(Optimizer): if not isinstance(block, framework.Block): raise TypeError("block is not instance of framework.Block.") + if isinstance(param_and_grad, dict): + param_and_grad = self._update_param_group(param_and_grad) + momentum_acc = self._get_accumulator(self._momentum_acc_str, param_and_grad[0]) mean_square_acc = self._get_accumulator(self._mean_square_acc_str, @@ -205,3 +240,13 @@ class RMSProp(Optimizer): stop_gradient=True) return rmsprop_op + + def _update_param_group(self, parameters): + self._epsilon = parameters.get('epsilon', self._default_dict['epsilon']) + self._rho = parameters.get('rho', self._default_dict['rho']) + self._momentum = parameters.get('momentum', + self._default_dict['momentum']) + self._centered = parameters.get('centered', + self._default_dict['centered']) + parameters = parameters.get('params') + return parameters diff --git a/python/paddle/optimizer/sgd.py b/python/paddle/optimizer/sgd.py index 4526034b405b0c97f1b06e07f3e4279cdc2d0d95..107581e060588af8b51744f87eba1278c6f1c1eb 100644 --- a/python/paddle/optimizer/sgd.py +++ b/python/paddle/optimizer/sgd.py @@ -87,6 +87,8 @@ class SGD(Optimizer): @no_grad def _append_optimize_op(self, block, param_and_grad): + if isinstance(param_and_grad, dict): + param_and_grad = self._update_param_group(param_and_grad) lr = self._create_param_lr(param_and_grad) if framework.in_dygraph_mode(): core.ops.sgd(param_and_grad[0], lr, param_and_grad[1], @@ -106,3 +108,7 @@ class SGD(Optimizer): stop_gradient=True) return sgd_op + + def _update_param_group(self, parameters): + parameters = parameters.get('params') + return parameters
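
The pattern added across these optimizers is the same: ``parameters`` may be a list of dicts, each carrying a ``'params'`` key plus optional per-group overrides such as ``'weight_decay'``, ``'learning_rate'`` (a scale applied to the base learning rate), or optimizer-specific keys like ``'beta1'`` and ``'momentum'``; the keys each optimizer honors are the ones read back in its ``_update_param_group`` above. A minimal dygraph usage sketch, adapted from the examples added in this patch (illustrative only):

    import numpy as np
    import paddle

    paddle.disable_static()
    inp = paddle.to_tensor(np.random.rand(10, 10).astype("float32"))
    linear_1 = paddle.nn.Linear(10, 10)
    linear_2 = paddle.nn.Linear(10, 3)

    # The second group overrides weight_decay and scales the base
    # learning rate by 0.1 (effective rate 0.01).
    opt = paddle.optimizer.Momentum(
        learning_rate=0.1,
        momentum=0.9,
        weight_decay=0.01,
        parameters=[{
            'params': linear_1.parameters()
        }, {
            'params': linear_2.parameters(),
            'weight_decay': 0.001,
            'learning_rate': 0.1
        }])

    out = linear_2(linear_1(inp))
    loss = paddle.mean(out)
    loss.backward()
    opt.step()
    opt.clear_grad()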