From bb80dae7d08aca609137576877bc6a078ff199b3 Mon Sep 17 00:00:00 2001 From: chengduo Date: Fri, 29 Mar 2019 11:17:40 -0500 Subject: [PATCH] Add DecoupledWeightDecay (#16427) * Add DecoupledWeightDecay --- paddle/fluid/API.spec | 13 ++ python/paddle/fluid/contrib/__init__.py | 3 + .../contrib/extend_optimizer/__init__.py | 20 +++ .../extend_optimizer_with_weight_decay.py | 152 ++++++++++++++++++ .../contrib/tests/test_weight_decay_extend.py | 151 +++++++++++++++++ python/paddle/fluid/optimizer.py | 99 +++++++----- python/setup.py.in | 1 + 7 files changed, 402 insertions(+), 37 deletions(-) create mode 100644 python/paddle/fluid/contrib/extend_optimizer/__init__.py create mode 100644 python/paddle/fluid/contrib/extend_optimizer/extend_optimizer_with_weight_decay.py create mode 100644 python/paddle/fluid/contrib/tests/test_weight_decay_extend.py diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 79277a4174b..923a923bccc 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -406,6 +406,7 @@ paddle.fluid.contrib.HDFSClient.rename (ArgSpec(args=['self', 'hdfs_src_path', ' paddle.fluid.contrib.HDFSClient.upload (ArgSpec(args=['self', 'hdfs_path', 'local_path', 'overwrite', 'retry_times'], varargs=None, keywords=None, defaults=(False, 5)), ('document', '7d053b4bfd6dcfdd2c9dda0e0dbd9665')) paddle.fluid.contrib.multi_download (ArgSpec(args=['client', 'hdfs_path', 'local_path', 'trainer_id', 'trainers', 'multi_processes'], varargs=None, keywords=None, defaults=(5,)), ('document', '100927be598ed8f9eaa1f3ef1b23568a')) paddle.fluid.contrib.multi_upload (ArgSpec(args=['client', 'hdfs_path', 'local_path', 'multi_processes', 'overwrite', 'sync'], varargs=None, keywords=None, defaults=(5, False, True)), ('document', '183f34c83d30dbe16e09e8716c41958a')) +paddle.fluid.contrib.extend_with_decoupled_weight_decay (ArgSpec(args=['base_optimizer'], varargs=None, keywords=None, defaults=None), ('document', 'a1095dfd4ec725747f662d69cd7659d4')) paddle.fluid.transpiler.DistributeTranspiler.__init__ (ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program (ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None), ('document', '292ab72977afbe58e6a3bde175452680')) paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs (ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None), ('document', '78f4949aedf317666a89ca74b3748ba8')) @@ -428,63 +429,75 @@ paddle.fluid.nets.scaled_dot_product_attention (ArgSpec(args=['queries', 'keys', paddle.fluid.nets.img_conv_group (ArgSpec(args=['input', 'conv_num_filter', 'pool_size', 'conv_padding', 'conv_filter_size', 'conv_act', 'param_attr', 'conv_with_batchnorm', 'conv_batchnorm_drop_rate', 'pool_stride', 'pool_type', 'use_cudnn'], varargs=None, keywords=None, defaults=(1, 3, None, None, False, 0.0, 1, 'max', True)), ('document', '3802be78fbfb206dae64a2d9f8480970')) paddle.fluid.optimizer.SGDOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'regularization', 'name'], varargs=None, keywords=None, defaults=(None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.SGDOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) +paddle.fluid.optimizer.SGDOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.SGDOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.SGDOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.SGDOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) paddle.fluid.optimizer.MomentumOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'momentum', 'use_nesterov', 'regularization', 'name'], varargs=None, keywords=None, defaults=(False, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.MomentumOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) +paddle.fluid.optimizer.MomentumOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.MomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.MomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.MomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) paddle.fluid.optimizer.AdagradOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'epsilon', 'regularization', 'name', 'initial_accumulator_value'], varargs=None, keywords=None, defaults=(1e-06, None, None, 0.0)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.AdagradOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) +paddle.fluid.optimizer.AdagradOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.AdagradOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.AdagradOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.AdagradOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) paddle.fluid.optimizer.AdamOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'beta1', 'beta2', 'epsilon', 'regularization', 'name', 'lazy_mode'], varargs=None, keywords=None, defaults=(0.001, 0.9, 0.999, 1e-08, None, None, False)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.AdamOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) +paddle.fluid.optimizer.AdamOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.AdamOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.AdamOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.AdamOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) paddle.fluid.optimizer.AdamaxOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'beta1', 'beta2', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.9, 0.999, 1e-08, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.AdamaxOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) +paddle.fluid.optimizer.AdamaxOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.AdamaxOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.AdamaxOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.AdamaxOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) paddle.fluid.optimizer.DecayedAdagradOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'decay', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.95, 1e-06, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.DecayedAdagradOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) +paddle.fluid.optimizer.DecayedAdagradOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.DecayedAdagradOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.DecayedAdagradOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.DecayedAdagradOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) paddle.fluid.optimizer.FtrlOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'l1', 'l2', 'lr_power', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.0, 0.0, -0.5, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.FtrlOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) +paddle.fluid.optimizer.FtrlOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.FtrlOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.FtrlOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.FtrlOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) paddle.fluid.optimizer.RMSPropOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'rho', 'epsilon', 'momentum', 'centered', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.95, 1e-06, 0.0, False, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.RMSPropOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) +paddle.fluid.optimizer.RMSPropOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.RMSPropOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.RMSPropOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.RMSPropOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) paddle.fluid.optimizer.AdadeltaOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'epsilon', 'rho', 'regularization', 'name'], varargs=None, keywords=None, defaults=(1e-06, 0.95, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.AdadeltaOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) +paddle.fluid.optimizer.AdadeltaOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.AdadeltaOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.AdadeltaOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.AdadeltaOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) paddle.fluid.optimizer.ModelAverage.__init__ (ArgSpec(args=['self', 'average_window_rate', 'min_average_window', 'max_average_window', 'regularization', 'name'], varargs=None, keywords=None, defaults=(10000, 10000, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.ModelAverage.apply (ArgSpec(args=['self', 'executor', 'need_restore'], varargs=None, keywords=None, defaults=(True,)), ('document', '46234a5470590feb336346f70a3db715')) paddle.fluid.optimizer.ModelAverage.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) +paddle.fluid.optimizer.ModelAverage.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.ModelAverage.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.ModelAverage.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.ModelAverage.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) paddle.fluid.optimizer.ModelAverage.restore (ArgSpec(args=['self', 'executor'], varargs=None, keywords=None, defaults=None), ('document', '18db9c70be9c4dd466f9844457b21bfe')) paddle.fluid.optimizer.LarsMomentumOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'momentum', 'lars_coeff', 'lars_weight_decay', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.0005, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.LarsMomentumOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) +paddle.fluid.optimizer.LarsMomentumOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.LarsMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.LarsMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.LarsMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) paddle.fluid.optimizer.DGCMomentumOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'momentum', 'rampup_begin_step', 'rampup_step', 'sparsity', 'use_nesterov', 'local_grad_clip_norm', 'num_trainers', 'regularization', 'name'], varargs=None, keywords=None, defaults=(1, [0.999], False, None, None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.DGCMomentumOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) +paddle.fluid.optimizer.DGCMomentumOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.DGCMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.DGCMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.DGCMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) diff --git a/python/paddle/fluid/contrib/__init__.py b/python/paddle/fluid/contrib/__init__.py index 870c57e5401..7442059ba07 100644 --- a/python/paddle/fluid/contrib/__init__.py +++ b/python/paddle/fluid/contrib/__init__.py @@ -30,6 +30,8 @@ from . import slim from .slim import * from . import utils from .utils import * +from . import extend_optimizer +from .extend_optimizer import * __all__ = [] __all__ += decoder.__all__ @@ -40,3 +42,4 @@ __all__ += int8_inference.__all__ __all__ += reader.__all__ __all__ += slim.__all__ __all__ += utils.__all__ +__all__ += extend_optimizer.__all__ diff --git a/python/paddle/fluid/contrib/extend_optimizer/__init__.py b/python/paddle/fluid/contrib/extend_optimizer/__init__.py new file mode 100644 index 00000000000..697ea0f05ae --- /dev/null +++ b/python/paddle/fluid/contrib/extend_optimizer/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +from . import extend_optimizer_with_weight_decay +from .extend_optimizer_with_weight_decay import * + +__all__ = [] +__all__ += extend_optimizer_with_weight_decay.__all__ diff --git a/python/paddle/fluid/contrib/extend_optimizer/extend_optimizer_with_weight_decay.py b/python/paddle/fluid/contrib/extend_optimizer/extend_optimizer_with_weight_decay.py new file mode 100644 index 00000000000..fcc99c07346 --- /dev/null +++ b/python/paddle/fluid/contrib/extend_optimizer/extend_optimizer_with_weight_decay.py @@ -0,0 +1,152 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import paddle.fluid +from paddle.fluid import framework as framework + +__all__ = ["extend_with_decoupled_weight_decay"] + + +class DecoupledWeightDecay(object): + def __init__(self, coeff=0.0, apply_decay_param_fun=None, **kwargs): + if not isinstance(coeff, float) and \ + not isinstance(coeff, framework.Variable): + raise TypeError("coeff should be float or Variable.") + self._params_name = set() + self._apply_decay_param_fun = apply_decay_param_fun + self._coeff = coeff + super(DecoupledWeightDecay, self).__init__(**kwargs) + + def _scale_parameters(self, params_and_grads): + """ + Adds weight decay ops. + scaled_parameter = parameter * coeff + + Args: + params_and_grads: A list of (parameters, gradients) pairs, + the parameters need to decay. + Raises: + Exception: The type of coeff and parameter is not consistent. + """ + if isinstance(self._coeff, float) and self._coeff == 0.0: + return + + scaled_params = [] + for param, grad in params_and_grads: + # If no gradient then we don't need to do anything + if grad is None: + continue + if self._apply_decay_param_fun is not None \ + and not self._apply_decay_param_fun(param.name): + continue + + if isinstance(self._coeff, float): + assert param.dtype is not paddle.fluid.core.VarDesc.VarType.FP32, \ + "the type of coeff(float) and parameter(%s) is not consistent."%(self._coeff.dtype) + else: + assert self._coeff.dtype == param.dtype, \ + "the type of coeff(%s) and parameter(%s) is not consistent."%(self._coeff.dtype, param.dtype) + + with param.block.program._optimized_guard( + [param, grad]), framework.name_scope('weight decay'): + assert param.name not in self._params_name + scaled_params.append((param, grad, param * self._coeff)) + self._params_name.add(param.name) + return scaled_params + + def backward(self, **kargs): + return super(DecoupledWeightDecay, self).backward(**kargs) + + def apply_optimize(self, **kargs): + return super(DecoupledWeightDecay, self).apply_optimize(**kargs) + + def minimize(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + params_grads = self.backward( + loss=loss, + startup_program=startup_program, + parameter_list=parameter_list, + no_grad_set=no_grad_set) + scaled_params = self._scale_parameters(params_grads) + for p_grad_sgrad in scaled_params: + param, grad, scaled_param = p_grad_sgrad + with param.block.program._optimized_guard( + [param, grad]), framework.name_scope('weight decay'): + updated_param = paddle.fluid.layers.elementwise_sub( + x=param, y=scaled_param) + paddle.fluid.layers.assign(input=updated_param, output=param) + + optimize_ops = self.apply_optimize( + loss=loss, + params_grads=params_grads, + startup_program=startup_program) + return optimize_ops, params_grads + + def __str__(self): + return " ".join(["Weight Decay, params:", ",".join(self._params_name)]) + + +def extend_with_decoupled_weight_decay(base_optimizer): + """ + extend_with_decoupled_weight_decay is a decorator function, it returns an + optimizer class with decoupled weight decay. The returned optimizer will + apply weight decay on the optimized parameters with the parameters before + optimization, i.e: new_parameter = optimized_parameter - parameter * coeff. + The details of decoupled weight decay yplease refer to this + `DECOUPLED WEIGHT DECAY REGULARIZATION `_. + + Args: + base_optimizer (Optimizer): The base_optimizer should be a derived class of Optimizer. + + Returns: + OptimizerWithDecoupledWeightDecay: the optimizer with decouple weight decay. + + Examples: + + .. code-block:: python + + AdamW = fluid.contrib.extend_with_decoupled_weight_decay( + fluid.optimizer.Adam) + optimizer = AdamW(learning_rate=0.1, + weight_decay=0.01) + + optimizer.minimize(cost) + """ + if not issubclass(base_optimizer, paddle.fluid.optimizer.Optimizer): + raise TypeError( + "The input(base_optimizer) should be a derived class of Optimizer.") + + class OptimizerWithDecoupledWeightDecay(DecoupledWeightDecay, + base_optimizer): + """ + OptimizerWithDecoupledWeightDecay is used to update the optimized parameters + with the parameters before optimization. For more information, please refer: + https://arxiv.org/pdf/1711.05101.pdf. + + Args: + weight_decay (float|Variable): The weight decay coefficient, it can be + float or Variable. + apply_decay_param_fun (function|None): If it is not None, + only variables that makes apply_decay_param_fun(variable)==True + will be updated. It only works when we want to specify variables. + Default: None. + """ + + def __init__(self, weight_decay, apply_decay_param_fun=None, **kwargs): + super(OptimizerWithDecoupledWeightDecay, self).__init__( + weight_decay, apply_decay_param_fun, **kwargs) + + return OptimizerWithDecoupledWeightDecay diff --git a/python/paddle/fluid/contrib/tests/test_weight_decay_extend.py b/python/paddle/fluid/contrib/tests/test_weight_decay_extend.py new file mode 100644 index 00000000000..2b331308de5 --- /dev/null +++ b/python/paddle/fluid/contrib/tests/test_weight_decay_extend.py @@ -0,0 +1,151 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +from functools import partial +import numpy as np +import paddle +import paddle.fluid as fluid +import contextlib + + +def get_places(): + places = [fluid.CPUPlace()] + if fluid.core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + return places + + +@contextlib.contextmanager +def prog_scope_guard(main_prog, startup_prog): + scope = fluid.core.Scope() + with fluid.unique_name.guard(): + with fluid.scope_guard(scope): + with fluid.program_guard(main_prog, startup_prog): + yield + + +def bow_net(data, + label, + dict_dim, + is_sparse=False, + emb_dim=128, + hid_dim=128, + hid_dim2=96, + class_dim=2): + """ + BOW net + This model is from https://github.com/PaddlePaddle/models: + fluid/PaddleNLP/text_classification/nets.py + """ + emb = fluid.layers.embedding( + input=data, is_sparse=is_sparse, size=[dict_dim, emb_dim]) + bow = fluid.layers.sequence_pool(input=emb, pool_type='sum') + bow_tanh = fluid.layers.tanh(bow) + fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh") + fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh") + prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax") + cost = fluid.layers.cross_entropy(input=prediction, label=label) + avg_cost = fluid.layers.mean(x=cost) + + return avg_cost + + +class TestWeightDecay(unittest.TestCase): + def setUp(self): + self.word_dict = paddle.dataset.imdb.word_dict() + reader = paddle.batch( + paddle.dataset.imdb.train(self.word_dict), batch_size=2)() + self.train_data = [next(reader) for _ in range(5)] + self.learning_rate = .5 + + def run_program(self, place, feed_list): + exe = fluid.Executor(place) + feeder = fluid.DataFeeder(feed_list=feed_list, place=place) + exe.run(fluid.default_startup_program()) + + main_prog = fluid.default_main_program() + param_list = [var.name for var in main_prog.block(0).all_parameters()] + + param_sum = [] + for data in self.train_data: + out = exe.run(main_prog, + feed=feeder.feed(data), + fetch_list=param_list) + p_sum = 0 + for v in out: + p_sum += np.sum(np.abs(v)) + param_sum.append(p_sum) + return param_sum + + def check_weight_decay(self, place, model): + main_prog = fluid.framework.Program() + startup_prog = fluid.framework.Program() + startup_prog.random_seed = 1 + with prog_scope_guard(main_prog=main_prog, startup_prog=startup_prog): + data = fluid.layers.data( + name="words", shape=[1], dtype="int64", lod_level=1) + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + avg_cost = model(data, label, len(self.word_dict)) + AdamW = fluid.contrib.extend_with_decoupled_weight_decay( + fluid.optimizer.Adam) + + optimizer = AdamW( + learning_rate=self.learning_rate, + weight_decay=self.learning_rate) + + optimizer.minimize(avg_cost) + param_sum = self.run_program(place, [data, label]) + + return param_sum + + def check_weight_decay2(self, place, model): + main_prog = fluid.framework.Program() + startup_prog = fluid.framework.Program() + startup_prog.random_seed = 1 + with prog_scope_guard(main_prog=main_prog, startup_prog=startup_prog): + data = fluid.layers.data( + name="words", shape=[1], dtype="int64", lod_level=1) + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + + avg_cost = model(data, label, len(self.word_dict)) + + param_list = [(var, var * self.learning_rate) + for var in main_prog.block(0).all_parameters()] + + optimizer = fluid.optimizer.Adam(learning_rate=self.learning_rate) + + optimizer.minimize(avg_cost) + for params in param_list: + updated_p = fluid.layers.elementwise_sub( + x=params[0], y=params[1]) + fluid.layers.assign(input=updated_p, output=params[0]) + + param_sum = self.run_program(place, [data, label]) + return param_sum + + def test_weight_decay(self): + for place in get_places(): + model = partial(bow_net, is_sparse=False) + param_sum1 = self.check_weight_decay(place, model) + param_sum2 = self.check_weight_decay2(place, model) + + for i in range(len(param_sum1)): + assert np.isclose(a=param_sum1[i], b=param_sum2[i], rtol=5e-5) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 479c0b0a4ab..45a065da835 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -325,12 +325,38 @@ class Optimizer(object): Examples: See examples in `apply_gradients`. """ - if callbacks is None: - callbacks = [error_clip_callback] + self._dtype = loss.dtype + if framework._in_dygraph_mode(): + if parameter_list is not None: + parameters = parameter_list + else: + parameters = framework._dygraph_tracer().all_parameters() + + params_grads = [] + for param in parameters: + if not param.trainable: + continue + if param._ivar._grad_ivar() is not None: + # create gradient variable + grad_var = Variable( + block=loss.block, + name=param._ivar._grad_name(), + stop_gradient=True, + ivar=param._ivar._grad_ivar()) + params_grads.append((param, grad_var)) else: - assert (isinstance(callbacks, list)) - callbacks.append(error_clip_callback) - return append_backward(loss, parameter_list, no_grad_set, callbacks) + if callbacks is None: + callbacks = [error_clip_callback] + else: + assert (isinstance(callbacks, list)) + program = loss.block.program + with program_guard(program, startup_program): + params_grads = append_backward(loss, parameter_list, + no_grad_set, callbacks) + # Note: since we can't use all_reduce_op now, + # dgc_op should be the last op of one grad. + self._append_dgc_ops(params_grads) + return params_grads def apply_gradients(self, params_grads): """ @@ -371,6 +397,30 @@ class Optimizer(object): return optimize_ops + def apply_optimize(self, loss, startup_program, params_grads): + """ + Second part of `minimize`, appending optimization operators for + given `params_grads` pairs. + + Args: + loss (Variable): loss variable to run optimizations. + startup_program (Program): startup_program for initializing parameters + in `parameter_list`. + params_grads (list): list of (param, grad) pair to do optimization. + + Returns: + list: A list of operators appended to the current program. + """ + if framework._in_dygraph_mode(): + with program_guard(framework.default_main_program(), + framework.default_startup_program()): + optimize_ops = self._create_optimization_pass(params_grads) + else: + program = loss.block.program + with program_guard(program, startup_program): + optimize_ops = self.apply_gradients(params_grads) + return optimize_ops + def minimize(self, loss, startup_program=None, @@ -393,38 +443,13 @@ class Optimizer(object): tuple: (optimize_ops, params_grads) which are, list of operators appended; and list of (param, grad) Variables pair for optimization. """ - self._dtype = loss.dtype - optimize_ops = [] - if framework._in_dygraph_mode(): - if parameter_list is not None: - parameters = parameter_list - else: - parameters = framework._dygraph_tracer().all_parameters() - - params_grads = [] - for param in parameters: - if not param.trainable: - continue - if param._ivar._grad_ivar() is not None: - # create gradient variable - grad_var = Variable( - block=loss.block, - name=param._ivar._grad_name(), - stop_gradient=True, - ivar=param._ivar._grad_ivar()) - params_grads.append((param, grad_var)) - with program_guard(framework.default_main_program(), - framework.default_startup_program()): - optimize_ops = self._create_optimization_pass(params_grads) - else: - program = loss.block.program - with program_guard(program, startup_program): - params_grads = self.backward(loss, startup_program, - parameter_list, no_grad_set) - # Note: since we can't use all_reduce_op now, - # dgc_op should be the last op of one grad. - self._append_dgc_ops(params_grads) - optimize_ops = self.apply_gradients(params_grads) + params_grads = self.backward( + loss, + startup_program=startup_program, + parameter_list=parameter_list, + no_grad_set=no_grad_set) + optimize_ops = self.apply_optimize( + loss, startup_program=startup_program, params_grads=params_grads) return optimize_ops, params_grads diff --git a/python/setup.py.in b/python/setup.py.in index 68f96273a23..75e821582f4 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -119,6 +119,7 @@ packages=['paddle', 'paddle.fluid.contrib.slim.quantization', 'paddle.fluid.contrib.slim.distillation', 'paddle.fluid.contrib.utils', + 'paddle.fluid.contrib.extend_optimizer', 'paddle.fluid.transpiler', 'paddle.fluid.transpiler.details'] -- GitLab