From 9f85f21880566ddd8bcfdd507293a493e82ff0db Mon Sep 17 00:00:00 2001
From: Hongyu Liu <43953930+phlrain@users.noreply.github.com>
Date: Wed, 29 May 2019 10:34:13 +0800
Subject: [PATCH] Add new grad clip [old gradient clip not supported in dygraph] (#17523)

* add gradient clip in minimize; test=develop
* fix bug; test=develop
* fix format; test=develop
* move new grad clip to dygraph/grad_clip.py; test=develop
* fix lr decay and grad clip test; test=develop
* separate dygraph grad clip; test=develop
* fix grad clip test; test=develop
* fix api spec bug; test=develop
* add blank line, test=develop,test=document_preview to fix format problem
---
 paddle/fluid/API.spec                    |  29 +-
 python/paddle/fluid/__init__.py          |   2 +
 python/paddle/fluid/dygraph_grad_clip.py | 282 ++++++++++++++++++
 python/paddle/fluid/optimizer.py         |  11 +-
 .../unittests/test_grad_clip_minimize.py | 254 ++++++++++++++++
 5 files changed, 564 insertions(+), 14 deletions(-)
 create mode 100644 python/paddle/fluid/dygraph_grad_clip.py
 create mode 100644 python/paddle/fluid/tests/unittests/test_grad_clip_minimize.py

diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec
index 38631ea798..2732b8f7a8 100644
--- a/paddle/fluid/API.spec
+++ b/paddle/fluid/API.spec
@@ -448,81 +448,81 @@ paddle.fluid.optimizer.SGDOptimizer.apply_gradients (ArgSpec(args=['self', 'para
 paddle.fluid.optimizer.SGDOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae'))
 paddle.fluid.optimizer.SGDOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))
 paddle.fluid.optimizer.SGDOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
-paddle.fluid.optimizer.SGDOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea'))
+paddle.fluid.optimizer.SGDOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'b15cffad0903fc81af77a0580ceb2a9b'))
 paddle.fluid.optimizer.MomentumOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'momentum', 'use_nesterov', 'regularization', 'name'], varargs=None, keywords=None, defaults=(False, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
 paddle.fluid.optimizer.MomentumOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871'))
 paddle.fluid.optimizer.MomentumOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae'))
 paddle.fluid.optimizer.MomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f'))
 paddle.fluid.optimizer.MomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', 
'6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.optimizer.MomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) +paddle.fluid.optimizer.MomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'b15cffad0903fc81af77a0580ceb2a9b')) paddle.fluid.optimizer.AdagradOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'epsilon', 'regularization', 'name', 'initial_accumulator_value'], varargs=None, keywords=None, defaults=(1e-06, None, None, 0.0)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.AdagradOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) paddle.fluid.optimizer.AdagradOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.AdagradOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.AdagradOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.optimizer.AdagradOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) +paddle.fluid.optimizer.AdagradOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'b15cffad0903fc81af77a0580ceb2a9b')) paddle.fluid.optimizer.AdamOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'beta1', 'beta2', 'epsilon', 'regularization', 'name', 'lazy_mode'], varargs=None, keywords=None, defaults=(0.001, 0.9, 0.999, 1e-08, None, None, False)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.AdamOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) paddle.fluid.optimizer.AdamOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.AdamOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.AdamOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.optimizer.AdamOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) 
+paddle.fluid.optimizer.AdamOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'b15cffad0903fc81af77a0580ceb2a9b')) paddle.fluid.optimizer.AdamaxOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'beta1', 'beta2', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.9, 0.999, 1e-08, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.AdamaxOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) paddle.fluid.optimizer.AdamaxOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.AdamaxOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.AdamaxOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.optimizer.AdamaxOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) +paddle.fluid.optimizer.AdamaxOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'b15cffad0903fc81af77a0580ceb2a9b')) paddle.fluid.optimizer.DecayedAdagradOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'decay', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.95, 1e-06, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.DecayedAdagradOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) paddle.fluid.optimizer.DecayedAdagradOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.DecayedAdagradOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.DecayedAdagradOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.optimizer.DecayedAdagradOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) +paddle.fluid.optimizer.DecayedAdagradOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'b15cffad0903fc81af77a0580ceb2a9b')) 
paddle.fluid.optimizer.FtrlOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'l1', 'l2', 'lr_power', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.0, 0.0, -0.5, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.FtrlOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) paddle.fluid.optimizer.FtrlOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.FtrlOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.FtrlOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.optimizer.FtrlOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) +paddle.fluid.optimizer.FtrlOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'b15cffad0903fc81af77a0580ceb2a9b')) paddle.fluid.optimizer.RMSPropOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'rho', 'epsilon', 'momentum', 'centered', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.95, 1e-06, 0.0, False, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.RMSPropOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) paddle.fluid.optimizer.RMSPropOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.RMSPropOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.RMSPropOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.optimizer.RMSPropOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) +paddle.fluid.optimizer.RMSPropOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'b15cffad0903fc81af77a0580ceb2a9b')) paddle.fluid.optimizer.AdadeltaOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'epsilon', 'rho', 'regularization', 'name'], varargs=None, keywords=None, defaults=(1e-06, 0.95, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.AdadeltaOptimizer.apply_gradients 
(ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) paddle.fluid.optimizer.AdadeltaOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.AdadeltaOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.AdadeltaOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.optimizer.AdadeltaOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) +paddle.fluid.optimizer.AdadeltaOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'b15cffad0903fc81af77a0580ceb2a9b')) paddle.fluid.optimizer.ModelAverage.__init__ (ArgSpec(args=['self', 'average_window_rate', 'min_average_window', 'max_average_window', 'regularization', 'name'], varargs=None, keywords=None, defaults=(10000, 10000, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.ModelAverage.apply (ArgSpec(args=['self', 'executor', 'need_restore'], varargs=None, keywords=None, defaults=(True,)), ('document', '46234a5470590feb336346f70a3db715')) paddle.fluid.optimizer.ModelAverage.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) paddle.fluid.optimizer.ModelAverage.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.ModelAverage.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.ModelAverage.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.optimizer.ModelAverage.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) +paddle.fluid.optimizer.ModelAverage.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'b15cffad0903fc81af77a0580ceb2a9b')) paddle.fluid.optimizer.ModelAverage.restore (ArgSpec(args=['self', 'executor'], varargs=None, keywords=None, defaults=None), ('document', '18db9c70be9c4dd466f9844457b21bfe')) paddle.fluid.optimizer.LarsMomentumOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'momentum', 'lars_coeff', 'lars_weight_decay', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.0005, None, None)), ('document', 
'6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.LarsMomentumOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) paddle.fluid.optimizer.LarsMomentumOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.LarsMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.LarsMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.optimizer.LarsMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) +paddle.fluid.optimizer.LarsMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'b15cffad0903fc81af77a0580ceb2a9b')) paddle.fluid.optimizer.DGCMomentumOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'momentum', 'rampup_begin_step', 'rampup_step', 'sparsity', 'use_nesterov', 'local_grad_clip_norm', 'num_trainers', 'regularization', 'name'], varargs=None, keywords=None, defaults=(1, [0.999], False, None, None, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.DGCMomentumOptimizer.apply_gradients (ArgSpec(args=['self', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) paddle.fluid.optimizer.DGCMomentumOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.DGCMomentumOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.DGCMomentumOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.optimizer.DGCMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) +paddle.fluid.optimizer.DGCMomentumOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'b15cffad0903fc81af77a0580ceb2a9b')) paddle.fluid.optimizer.LambOptimizer.__init__ (ArgSpec(args=['self', 'learning_rate', 'lamb_weight_decay', 'beta1', 'beta2', 'epsilon', 'regularization', 'name'], varargs=None, keywords=None, defaults=(0.001, 0.01, 0.9, 0.999, 1e-06, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.LambOptimizer.apply_gradients (ArgSpec(args=['self', 
'params_grads'], varargs=None, keywords=None, defaults=None), ('document', 'bfe7305918552aaecfdaa22411dbe871')) paddle.fluid.optimizer.LambOptimizer.apply_optimize (ArgSpec(args=['self', 'loss', 'startup_program', 'params_grads'], varargs=None, keywords=None, defaults=None), ('document', '5c46d1926a40f1f873ffe9f37ac89dae')) paddle.fluid.optimizer.LambOptimizer.backward (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'callbacks'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'ba3a113d0229ff7bc9d39bda0a6d947f')) paddle.fluid.optimizer.LambOptimizer.get_opti_var_name_list (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) -paddle.fluid.optimizer.LambOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', '35fd5d3330c97903528c7e0dacc7f6ea')) +paddle.fluid.optimizer.LambOptimizer.minimize (ArgSpec(args=['self', 'loss', 'startup_program', 'parameter_list', 'no_grad_set', 'grad_clip'], varargs=None, keywords=None, defaults=(None, None, None, None)), ('document', 'b15cffad0903fc81af77a0580ceb2a9b')) paddle.fluid.optimizer.ExponentialMovingAverage.__init__ (ArgSpec(args=['self', 'decay', 'thres_steps', 'name'], varargs=None, keywords=None, defaults=(0.999, None, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.optimizer.ExponentialMovingAverage.apply (ArgSpec(args=['self', 'executor', 'need_restore'], varargs=None, keywords=None, defaults=(True,)), ('document', '30f494752ac8921dc5835a63637f453a')) paddle.fluid.optimizer.ExponentialMovingAverage.restore (ArgSpec(args=['self', 'executor'], varargs=None, keywords=None, defaults=None), ('document', '8c8a1791608b02a1ede53d6dd3a4fcec')) @@ -551,6 +551,9 @@ paddle.fluid.clip.ErrorClipByValue.__init__ (ArgSpec(args=['self', 'max', 'min'] paddle.fluid.clip.GradientClipByValue.__init__ (ArgSpec(args=['self', 'max', 'min'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.clip.GradientClipByNorm.__init__ (ArgSpec(args=['self', 'clip_norm'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.clip.GradientClipByGlobalNorm.__init__ (ArgSpec(args=['self', 'clip_norm', 'group_name'], varargs=None, keywords=None, defaults=('default_group',)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) +paddle.fluid.dygraph_grad_clip.GradClipByValue.__init__ (ArgSpec(args=['self', 'min_value', 'max_value'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) +paddle.fluid.dygraph_grad_clip.GradClipByNorm.__init__ (ArgSpec(args=['self', 'clip_norm'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) +paddle.fluid.dygraph_grad_clip.GradClipByGlobalNorm.__init__ (ArgSpec(args=['self', 'max_global_norm'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.profiler.cuda_profiler (ArgSpec(args=['output_file', 'output_mode', 'config'], varargs=None, keywords=None, defaults=(None, None)), ('document', '49f5db5da13cfd8c069754dd11be3901')) paddle.fluid.profiler.reset_profiler (ArgSpec(args=[], varargs=None, keywords=None, defaults=None), ('document', 'd33483b1781e47c4c5d5fefa7b7debcb')) paddle.fluid.profiler.profiler (ArgSpec(args=['state', 
'sorted_key', 'profile_path'], varargs=None, keywords=None, defaults=(None, '/tmp/profile')), ('document', 'd8db46bf9a579bec476d09dea80eb23d'))
diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py
index 71ad2f0cf0..00f97389b7 100644
--- a/python/paddle/fluid/__init__.py
+++ b/python/paddle/fluid/__init__.py
@@ -54,6 +54,7 @@ from .transpiler import DistributeTranspiler, \
     memory_optimize, release_memory, DistributeTranspilerConfig
 from .lod_tensor import create_lod_tensor, create_random_int_lodtensor
 from . import clip
+from . import dygraph_grad_clip
 from . import profiler
 from . import unique_name
 from . import recordio_writer
@@ -93,6 +94,7 @@ __all__ = framework.__all__ + executor.__all__ + \
     'WeightNormParamAttr',
     'DataFeeder',
     'clip',
+    'dygraph_grad_clip',
     'profiler',
     'unique_name',
     'recordio_writer',
diff --git a/python/paddle/fluid/dygraph_grad_clip.py b/python/paddle/fluid/dygraph_grad_clip.py
new file mode 100644
index 0000000000..bcc307511e
--- /dev/null
+++ b/python/paddle/fluid/dygraph_grad_clip.py
@@ -0,0 +1,282 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import copy
+import six
+
+import functools
+
+from . import layers
+from . import framework
+from . import core
+
+__all__ = [
+    'GradClipByValue',
+    'GradClipByNorm',
+    'GradClipByGlobalNorm',
+]
+
+
+class GradClipBase(object):
+    def __str__(self):
+        raise NotImplementedError()
+
+    def _clip(self, para_and_grad):
+        raise NotImplementedError
+
+    def __call__(self, para_and_grad):
+        return self._clip(para_and_grad)
+
+
+class GradClipByValue(GradClipBase):
+    """
+    Clips gradient values to the range [min_value, max_value].
+
+    Given a gradient g, this operation clips each of its values to that range:
+
+    - Any values less than min_value are set to min_value.
+    - Any values greater than max_value are set to max_value.
+
+    Args:
+        min_value (float|None): The minimum value to clip by. If set to None, it
+            will be set to -max_value by the framework (in this case max_value
+            MUST be positive).
+        max_value (float, optional): The maximum value to clip by.
+
+    Examples:
+        .. code-block:: python
+
+            import numpy as np
+            import paddle
+            import paddle.fluid as fluid
+
+            from paddle.fluid.dygraph.base import to_variable
+            from paddle.fluid.dygraph.nn import FC
+
+            from paddle.fluid.dygraph_grad_clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm
+
+            from paddle.fluid.optimizer import SGDOptimizer
+
+            with fluid.dygraph.guard():
+                value_clip = GradClipByValue(-1.0, 1.0)
+                sgd = SGDOptimizer(learning_rate=1.0)
+
+                init_value = np.random.uniform(-1, 1, (10, 10)).astype('float32')
+
+                fc = FC("fc", 10)
+
+                out = fc(to_variable(init_value))
+
+                loss = fluid.layers.reduce_mean(out)
+
+                loss.backward()
+                sgd.minimize(loss, grad_clip=value_clip)
+
+    """
+
+    def __init__(self, min_value, max_value=None):
+
+        if min_value is None:
+            assert (max_value > 0.0)
+            min_value = -max_value
+        else:
+            min_value = float(min_value)
+        self.max_value = max_value
+        self.min_value = min_value
+
+    def __str__(self):
+        return "ClipByValue, min = %f, max=%f" % (self.min_value,
+                                                  self.max_value)
+
+    def _clip(self, para_and_grad):
+        out = []
+        for p, g in para_and_grad:
+            if g is None:
+                out.append((p, g))
+                continue
+
+            new_grad = layers.clip(x=g, min=self.min_value, max=self.max_value)
+
+            out.append((p, new_grad))
+
+        return out
+
+
+class GradClipByNorm(GradClipBase):
+    """
+    Clips tensor values to a maximum L2-norm.
+
+    This operator limits the L2 norm of the input :math:`X` within :math:`max\_norm`.
+    If the L2 norm of :math:`X` is less than or equal to :math:`max\_norm`, :math:`Out`
+    will be the same as :math:`X`. If the L2 norm of :math:`X` is greater than
+    :math:`max\_norm`, :math:`X` will be linearly scaled to make the L2 norm of
+    :math:`Out` equal to :math:`max\_norm`, as shown in the following formula:
+
+    .. math::
+
+        Out = \\frac{max\_norm * X}{norm(X)},
+
+    where :math:`norm(X)` represents the L2 norm of :math:`X`.
+
+    Args:
+        clip_norm (float): The maximum norm value (:math:`max\_norm` above).
+
+    Examples:
+        .. code-block:: python
+
+            import numpy as np
+            import paddle
+            import paddle.fluid as fluid
+
+            from paddle.fluid.dygraph.base import to_variable
+            from paddle.fluid.dygraph.nn import FC
+
+            from paddle.fluid.dygraph_grad_clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm
+
+            from paddle.fluid.optimizer import SGDOptimizer
+
+            with fluid.dygraph.guard():
+                norm_clip = GradClipByNorm(5.0)
+                sgd = SGDOptimizer(learning_rate=1.0)
+
+                init_value = np.random.uniform(-1, 1, (10, 10)).astype('float32')
+
+                fc = FC("fc", 10)
+
+                out = fc(to_variable(init_value))
+
+                loss = fluid.layers.reduce_mean(out)
+
+                loss.backward()
+                sgd.minimize(loss, grad_clip=norm_clip)
+
+    """
+
+    def __init__(self, clip_norm):
+        self.clip_norm = clip_norm
+
+    def __str__(self):
+        return "ClipByNorm, clip_norm=%f" % self.clip_norm
+
+    def _clip(self, para_and_grad):
+        out = []
+
+        for p, g in para_and_grad:
+            if g is None:
+                out.append((p, g))
+                continue
+            new_g = layers.clip_by_norm(x=g, max_norm=self.clip_norm)
+
+            out.append((p, new_g))
+
+        return out
+
+
+class GradClipByGlobalNorm(GradClipBase):
+    """
+    Clips values of multiple tensors by the ratio of the sum of their norms.
+
+    Given a list of tensors t_list and a clipping ratio clip_norm, this
+    operation returns a list of clipped tensors list_clipped.
+
+    To perform the clipping, the values :math:`t\_list[i]` are set to:
+
+    .. math::
+
+        t\_list[i] = t\_list[i] * \\frac{clip\_norm}{\max(global\_norm, clip\_norm)}
+
+    where:
+
+    .. math::
+
+        global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2}
+
+    If :math:`clip\_norm > global\_norm` then the entries in t_list remain as they are,
+    otherwise they're all shrunk by the global ratio.
+
+    Args:
+        max_global_norm (float): The maximum global norm value, i.e. the
+            :math:`clip\_norm` in the formula above.
+
+    Examples:
+        .. code-block:: python
+
+            import numpy as np
+            import paddle
+            import paddle.fluid as fluid
+
+            from paddle.fluid.dygraph.base import to_variable
+            from paddle.fluid.dygraph.nn import FC
+
+            from paddle.fluid.dygraph_grad_clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm
+
+            from paddle.fluid.optimizer import SGDOptimizer
+
+            with fluid.dygraph.guard():
+                global_norm_clip = GradClipByGlobalNorm(5.0)
+                sgd = SGDOptimizer(learning_rate=1.0)
+
+                init_value = np.random.uniform(-1, 1, (10, 10)).astype('float32')
+
+                fc = FC("fc", 10)
+
+                out = fc(to_variable(init_value))
+
+                loss = fluid.layers.reduce_mean(out)
+
+                loss.backward()
+                sgd.minimize(loss, grad_clip=global_norm_clip)
+
+    """
+
+    def __init__(self, max_global_norm):
+        self.max_global_norm = layers.fill_constant(
+            shape=[1], dtype='float32', value=max_global_norm)
+
+    def __str__(self):
+        return "ClipByGlobalNorm, max_global_norm=%f" % (self.max_global_norm)
+
+    def _clip(self, para_and_grad):
+
+        out = []
+
+        norm_arr = []
+        for p, g in para_and_grad:
+            if g is None:
+                continue
+            power = layers.square(g)
+            sum_t = layers.reduce_sum(power)
+            norm_arr.append(sum_t)
+
+        norm_global = layers.concat(norm_arr)
+        norm_global = layers.reduce_sum(norm_global)
+        norm_global = layers.sqrt(norm_global)
+
+        clip_scale = layers.elementwise_div(
+            x=self.max_global_norm,
+            y=layers.elementwise_max(
+                x=norm_global, y=self.max_global_norm))
+
+        for p, g in para_and_grad:
+            if g is None:
+                out.append((p, g))
+                continue
+            new_grad = g * clip_scale
+
+            out.append((p, new_grad))
+
+        return out
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 88587bdb41..f8c6683e32 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -463,6 +463,8 @@ class Optimizer(object):
         if framework.in_dygraph_mode():
             with program_guard(framework.default_main_program(),
                                framework.default_startup_program()):
+                params_grads = append_regularization_ops(params_grads,
+                                                         self.regularization)
                 optimize_ops = self._create_optimization_pass(params_grads)
         else:
             program = loss.block.program
@@ -474,7 +476,8 @@
                  loss,
                  startup_program=None,
                  parameter_list=None,
-                 no_grad_set=None):
+                 no_grad_set=None,
+                 grad_clip=None):
         """
         Add operations to minimize `loss` by updating `parameter_list`.
@@ -487,6 +490,7 @@
                 in `parameter_list`.
             parameter_list (list): list of Variables to update.
             no_grad_set (set|None): set of Variables should be ignored.
+            grad_clip (GradClipBase|None): Gradient clip strategy.
         Returns:
             tuple: (optimize_ops, params_grads) which are, list of operators appended;
@@ -497,6 +501,11 @@
                 startup_program=startup_program,
                 parameter_list=parameter_list,
                 no_grad_set=no_grad_set)
+
+        if grad_clip is not None and framework.in_dygraph_mode():
+            # TODO(hongyu): FIX later, this is only for dygraph; it should also work for static mode
+            params_grads = grad_clip(params_grads)
+
         optimize_ops = self.apply_optimize(
             loss, startup_program=startup_program, params_grads=params_grads)
diff --git a/python/paddle/fluid/tests/unittests/test_grad_clip_minimize.py b/python/paddle/fluid/tests/unittests/test_grad_clip_minimize.py
new file mode 100644
index 0000000000..fb80b5c1d2
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_grad_clip_minimize.py
@@ -0,0 +1,254 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import contextlib
+import unittest
+import numpy as np
+import six
+
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid import core
+
+from paddle.fluid.dygraph.base import to_variable
+
+from paddle.fluid.dygraph_grad_clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm
+
+
+class TestGradClipByGlobalNorm(unittest.TestCase):
+    def init_value(self):
+        self.max_global_norm = 5.0
+        self.init_scale = 1.0
+
+        self.shape = (20, 20)
+
+    def generate_p_g(self):
+
+        self.para_and_grad = []
+        for i in range(10):
+            self.para_and_grad.append(
+                (np.random.uniform(-self.init_scale, self.init_scale,
+                                   self.shape).astype('float32'),
+                 np.random.uniform(-self.init_scale, self.init_scale,
+                                   self.shape).astype('float32')))
+
+    def get_numpy_global_norm_result(self):
+        global_norm = 0.0
+        for p, g in self.para_and_grad:
+            global_norm += np.sum(np.square(g))
+
+        global_norm_np = np.sqrt(global_norm)
+
+        new_np_p_g = []
+        scale = 1.0
+        if global_norm_np > self.max_global_norm:
+            scale = self.max_global_norm / global_norm_np
+
+        for p, g in self.para_and_grad:
+            new_np_p_g.append((p, g * scale))
+
+        return new_np_p_g
+
+    def get_dygrap_global_norm_result(self):
+        with fluid.dygraph.guard():
+
+            global_norm_clip = GradClipByGlobalNorm(self.max_global_norm)
+            p_g_var = []
+            for p, g in self.para_and_grad:
+                new_p = to_variable(p)
+                new_g = to_variable(g)
+                p_g_var.append((new_p, new_g))
+
+            new_p_g_var = global_norm_clip(p_g_var)
+
+            p_g_dy_out = []
+            for p, g in new_p_g_var:
+                p_g_dy_out.append((p.numpy(), g.numpy()))
+
+            return p_g_dy_out
+
+    def test_clip_by_global_norm(self):
+        self.init_value()
+        self.generate_p_g()
+        np_p_g = 
self.get_numpy_global_norm_result() + dy_out_p_g = self.get_dygrap_global_norm_result() + + for (p_np, g_np), (p_dy, g_dy) in zip(np_p_g, dy_out_p_g): + self.assertTrue(np.allclose(g_np, g_dy, rtol=1e-6, atol=1e-8)) + + +class TestGradClipByNorm(unittest.TestCase): + def init_value(self): + self.max_norm = 5.0 + self.init_scale = 1.0 + + self.shape = (10, 10) + + def generate_p_g(self): + + self.para_and_grad = [] + for i in range(10): + self.para_and_grad.append( + (np.random.uniform(-self.init_scale, self.init_scale, + self.shape).astype('float32'), + np.random.uniform(-self.init_scale, self.init_scale, + self.shape).astype('float32'))) + + def get_numpy_norm_result(self): + + new_p_g = [] + for p, g in self.para_and_grad: + norm = np.sqrt(np.sum(np.square(g))) + + if norm > self.max_norm: + new_p_g.append((p, g * self.max_norm / norm)) + else: + new_p_g.append((p, g)) + + return new_p_g + + def get_dygrap_norm_result(self): + with fluid.dygraph.guard(): + + norm_clip = GradClipByNorm(self.max_norm) + p_g_var = [] + for p, g in self.para_and_grad: + new_p = to_variable(p) + new_g = to_variable(g) + p_g_var.append((new_p, new_g)) + + new_p_g_var = norm_clip(p_g_var) + + p_g_dy_out = [] + for p, g in new_p_g_var: + p_g_dy_out.append((p.numpy(), g.numpy())) + + return p_g_dy_out + + def test_clip_by_norm(self): + self.init_value() + self.generate_p_g() + np_p_g = self.get_numpy_norm_result() + dy_out_p_g = self.get_dygrap_norm_result() + + for (p_np, g_np), (p_dy, g_dy) in zip(np_p_g, dy_out_p_g): + self.assertTrue(np.allclose(g_np, g_dy, rtol=1e-6, atol=1e-8)) + + def test_clip_by_norm_2(self): + self.init_value() + + self.init_scale = 0.2 + self.max_norm = 10.0 + self.generate_p_g() + np_p_g = self.get_numpy_norm_result() + dy_out_p_g = self.get_dygrap_norm_result() + + for (p_np, g_np), (p_dy, g_dy) in zip(np_p_g, dy_out_p_g): + self.assertTrue(np.allclose(g_np, g_dy, rtol=1e-6, atol=1e-8)) + + +class TestGradClipByValue(unittest.TestCase): + def init_value(self): + self.max_value = 0.8 + self.min_value = -0.1 + self.init_scale = 1.0 + + self.shape = (10, 10) + + def generate_p_g(self): + + self.para_and_grad = [] + for i in range(10): + self.para_and_grad.append( + (np.random.uniform(-self.init_scale, self.init_scale, + self.shape).astype('float32'), + np.random.uniform(-self.init_scale, self.init_scale, + self.shape).astype('float32'))) + + def get_numpy_clip_result(self): + + new_p_g = [] + for p, g in self.para_and_grad: + new_p_g.append((p, np.clip(g, self.min_value, self.max_value))) + + return new_p_g + + def get_dygrap_clip_result(self): + with fluid.dygraph.guard(): + + value_clip = GradClipByValue(self.min_value, self.max_value) + p_g_var = [] + for p, g in self.para_and_grad: + new_p = to_variable(p) + new_g = to_variable(g) + p_g_var.append((new_p, new_g)) + + new_p_g_var = value_clip(p_g_var) + + p_g_dy_out = [] + for p, g in new_p_g_var: + p_g_dy_out.append((p.numpy(), g.numpy())) + + return p_g_dy_out + + def test_clip_by_value(self): + self.init_value() + self.generate_p_g() + np_p_g = self.get_numpy_clip_result() + dy_out_p_g = self.get_dygrap_clip_result() + + for (p_np, g_np), (p_dy, g_dy) in zip(np_p_g, dy_out_p_g): + self.assertTrue(np.allclose(g_np, g_dy, rtol=1e-6, atol=1e-8)) + + def test_clip_by_norm_2(self): + self.init_value() + + self.init_scale = 0.2 + self.generate_p_g() + np_p_g = self.get_numpy_clip_result() + dy_out_p_g = self.get_dygrap_clip_result() + + for (p_np, g_np), (p_dy, g_dy) in zip(np_p_g, dy_out_p_g): + self.assertTrue(np.allclose(g_np, g_dy, 
rtol=1e-6, atol=1e-8)) + + def test_clip_by_norm_3(self): + self.init_value() + + self.init_scale = 0.5 + self.max_value = 0.6 + self.min_value = None + self.generate_p_g() + np_p_g = self.get_numpy_clip_result() + dy_out_p_g = self.get_dygrap_clip_result() + + for (p_np, g_np), (p_dy, g_dy) in zip(np_p_g, dy_out_p_g): + self.assertTrue(np.allclose(g_np, g_dy, rtol=1e-6, atol=1e-8)) + + +if __name__ == '__main__': + unittest.main() -- GitLab
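
Reviewer note (not part of the patch): the snippet below is a minimal sketch of the end-to-end call pattern this change enables, assembled from the docstring examples above. It assumes the FC dygraph layer and SGDOptimizer available in this version of Paddle; treat it as an illustration rather than additional patch content. Gradients are produced by loss.backward(), and minimize() applies the chosen clip object to the (parameter, gradient) pairs before the optimizer update.

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph.base import to_variable
    from paddle.fluid.dygraph.nn import FC
    from paddle.fluid.dygraph_grad_clip import GradClipByGlobalNorm
    from paddle.fluid.optimizer import SGDOptimizer

    with fluid.dygraph.guard():
        # Rescale gradients so that their global L2 norm does not exceed 5.0.
        clip = GradClipByGlobalNorm(5.0)
        sgd = SGDOptimizer(learning_rate=0.1)

        fc = FC("fc", 10)
        x = np.random.uniform(-1, 1, (4, 10)).astype('float32')
        loss = fluid.layers.reduce_mean(fc(to_variable(x)))

        loss.backward()
        # grad_clip is the new keyword added to Optimizer.minimize() by this patch;
        # clipping happens inside minimize(), between backward() and the update.
        sgd.minimize(loss, grad_clip=clip)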