diff --git a/configs/faster_rcnn/_base_/optimizer_swin_1x.yml b/configs/faster_rcnn/_base_/optimizer_swin_1x.yml index 5b76183a2e30719306d93f71ff7589eeea69048c..5c1c6679940834f8ff3bb985bb44f6dc2f281428 100644 --- a/configs/faster_rcnn/_base_/optimizer_swin_1x.yml +++ b/configs/faster_rcnn/_base_/optimizer_swin_1x.yml @@ -15,4 +15,8 @@ OptimizerBuilder: optimizer: type: AdamW weight_decay: 0.05 - without_weight_decay_params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm'] + + param_groups: + - + params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm'] + weight_decay: 0. diff --git a/ppdet/optimizer.py b/ppdet/optimizer.py index df9ade2e792ce1db1f4132d04e610777fbda762a..bed95fab87eab648b86498477057aefb17d07dd6 100644 --- a/ppdet/optimizer.py +++ b/ppdet/optimizer.py @@ -249,21 +249,37 @@ class OptimizerBuilder(): optim_args['weight_decay'] = regularization op = getattr(optimizer, optim_type) - if 'without_weight_decay_params' in optim_args: - keys = optim_args['without_weight_decay_params'] - params = [{ - 'params': [ - p for n, p in model.named_parameters() - if any([k in n for k in keys]) - ], - 'weight_decay': 0. - }, { - 'params': [ - p for n, p in model.named_parameters() - if all([k not in n for k in keys]) - ] - }] - del optim_args['without_weight_decay_params'] + if 'param_groups' in optim_args: + assert isinstance(optim_args['param_groups'], list), '' + + param_groups = optim_args.pop('param_groups') + + params, visited = [], [] + for group in param_groups: + assert isinstance(group, + dict) and 'params' in group and isinstance( + group['params'], list), '' + _params = { + n: p + for n, p in model.named_parameters() + if any([k in n for k in group['params']]) + } + _group = group.copy() + _group.update({'params': list(_params.values())}) + + params.append(_group) + visited.extend(list(_params.keys())) + + ext_params = [ + p for n, p in model.named_parameters() if n not in visited + ] + + if len(ext_params) < len(model.parameters()): + params.append({'params': ext_params}) + + elif len(ext_params) > len(model.parameters()): + raise RuntimeError + else: params = model.parameters()