未验证 提交 0ca62744 编写于 作者: D dzhwinter 提交者: GitHub

"add global regularization" (#6443)

* "add global regularization"

* Polish `append_regularization_ops`
上级 5926e9a2
...@@ -18,8 +18,9 @@ class Optimizer(object): ...@@ -18,8 +18,9 @@ class Optimizer(object):
but need to use one of it's implementation. but need to use one of it's implementation.
""" """
def __init__(self, global_step=None): def __init__(self, global_step=None, regularization=None):
self._global_step = global_step self._global_step = global_step
self.regularization = regularization
# Dictionary of accumulators. Some optimizer subclasses need to # Dictionary of accumulators. Some optimizer subclasses need to
# allocate and manage extra variables associated with the parameters # allocate and manage extra variables associated with the parameters
# to train. These variables are called accumulators. # to train. These variables are called accumulators.
...@@ -199,7 +200,8 @@ class Optimizer(object): ...@@ -199,7 +200,8 @@ class Optimizer(object):
""" """
params_grads = append_backward_ops(loss, parameter_list, no_grad_set) params_grads = append_backward_ops(loss, parameter_list, no_grad_set)
# Add regularization if any # Add regularization if any
params_grads = append_regularization_ops(params_grads) params_grads = append_regularization_ops(params_grads,
self.regularization)
optimize_ops = self.create_optimization_pass(params_grads, loss, optimize_ops = self.create_optimization_pass(params_grads, loss,
startup_program) startup_program)
return optimize_ops return optimize_ops
...@@ -209,9 +211,9 @@ class SGDOptimizer(Optimizer): ...@@ -209,9 +211,9 @@ class SGDOptimizer(Optimizer):
""" Simple SGD optimizer without any state. """ Simple SGD optimizer without any state.
""" """
def __init__(self, learning_rate, global_step=None): def __init__(self, learning_rate, **kwargs):
assert learning_rate is not None assert learning_rate is not None
super(SGDOptimizer, self).__init__(global_step) super(SGDOptimizer, self).__init__(**kwargs)
self.type = "sgd" self.type = "sgd"
self._learning_rate = learning_rate self._learning_rate = learning_rate
...@@ -236,14 +238,10 @@ class MomentumOptimizer(Optimizer): ...@@ -236,14 +238,10 @@ class MomentumOptimizer(Optimizer):
""" """
_velocity_acc_str = "velocity" _velocity_acc_str = "velocity"
def __init__(self, def __init__(self, learning_rate, momentum, use_nesterov=False, **kwargs):
learning_rate,
momentum,
use_nesterov=False,
global_step=None):
assert learning_rate is not None assert learning_rate is not None
assert momentum is not None assert momentum is not None
super(MomentumOptimizer, self).__init__(global_step) super(MomentumOptimizer, self).__init__(**kwargs)
self.type = "momentum" self.type = "momentum"
self._learning_rate = learning_rate self._learning_rate = learning_rate
self._momentum = momentum self._momentum = momentum
...@@ -284,10 +282,10 @@ class AdagradOptimizer(Optimizer): ...@@ -284,10 +282,10 @@ class AdagradOptimizer(Optimizer):
""" """
_moment_acc_str = "moment" _moment_acc_str = "moment"
def __init__(self, learning_rate, epsilon=1.0e-6, global_step=None): def __init__(self, learning_rate, epsilon=1.0e-6, **kwargs):
assert learning_rate is not None assert learning_rate is not None
assert epsilon is not None assert epsilon is not None
super(AdagradOptimizer, self).__init__(global_step) super(AdagradOptimizer, self).__init__(**kwargs)
self.type = "adagrad" self.type = "adagrad"
self._learning_rate = learning_rate self._learning_rate = learning_rate
self._epsilon = epsilon self._epsilon = epsilon
...@@ -331,12 +329,12 @@ class AdamOptimizer(Optimizer): ...@@ -331,12 +329,12 @@ class AdamOptimizer(Optimizer):
beta1=0.9, beta1=0.9,
beta2=0.999, beta2=0.999,
epsilon=1e-8, epsilon=1e-8,
global_step=None): **kwargs):
assert learning_rate is not None assert learning_rate is not None
assert beta1 is not None assert beta1 is not None
assert beta2 is not None assert beta2 is not None
assert epsilon is not None assert epsilon is not None
super(AdamOptimizer, self).__init__(global_step) super(AdamOptimizer, self).__init__(**kwargs)
self.type = "adam" self.type = "adam"
self._learning_rate = learning_rate self._learning_rate = learning_rate
self._beta1 = beta1 self._beta1 = beta1
...@@ -436,12 +434,12 @@ class AdamaxOptimizer(Optimizer): ...@@ -436,12 +434,12 @@ class AdamaxOptimizer(Optimizer):
beta1=0.9, beta1=0.9,
beta2=0.999, beta2=0.999,
epsilon=1e-8, epsilon=1e-8,
global_step=None): **kwargs):
assert learning_rate is not None assert learning_rate is not None
assert beta1 is not None assert beta1 is not None
assert beta2 is not None assert beta2 is not None
assert epsilon is not None assert epsilon is not None
super(AdamaxOptimizer, self).__init__() super(AdamaxOptimizer, self).__init__(**kwargs)
self.type = "adamax" self.type = "adamax"
self._learning_rate = learning_rate self._learning_rate = learning_rate
self._beta1 = beta1 self._beta1 = beta1
...@@ -514,16 +512,12 @@ class DecayedAdagradOptimizer(Optimizer): ...@@ -514,16 +512,12 @@ class DecayedAdagradOptimizer(Optimizer):
""" """
_moment_acc_str = "moment" _moment_acc_str = "moment"
def __init__(self, def __init__(self, learning_rate, decay=0.95, epsilon=1.0e-6, **kwargs):
learning_rate,
decay=0.95,
epsilon=1.0e-6,
global_step=None):
assert learning_rate is not None assert learning_rate is not None
assert decay is not None assert decay is not None
assert epsilon is not None assert epsilon is not None
super(DecayedAdagradOptimizer, self).__init__(global_step) super(DecayedAdagradOptimizer, self).__init__(**kwargs)
self.type = "decayed_adagrad" self.type = "decayed_adagrad"
self._learning_rate = learning_rate self._learning_rate = learning_rate
self._decay = decay self._decay = decay
......
...@@ -3,7 +3,7 @@ import framework ...@@ -3,7 +3,7 @@ import framework
__all__ = ['append_regularization_ops', 'L1Decay', 'L2Decay'] __all__ = ['append_regularization_ops', 'L1Decay', 'L2Decay']
def append_regularization_ops(parameters_and_grads): def append_regularization_ops(parameters_and_grads, regularization=None):
"""Create and add backward regularization Operators """Create and add backward regularization Operators
Creates and adds backward regularization operators in the BlockDesc. Creates and adds backward regularization operators in the BlockDesc.
...@@ -14,6 +14,8 @@ def append_regularization_ops(parameters_and_grads): ...@@ -14,6 +14,8 @@ def append_regularization_ops(parameters_and_grads):
Args: Args:
parameters_and_grads: A list of (parameters, gradients) pairs parameters_and_grads: A list of (parameters, gradients) pairs
that need to be regularized. that need to be regularized.
regularization: A global regularizer. If the parameter is not
set. It will be applied with regularizer.
Returns: Returns:
list of (parameters, gradients) pair with the regularized gradient list of (parameters, gradients) pair with the regularized gradient
...@@ -23,14 +25,19 @@ def append_regularization_ops(parameters_and_grads): ...@@ -23,14 +25,19 @@ def append_regularization_ops(parameters_and_grads):
""" """
params_and_grads = [] params_and_grads = []
for param, grad in parameters_and_grads: for param, grad in parameters_and_grads:
regularization_term = None
if param.regularizer is not None:
# Add variable for regularization term in grad block
regularization_term = param.regularizer(param, grad.block)
elif regularization is not None:
regularization_term = regularization(param, grad.block)
# If no gradient or no regularization specified, # If no gradient or no regularization specified,
# then we don't need to do anything # then we don't need to do anything
if grad is None or param.regularizer is None: if grad is None or regularization_term is None:
params_and_grads.append((param, grad)) params_and_grads.append((param, grad))
continue continue
# Add variable for regularization term in grad block
regularization_term = param.regularizer(param, grad.block)
assert grad.shape == regularization_term.shape assert grad.shape == regularization_term.shape
grad.block.append_op( grad.block.append_op(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册