Unverified commit 1755a2b2, authored by Wenyu, committed via GitHub

add w/o weight decay params groups (#4337)

Parent commit: 1bf6d854
@@ -15,3 +15,4 @@ OptimizerBuilder:
   optimizer:
     type: AdamW
     weight_decay: 0.05
+    without_weight_decay_params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm']
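For reference, the keys in without_weight_decay_params are matched by substring against full parameter names: a parameter is excluded from weight decay if any key occurs anywhere in its name. The short sketch below illustrates that matching rule; the parameter names are made-up, Swin-style examples, not names taken from this repository.

# Illustration only: assumed parameter names, matched against the config keys above.
keys = ['absolute_pos_embed', 'relative_position_bias_table', 'norm']

names = [
    'backbone.absolute_pos_embed',        # matched -> weight_decay 0
    'backbone.blocks.0.norm1.weight',     # matched ('norm') -> weight_decay 0
    'backbone.blocks.0.attn.qkv.weight',  # no match -> default weight_decay (0.05)
]

for n in names:
    skipped = any(k in n for k in keys)
    print(n, '-> no weight decay' if skipped else '-> weight decay 0.05')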
@@ -115,8 +115,7 @@ class Trainer(object):
         if self.mode == 'train':
             steps_per_epoch = len(self.loader)
             self.lr = create('LearningRate')(steps_per_epoch)
-            self.optimizer = create('OptimizerBuilder')(self.lr,
-                                                        self.model.parameters())
+            self.optimizer = create('OptimizerBuilder')(self.lr, self.model)

         self._nranks = dist.get_world_size()
         self._local_rank = dist.get_rank()
......
@@ -225,7 +225,7 @@ class OptimizerBuilder():
         self.regularizer = regularizer
         self.optimizer = optimizer

-    def __call__(self, learning_rate, params=None):
+    def __call__(self, learning_rate, model=None):
        if self.clip_grad_by_norm is not None:
            grad_clip = nn.ClipGradByGlobalNorm(
                clip_norm=self.clip_grad_by_norm)
@@ -244,6 +244,25 @@ class OptimizerBuilder():
         if optim_type != 'AdamW':
             optim_args['weight_decay'] = regularization
         op = getattr(optimizer, optim_type)
+
+        if 'without_weight_decay_params' in optim_args:
+            keys = optim_args['without_weight_decay_params']
+            params = [{
+                'params': [
+                    p for n, p in model.named_parameters()
+                    if any([k in n for k in keys])
+                ],
+                'weight_decay': 0.
+            }, {
+                'params': [
+                    p for n, p in model.named_parameters()
+                    if all([k not in n for k in keys])
+                ]
+            }]
+            del optim_args['without_weight_decay_params']
+        else:
+            params = model.parameters()
+
         return op(learning_rate=learning_rate,
                   parameters=params,
                   grad_clip=grad_clip,
......
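The grouping added to OptimizerBuilder.__call__ can be exercised on its own. Below is a minimal sketch under assumptions: a toy paddle.nn.Layer and the single key 'norm' stand in for the real model and config. It builds the same two parameter groups (matched names with weight_decay forced to 0, everything else inheriting the optimizer-level value) and passes them as group dicts to paddle.optimizer.AdamW, which is how the added code hands groups to the optimizer.

import paddle
import paddle.nn as nn

class ToyModel(nn.Layer):
    # Toy model for illustration; 'norm' in the attribute name puts the
    # LayerNorm parameters into the no-weight-decay group.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)

    def forward(self, x):
        return self.norm(self.linear(x))

model = ToyModel()
keys = ['norm']  # assumed stand-in for without_weight_decay_params

# Split parameters exactly as in the diff: any key match -> weight_decay 0,
# no match -> fall back to the optimizer-level weight_decay below.
no_decay = [p for n, p in model.named_parameters() if any(k in n for k in keys)]
decay = [p for n, p in model.named_parameters() if all(k not in n for k in keys)]
param_groups = [{'params': no_decay, 'weight_decay': 0.0}, {'params': decay}]

opt = paddle.optimizer.AdamW(
    learning_rate=0.001, parameters=param_groups, weight_decay=0.05)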