Unverified commit a526b3e0, authored by Zhang Ting, committed by GitHub

fuse L2Decay and momentum when param.regularizer is set (#32845)

* fuse L2Decay and momentum when param.regularizer is set

* add unittest

* refine

* refine _create_regularization_of_grad of momentum

* improve append_optimizer_op
Parent: 42c1297e
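In short: when a parameter's own regularizer is L2Decay, the momentum op now applies the L2 weight decay itself, instead of materializing the decay beforehand through extra scale/sum ops on the gradient. A minimal sketch of the fused update rule with toy scalar values (illustrative only; in the library this happens inside the momentum op on tensors):

# Toy illustration of momentum with fused L2 decay (hypothetical values).
mu, lr, coeff = 0.9, 0.01, 0.1           # momentum, learning rate, L2Decay coefficient
param, grad, velocity = 0.5, 0.2, 0.0    # scalars standing in for tensors

decayed_grad = grad + coeff * param      # L2 decay folded into the gradient
velocity = mu * velocity + decayed_grad  # velocity accumulator update
param = param - lr * velocity            # parameter update
# velocity is now ~0.25 and param ~0.4975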
...@@ -33,7 +33,6 @@ from .framework import program_guard
from .initializer import Constant
from .layer_helper import LayerHelper
from .layers import ops
from .regularizer import append_regularization_ops
from .dygraph import base as imperative_base
from .dygraph import no_grad
from .dygraph.learning_rate_scheduler import LearningRateDecay, _LearningRateEpochDecay
...@@ -884,6 +883,93 @@ class Optimizer(object):
act_no_grad_set, callbacks)
return params_grads
def _create_regularization_of_grad(self, param, grad, regularization=None):
""" Create and add backward regularization Operators
Function helper of append_regularization_ops.
"""
# If no gradient or no regularization is specified, then we don't need to do anything
if grad is None or ((not hasattr(param, 'regularizer') or
(hasattr(param, 'regularizer') and
param.regularizer is None)) and
regularization is None):
return grad
regularization_term = None
if hasattr(param, 'regularizer') and param.regularizer is not None:
# Add variable for regularization term in grad block
regularization_term = param.regularizer(param, grad, grad.block)
elif regularization is not None:
regularization_term = regularization(param, grad, grad.block)
assert regularization_term is not None
new_grad = grad
if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
# FIXME(zcd): If the grad is SELECTED_ROWS, after regularization,
# the grad's type and name will be changed. But the gradient's name
# is used in ParallelExecutor Reduce mode, so I add a flag for
# the new_grad here.
new_grad = grad.block.create_var(
name=grad.name + core.kNewGradSuffix(),
dtype=param.dtype,
shape=param.shape,
lod_level=param.lod_level,
type=core.VarDesc.VarType.LOD_TENSOR)
inputs = {"X": [grad, regularization_term]}
outputs = {"Out": [new_grad]}
if framework.in_dygraph_mode():
new_grad = core.ops.sum([grad, regularization_term])
else:
grad.block.append_op(type='sum', inputs=inputs, outputs=outputs)
return new_grad
def append_regularization_ops(self,
parameters_and_grads,
regularization=None):
r"""Create and add backward regularization Operators
Creates and adds backward regularization operators in the BlockDesc.
This will add gradients of the regularizer function to the gradients
of the parameters and return these modified gradients. This is the
same as implementing weight decay in optimizers for regularization.
Args:
parameters_and_grads: A list of (parameters, gradients) pairs
that need to be regularized.
regularization: A global regularizer. It is applied only to parameters
that do not have a regularizer of their own.
Returns:
list[(Variable, Variable)]: list of (parameters, gradients) \
pair with the regularized gradient
Raises:
Exception: Unknown regularization type
"""
params_and_grads = []
if framework.in_dygraph_mode():
for param, grad in parameters_and_grads:
new_grad = self._create_regularization_of_grad(param, grad,
regularization)
params_and_grads.append((param, new_grad))
else:
repeate_regularizer = False
with framework.name_scope('regularization'):
for param, grad in parameters_and_grads:
if not repeate_regularizer and param.regularizer is not None and regularization is not None:
repeate_regularizer = True
logging.info(
"If regularizer of a Parameter has been set by 'fluid.ParamAttr' or 'fluid.WeightNormParamAttr' already. "
"The Regularization[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!"
% regularization.__str__())
with param.block.program._optimized_guard([param, grad]):
new_grad = self._create_regularization_of_grad(
param, grad, regularization)
params_and_grads.append((param, new_grad))
return params_and_grads
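For reference, a minimal static-graph sketch of calling this method directly, now that the module-level fluid.regularizer.append_regularization_ops is removed (an assumed workflow modeled on the regularizer unit tests further below; the data and layer names are illustrative):

import paddle

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[10, 10], dtype='float32')
    loss = paddle.mean(paddle.nn.Linear(10, 10)(x))
    params_grads = paddle.static.append_backward(loss)
    optimizer = paddle.optimizer.Adam()
    # The regularization pass is now a method on the Optimizer instance.
    params_grads = optimizer.append_regularization_ops(
        params_grads, regularization=paddle.regularizer.L2Decay(0.5))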
def apply_gradients(self, params_grads):
"""
Second part of `minimize`, appending optimization operators for
...@@ -916,8 +1002,8 @@ class Optimizer(object):
params_grads = append_gradient_clip_ops(params_grads)
# Add regularization if any
- params_grads = append_regularization_ops(params_grads, self.regularization)
+ params_grads = self.append_regularization_ops(params_grads, self.regularization)
optimize_ops = self._create_optimization_pass(params_grads)
return optimize_ops
...@@ -939,8 +1025,8 @@ class Optimizer(object):
framework.default_startup_program()):
if self._grad_clip is not None:
params_grads = self._grad_clip(params_grads)
- params_grads = append_regularization_ops(params_grads, self.regularization)
+ params_grads = self.append_regularization_ops(params_grads, self.regularization)
optimize_ops = self._create_optimization_pass(params_grads)
else:
program = loss.block.program
...@@ -1674,8 +1760,8 @@ class DGCMomentumOptimizer(Optimizer):
not_dgc_params_grads = append_gradient_clip_ops(
not_dgc_params_grads)
- not_dgc_params_grads = append_regularization_ops(not_dgc_params_grads, self.regularization)
+ not_dgc_params_grads = self.append_regularization_ops(not_dgc_params_grads, self.regularization)
params_grads = not_dgc_params_grads + dgc_params_grads
params_grads = sorted(params_grads, key=lambda x: x[0].name)
...
...@@ -22,92 +22,6 @@ from . import core
__all__ = ['L1Decay', 'L2Decay', 'L1DecayRegularizer', 'L2DecayRegularizer']
def _create_regularization_of_grad(param, grad, regularization=None):
""" Create and add backward regularization Operators
Function helper of append_regularization_ops.
"""
# If no gradient or no regularization is specified, then we don't need to do anything
if grad is None or ((not hasattr(param, 'regularizer') or (
hasattr(param, 'regularizer') and param.regularizer is None)) and
regularization is None):
return grad
regularization_term = None
if hasattr(param, 'regularizer') and param.regularizer is not None:
# Add variable for regularization term in grad block
regularization_term = param.regularizer(param, grad, grad.block)
elif regularization is not None:
regularization_term = regularization(param, grad, grad.block)
assert regularization_term is not None
new_grad = grad
if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
# FIXME(zcd): If the grad is SELECTED_ROWS, after regularization,
# the grad's type and name will be changed. But the gradient's name
# is used in ParallelExecutor Reduce mode, so I add a flag for
# the new_grad here.
new_grad = grad.block.create_var(
name=grad.name + core.kNewGradSuffix(),
dtype=param.dtype,
shape=param.shape,
lod_level=param.lod_level,
type=core.VarDesc.VarType.LOD_TENSOR)
inputs = {"X": [grad, regularization_term]}
outputs = {"Out": [new_grad]}
if in_dygraph_mode():
new_grad = core.ops.sum([grad, regularization_term])
else:
grad.block.append_op(type='sum', inputs=inputs, outputs=outputs)
return new_grad
def append_regularization_ops(parameters_and_grads, regularization=None):
r"""Create and add backward regularization Operators
Creates and adds backward regularization operators in the BlockDesc.
This will add gradients of the regularizer function to the gradients
of the parameters and return these modified gradients. This is the
same as implementing weight decay in optimizers for regularization.
Args:
parameters_and_grads: A list of (parameters, gradients) pairs
that need to be regularized.
regularization: A global regularizer. If the parameter is not
set. It will be applied with regularizer.
Returns:
list[(Variable, Variable)]: list of (parameters, gradients) \
pair with the regularized gradient
Raises:
Exception: Unknown regularization type
"""
params_and_grads = []
if in_dygraph_mode():
for param, grad in parameters_and_grads:
new_grad = _create_regularization_of_grad(param, grad,
regularization)
params_and_grads.append((param, new_grad))
else:
repeate_regularizer = False
with framework.name_scope('regularization'):
for param, grad in parameters_and_grads:
if not repeate_regularizer and param.regularizer is not None and regularization is not None:
repeate_regularizer = True
logging.info(
"If regularizer of a Parameter has been set by 'fluid.ParamAttr' or 'fluid.WeightNormParamAttr' already. "
"The Regularization[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!"
% regularization.__str__())
with param.block.program._optimized_guard([param, grad]):
new_grad = _create_regularization_of_grad(param, grad,
regularization)
params_and_grads.append((param, new_grad))
return params_and_grads
class WeightDecayRegularizer(object):
"""Base class for weight decay regularizers
...
...@@ -613,6 +613,77 @@ class TestMomentumOpWithDecayAPI(unittest.TestCase):
exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list)
class TestFusedMomentumWithDecayAPI(unittest.TestCase):
def get_program(self, weight_attr, bias_attr=False):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(
main_program=main_program, startup_program=startup_program):
x = paddle.static.data(name='x', shape=[10, 10])
linear = paddle.nn.Linear(
10, 10, weight_attr=weight_attr, bias_attr=bias_attr)
out = linear(x)
loss = paddle.mean(out)
optimizer = paddle.optimizer.Momentum(
learning_rate=0.01,
momentum=0.9,
weight_decay=paddle.regularizer.L2Decay(0.5))
optimizer.minimize(loss)
return main_program
def test_param_has_l2decay(self):
paddle.enable_static()
weight_attr = paddle.ParamAttr(
name="weight",
initializer=paddle.nn.initializer.Constant(value=0.5),
regularizer=paddle.regularizer.L2Decay(0.1))
program = self.get_program(weight_attr, bias_attr=False)
ops = program.global_block().ops
self.assertEqual(ops[-1].attr('regularization_method'), 'l2_decay')
self.assertEqual(ops[-1].attr('regularization_coeff'), np.float32(0.1))
for i in range(len(ops)):
self.assertTrue('sum' not in ops[i].type)
self.assertTrue('scale' not in ops[i].type)
def test_param_has_l1decay(self):
paddle.enable_static()
weight_attr = paddle.ParamAttr(
name="weight",
initializer=paddle.nn.initializer.Constant(value=0.5),
regularizer=paddle.regularizer.L1Decay(0.1))
bias_attr = paddle.ParamAttr(
name="bias",
initializer=paddle.nn.initializer.Constant(value=0.),
regularizer=None)
program = self.get_program(weight_attr, bias_attr)
ops = program.global_block().ops
self.assertEqual(ops[-1].type, 'momentum')
self.assertEqual(ops[-2].type, 'momentum')
self.assertEqual(ops[-3].type, 'sum')
self.assertEqual(ops[-4].type, 'scale')
self.assertEqual(ops[-5].type, 'sign')
self.assertEqual(ops[-6].type, 'matmul_grad')
if 'weight' in ops[-1].input('Param'):
self.assertEqual(ops[-1].attr('regularization_method'), '')
self.assertEqual(ops[-1].attr('regularization_coeff'), 0)
if 'bias' in ops[-2].input('Param'):
self.assertEqual(ops[-2].attr('regularization_method'), 'l2_decay')
self.assertEqual(ops[-2].attr('regularization_coeff'),
np.float32(0.5))
def test_param_has_no_regularizer(self):
paddle.enable_static()
program = self.get_program(weight_attr=None)
ops = program.global_block().ops
self.assertEqual(ops[-1].attr('regularization_method'), 'l2_decay')
self.assertEqual(ops[-1].attr('regularization_coeff'), np.float32(0.5))
for i in range(len(ops)):
self.assertTrue('sum' not in ops[i].type)
self.assertTrue('scale' not in ops[i].type)
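Outside the unit tests, the fusion can be eyeballed by printing the momentum op's attributes from a program built the same way as get_program() above (a sketch under the same assumptions as the tests; attribute names come straight from the assertions):

import paddle

paddle.enable_static()
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name='x', shape=[10, 10], dtype='float32')
    loss = paddle.mean(paddle.nn.Linear(10, 10)(x))
    opt = paddle.optimizer.Momentum(
        learning_rate=0.01, momentum=0.9,
        weight_decay=paddle.regularizer.L2Decay(0.5))
    opt.minimize(loss)

for op in main.global_block().ops:
    if op.type == 'momentum':
        print(op.attr('regularization_method'), op.attr('regularization_coeff'))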
class TestMomentumOpVsMomentumOpWithDecayAPI(unittest.TestCase):
def __update_params(self, momentum, linear):
for i in range(10):
...
...@@ -59,6 +59,7 @@ class TestL2DecayRegularizer(unittest.TestCase):
params_grads = append_backward(mean_out)
self.assertEqual(len(params_grads), 1)
count_ops = len(block.ops)
optimizer = paddle.optimizer.Adam()
params_grads = optimizer.append_regularization_ops(params_grads)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(block.ops), count_ops + 2)
...@@ -97,6 +98,7 @@ class TestL1DecayRegularizer(unittest.TestCase):
params_grads = append_backward(mean_out)
self.assertEqual(len(params_grads), 1)
count_ops = len(block.ops)
optimizer = paddle.optimizer.Adam()
params_grads = optimizer.append_regularization_ops(params_grads)
self.assertEqual(len(params_grads), 1)
self.assertEqual(len(block.ops), count_ops + 3)
...
...@@ -252,6 +252,19 @@ class Momentum(Optimizer):
)
self._add_accumulator(self._velocity_acc_str, p)
def _create_regularization_of_grad(self, param, grad, regularization=None):
""" Create and add backward regularization Operators
Function helper of append_regularization_ops.
"""
# If the parameter's regularizer is L2Decay, skip regularization here; the L2 decay
# is instead fused into the momentum op (see _append_optimize_op below).
if hasattr(param, 'regularizer') and isinstance(param.regularizer,
L2DecayRegularizer):
return grad
return super(Momentum, self)._create_regularization_of_grad(
param, grad, regularization)
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
if isinstance(param_and_grad, dict):
...@@ -261,6 +274,20 @@ class Momentum(Optimizer):
param_and_grad[0])
lr = self._create_param_lr(param_and_grad)
# For fusion of momentum and l2decay
param = param_and_grad[0]
regularization_method = self._regularization_method
regularization_coeff = self._regularization_coeff
if hasattr(param, 'regularizer'):
# The param's L2 decay was skipped earlier, so fuse it into the momentum op here.
if isinstance(param.regularizer, L2DecayRegularizer):
regularization_method = "l2_decay"
regularization_coeff = param.regularizer._regularization_coeff
# The param's regularization was already applied to its gradient above, so avoid applying L2 decay again in the momentum op.
elif param.regularizer is not None:
regularization_method = ""
regularization_coeff = 0
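The precedence above reads as a small decision table: a param-level L2Decay wins and is fused into the op; any other param-level regularizer disables op-level decay because its gradient was already regularized; otherwise the optimizer-level settings apply. An illustrative restatement as a standalone helper (choose_regularization is a hypothetical name, not library code):

from paddle.fluid.regularizer import L2DecayRegularizer

def choose_regularization(param_regularizer, opt_method, opt_coeff):
    """Return (regularization_method, regularization_coeff) for the momentum op."""
    if isinstance(param_regularizer, L2DecayRegularizer):
        # The param-level L2Decay is fused into the momentum op.
        return 'l2_decay', param_regularizer._regularization_coeff
    if param_regularizer is not None:
        # A non-L2 regularizer already modified this gradient; do not decay again.
        return '', 0
    # No param-level regularizer: fall back to the optimizer-level settings.
    return opt_method, opt_coeff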
if framework.in_dygraph_mode():
if isinstance(param_and_grad, dict):
self._update_regularization(param_and_grad['weight_decay'])
...@@ -268,8 +295,8 @@ class Momentum(Optimizer):
param_and_grad[0], param_and_grad[1], velocity_acc, lr,
param_and_grad[0], velocity_acc, 'mu', self._momentum,
'use_nesterov', self._use_nesterov, 'regularization_method',
- self._regularization_method, 'regularization_coeff', self._regularization_coeff)
+ regularization_method, 'regularization_coeff', regularization_coeff)
return None
find_master = self._multi_precision and param_and_grad[
...@@ -280,8 +307,8 @@ class Momentum(Optimizer):
attrs = {
"mu": self._momentum,
"use_nesterov": self._use_nesterov,
- "regularization_method": self._regularization_method,
- "regularization_coeff": self._regularization_coeff,
+ "regularization_method": regularization_method,
+ "regularization_coeff": regularization_coeff,
"multi_precision": find_master,
"rescale_grad": self._rescale_grad
}
...
...@@ -32,7 +32,6 @@ from ..fluid.framework import program_guard, Parameter
from ..fluid.initializer import Constant
from ..fluid.layer_helper import LayerHelper
from ..fluid.layers import ops
from ..fluid.regularizer import append_regularization_ops
from ..fluid.dygraph import base as imperative_base
from ..fluid.dygraph import no_grad
from paddle.fluid import core
...@@ -850,8 +849,8 @@ class Optimizer(object):
params_grads = append_gradient_clip_ops(params_grads)
# Add regularization if any
- params_grads = append_regularization_ops(params_grads, self.regularization)
+ params_grads = self.append_regularization_ops(params_grads, self.regularization)
optimize_ops = self._create_optimization_pass(params_grads)
return optimize_ops
...@@ -874,7 +873,7 @@ class Optimizer(object):
if isinstance(params_grads, list):
if self._grad_clip is not None:
params_grads = self._grad_clip(params_grads)
- params_grads = append_regularization_ops(params_grads, self.regularization)
+ params_grads = self.append_regularization_ops(params_grads, self.regularization)
else:
grad_clip = params_grads['grad_clip']
...@@ -882,7 +881,7 @@ class Optimizer(object):
params_grads['params'] = grad_clip(params_grads[
'params'])
- params_grads['params'] = append_regularization_ops(params_grads['params'], self.regularization)
+ params_grads['params'] = self.append_regularization_ops(params_grads['params'], self.regularization)
optimize_ops = self._create_optimization_pass(params_grads)
else:
...@@ -891,6 +890,93 @@ class Optimizer(object):
optimize_ops = self.apply_gradients(params_grads)
return optimize_ops
def _create_regularization_of_grad(self, param, grad, regularization=None):
""" Create and add backward regularization Operators
Function helper of append_regularization_ops.
"""
# If no gradient or no regularization is specified, then we don't need to do anything
if grad is None or ((not hasattr(param, 'regularizer') or
(hasattr(param, 'regularizer') and
param.regularizer is None)) and
regularization is None):
return grad
regularization_term = None
if hasattr(param, 'regularizer') and param.regularizer is not None:
# Add variable for regularization term in grad block
regularization_term = param.regularizer(param, grad, grad.block)
elif regularization is not None:
regularization_term = regularization(param, grad, grad.block)
assert regularization_term is not None
new_grad = grad
if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
# FIXME(zcd): If the grad is SELECTED_ROWS, after regularization,
# the grad's type and name will be changed. But the gradient's name
# is used in ParallelExecutor Reduce mode, so I add a flag for
# the new_grad here.
new_grad = grad.block.create_var(
name=grad.name + core.kNewGradSuffix(),
dtype=param.dtype,
shape=param.shape,
lod_level=param.lod_level,
type=core.VarDesc.VarType.LOD_TENSOR)
inputs = {"X": [grad, regularization_term]}
outputs = {"Out": [new_grad]}
if framework.in_dygraph_mode():
new_grad = core.ops.sum([grad, regularization_term])
else:
grad.block.append_op(type='sum', inputs=inputs, outputs=outputs)
return new_grad
def append_regularization_ops(self,
parameters_and_grads,
regularization=None):
r"""Create and add backward regularization Operators
Creates and adds backward regularization operators in the BlockDesc.
This will add gradients of the regularizer function to the gradients
of the parameters and return these modified gradients. This is the
same as implementing weight decay in optimizers for regularization.
Args:
parameters_and_grads: A list of (parameters, gradients) pairs
that need to be regularized.
regularization: A global regularizer. It is applied only to parameters
that do not have a regularizer of their own.
Returns:
list[(Variable, Variable)]: list of (parameters, gradients) \
pair with the regularized gradient
Raises:
Exception: Unknown regularization type
"""
params_and_grads = []
if framework.in_dygraph_mode():
for param, grad in parameters_and_grads:
new_grad = self._create_regularization_of_grad(param, grad,
regularization)
params_and_grads.append((param, new_grad))
else:
repeate_regularizer = False
with framework.name_scope('regularization'):
for param, grad in parameters_and_grads:
if not repeate_regularizer and param.regularizer is not None and regularization is not None:
repeate_regularizer = True
logging.info(
"If regularizer of a Parameter has been set by 'fluid.ParamAttr' or 'fluid.WeightNormParamAttr' already. "
"The Regularization[%s] in Optimizer will not take effect, and it will only be applied to other Parameters!"
% regularization.__str__())
with param.block.program._optimized_guard([param, grad]):
new_grad = self._create_regularization_of_grad(
param, grad, regularization)
params_and_grads.append((param, new_grad))
return params_and_grads
def _get_no_grad_set(self, loss, no_grad_set=None):
no_grad_set = _get_no_grad_set_name(no_grad_set)
parameters = loss.block.program.global_block().all_parameters()
...