From 1a46f29fe1b42c6350cadf2b5ace66a217cdf145 Mon Sep 17 00:00:00 2001 From: Wenyu Date: Fri, 31 Dec 2021 15:02:26 +0800 Subject: [PATCH] params group (#4955) --- .../faster_rcnn/_base_/optimizer_swin_1x.yml | 6 ++- ppdet/optimizer.py | 46 +++++++++++++------ 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/configs/faster_rcnn/_base_/optimizer_swin_1x.yml b/configs/faster_rcnn/_base_/optimizer_swin_1x.yml index 5b76183a2..5c1c66799 100644 --- a/configs/faster_rcnn/_base_/optimizer_swin_1x.yml +++ b/configs/faster_rcnn/_base_/optimizer_swin_1x.yml @@ -15,4 +15,8 @@ OptimizerBuilder: optimizer: type: AdamW weight_decay: 0.05 - without_weight_decay_params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm'] + + param_groups: + - + params: ['absolute_pos_embed', 'relative_position_bias_table', 'norm'] + weight_decay: 0. diff --git a/ppdet/optimizer.py b/ppdet/optimizer.py index df9ade2e7..bed95fab8 100644 --- a/ppdet/optimizer.py +++ b/ppdet/optimizer.py @@ -249,21 +249,37 @@ class OptimizerBuilder(): optim_args['weight_decay'] = regularization op = getattr(optimizer, optim_type) - if 'without_weight_decay_params' in optim_args: - keys = optim_args['without_weight_decay_params'] - params = [{ - 'params': [ - p for n, p in model.named_parameters() - if any([k in n for k in keys]) - ], - 'weight_decay': 0. - }, { - 'params': [ - p for n, p in model.named_parameters() - if all([k not in n for k in keys]) - ] - }] - del optim_args['without_weight_decay_params'] + if 'param_groups' in optim_args: + assert isinstance(optim_args['param_groups'], list), '' + + param_groups = optim_args.pop('param_groups') + + params, visited = [], [] + for group in param_groups: + assert isinstance(group, + dict) and 'params' in group and isinstance( + group['params'], list), '' + _params = { + n: p + for n, p in model.named_parameters() + if any([k in n for k in group['params']]) + } + _group = group.copy() + _group.update({'params': list(_params.values())}) + + params.append(_group) + visited.extend(list(_params.keys())) + + ext_params = [ + p for n, p in model.named_parameters() if n not in visited + ] + + if len(ext_params) < len(model.parameters()): + params.append({'params': ext_params}) + + elif len(ext_params) > len(model.parameters()): + raise RuntimeError + else: params = model.parameters() -- GitLab