Unverified commit 1a46f29f, authored by Wenyu, committed by GitHub

params group (#4955)

Parent 54f2411f
@@ -15,4 +15,8 @@ OptimizerBuilder:
   optimizer:
     type: AdamW
     weight_decay: 0.05
-    without_weight_decay_params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm']
+    param_groups:
+      -
+        params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm']
+        weight_decay: 0.
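The config change above replaces the special-purpose `without_weight_decay_params` key with a generic `param_groups` list: each entry lists name substrings to match (here `absolute_pos_embed`, `relative_position_bias_table` and `norm`) and may carry per-group optimizer overrides such as `weight_decay: 0.`. The matching and fallback behaviour is implemented in the optimizer builder in the next hunk; a standalone sketch of that logic follows it.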
@@ -249,21 +249,37 @@ class OptimizerBuilder():
             optim_args['weight_decay'] = regularization
         op = getattr(optimizer, optim_type)
 
-        if 'without_weight_decay_params' in optim_args:
-            keys = optim_args['without_weight_decay_params']
-            params = [{
-                'params': [
-                    p for n, p in model.named_parameters()
-                    if any([k in n for k in keys])
-                ],
-                'weight_decay': 0.
-            }, {
-                'params': [
-                    p for n, p in model.named_parameters()
-                    if all([k not in n for k in keys])
-                ]
-            }]
-            del optim_args['without_weight_decay_params']
+        if 'param_groups' in optim_args:
+            assert isinstance(optim_args['param_groups'], list), ''
+            param_groups = optim_args.pop('param_groups')
+
+            params, visited = [], []
+            for group in param_groups:
+                assert isinstance(group,
+                                  dict) and 'params' in group and isinstance(
+                                      group['params'], list), ''
+                _params = {
+                    n: p
+                    for n, p in model.named_parameters()
+                    if any([k in n for k in group['params']])
+                }
+                _group = group.copy()
+                _group.update({'params': list(_params.values())})
+                params.append(_group)
+                visited.extend(list(_params.keys()))
+
+            ext_params = [
+                p for n, p in model.named_parameters() if n not in visited
+            ]
+
+            if len(ext_params) < len(model.parameters()):
+                params.append({'params': ext_params})
+            elif len(ext_params) > len(model.parameters()):
+                raise RuntimeError
         else:
             params = model.parameters()
...
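The grouping logic in the new code reads well in isolation: parameters whose names contain any substring from a group's `params` list join that group together with the group's overrides, and everything left over is collected into a plain trailing group. Below is a minimal, framework-free sketch of that behaviour; the function name `build_param_groups` and the toy parameter names are hypothetical and not part of the commit, and the length check against `model.parameters()` is simplified to a truthiness test.

```python
# Minimal sketch of the grouping logic in the diff above. The function name,
# the toy parameter names, and the use of plain strings instead of tensors
# are hypothetical; only the matching/fallback behaviour mirrors the commit.

def build_param_groups(named_params, param_groups):
    params, visited = [], []
    for group in param_groups:
        # Substring match, as in `any([k in n for k in group['params']])`.
        matched = {
            n: p
            for n, p in named_params.items()
            if any(k in n for k in group['params'])
        }
        new_group = group.copy()
        new_group['params'] = list(matched.values())
        params.append(new_group)
        visited.extend(matched.keys())

    # Parameters not claimed by any explicit group fall into a default group
    # (the commit compares lengths against model.parameters(); simplified here).
    ext_params = [p for n, p in named_params.items() if n not in visited]
    if ext_params:
        params.append({'params': ext_params})
    return params


if __name__ == '__main__':
    fake_named_params = {
        'backbone.absolute_pos_embed': 'tensor_a',
        'backbone.layers.0.norm1.weight': 'tensor_b',
        'head.fc.weight': 'tensor_c',
    }
    groups = [{
        'params': ['absolute_pos_embed', 'relative_position_bias_table', 'norm'],
        'weight_decay': 0.
    }]
    for g in build_param_groups(fake_named_params, groups):
        print(g)
    # First group: tensor_a and tensor_b with weight_decay 0.;
    # second group: tensor_c with the optimizer's default weight decay.
```

The rest of the builder (not shown in this hunk) presumably constructs the optimizer from this `params` list in place of the flat `model.parameters()` used in the unchanged `else` branch.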