diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index 7bba2b5d4aab89ef049530b52ea18045e98be34a..23a983d845bf26fabff10933f84525d0598f9571 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -137,9 +137,6 @@ class GradientClipBase(object): raise NotImplementedError def __call__(self, params_grads): - assert len( - params_grads - ) > 0, "The number of trainable parameters should be greater than 0." if framework.in_dygraph_mode(): return self._dygraph_clip(params_grads) else: @@ -147,7 +144,7 @@ class GradientClipBase(object): if getattr(p, 'gradient_clip_attr', None) is not None: warnings.warn( "'set_gradient_clip' will be ineffective, because you have " - "pass 'grad_clip' into 'minimize'. So, 'set_gradient_clip' " + "set 'grad_clip' in 'optimizer'. So, 'set_gradient_clip' " "is redundant and you can remove it.") break return self._static_clip(params_grads) @@ -170,7 +167,7 @@ class GradientClipByValue(GradientClipBase): The multi-dimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters in ``Program`` . If ``need_clip`` is not None, then only part of gradients can be selected for gradient clipping. - Gradient clip will takes effect after being set in ``optimizer.minimize(grad_clip)`` , see the document ``optimizer`` + Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer`` (for example: :ref:`api_fluid_optimizer_SGDOptimizer`). Args: @@ -208,8 +205,8 @@ class GradientClipByValue(GradientClipBase): # return Parameter.name=="fc_0.w_0" # clip = fluid.clip.GradientClipByValue(min=-1, max=1, need_clip=fileter_func) - sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1) - sgd_optimizer.minimize(loss, grad_clip=clip) + sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1, grad_clip=clip) + sgd_optimizer.minimize(loss) place = fluid.CPUPlace() exe = fluid.Executor(place) @@ -242,8 +239,8 @@ class GradientClipByValue(GradientClipBase): # clip = fluid.clip.GradientClipByValue(min=-1, max=1, need_clip=fileter_func) sgd_optimizer = fluid.optimizer.SGD( - learning_rate=0.1, parameter_list=linear.parameters()) - sgd_optimizer.minimize(loss, grad_clip=clip) + learning_rate=0.1, parameter_list=linear.parameters(), grad_clip=clip) + sgd_optimizer.minimize(loss) """ def __init__(self, max, min=None, need_clip=None): @@ -272,6 +269,7 @@ class GradientClipByValue(GradientClipBase): def _static_clip(self, params_grads): params_and_grads = [] + param_new_grad_name_dict = dict() with framework.name_scope('gradient_clip'): for p, g in params_grads: if g is None: @@ -284,7 +282,8 @@ class GradientClipByValue(GradientClipBase): with p.block.program._optimized_guard([p, g]): new_grad = layers.clip(x=g, min=self.min, max=self.max) params_and_grads.append((p, new_grad)) - _correct_clip_op_role_var(params_and_grads) + param_new_grad_name_dict[p.name] = new_grad.name + _correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict) return params_and_grads def _process_context(self, context, param, grad): @@ -306,7 +305,7 @@ class GradientClipByNorm(GradientClipBase): The multidimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters in ``Program`` . If ``need_clip`` is not None, then only part of gradients can be selected for gradient clipping. - Gradient clip will takes effect after being set in ``optimizer.minimize(grad_clip)`` , see the document ``optimizer`` + Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer`` (for example: :ref:`api_fluid_optimizer_SGDOptimizer`). The clipping formula is: @@ -359,8 +358,8 @@ class GradientClipByNorm(GradientClipBase): # return Parameter.name=="fc_0.w_0" # clip = fluid.clip.GradientClipByNorm(clip_norm=1.0, need_clip=fileter_func) - sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1) - sgd_optimizer.minimize(loss, grad_clip=clip) + sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1, grad_clip=clip) + sgd_optimizer.minimize(loss) place = fluid.CPUPlace() exe = fluid.Executor(place) @@ -394,8 +393,8 @@ class GradientClipByNorm(GradientClipBase): # clip = fluid.clip.GradientClipByNorm(clip_norm=1.0, need_clip=fileter_func) sgd_optimizer = fluid.optimizer.SGD( - learning_rate=0.1, parameter_list=linear.parameters()) - sgd_optimizer.minimize(loss, grad_clip=clip) + learning_rate=0.1, parameter_list=linear.parameters(), grad_clip=clip) + sgd_optimizer.minimize(loss) """ @@ -422,6 +421,7 @@ class GradientClipByNorm(GradientClipBase): def _static_clip(self, params_grads): params_and_grads = [] with framework.name_scope('gradient_clip'): + param_new_grad_name_dict = dict() for p, g in params_grads: if g is None: continue @@ -432,8 +432,9 @@ class GradientClipByNorm(GradientClipBase): with p.block.program._optimized_guard([p, g]): new_grad = layers.clip_by_norm(x=g, max_norm=self.clip_norm) + param_new_grad_name_dict[p.name] = new_grad.name params_and_grads.append((p, new_grad)) - _correct_clip_op_role_var(params_and_grads) + _correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict) return params_and_grads def _process_context(self, context, param, grad): @@ -456,7 +457,7 @@ class GradientClipByGlobalNorm(GradientClipBase): The list of Tensor :math:`t\_list` is not passed from this class, but the gradients of all parameters in ``Program`` . If ``need_clip`` is not None, then only part of gradients can be selected for gradient clipping. - Gradient clip will takes effect after being set in ``optimizer.minimize(grad_clip)`` , see the document ``optimizer`` + Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer`` (for example: :ref:`api_fluid_optimizer_SGDOptimizer`). The clipping formula is: @@ -505,8 +506,8 @@ class GradientClipByGlobalNorm(GradientClipBase): # return Parameter.name=="fc_0.w_0" # clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0, need_clip=fileter_func) - sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1) - sgd_optimizer.minimize(loss, grad_clip=clip) + sgd_optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.1, grad_clip=clip) + sgd_optimizer.minimize(loss) place = fluid.CPUPlace() exe = fluid.Executor(place) @@ -539,8 +540,8 @@ class GradientClipByGlobalNorm(GradientClipBase): # clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0, need_clip=fileter_func) sgd_optimizer = fluid.optimizer.SGD( - learning_rate=0.1, parameter_list=linear.parameters()) - sgd_optimizer.minimize(loss, grad_clip=clip) + learning_rate=0.1, parameter_list=linear.parameters(), grad_clip=clip) + sgd_optimizer.minimize(loss) """ @@ -628,6 +629,7 @@ class GradientClipByGlobalNorm(GradientClipBase): y=layers.elementwise_max( x=max_global_norm, y=global_norm_var)) + param_new_grad_name_dict = dict() for p, g in params_grads: if g is None: continue @@ -638,9 +640,10 @@ class GradientClipByGlobalNorm(GradientClipBase): with p.block.program._optimized_guard([p, g]): new_grad = layers.elementwise_mul(x=g, y=scale_var) + param_new_grad_name_dict[p.name] = new_grad.name params_and_grads.append((p, new_grad)) - _correct_clip_op_role_var(params_and_grads) + _correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict) return params_and_grads def _process_context(self, context, param, grad): @@ -692,9 +695,10 @@ def set_gradient_clip(clip, param_list=None, program=None): This API must be used after building network, and before ``minimize`` , and it may be removed in future releases, so it is not recommended. - It is recommended to use ``minimize(loss, grad_clip=clip)`` to clip gradient. - There are three clipping strategies: :ref:`api_fluid_clip_GradientClipByGlobalNorm` , - :ref:`api_fluid_clip_GradientClipByNorm` , :ref:`api_fluid_clip_GradientClipByValue` . + It is recommended to set ``grad_clip`` when initializing the ``optimizer`` , + this is a better method to clip gradient. There are three clipping strategies: + :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , + :ref:`api_fluid_clip_GradientClipByValue` . To specify parameters that require gradient clip. @@ -757,7 +761,7 @@ def set_gradient_clip(clip, param_list=None, program=None): sgd = fluid.optimizer.SGD(learning_rate=1e-3) sgd.minimize(loss) - # network 4: use 'set_gradient_clip' and 'minimize(grad_clip=clip)' together + # network 4: use 'set_gradient_clip' and 'optimize(grad_clip=clip)' together with fluid.program_guard(fluid.Program(), fluid.Program()): loss = network() clip1 = fluid.clip.GradientClipByValue(min=-1.0, max=1.0) @@ -765,8 +769,8 @@ def set_gradient_clip(clip, param_list=None, program=None): # Set the gradient clipping strategy: clip1 fluid.clip.set_gradient_clip(clip1) # Set the gradient clipping strategy: clip2 - sgd = fluid.optimizer.SGD(learning_rate=1e-3) - sgd.minimize(loss, grad_clip=clip2) + sgd = fluid.optimizer.SGD(learning_rate=1e-3, grad_clip=clip2) + sgd.minimize(loss) # 'set_gradient_clip' will not take effect when setting has a conflict, # and the gradient clipping strategy will be 'clip2' @@ -774,10 +778,10 @@ def set_gradient_clip(clip, param_list=None, program=None): """ warnings.warn("Caution! 'set_gradient_clip' is not recommended " "and may be deprecated in future! " - "We recommend a new strategy: clip gradient by " - "'optimizer.minimize(loss, grad_clip=clip)'. " + "We recommend a new strategy: set 'grad_clip' " + "when initializing the 'optimizer'. " "This method can reduce the mistakes, please " - "see documention of 'optimzier.minimize'.") + "refer to documention of 'optimizer'.") if not isinstance(clip, GradientClipBase): raise TypeError( @@ -824,33 +828,40 @@ def append_gradient_clip_ops(param_grads): clip_attr._process_context(context=context, param=p, grad=g) res = [] + param_new_grad_name_dict = dict() for p, g in param_grads: if g is None: continue with p.block.program._optimized_guard( [p, g]), framework.name_scope('graident_clip_@CLIP'): param, new_grad = clip_attr._create_operators(param=p, grad=g) + param_new_grad_name_dict[param.name] = new_grad.name res.append([param, new_grad]) - _correct_clip_op_role_var(res) + _correct_clip_op_role_var(res, param_new_grad_name_dict) return res # change wrong mapping relation between param & grad in clip op -def _correct_clip_op_role_var(params_grads): +def _correct_clip_op_role_var(params_grads, param_new_grad_name_dict): + block_id_list = [] + if len(param_new_grad_name_dict) == 0: + return for param, grad in params_grads: if grad is None: continue + block_id = param.block.idx + if block_id in block_id_list: + continue + block_id_list.append(block_id) for op in param.block.program.global_block().ops: if 'op_namescope' in op.all_attrs() and "gradient_clip" in op.attr( - "op_namescope"): - if op.attr('op_role_var'): - param_name = op.attr('op_role_var')[0] - index = 0 - for i in range(len(params_grads)): - if params_grads[i][0].name == param_name: - index = i - correct_p_g = [param_name, params_grads[index][1].name] + "op_namescope") and op.attr('op_role_var'): + param_name = op.attr('op_role_var')[0] + if param_name in param_new_grad_name_dict: + correct_p_g = [ + param_name, param_new_grad_name_dict[param_name] + ] op._set_attr('op_role_var', correct_p_g) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 11f14051f14b8a90d38142173fb8b7fb422d8f96..38ffbc74b16bdb9c2ba6350b67ab3deb843d2f75 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -24,7 +24,7 @@ from . import framework from . import layers from . import unique_name from .backward import append_backward, _some_in_set_, _append_grad_suffix_, _get_no_grad_set_name -from .clip import GradientClipBase, error_clip_callback, append_gradient_clip_ops +from .clip import GradientClipBase, GradientClipByNorm, error_clip_callback, append_gradient_clip_ops from .framework import program_guard from .initializer import Constant from .layer_helper import LayerHelper @@ -64,6 +64,7 @@ class Optimizer(object): learning_rate, parameter_list=None, regularization=None, + grad_clip=None, name=None): self._parameter_list = parameter_list if framework.in_dygraph_mode(): @@ -88,7 +89,13 @@ class Optimizer(object): type(learning_rate)) self._name = name + if grad_clip is not None: + if not isinstance(grad_clip, GradientClipBase): + raise TypeError( + "'grad_clip' should be an instance of GradientClipBase's derived class" + ) self.regularization = regularization + self._grad_clip = grad_clip self._learning_rate = learning_rate # the learning rate type should be inferenced from loss self._dtype = None @@ -107,8 +114,6 @@ class Optimizer(object): self._opti_name_list = [] self._accumulators_holder = {} self._param_device_map = dict() - # if pass grad_clip into minimize, it will not be None - self._grad_clip = None @framework.dygraph_only def state_dict(self): @@ -787,8 +792,7 @@ class Optimizer(object): loss, startup_program=None, parameter_list=None, - no_grad_set=None, - grad_clip=None): + no_grad_set=None): """ Add operations to minimize ``loss`` by updating ``parameter_list``. @@ -801,12 +805,7 @@ class Optimizer(object): to minimize ``loss``. The default value is None, at this time all parameters will be updated. no_grad_set (set, optional): Set of ``Variable`` or ``Variable.name`` that don't need - to be updated. The default value is None. - grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of - some derived class of ``GradientClipBase`` . There are three cliping strategies - ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , - :ref:`api_fluid_clip_GradientClipByValue` ). Default value: None, and there is no - gradient clipping. + to be updated. The default value is None. Returns: tuple: tuple (optimize_ops, params_grads), A list of operators appended @@ -820,12 +819,7 @@ class Optimizer(object): Please refer to the example of current Optimizer. """ assert isinstance(loss, Variable), "The loss should be an Variable." - if grad_clip is not None: - if not isinstance(grad_clip, GradientClipBase): - raise TypeError( - "'grad_clip' should be an instance of GradientClipBase's derived class" - ) - self._grad_clip = grad_clip + parameter_list = parameter_list if parameter_list \ else self._parameter_list params_grads = self.backward( @@ -859,6 +853,10 @@ class SGDOptimizer(Optimizer): regularizer using :ref:`api_fluid_ParamAttr` already, the regularization setting here in optimizer will be \ ignored for this parameter. Otherwise, the regularization setting here in optimizer will take effect. \ Default None, meaning there is no regularization. + grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of + some derived class of ``GradientClipBase`` . There are three cliping strategies + ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , + :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping. name (str, optional): This parameter is used by developers to print debugging information. \ For details, please refer to :ref:`api_guide_Name`. Default is None. @@ -896,12 +894,14 @@ class SGDOptimizer(Optimizer): learning_rate, parameter_list=None, regularization=None, + grad_clip=None, name=None): assert learning_rate is not None super(SGDOptimizer, self).__init__( learning_rate=learning_rate, parameter_list=parameter_list, regularization=regularization, + grad_clip=grad_clip, name=name) self.type = "sgd" @@ -962,6 +962,10 @@ class MomentumOptimizer(Optimizer): regularizer using :ref:`api_fluid_ParamAttr` already, the regularization setting here in optimizer will be \ ignored for this parameter. Otherwise, the regularization setting here in optimizer will take effect. \ Default None, meaning there is no regularization. + grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of + some derived class of ``GradientClipBase`` . There are three cliping strategies + ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , + :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping. name (str, optional): This parameter is used by developers to print debugging information. \ For details, please refer to :ref:`api_guide_Name`. Default is None. @@ -1002,6 +1006,7 @@ class MomentumOptimizer(Optimizer): parameter_list=None, use_nesterov=False, regularization=None, + grad_clip=None, name=None): assert learning_rate is not None assert momentum is not None @@ -1009,6 +1014,7 @@ class MomentumOptimizer(Optimizer): learning_rate=learning_rate, parameter_list=parameter_list, regularization=regularization, + grad_clip=grad_clip, name=name) self.type = "momentum" self._momentum = momentum @@ -1097,13 +1103,14 @@ class DGCMomentumOptimizer(Optimizer): This parameter is required in dygraph mode. \ The default value is None in static mode, at this time all parameters will be updated. use_nesterov (bool): Enables Nesterov momentum. True means use Nesterov. Default is False. - local_grad_clip_norm (float, optional): Local gradient clip norm value. Optional, default is None, represent no need clip. - num_trainers (int, optional): The number of training nodes. Optional, default is None. regularization (WeightDecayRegularizer, optional): The strategy of regularization. There are two method: \ :ref:`api_fluid_regularizer_L1Decay` , :ref:`api_fluid_regularizer_L2Decay` . If a parameter has set \ regularizer using :ref:`api_fluid_ParamAttr` already, the regularization setting here in optimizer will be \ ignored for this parameter. Otherwise, the regularization setting here in optimizer will take effect. \ Default None, meaning there is no regularization. + grad_clip (GradientClipByNorm, optional): Gradient cliping strategy. ``DGCMomentumOptimizer`` only support + :ref:`api_fluid_clip_GradientClipByNorm` , and if not, it will raise TypeError. Default None, + meaning there is no gradient clipping. name (str, optional): This parameter is used by developers to print debugging information. \ For details, please refer to :ref:`api_guide_Name`. Default is None. @@ -1130,9 +1137,9 @@ class DGCMomentumOptimizer(Optimizer): sparsity=[0.999], parameter_list=None, use_nesterov=False, - local_grad_clip_norm=None, num_trainers=None, regularization=None, + grad_clip=None, name=None): if framework.in_dygraph_mode(): raise Exception("In dygraph, don't support DGCMomentumOptimizer.") @@ -1146,6 +1153,7 @@ class DGCMomentumOptimizer(Optimizer): learning_rate=learning_rate, parameter_list=parameter_list, regularization=regularization, + grad_clip=grad_clip, name=name) self.type = "dgc_momentum" self._momentum = momentum @@ -1159,20 +1167,23 @@ class DGCMomentumOptimizer(Optimizer): self._rampup_begin_step_var = None self._global_step_var = None - self._local_grad_clip_norm = None - self._clip_norm = None - if local_grad_clip_norm is not None: - assert isinstance(num_trainers, int) - assert isinstance(local_grad_clip_norm, float) - assert num_trainers > 0 + self._dgc_clip_norm = None + if grad_clip is not None: + if not isinstance(grad_clip, GradientClipByNorm): + raise TypeError( + "The type of grad_clip should be 'GradientClipByNorm', because DGCMomentumOptimizer only support GradientClipByNorm" + ) + assert isinstance( + num_trainers, int + ), "The type of num_trainers should be 'int', but received %s" % type( + value) + assert num_trainers > 0, "The value of num_trainers should be greater than 0!" - self._local_grad_clip_norm = local_grad_clip_norm self._num_trainers = num_trainers - self._clip_norm = local_grad_clip_norm * (num_trainers**-0.5) + self._dgc_clip_norm = grad_clip.clip_norm * (num_trainers**-0.5) self.regular_type, self.regular_coeff = self._get_regularization_param( self.regularization) - self._grad_clip = None def _get_regularization_param(self, regularization): regular_type = 0 @@ -1342,8 +1353,8 @@ class DGCMomentumOptimizer(Optimizer): op._remove_attr(op_maker.kOpRoleVarAttrName()) clip_var = grad_var - if self._local_grad_clip_norm is not None: - clip_var = self._append_clip_norm(grad_var, self._clip_norm) + if self._dgc_clip_norm is not None: + clip_var = self._append_clip_norm(grad_var, self._dgc_clip_norm) self._dgc_op(param_var, clip_var, grad_var, u_var, v_var, k_var, encoded_var, gather_var) @@ -1494,6 +1505,10 @@ class LarsMomentumOptimizer(Optimizer): regularizer using :ref:`api_fluid_ParamAttr` already, the regularization setting here in optimizer will be \ ignored for this parameter. Otherwise, the regularization setting here in optimizer will take effect. \ Default None, meaning there is no regularization. + grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of + some derived class of ``GradientClipBase`` . There are three cliping strategies + ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , + :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping. name (str, optional): This parameter is used by developers to print debugging information. \ For details, please refer to :ref:`api_guide_Name`. Default is None. @@ -1526,6 +1541,7 @@ class LarsMomentumOptimizer(Optimizer): lars_weight_decay=0.0005, parameter_list=None, regularization=None, + grad_clip=None, name=None): assert learning_rate is not None assert momentum is not None @@ -1533,6 +1549,7 @@ class LarsMomentumOptimizer(Optimizer): learning_rate=learning_rate, parameter_list=parameter_list, regularization=regularization, + grad_clip=grad_clip, name=name) self.type = "lars_momentum" self._momentum = momentum @@ -1607,6 +1624,10 @@ class AdagradOptimizer(Optimizer): regularizer using :ref:`api_fluid_ParamAttr` already, the regularization setting here in optimizer will be \ ignored for this parameter. Otherwise, the regularization setting here in optimizer will take effect. \ Default None, meaning there is no regularization. + grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of + some derived class of ``GradientClipBase`` . There are three cliping strategies + ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , + :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping. name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. The default value is None. @@ -1639,6 +1660,7 @@ class AdagradOptimizer(Optimizer): epsilon=1.0e-6, parameter_list=None, regularization=None, + grad_clip=None, name=None, initial_accumulator_value=0.0): assert learning_rate is not None @@ -1647,6 +1669,7 @@ class AdagradOptimizer(Optimizer): learning_rate=learning_rate, parameter_list=parameter_list, regularization=regularization, + grad_clip=grad_clip, name=name) self.type = "adagrad" self._epsilon = epsilon @@ -1726,6 +1749,10 @@ class AdamOptimizer(Optimizer): regularizer using :ref:`api_fluid_ParamAttr` already, the regularization setting here in optimizer will be \ ignored for this parameter. Otherwise, the regularization setting here in optimizer will take effect. \ Default None, meaning there is no regularization. + grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of + some derived class of ``GradientClipBase`` . There are three cliping strategies + ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , + :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping. name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. The default value is None. @@ -1835,6 +1862,7 @@ class AdamOptimizer(Optimizer): epsilon=1e-8, parameter_list=None, regularization=None, + grad_clip=None, name=None, lazy_mode=False): assert learning_rate is not None @@ -1845,6 +1873,7 @@ class AdamOptimizer(Optimizer): learning_rate=learning_rate, parameter_list=parameter_list, regularization=regularization, + grad_clip=grad_clip, name=name) self.type = "adam" self._beta1 = beta1 @@ -1986,6 +2015,10 @@ class AdamaxOptimizer(Optimizer): regularizer using :ref:`api_fluid_ParamAttr` already, the regularization setting here in optimizer will be \ ignored for this parameter. Otherwise, the regularization setting here in optimizer will take effect. \ Default None, meaning there is no regularization. + grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of + some derived class of ``GradientClipBase`` . There are three cliping strategies + ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , + :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping. name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. The default value is None. @@ -2031,6 +2064,7 @@ class AdamaxOptimizer(Optimizer): epsilon=1e-8, parameter_list=None, regularization=None, + grad_clip=None, name=None): assert learning_rate is not None assert beta1 is not None @@ -2040,6 +2074,7 @@ class AdamaxOptimizer(Optimizer): learning_rate=learning_rate, parameter_list=parameter_list, regularization=regularization, + grad_clip=grad_clip, name=name) self.type = "adamax" self._beta1 = beta1 @@ -2238,6 +2273,10 @@ class DecayedAdagradOptimizer(Optimizer): regularizer using :ref:`api_fluid_ParamAttr` already, the regularization setting here in optimizer will be \ ignored for this parameter. Otherwise, the regularization setting here in optimizer will take effect. \ Default None, meaning there is no regularization. + grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of + some derived class of ``GradientClipBase`` . There are three cliping strategies + ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , + :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping. name (str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. The default value is None. @@ -2264,6 +2303,7 @@ class DecayedAdagradOptimizer(Optimizer): epsilon=1.0e-6, parameter_list=None, regularization=None, + grad_clip=None, name=None): assert learning_rate is not None assert decay is not None @@ -2273,6 +2313,7 @@ class DecayedAdagradOptimizer(Optimizer): learning_rate=learning_rate, parameter_list=parameter_list, regularization=regularization, + grad_clip=grad_clip, name=name) self.type = "decayed_adagrad" self._decay = decay @@ -2337,6 +2378,10 @@ class AdadeltaOptimizer(Optimizer): regularizer using :ref:`api_fluid_ParamAttr` already, the regularization setting here in optimizer will be \ ignored for this parameter. Otherwise, the regularization setting here in optimizer will take effect. \ Default None, meaning there is no regularization. + grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of + some derived class of ``GradientClipBase`` . There are three cliping strategies + ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , + :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping. name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . @@ -2367,6 +2412,7 @@ class AdadeltaOptimizer(Optimizer): rho=0.95, parameter_list=None, regularization=None, + grad_clip=None, name=None): if learning_rate is None: raise ValueError("learning_rate is not set.") @@ -2378,6 +2424,7 @@ class AdadeltaOptimizer(Optimizer): learning_rate=learning_rate, parameter_list=parameter_list, regularization=regularization, + grad_clip=grad_clip, name=name) self.type = "adadelta" self._epsilon = epsilon @@ -2488,6 +2535,10 @@ class RMSPropOptimizer(Optimizer): regularizer using :ref:`api_fluid_ParamAttr` already, the regularization setting here in optimizer will be \ ignored for this parameter. Otherwise, the regularization setting here in optimizer will take effect. \ Default None, meaning there is no regularization. + grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of + some derived class of ``GradientClipBase`` . There are three cliping strategies + ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , + :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping. name (str, optional): This parameter is used by developers to print debugging information. \ For details, please refer to :ref:`api_guide_Name`. Default is None. @@ -2536,11 +2587,13 @@ class RMSPropOptimizer(Optimizer): centered=False, parameter_list=None, regularization=None, + grad_clip=None, name=None): super(RMSPropOptimizer, self).__init__( learning_rate=learning_rate, parameter_list=parameter_list, regularization=regularization, + grad_clip=grad_clip, name=name) if learning_rate is None: raise ValueError("learning_rate is not set.") @@ -2656,6 +2709,10 @@ class FtrlOptimizer(Optimizer): regularizer using :ref:`api_fluid_ParamAttr` already, the regularization setting here in optimizer will be \ ignored for this parameter. Otherwise, the regularization setting here in optimizer will take effect. \ Default None, meaning there is no regularization. + grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of + some derived class of ``GradientClipBase`` . There are three cliping strategies + ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , + :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping. name (str, optional): This parameter is used by developers to print debugging information. \ For details, please refer to :ref:`api_guide_Name`. Default is None. @@ -2704,11 +2761,13 @@ class FtrlOptimizer(Optimizer): lr_power=-0.5, parameter_list=None, regularization=None, + grad_clip=None, name=None): super(FtrlOptimizer, self).__init__( learning_rate=learning_rate, parameter_list=parameter_list, regularization=regularization, + grad_clip=grad_clip, name=name) if learning_rate is None: raise ValueError("learning_rate is not set.") @@ -2798,6 +2857,10 @@ class LambOptimizer(AdamOptimizer): regularizer using :ref:`api_fluid_ParamAttr` already, the regularization setting here in optimizer will be \ ignored for this parameter. Otherwise, the regularization setting here in optimizer will take effect. \ Default None, meaning there is no regularization. + grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of + some derived class of ``GradientClipBase`` . There are three cliping strategies + ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , + :ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping. exclude_from_weight_decay_fn (function|None): Exclude a parameter from weight decay when **exclude_from_weight_decay_fn(parameter)** returns true. Default None. @@ -2834,6 +2897,7 @@ class LambOptimizer(AdamOptimizer): epsilon=1e-6, parameter_list=None, regularization=None, + grad_clip=None, exclude_from_weight_decay_fn=None, name=None): assert learning_rate is not None @@ -2845,6 +2909,7 @@ class LambOptimizer(AdamOptimizer): learning_rate=learning_rate, parameter_list=parameter_list, regularization=regularization, + grad_clip=grad_clip, beta1=beta1, beta2=beta2, epsilon=epsilon, @@ -4046,20 +4111,13 @@ class RecomputeOptimizer(Optimizer): loss, startup_program=None, parameter_list=None, - no_grad_set=None, - grad_clip=None): + no_grad_set=None): assert isinstance(loss, Variable), "The loss should be an Variable." assert (self._checkpoints is not None ), "You should call _set_checkpoints first" if framework.in_dygraph_mode(): raise NotImplementedError( "DyGraph current does not support recompute") - if grad_clip is not None: - if not isinstance(grad_clip, GradientClipBase): - raise TypeError( - "'grad_clip' should be an instance of GradientClipBase's derived class" - ) - self._optimizer._grad_clip = grad_clip params_grads = self.backward( loss, startup_program=startup_program, diff --git a/python/paddle/fluid/param_attr.py b/python/paddle/fluid/param_attr.py index 7f2de02ecf738ff750fb10beec548e615b015b84..c6aed725b405e8975c3de60a0131e2a7c6463bb8 100644 --- a/python/paddle/fluid/param_attr.py +++ b/python/paddle/fluid/param_attr.py @@ -36,7 +36,7 @@ class ParamAttr(object): Note: ``gradient_clip`` of ``ParamAttr`` HAS BEEN DEPRECATED since 2.0. - It is recommended to use ``minimize(loss, grad_clip=clip)`` to clip gradient. + It is recommended to set ``grad_clip`` in ``optimizer`` to clip gradient. There are three clipping strategies: :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , :ref:`api_fluid_clip_GradientClipByValue` . diff --git a/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py b/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py index 29050710c62fa9af08f7e5e0a2a18588a241cdba..9f4e9ccd0d38d96c6f7a7a402abfc1ae8d588aba 100644 --- a/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_dgc_optimizer.py @@ -19,6 +19,7 @@ import unittest import paddle.fluid.framework as framework import paddle.fluid.optimizer as optimizer import paddle.fluid.regularizer as regularizer +import paddle.fluid.clip as clip import paddle.compat as cpt from paddle.fluid.backward import append_backward from paddle.fluid.transpiler.details import program_to_code @@ -70,9 +71,9 @@ class TestDGCMomentumOptimizer(unittest.TestCase): learning_rate=learning_rate, momentum=0.2, rampup_begin_step=0, - local_grad_clip_norm=1.0, num_trainers=2, - regularization=regularization) + regularization=regularization, + grad_clip=clip.GradientClipByNorm(1.0)) if use_recompute: dgc_momentum_optimizer = optimizer.RecomputeOptimizer( @@ -124,6 +125,16 @@ class TestDGCMomentumOptimizer(unittest.TestCase): #with open("test_dgc_optimizer_" + name + str(use_recompute) + ".log", "w") as f: # program_to_code(program, fout=f) + def test_tpyeError(self): + # the type of DGCMomentumOptimizer(grad_clip=) must be 'GradientClipByNorm' + with self.assertRaises(TypeError): + dgc_momentum_optimizer = self.MockDGCMomentum( + learning_rate=0.01, + momentum=0.2, + rampup_begin_step=0, + num_trainers=2, + grad_clip=clip.GradientClipByGlobalNorm(1.0)) + def test_momentum_without_dgc(self): self.check_dgc_momentum_optimizer( regularization=regularizer.L1Decay(1e-4)) diff --git a/python/paddle/fluid/tests/unittests/test_gradient_clip.py b/python/paddle/fluid/tests/unittests/test_gradient_clip.py index 362b527d456b7c12781cc1535258ef7094f0f2b0..cc54e680c7525a044892a35251862cc7d10be5c4 100644 --- a/python/paddle/fluid/tests/unittests/test_gradient_clip.py +++ b/python/paddle/fluid/tests/unittests/test_gradient_clip.py @@ -76,8 +76,8 @@ class TestGradientClip(unittest.TestCase): startup_program = fluid.Program() with fluid.program_guard( main_program=prog, startup_program=startup_program): - image = fluid.data(name='x', shape=[-1, 784], dtype='float32') - label = fluid.data(name='y', shape=[-1, 1], dtype='int64') + image = fluid.data(name="a", shape=[-1, 784], dtype='float32') + label = fluid.data(name="b", shape=[-1, 1], dtype='int64') hidden = fluid.layers.fc(input=image, size=32, act='relu') predict = fluid.layers.fc(input=hidden, size=10, act='softmax') @@ -112,13 +112,13 @@ class TestGradientClip(unittest.TestCase): self.check_clip_result(out, out_clip) def check_sparse_gradient_clip(self, place): - prog = fluid.framework.Program() - startup_program = fluid.framework.Program() + prog = fluid.Program() + startup_program = fluid.Program() with fluid.program_guard( main_program=prog, startup_program=startup_program): - data = fluid.layers.data( - name="words", shape=[1], dtype="int64", lod_level=1) - label = fluid.layers.data(name="label", shape=[1], dtype="int64") + data = fluid.data( + name="words", shape=[-1, 1], dtype="int64", lod_level=1) + label = fluid.data(name="label", shape=[-1, 1], dtype="int64") cost = bow_net(data, label, self.word_dict_len) self.backward_and_optimize(cost) @@ -172,7 +172,7 @@ class TestGradientClipByGlobalNorm(TestGradientClip): self.clip_gradient = func self.check_gradient_clip(fluid.CPUPlace()) - # test whether the ouput is right when use 'minimize(grad_clip)' + # test whether the ouput is right when use grad_clip def test_new_gradient_clip(self): def func(params_grads): clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm) @@ -192,9 +192,10 @@ class TestGradientClipByGlobalNorm(TestGradientClip): clip = fluid.clip.GradientClipByGlobalNorm( clip_norm=5.0, need_clip=fileter_func) fluid.clip.set_gradient_clip(clip) - sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01) - # if 'set_gradient_clip' and 'minimize(grad_clip)' together, 'set_gradient_clip' will be ineffective - sgd_optimizer.minimize(cost, grad_clip=clip) + sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01, + grad_clip=clip) + # if 'set_gradient_clip' and 'optimize(grad_clip)' together, 'set_gradient_clip' will be ineffective + sgd_optimizer.minimize(cost) # 'set_gradient_clip' must before 'minimize', otherwise, 'set_gradient_clip' will be ineffective fluid.clip.set_gradient_clip(clip) @@ -232,24 +233,10 @@ class TestGradientClipByGlobalNorm(TestGradientClip): clip = fluid.clip.GradientClipByGlobalNorm( clip_norm=self.clip_norm, need_clip="test") - # the type of minimize(grad_clip=) must be an instance of GradientClipBase's derived class - with self.assertRaises(TypeError): - x = fluid.default_main_program().global_block().create_parameter( - name="x", shape=[2, 3], dtype="float32") - loss = fluid.layers.reduce_mean(x) - sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1) - sgd_optimizer.minimize(loss, grad_clip="test") - - # the type of RecomputeOptimizer.minimize(grad_clip=) must be an instance of GradientClipBase's derived class + # the type of optimizer(grad_clip=) must be an instance of GradientClipBase's derived class with self.assertRaises(TypeError): - x = fluid.default_main_program().global_block().create_parameter( - name="x", shape=[2, 3], dtype="float32") - loss = fluid.layers.reduce_mean(x) - sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1) - recompute_optimizer = fluid.optimizer.RecomputeOptimizer( - sgd_optimizer) - recompute_optimizer._set_checkpoints([x]) - recompute_optimizer.minimize(loss, grad_clip="test") + sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1, + grad_clip="test") class TestGradientClipByNorm(TestGradientClip): @@ -271,7 +258,7 @@ class TestGradientClipByNorm(TestGradientClip): a=u, b=v, rtol=1e-5, atol=1e-8), "gradient clip by norm has wrong results!") - # test whether the ouput is right when use 'minimize(grad_clip)' + # test whether the ouput is right when use grad_clip def test_gradient_clip(self): self.check_gradient_clip(fluid.CPUPlace()) @@ -319,7 +306,7 @@ class TestGradientClipByValue(TestGradientClip): a=u, b=v, rtol=1e-6, atol=1e-8), "gradient clip by value has wrong results!") - # test whether the ouput is right when use 'minimize(grad_clip)' + # test whether the ouput is right when use grad_clip def test_gradient_clip(self): self.check_gradient_clip(fluid.CPUPlace()) @@ -357,7 +344,9 @@ class TestDygraphGradientClip(unittest.TestCase): loss = fluid.layers.reduce_mean(out) loss.backward() sgd_optimizer = fluid.optimizer.SGD( - learning_rate=0.0, parameter_list=linear.parameters()) + learning_rate=0.0, + parameter_list=linear.parameters(), + grad_clip=fluid.clip.GradientClipByGlobalNorm(0.1)) self.check_clip_result(loss, sgd_optimizer) def check_clip_result(self, loss, optimizer): @@ -384,7 +373,7 @@ class TestDygraphGradientClipByGlobalNorm(TestDygraphGradientClip): np.array([3, 4]).astype("float32"), name="y") assert len(self.clip1([(x, x), (x, y), (x, None)])) == 2 # get params and grads from network - opt, params_grads = optimizer.minimize(loss, grad_clip=self.clip2) + opt, params_grads = optimizer.minimize(loss) _, grads = zip(*params_grads) params_grads = self.clip2(params_grads) _, grads_clip = zip(*params_grads) @@ -426,7 +415,7 @@ class TestDygraphGradientClipByNorm(TestDygraphGradientClip): assert len(self.clip([(x, None)])) == 0 # get params and grads from network self.clip([(fluid.dygraph.to_variable(np.array([2, 3])), None)]) - params_grads = optimizer.backward(loss) + opt, params_grads = optimizer.minimize(loss) _, grads = zip(*params_grads) params_grads = self.clip(params_grads) _, grads_clip = zip(*params_grads) @@ -460,7 +449,7 @@ class TestDygraphGradientClipByValue(TestDygraphGradientClip): x = fluid.dygraph.to_variable(np.array([2, 3]).astype("float32")) assert len(self.clip([(x, None)])) == 0 # get params and grads from network - params_grads = optimizer.backward(loss) + opt, params_grads = optimizer.minimize(loss) _, grads = zip(*params_grads) params_grads = self.clip(params_grads) _, grads_clip = zip(*params_grads) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py index 3d9b4e2ef27fa4b6724eff4347d3a20023313d97..2a25bf6f8abade11d9ad25894753f6d17066e7fd 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py @@ -329,9 +329,9 @@ class TestImperativeAutoPrune(unittest.TestCase): place = fluid.CPUPlace() with fluid.dygraph.guard(place): model = MyLayer(size, vocab_size, size) - optimizer = fluid.optimizer.AdamOptimizer( - 0.001, parameter_list=model.parameters()) grad_clip = fluid.clip.GradientClipByGlobalNorm(0.001) + optimizer = fluid.optimizer.AdamOptimizer( + 0.001, parameter_list=model.parameters(), grad_clip=grad_clip) indices = fluid.dygraph.to_variable(indices) embed = fluid.dygraph.to_variable(embed) @@ -339,7 +339,7 @@ class TestImperativeAutoPrune(unittest.TestCase): loss = model.embed_linear0(indices) loss.backward() - _, params_grads = optimizer.minimize(loss, grad_clip=grad_clip) + _, params_grads = optimizer.minimize(loss) for items in params_grads: assert items[0].name is not model.embed1.weight.name assert items[0].name is not model.linear_1.weight.name @@ -348,9 +348,9 @@ class TestImperativeAutoPrune(unittest.TestCase): with fluid.dygraph.guard(place): model = MyLayer2(size, vocab_size, size) - optimizer = fluid.optimizer.AdamOptimizer( - 0.001, parameter_list=model.parameters()) grad_clip = fluid.clip.GradientClipByGlobalNorm(0.001) + optimizer = fluid.optimizer.AdamOptimizer( + 0.001, parameter_list=model.parameters(), grad_clip=grad_clip) indices = fluid.dygraph.to_variable(indices) emebd = fluid.dygraph.to_variable(embed) @@ -358,7 +358,7 @@ class TestImperativeAutoPrune(unittest.TestCase): loss = model.embed_linear0(indices) loss.backward() - optimizer.minimize(loss, grad_clip=grad_clip) + optimizer.minimize(loss) for items in params_grads: assert items[0].name is not model.embed1.weight.name assert items[0].name is not model.linear_1.weight.name diff --git a/python/paddle/fluid/tests/unittests/test_imperative_selected_rows.py b/python/paddle/fluid/tests/unittests/test_imperative_selected_rows.py index 3c4e959c5cc1ff1fb032e4cac1a88307b78028e2..cfaca5a565deb4ac29542c76f4b9e7b4f8ec1431 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_selected_rows.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_selected_rows.py @@ -58,14 +58,15 @@ class TestSimpleNet(unittest.TestCase): simplenet = SimpleNet(20, 32, dtype) adam = SGDOptimizer( learning_rate=0.001, - parameter_list=simplenet.parameters()) + parameter_list=simplenet.parameters( + )) # grad_clip=grad_clip input_emb, emb = simplenet(input) self.assertTrue(emb.weight.gradient() is None) self.assertTrue(input_emb.gradient() is None) input_emb.backward(backward_strategy) - adam.minimize(input_emb) # grad_clip=grad_clip + adam.minimize(input_emb) self.assertTrue(emb.weight.gradient() is not None) emb.clear_gradients() @@ -92,14 +93,15 @@ class TestSimpleNet(unittest.TestCase): simplenet = SimpleNet(20, 32, "float32") adam = SGDOptimizer( learning_rate=0.001, - parameter_list=simplenet.parameters()) + parameter_list=simplenet.parameters(), + grad_clip=grad_clip) input_emb, emb = simplenet(input) self.assertTrue(emb.weight.gradient() is None) self.assertTrue(input_emb.gradient() is None) input_emb.backward(backward_strategy) - adam.minimize(input_emb, grad_clip=grad_clip) + adam.minimize(input_emb) self.assertTrue(emb.weight.gradient() is not None) emb.clear_gradients()