From a526b3e03335ea6ec47329e96e3a769d46341f91 Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Thu, 10 Jun 2021 10:11:21 +0800 Subject: [PATCH] fuse L2Decay and momentum when param.regularizer is set (#32845) * fuse L2Decay and momentum when param.regularizer is set * add unittest * refine * refine _create_regularization_of_grad of momentum * improve append_optimizer_op --- python/paddle/fluid/optimizer.py | 100 ++++++++++++++++-- python/paddle/fluid/regularizer.py | 86 --------------- .../fluid/tests/unittests/test_momentum_op.py | 71 +++++++++++++ .../fluid/tests/unittests/test_regularizer.py | 2 + python/paddle/optimizer/momentum.py | 35 +++++- python/paddle/optimizer/optimizer.py | 96 ++++++++++++++++- 6 files changed, 288 insertions(+), 102 deletions(-) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 254ffc796b3..e2ddc20b8f9 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -33,7 +33,6 @@ from .framework import program_guard from .initializer import Constant from .layer_helper import LayerHelper from .layers import ops -from .regularizer import append_regularization_ops from .dygraph import base as imperative_base from .dygraph import no_grad from .dygraph.learning_rate_scheduler import LearningRateDecay, _LearningRateEpochDecay @@ -884,6 +883,93 @@ class Optimizer(object): act_no_grad_set, callbacks) return params_grads + def _create_regularization_of_grad(self, param, grad, regularization=None): + """ Create and add backward regularization Operators + + Function helper of append_regularization_ops. + """ + # If no gradient or no regularization is specified, then we don't need to do anything + if grad is None or ((not hasattr(param, 'regularizer') or + (hasattr(param, 'regularizer') and + param.regularizer is None)) and + regularization is None): + return grad + regularization_term = None + if hasattr(param, 'regularizer') and param.regularizer is not None: + # Add variable for regularization term in grad block + regularization_term = param.regularizer(param, grad, grad.block) + elif regularization is not None: + regularization_term = regularization(param, grad, grad.block) + + assert regularization_term is not None + + new_grad = grad + if grad.type == core.VarDesc.VarType.SELECTED_ROWS: + # FIXME(zcd): If the grad is SELECTED_ROWS, after regularization, + # the grad's type and name will be changed. But the gradient's name + # is used in ParallelExecutor Reduce mode, so I add a flag for + # the new_grad here. + new_grad = grad.block.create_var( + name=grad.name + core.kNewGradSuffix(), + dtype=param.dtype, + shape=param.shape, + lod_level=param.lod_level, + type=core.VarDesc.VarType.LOD_TENSOR) + + inputs = {"X": [grad, regularization_term]} + outputs = {"Out": [new_grad]} + if framework.in_dygraph_mode(): + new_grad = core.ops.sum([grad, regularization_term]) + else: + grad.block.append_op(type='sum', inputs=inputs, outputs=outputs) + + return new_grad + + def append_regularization_ops(self, + parameters_and_grads, + regularization=None): + r"""Create and add backward regularization Operators + + Creates and adds backward regularization operators in the BlockDesc. + This will add gradients of the regularizer function to the gradients + of the parameters and return these modified gradients. This is the + same as implementing weight decay in optimizers for regularization. + + Args: + parameters_and_grads: A list of (parameters, gradients) pairs + that need to be regularized. + regularization: A global regularizer. If the parameter is not + set. It will be applied with regularizer. + + Returns: + list[(Variable, Variable)]: list of (parameters, gradients) \ + pair with the regularized gradient + + Raises: + Exception: Unknown regularization type + """ + params_and_grads = [] + if framework.in_dygraph_mode(): + for param, grad in parameters_and_grads: + new_grad = self._create_regularization_of_grad(param, grad, + regularization) + params_and_grads.append((param, new_grad)) + else: + repeate_regularizer = False + with framework.name_scope('regularization'): + for param, grad in parameters_and_grads: + if not repeate_regularizer and param.regularizer is not None and regularization is not None: + repeate_regularizer = True + logging.info( + "If regularizer of a Parameter has been set by 'fluid.ParamAttr' or 'fluid.WeightNormParamAttr' already. " + "The Regularization[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!" + % regularization.__str__()) + with param.block.program._optimized_guard([param, grad]): + new_grad = self._create_regularization_of_grad( + param, grad, regularization) + params_and_grads.append((param, new_grad)) + return params_and_grads + def apply_gradients(self, params_grads): """ Second part of `minimize`, appending optimization operators for @@ -916,8 +1002,8 @@ class Optimizer(object): params_grads = append_gradient_clip_ops(params_grads) # Add regularization if any - params_grads = append_regularization_ops(params_grads, - self.regularization) + params_grads = self.append_regularization_ops(params_grads, + self.regularization) optimize_ops = self._create_optimization_pass(params_grads) return optimize_ops @@ -939,8 +1025,8 @@ class Optimizer(object): framework.default_startup_program()): if self._grad_clip is not None: params_grads = self._grad_clip(params_grads) - params_grads = append_regularization_ops(params_grads, - self.regularization) + params_grads = self.append_regularization_ops( + params_grads, self.regularization) optimize_ops = self._create_optimization_pass(params_grads) else: program = loss.block.program @@ -1674,8 +1760,8 @@ class DGCMomentumOptimizer(Optimizer): not_dgc_params_grads = append_gradient_clip_ops( not_dgc_params_grads) - not_dgc_params_grads = append_regularization_ops(not_dgc_params_grads, - self.regularization) + not_dgc_params_grads = self.append_regularization_ops( + not_dgc_params_grads, self.regularization) params_grads = not_dgc_params_grads + dgc_params_grads params_grads = sorted(params_grads, key=lambda x: x[0].name) diff --git a/python/paddle/fluid/regularizer.py b/python/paddle/fluid/regularizer.py index 64ce283a63c..64bbca6c57c 100644 --- a/python/paddle/fluid/regularizer.py +++ b/python/paddle/fluid/regularizer.py @@ -22,92 +22,6 @@ from . import core __all__ = ['L1Decay', 'L2Decay', 'L1DecayRegularizer', 'L2DecayRegularizer'] -def _create_regularization_of_grad(param, grad, regularization=None): - """ Create and add backward regularization Operators - - Function helper of append_regularization_ops. - """ - # If no gradient or no regularization is specified, then we don't need to do anything - if grad is None or ((not hasattr(param, 'regularizer') or ( - hasattr(param, 'regularizer') and param.regularizer is None)) and - regularization is None): - return grad - regularization_term = None - if hasattr(param, 'regularizer') and param.regularizer is not None: - # Add variable for regularization term in grad block - regularization_term = param.regularizer(param, grad, grad.block) - elif regularization is not None: - regularization_term = regularization(param, grad, grad.block) - - assert regularization_term is not None - - new_grad = grad - if grad.type == core.VarDesc.VarType.SELECTED_ROWS: - # FIXME(zcd): If the grad is SELECTED_ROWS, after regularization, - # the grad's type and name will be changed. But the gradient's name - # is used in ParallelExecutor Reduce mode, so I add a flag for - # the new_grad here. - new_grad = grad.block.create_var( - name=grad.name + core.kNewGradSuffix(), - dtype=param.dtype, - shape=param.shape, - lod_level=param.lod_level, - type=core.VarDesc.VarType.LOD_TENSOR) - - inputs = {"X": [grad, regularization_term]} - outputs = {"Out": [new_grad]} - if in_dygraph_mode(): - new_grad = core.ops.sum([grad, regularization_term]) - else: - grad.block.append_op(type='sum', inputs=inputs, outputs=outputs) - - return new_grad - - -def append_regularization_ops(parameters_and_grads, regularization=None): - r"""Create and add backward regularization Operators - - Creates and adds backward regularization operators in the BlockDesc. - This will add gradients of the regularizer function to the gradients - of the parameters and return these modified gradients. This is the - same as implementing weight decay in optimizers for regularization. - - Args: - parameters_and_grads: A list of (parameters, gradients) pairs - that need to be regularized. - regularization: A global regularizer. If the parameter is not - set. It will be applied with regularizer. - - Returns: - list[(Variable, Variable)]: list of (parameters, gradients) \ - pair with the regularized gradient - - Raises: - Exception: Unknown regularization type - """ - params_and_grads = [] - if in_dygraph_mode(): - for param, grad in parameters_and_grads: - new_grad = _create_regularization_of_grad(param, grad, - regularization) - params_and_grads.append((param, new_grad)) - else: - repeate_regularizer = False - with framework.name_scope('regularization'): - for param, grad in parameters_and_grads: - if not repeate_regularizer and param.regularizer is not None and regularization is not None: - repeate_regularizer = True - logging.info( - "If regularizer of a Parameter has been set by 'fluid.ParamAttr' or 'fluid.WeightNormParamAttr' already. " - "The Regularization[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!" - % regularization.__str__()) - with param.block.program._optimized_guard([param, grad]): - new_grad = _create_regularization_of_grad(param, grad, - regularization) - params_and_grads.append((param, new_grad)) - return params_and_grads - - class WeightDecayRegularizer(object): """Base class for weight decay regularizers diff --git a/python/paddle/fluid/tests/unittests/test_momentum_op.py b/python/paddle/fluid/tests/unittests/test_momentum_op.py index e31587b225e..e79f6e5eb4a 100644 --- a/python/paddle/fluid/tests/unittests/test_momentum_op.py +++ b/python/paddle/fluid/tests/unittests/test_momentum_op.py @@ -613,6 +613,77 @@ class TestMomentumOpWithDecayAPI(unittest.TestCase): exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list) +class TestFusedMomentumWithDecayAPI(unittest.TestCase): + def get_program(self, weight_attr, bias_attr=False): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard( + main_program=main_program, startup_program=startup_program): + x = paddle.static.data(name='x', shape=[10, 10]) + linear = paddle.nn.Linear( + 10, 10, weight_attr=weight_attr, bias_attr=bias_attr) + out = linear(x) + loss = paddle.mean(out) + optimizer = paddle.optimizer.Momentum( + learning_rate=0.01, + momentum=0.9, + weight_decay=paddle.regularizer.L2Decay(0.5)) + optimizer.minimize(loss) + return main_program + + def test_param_has_l2decay(self): + paddle.enable_static() + weight_attr = paddle.ParamAttr( + name="weight", + initializer=paddle.nn.initializer.Constant(value=0.5), + regularizer=paddle.regularizer.L2Decay(0.1)) + program = self.get_program(weight_attr, bias_attr=False) + ops = program.global_block().ops + + self.assertEqual(ops[-1].attr('regularization_method'), 'l2_decay') + self.assertEqual(ops[-1].attr('regularization_coeff'), np.float32(0.1)) + for i in range(len(ops)): + self.assertTrue('sum' not in ops[i].type) + self.assertTrue('scale' not in ops[i].type) + + def test_param_has_l1decay(self): + paddle.enable_static() + weight_attr = paddle.ParamAttr( + name="weight", + initializer=paddle.nn.initializer.Constant(value=0.5), + regularizer=paddle.regularizer.L1Decay(0.1)) + bias_attr = paddle.ParamAttr( + name="bias", + initializer=paddle.nn.initializer.Constant(value=0.), + regularizer=None) + program = self.get_program(weight_attr, bias_attr) + ops = program.global_block().ops + + self.assertEqual(ops[-1].type, 'momentum') + self.assertEqual(ops[-2].type, 'momentum') + self.assertEqual(ops[-3].type, 'sum') + self.assertEqual(ops[-4].type, 'scale') + self.assertEqual(ops[-5].type, 'sign') + self.assertEqual(ops[-6].type, 'matmul_grad') + if 'weight' in ops[-1].input('Param'): + self.assertEqual(ops[-1].attr('regularization_method'), '') + self.assertEqual(ops[-1].attr('regularization_coeff'), 0) + if 'bias' in ops[-2].input('Param'): + self.assertEqual(ops[-2].attr('regularization_method'), 'l2_decay') + self.assertEqual(ops[-2].attr('regularization_coeff'), + np.float32(0.5)) + + def test_param_has_no_regularizer(self): + paddle.enable_static() + program = self.get_program(weight_attr=None) + ops = program.global_block().ops + self.assertEqual(ops[-1].attr('regularization_method'), 'l2_decay') + self.assertEqual(ops[-1].attr('regularization_coeff'), np.float32(0.5)) + for i in range(len(ops)): + self.assertTrue('sum' not in ops[i].type) + self.assertTrue('scale' not in ops[i].type) + + class TestMomentumOpVsMomentumOpWithDecayAPI(unittest.TestCase): def __update_params(self, momentum, linear): for i in range(10): diff --git a/python/paddle/fluid/tests/unittests/test_regularizer.py b/python/paddle/fluid/tests/unittests/test_regularizer.py index edd69d67aaf..08a70fe1852 100644 --- a/python/paddle/fluid/tests/unittests/test_regularizer.py +++ b/python/paddle/fluid/tests/unittests/test_regularizer.py @@ -59,6 +59,7 @@ class TestL2DecayRegularizer(unittest.TestCase): params_grads = append_backward(mean_out) self.assertEqual(len(params_grads), 1) count_ops = len(block.ops) + optimizer = paddle.optimizer.Adam() params_grads = optimizer.append_regularization_ops(params_grads) self.assertEqual(len(params_grads), 1) self.assertEqual(len(block.ops), count_ops + 2) @@ -97,6 +98,7 @@ class TestL1DecayRegularizer(unittest.TestCase): params_grads = append_backward(mean_out) self.assertEqual(len(params_grads), 1) count_ops = len(block.ops) + optimizer = paddle.optimizer.Adam() params_grads = optimizer.append_regularization_ops(params_grads) self.assertEqual(len(params_grads), 1) self.assertEqual(len(block.ops), count_ops + 3) diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py index faff090bcb1..85c5c60a34c 100644 --- a/python/paddle/optimizer/momentum.py +++ b/python/paddle/optimizer/momentum.py @@ -252,6 +252,19 @@ class Momentum(Optimizer): ) self._add_accumulator(self._velocity_acc_str, p) + def _create_regularization_of_grad(self, param, grad, regularization=None): + """ Create and add backward regularization Operators + + Function helper of append_regularization_ops. + """ + # If ParamAttr is set to L2Decay, we skip doing regularization here. And then we fused + # L2Decay with momentum which can refer to _append_optimize_op below. + if hasattr(param, 'regularizer') and isinstance(param.regularizer, + L2DecayRegularizer): + return grad + return super(Momentum, self)._create_regularization_of_grad( + param, grad, regularization) + def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) if isinstance(param_and_grad, dict): @@ -261,6 +274,20 @@ class Momentum(Optimizer): param_and_grad[0]) lr = self._create_param_lr(param_and_grad) + # For fusion of momentum and l2decay + param = param_and_grad[0] + regularization_method = self._regularization_method + regularization_coeff = self._regularization_coeff + if hasattr(param, 'regularizer'): + # we skip param's l2decay before, so fuse it with momentum here. + if isinstance(param.regularizer, L2DecayRegularizer): + regularization_method = "l2_decay" + regularization_coeff = param.regularizer._regularization_coeff + # the param's regularization has been done before, we avoid do l2decay in momentum. + elif param.regularizer is not None: + regularization_method = "" + regularization_coeff = 0 + if framework.in_dygraph_mode(): if isinstance(param_and_grad, dict): self._update_regularization(param_and_grad['weight_decay']) @@ -268,8 +295,8 @@ class Momentum(Optimizer): param_and_grad[0], param_and_grad[1], velocity_acc, lr, param_and_grad[0], velocity_acc, 'mu', self._momentum, 'use_nesterov', self._use_nesterov, 'regularization_method', - self._regularization_method, 'regularization_coeff', - self._regularization_coeff) + regularization_method, 'regularization_coeff', + regularization_coeff) return None find_master = self._multi_precision and param_and_grad[ @@ -280,8 +307,8 @@ class Momentum(Optimizer): attrs = { "mu": self._momentum, "use_nesterov": self._use_nesterov, - "regularization_method": self._regularization_method, - "regularization_coeff": self._regularization_coeff, + "regularization_method": regularization_method, + "regularization_coeff": regularization_coeff, "multi_precision": find_master, "rescale_grad": self._rescale_grad } diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 0f22b920b17..2cdf1d0d28e 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -32,7 +32,6 @@ from ..fluid.framework import program_guard, Parameter from ..fluid.initializer import Constant from ..fluid.layer_helper import LayerHelper from ..fluid.layers import ops -from ..fluid.regularizer import append_regularization_ops from ..fluid.dygraph import base as imperative_base from ..fluid.dygraph import no_grad from paddle.fluid import core @@ -850,8 +849,8 @@ class Optimizer(object): params_grads = append_gradient_clip_ops(params_grads) # Add regularization if any - params_grads = append_regularization_ops(params_grads, - self.regularization) + params_grads = self.append_regularization_ops(params_grads, + self.regularization) optimize_ops = self._create_optimization_pass(params_grads) return optimize_ops @@ -874,7 +873,7 @@ class Optimizer(object): if isinstance(params_grads, list): if self._grad_clip is not None: params_grads = self._grad_clip(params_grads) - params_grads = append_regularization_ops( + params_grads = self.append_regularization_ops( params_grads, self.regularization) else: grad_clip = params_grads['grad_clip'] @@ -882,7 +881,7 @@ class Optimizer(object): params_grads['params'] = grad_clip(params_grads[ 'params']) - params_grads['params'] = append_regularization_ops( + params_grads['params'] = self.append_regularization_ops( params_grads['params'], self.regularization) optimize_ops = self._create_optimization_pass(params_grads) else: @@ -891,6 +890,93 @@ class Optimizer(object): optimize_ops = self.apply_gradients(params_grads) return optimize_ops + def _create_regularization_of_grad(self, param, grad, regularization=None): + """ Create and add backward regularization Operators + + Function helper of append_regularization_ops. + """ + # If no gradient or no regularization is specified, then we don't need to do anything + if grad is None or ((not hasattr(param, 'regularizer') or + (hasattr(param, 'regularizer') and + param.regularizer is None)) and + regularization is None): + return grad + regularization_term = None + if hasattr(param, 'regularizer') and param.regularizer is not None: + # Add variable for regularization term in grad block + regularization_term = param.regularizer(param, grad, grad.block) + elif regularization is not None: + regularization_term = regularization(param, grad, grad.block) + + assert regularization_term is not None + + new_grad = grad + if grad.type == core.VarDesc.VarType.SELECTED_ROWS: + # FIXME(zcd): If the grad is SELECTED_ROWS, after regularization, + # the grad's type and name will be changed. But the gradient's name + # is used in ParallelExecutor Reduce mode, so I add a flag for + # the new_grad here. + new_grad = grad.block.create_var( + name=grad.name + core.kNewGradSuffix(), + dtype=param.dtype, + shape=param.shape, + lod_level=param.lod_level, + type=core.VarDesc.VarType.LOD_TENSOR) + + inputs = {"X": [grad, regularization_term]} + outputs = {"Out": [new_grad]} + if framework.in_dygraph_mode(): + new_grad = core.ops.sum([grad, regularization_term]) + else: + grad.block.append_op(type='sum', inputs=inputs, outputs=outputs) + + return new_grad + + def append_regularization_ops(self, + parameters_and_grads, + regularization=None): + r"""Create and add backward regularization Operators + + Creates and adds backward regularization operators in the BlockDesc. + This will add gradients of the regularizer function to the gradients + of the parameters and return these modified gradients. This is the + same as implementing weight decay in optimizers for regularization. + + Args: + parameters_and_grads: A list of (parameters, gradients) pairs + that need to be regularized. + regularization: A global regularizer. If the parameter is not + set. It will be applied with regularizer. + + Returns: + list[(Variable, Variable)]: list of (parameters, gradients) \ + pair with the regularized gradient + + Raises: + Exception: Unknown regularization type + """ + params_and_grads = [] + if framework.in_dygraph_mode(): + for param, grad in parameters_and_grads: + new_grad = self._create_regularization_of_grad(param, grad, + regularization) + params_and_grads.append((param, new_grad)) + else: + repeate_regularizer = False + with framework.name_scope('regularization'): + for param, grad in parameters_and_grads: + if not repeate_regularizer and param.regularizer is not None and regularization is not None: + repeate_regularizer = True + logging.info( + "If regularizer of a Parameter has been set by 'fluid.ParamAttr' or 'fluid.WeightNormParamAttr' already. " + "The Regularization[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!" + % regularization.__str__()) + with param.block.program._optimized_guard([param, grad]): + new_grad = self._create_regularization_of_grad( + param, grad, regularization) + params_and_grads.append((param, new_grad)) + return params_and_grads + def _get_no_grad_set(self, loss, no_grad_set=None): no_grad_set = _get_no_grad_set_name(no_grad_set) parameters = loss.block.program.global_block().all_parameters() -- GitLab