Commit 079434dc authored by gaotingquan

feat: add AdamW

Parent: 17a06daf
@@ -175,7 +175,7 @@ class Engine(object):
         if self.mode == 'train':
             self.optimizer, self.lr_sch = build_optimizer(
                 self.config["Optimizer"], self.config["Global"]["epochs"],
-                len(self.train_dataloader), self.model.parameters())
+                len(self.train_dataloader), [self.model])
         # for distributed
         self.config["Global"][
...
@@ -41,18 +41,21 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
     return lr
 
 
-def build_optimizer(config, epochs, step_each_epoch, parameters=None):
+def build_optimizer(config, epochs, step_each_epoch, model_list):
     config = copy.deepcopy(config)
     # step1 build lr
     lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
     logger.debug("build lr ({}) success..".format(lr))
     # step2 build regularization
     if 'regularizer' in config and config['regularizer'] is not None:
+        if 'weight_decay' in config:
+            logger.warning(
+                "ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
+            )
         reg_config = config.pop('regularizer')
         reg_name = reg_config.pop('name') + 'Decay'
         reg = getattr(paddle.regularizer, reg_name)(**reg_config)
-    else:
-        reg = None
+        config["weight_decay"] = reg
     logger.debug("build regularizer ({}) success..".format(reg))
     # step3 build optimizer
     optim_name = config.pop('name')
@@ -62,8 +65,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters=None):
     else:
         grad_clip = None
     optim = getattr(optimizer, optim_name)(learning_rate=lr,
-                                           weight_decay=reg,
                                            grad_clip=grad_clip,
-                                           **config)(parameters=parameters)
+                                           **config)(model_list=model_list)
     logger.debug("build optimizer ({}) success..".format(optim))
     return optim, lr
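
With these two hunks, the regularizer entry of the Optimizer config is folded into Paddle's weight_decay argument inside build_optimizer (the explicit weight_decay entry is ignored with a warning if both are set), and the function now receives a list of models instead of a flat parameter list. A minimal sketch of how the rewritten function is expected to be called follows; the import path, the lr sub-config fields, and all values are assumptions for illustration, not part of this commit:

    import paddle

    from ppcls.optimizer import build_optimizer  # import path assumed from the repo layout

    model = paddle.nn.Linear(8, 8)
    config = {
        "name": "Momentum",                              # resolved via getattr(optimizer, optim_name)
        "momentum": 0.9,
        "lr": {"name": "Cosine", "learning_rate": 0.1},  # schematic lr sub-config
        "regularizer": {"name": "L2", "coeff": 1e-4},    # becomes paddle.regularizer.L2Decay(coeff=1e-4)
    }
    # model_list replaces the old `parameters` argument; parameters are gathered inside the wrappers.
    optim, lr_sch = build_optimizer(config, epochs=10, step_each_epoch=100, model_list=[model])
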
@@ -35,14 +35,15 @@ class Momentum(object):
                  weight_decay=None,
                  grad_clip=None,
                  multi_precision=False):
-        super(Momentum, self).__init__()
+        super().__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
         self.weight_decay = weight_decay
         self.grad_clip = grad_clip
         self.multi_precision = multi_precision
 
-    def __call__(self, parameters):
+    def __call__(self, model_list):
+        parameters = sum([m.parameters() for m in model_list], [])
         opt = optim.Momentum(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
@@ -77,7 +78,8 @@ class Adam(object):
         self.lazy_mode = lazy_mode
         self.multi_precision = multi_precision
 
-    def __call__(self, parameters):
+    def __call__(self, model_list):
+        parameters = sum([m.parameters() for m in model_list], [])
         opt = optim.Adam(
             learning_rate=self.learning_rate,
             beta1=self.beta1,
@@ -112,7 +114,7 @@ class RMSProp(object):
                  weight_decay=None,
                  grad_clip=None,
                  multi_precision=False):
-        super(RMSProp, self).__init__()
+        super().__init__()
         self.learning_rate = learning_rate
         self.momentum = momentum
         self.rho = rho
@@ -120,7 +122,8 @@ class RMSProp(object):
         self.weight_decay = weight_decay
         self.grad_clip = grad_clip
 
-    def __call__(self, parameters):
+    def __call__(self, model_list):
+        parameters = sum([m.parameters() for m in model_list], [])
         opt = optim.RMSProp(
             learning_rate=self.learning_rate,
             momentum=self.momentum,
@@ -130,3 +133,57 @@ class RMSProp(object):
             grad_clip=self.grad_clip,
             parameters=parameters)
         return opt
+
+
+class AdamW(object):
+    def __init__(self,
+                 learning_rate=0.001,
+                 beta1=0.9,
+                 beta2=0.999,
+                 epsilon=1e-8,
+                 weight_decay=None,
+                 multi_precision=False,
+                 grad_clip=None,
+                 no_weight_decay_name=None,
+                 one_dim_param_no_weight_decay=False,
+                 **args):
+        super().__init__()
+        self.learning_rate = learning_rate
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.epsilon = epsilon
+        self.grad_clip = grad_clip
+        self.weight_decay = weight_decay
+        self.multi_precision = multi_precision
+        self.no_weight_decay_name_list = no_weight_decay_name.split(
+        ) if no_weight_decay_name else []
+        self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay
+
+    def __call__(self, model_list):
+        parameters = sum([m.parameters() for m in model_list], [])
+
+        self.no_weight_decay_param_name_list = [
+            p.name for model in model_list for n, p in model.named_parameters()
+            if any(nd in n for nd in self.no_weight_decay_name_list)
+        ]
+
+        if self.one_dim_param_no_weight_decay:
+            self.no_weight_decay_param_name_list += [
+                p.name for model in model_list
+                for n, p in model.named_parameters() if len(p.shape) == 1
+            ]
+
+        opt = optim.AdamW(
+            learning_rate=self.learning_rate,
+            beta1=self.beta1,
+            beta2=self.beta2,
+            epsilon=self.epsilon,
+            parameters=parameters,
+            weight_decay=self.weight_decay,
+            multi_precision=self.multi_precision,
+            grad_clip=self.grad_clip,
+            apply_decay_param_fun=self._apply_decay_param_fun)
+        return opt
+
+    def _apply_decay_param_fun(self, name):
+        return name not in self.no_weight_decay_param_name_list
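
The AdamW wrapper added above defers all per-parameter logic to __call__: it flattens the parameters of every model in model_list, records the names of parameters that should not be decayed (by substring match against the tokens in no_weight_decay_name, and optionally every 1-D parameter), and hands _apply_decay_param_fun to paddle.optimizer.AdamW so decay is skipped for exactly those names. A minimal sketch of the intended usage; the model, the import path, and the hyperparameter values are assumptions for illustration, not part of the commit:

    import paddle

    from ppcls.optimizer.optimizer import AdamW  # import path assumed from the repo layout


    class TinyNet(paddle.nn.Layer):
        def __init__(self):
            super().__init__()
            self.fc = paddle.nn.Linear(16, 32)
            self.norm = paddle.nn.LayerNorm(32)

        def forward(self, x):
            return self.norm(self.fc(x))


    model = TinyNet()
    builder = AdamW(
        learning_rate=1e-3,
        weight_decay=0.05,
        no_weight_decay_name="norm",         # any parameter whose name contains "norm" skips decay
        one_dim_param_no_weight_decay=True)  # biases and norm scales (1-D params) also skip decay
    opt = builder(model_list=[model])
    # paddle.optimizer.AdamW calls _apply_decay_param_fun(param.name) per parameter and
    # applies weight decay only where it returns True.
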