Unverified  Commit 487c7972  authored by littletomatodonkey, committed by GitHub

fix optimizer builder (#751)

Parent dd79f81f
@@ -60,7 +60,6 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
     optim = getattr(optimizer, optim_name)(learning_rate=lr,
                                            weight_decay=reg,
                                            grad_clip=grad_clip,
-                                           parameter_list=parameters,
-                                           **config)()
+                                           **config)(parameters=parameters)
    logger.info("build optimizer ({}) success..".format(optim))
    return optim, lr
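The fix above changes how the optimizer wrapper classes are invoked: the wrapper is built from the config first, and the trainable parameters are bound only when it is called. A minimal sketch of the corrected convention, assuming the Momentum wrapper rewritten below and an illustrative model (the import path is assumed from the repository layout, not shown in this diff):

import paddle
from ppocr.optimizer.optimizer import Momentum  # module path assumed

model = paddle.nn.Linear(16, 16)                 # stand-in for a real OCR model
builder = Momentum(learning_rate=0.001, momentum=0.9)
# Previously parameter_list went into __init__ and __call__ took no arguments;
# now __call__ binds the trainable parameters, matching paddle.optimizer's API.
opt = builder(parameters=model.parameters())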
@@ -16,52 +16,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import sys
-
-import paddle
-import paddle.regularizer as regularizer
-
-__all__ = ['OptimizerBuilder']
-
-
-class L1Decay(object):
-    """
-    L1 Weight Decay Regularization, which encourages the weights to be sparse.
-    Args:
-        factor(float): regularization coeff. Default:0.0.
-    """
-
-    def __init__(self, factor=0.0):
-        super(L1Decay, self).__init__()
-        self.factor = factor
-
-    def __call__(self):
-        reg = regularizer.L1Decay(self.factor)
-        return reg
-
-
-class L2Decay(object):
-    """
-    L2 Weight Decay Regularization, which encourages the weights to be sparse.
-    Args:
-        factor(float): regularization coeff. Default:0.0.
-    """
-
-    def __init__(self, factor=0.0):
-        super(L2Decay, self).__init__()
-        self.factor = factor
-
-    def __call__(self):
-        reg = regularizer.L2Decay(self.factor)
-        return reg
-
+from paddle import optimizer as optim
 
 
 class Momentum(object):
     """
     Simple Momentum optimizer with velocity state.
     Args:
         learning_rate (float|Variable) - The learning rate used to update parameters.
             Can be a float value or a Variable with one float value as data element.
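With the local L1Decay/L2Decay wrappers deleted in this hunk, regularization now reaches the optimizer through the weight_decay argument that build_optimizer forwards. A minimal sketch using paddle's own regularizer (model and coefficient are illustrative):

import paddle

model = paddle.nn.Linear(16, 16)
reg = paddle.regularizer.L2Decay(coeff=1e-4)     # paddle.regularizer.L1Decay works the same way
opt = paddle.optimizer.Momentum(
    learning_rate=0.01,
    momentum=0.9,
    weight_decay=reg,                            # replaces the deleted local wrappers
    parameters=model.parameters())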
@@ -72,31 +32,63 @@ class Momentum(object):
     def __init__(self,
                  learning_rate,
                  momentum,
-                 parameter_list=None,
-                 regularization=None,
-                 multi_precision=False,
-                 **args):
+                 weight_decay=None,
+                 grad_clip=None):
         super(Momentum, self).__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
-        self.parameter_list = parameter_list
-        self.regularization = regularization
-        self.multi_precision = multi_precision
+        self.weight_decay = weight_decay
+        self.grad_clip = grad_clip
 
-    def __call__(self):
-        opt = paddle.optimizer.Momentum(
+    def __call__(self, parameters):
+        opt = optim.Momentum(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
-            parameters=self.parameter_list,
-            weight_decay=self.regularization,
-            multi_precision=self.multi_precision)
+            weight_decay=self.weight_decay,
+            grad_clip=self.grad_clip,
+            parameters=parameters)
         return opt
 
 
+class Adam(object):
+    def __init__(self,
+                 learning_rate=0.001,
+                 beta1=0.9,
+                 beta2=0.999,
+                 epsilon=1e-08,
+                 parameter_list=None,
+                 weight_decay=None,
+                 grad_clip=None,
+                 name=None,
+                 lazy_mode=False):
+        self.learning_rate = learning_rate
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.parameter_list = parameter_list
+        self.learning_rate = learning_rate
+        self.weight_decay = weight_decay
+        self.grad_clip = grad_clip
+        self.name = name
+        self.lazy_mode = lazy_mode
+
+    def __call__(self, parameters):
+        opt = optim.Adam(
+            learning_rate=self.learning_rate,
+            beta1=self.beta1,
+            beta2=self.beta2,
+            epsilon=self.epsilon,
+            weight_decay=self.weight_decay,
+            grad_clip=self.grad_clip,
+            name=self.name,
+            lazy_mode=self.lazy_mode,
+            parameters=parameters)
+        return opt
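The rewritten wrappers also forward grad_clip directly to paddle.optimizer, so gradient clipping can be configured next to the other optimizer options. A short sketch mirroring what the Adam wrapper above constructs (the clip norm and hyper-parameters are illustrative):

import paddle

model = paddle.nn.Linear(16, 16)
opt = paddle.optimizer.Adam(
    learning_rate=0.001,
    beta1=0.9,
    beta2=0.999,
    epsilon=1e-08,
    weight_decay=1e-4,                                        # float or a paddle regularizer
    grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0),  # forwarded unchanged by the wrapper
    parameters=model.parameters())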
 class RMSProp(object):
     """
     Root Mean Squared Propagation (RMSProp) is an unpublished, adaptive learning rate method.
     Args:
         learning_rate (float|Variable) - The learning rate used to update parameters.
             Can be a float value or a Variable with one float value as data element.
@@ -108,58 +100,26 @@ class RMSProp(object):
     def __init__(self,
                  learning_rate,
-                 momentum,
+                 momentum=0.0,
                  rho=0.95,
                  epsilon=1e-6,
-                 parameter_list=None,
-                 regularization=None,
-                 **args):
+                 weight_decay=None,
+                 grad_clip=None):
         super(RMSProp, self).__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
         self.rho = rho
         self.epsilon = epsilon
-        self.parameter_list = parameter_list
-        self.regularization = regularization
+        self.weight_decay = weight_decay
+        self.grad_clip = grad_clip
 
-    def __call__(self):
-        opt = paddle.optimizer.RMSProp(
+    def __call__(self, parameters):
+        opt = optim.RMSProp(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
             rho=self.rho,
             epsilon=self.epsilon,
-            parameters=self.parameter_list,
-            weight_decay=self.regularization)
-        return opt
-
-
-class OptimizerBuilder(object):
-    """
-    Build optimizer
-    Args:
-        function(str): optimizer name of learning rate
-        params(dict): parameters used for init the class
-        regularizer (dict): parameters used for create regularization
-    """
-
-    def __init__(self,
-                 function='Momentum',
-                 params={'momentum': 0.9},
-                 regularizer=None):
-        self.function = function
-        self.params = params
-        # create regularizer
-        if regularizer is not None:
-            mod = sys.modules[__name__]
-            reg_func = regularizer['function'] + 'Decay'
-            del regularizer['function']
-            reg = getattr(mod, reg_func)(**regularizer)()
-            self.params['regularization'] = reg
-
-    def __call__(self, learning_rate, parameter_list=None):
-        mod = sys.modules[__name__]
-        opt = getattr(mod, self.function)
-        return opt(learning_rate=learning_rate,
-                   parameter_list=parameter_list,
-                   **self.params)()
+            weight_decay=self.weight_decay,
+            grad_clip=self.grad_clip,
+            parameters=parameters)
+        return opt
\ No newline at end of file
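The removed OptimizerBuilder resolved the optimizer class by name and injected the regularizer itself; after this commit build_optimizer (first hunk) plays that role, looking the wrapper class up on the optimizer module and binding parameters at call time. A simplified, self-contained sketch of that dispatch pattern (the config keys and module path are assumptions, not the exact PaddleOCR schema):

import paddle
from ppocr.optimizer import optimizer as optim_mod    # module path assumed


def build_optimizer_sketch(config, parameters):
    cfg = dict(config)
    name = cfg.pop("name")                             # e.g. "Adam"; key name is illustrative
    builder = getattr(optim_mod, name)(**cfg)          # resolve the wrapper class by name
    return builder(parameters=parameters)              # bind parameters when calling it


model = paddle.nn.Linear(8, 8)
opt = build_optimizer_sketch(
    {"name": "Adam", "learning_rate": 0.001, "weight_decay": 1e-4},
    model.parameters())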