提交 079434dc 编写于 作者: G gaotingquan

feat: add AdamW

上级 17a06daf
......@@ -175,7 +175,7 @@ class Engine(object):
if self.mode == 'train':
self.optimizer, self.lr_sch = build_optimizer(
self.config["Optimizer"], self.config["Global"]["epochs"],
len(self.train_dataloader), self.model.parameters())
len(self.train_dataloader), [self.model])
# for distributed
self.config["Global"][
......
......@@ -41,18 +41,21 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
return lr
def build_optimizer(config, epochs, step_each_epoch, parameters=None):
def build_optimizer(config, epochs, step_each_epoch, model_list):
config = copy.deepcopy(config)
# step1 build lr
lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
logger.debug("build lr ({}) success..".format(lr))
# step2 build regularization
if 'regularizer' in config and config['regularizer'] is not None:
if 'weight_decay' in config:
logger.warning(
"ConfigError: Only one of regularizer and weight_decay can be set in Optimizer Config. \"weight_decay\" has been ignored."
)
reg_config = config.pop('regularizer')
reg_name = reg_config.pop('name') + 'Decay'
reg = getattr(paddle.regularizer, reg_name)(**reg_config)
else:
reg = None
config["weight_decay"] = reg
logger.debug("build regularizer ({}) success..".format(reg))
# step3 build optimizer
optim_name = config.pop('name')
......@@ -62,8 +65,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters=None):
else:
grad_clip = None
optim = getattr(optimizer, optim_name)(learning_rate=lr,
weight_decay=reg,
grad_clip=grad_clip,
**config)(parameters=parameters)
**config)(model_list=model_list)
logger.debug("build optimizer ({}) success..".format(optim))
return optim, lr
......@@ -35,14 +35,15 @@ class Momentum(object):
weight_decay=None,
grad_clip=None,
multi_precision=False):
super(Momentum, self).__init__()
super().__init__()
self.learning_rate = learning_rate
self.momentum = momentum
self.weight_decay = weight_decay
self.grad_clip = grad_clip
self.multi_precision = multi_precision
def __call__(self, parameters):
def __call__(self, model_list):
parameters = sum([m.parameters() for m in model_list], [])
opt = optim.Momentum(
learning_rate=self.learning_rate,
momentum=self.momentum,
......@@ -77,7 +78,8 @@ class Adam(object):
self.lazy_mode = lazy_mode
self.multi_precision = multi_precision
def __call__(self, parameters):
def __call__(self, model_list):
parameters = sum([m.parameters() for m in model_list], [])
opt = optim.Adam(
learning_rate=self.learning_rate,
beta1=self.beta1,
......@@ -112,7 +114,7 @@ class RMSProp(object):
weight_decay=None,
grad_clip=None,
multi_precision=False):
super(RMSProp, self).__init__()
super().__init__()
self.learning_rate = learning_rate
self.momentum = momentum
self.rho = rho
......@@ -120,7 +122,8 @@ class RMSProp(object):
self.weight_decay = weight_decay
self.grad_clip = grad_clip
def __call__(self, parameters):
def __call__(self, model_list):
parameters = sum([m.parameters() for m in model_list], [])
opt = optim.RMSProp(
learning_rate=self.learning_rate,
momentum=self.momentum,
......@@ -130,3 +133,57 @@ class RMSProp(object):
grad_clip=self.grad_clip,
parameters=parameters)
return opt
class AdamW(object):
def __init__(self,
learning_rate=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-8,
weight_decay=None,
multi_precision=False,
grad_clip=None,
no_weight_decay_name=None,
one_dim_param_no_weight_decay=False,
**args):
super().__init__()
self.learning_rate = learning_rate
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.grad_clip = grad_clip
self.weight_decay = weight_decay
self.multi_precision = multi_precision
self.no_weight_decay_name_list = no_weight_decay_name.split(
) if no_weight_decay_name else []
self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay
def __call__(self, model_list):
parameters = sum([m.parameters() for m in model_list], [])
self.no_weight_decay_param_name_list = [
p.name for model in model_list for n, p in model.named_parameters()
if any(nd in n for nd in self.no_weight_decay_name_list)
]
if self.one_dim_param_no_weight_decay:
self.no_weight_decay_param_name_list += [
p.name for model in model_list
for n, p in model.named_parameters() if len(p.shape) == 1
]
opt = optim.AdamW(
learning_rate=self.learning_rate,
beta1=self.beta1,
beta2=self.beta2,
epsilon=self.epsilon,
parameters=parameters,
weight_decay=self.weight_decay,
multi_precision=self.multi_precision,
grad_clip=self.grad_clip,
apply_decay_param_fun=self._apply_decay_param_fun)
return opt
def _apply_decay_param_fun(self, name):
return name not in self.no_weight_decay_param_name_list
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册