未验证 提交 1a46f29f 编写于 作者: W Wenyu 提交者: GitHub

params group (#4955)

上级 54f2411f
......@@ -15,4 +15,8 @@ OptimizerBuilder:
optimizer:
type: AdamW
weight_decay: 0.05
without_weight_decay_params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm']
param_groups:
-
params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm']
weight_decay: 0.
......@@ -249,21 +249,37 @@ class OptimizerBuilder():
optim_args['weight_decay'] = regularization
op = getattr(optimizer, optim_type)
if 'without_weight_decay_params' in optim_args:
keys = optim_args['without_weight_decay_params']
params = [{
'params': [
p for n, p in model.named_parameters()
if any([k in n for k in keys])
],
'weight_decay': 0.
}, {
'params': [
p for n, p in model.named_parameters()
if all([k not in n for k in keys])
]
}]
del optim_args['without_weight_decay_params']
if 'param_groups' in optim_args:
assert isinstance(optim_args['param_groups'], list), ''
param_groups = optim_args.pop('param_groups')
params, visited = [], []
for group in param_groups:
assert isinstance(group,
dict) and 'params' in group and isinstance(
group['params'], list), ''
_params = {
n: p
for n, p in model.named_parameters()
if any([k in n for k in group['params']])
}
_group = group.copy()
_group.update({'params': list(_params.values())})
params.append(_group)
visited.extend(list(_params.keys()))
ext_params = [
p for n, p in model.named_parameters() if n not in visited
]
if len(ext_params) < len(model.parameters()):
params.append({'params': ext_params})
elif len(ext_params) > len(model.parameters()):
raise RuntimeError
else:
params = model.parameters()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册