From 10d647f0e9ae926a332f35a163c71d40839d7b13 Mon Sep 17 00:00:00 2001 From: Double_V Date: Thu, 30 Jun 2022 16:18:54 +0800 Subject: [PATCH] optimizer: fix dict object has no attribute trainable (#6311) * add FGD distill code * add configs * add doc * fix pretrain * pre-commit * fix ci * fix readme * fix readme * fix ci * fix param groups * fix --- ppdet/optimizer/optimizer.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/ppdet/optimizer/optimizer.py b/ppdet/optimizer/optimizer.py index bcba182d3..e6d5fe44a 100644 --- a/ppdet/optimizer/optimizer.py +++ b/ppdet/optimizer/optimizer.py @@ -341,7 +341,8 @@ class OptimizerBuilder(): _params = { n: p for n, p in model.named_parameters() - if any([k in n for k in group['params']]) + if any([k in n
                            for k in group['params']]) and p.trainable is True
 } _group = group.copy() _group.update({'params': list(_params.values())}) @@ -350,7 +351,8 @@ class OptimizerBuilder(): visited.extend(list(_params.keys())) ext_params = [ - p for n, p in model.named_parameters() if n not in visited + p for n, p in model.named_parameters()
            if n not in visited and p.trainable is True
 ] if len(ext_params) < len(model.parameters()): @@ -360,10 +362,10 @@ class OptimizerBuilder(): raise RuntimeError else: - params = model.parameters() + _params = model.parameters()
            params = [param for param in _params if param.trainable is True]
 - train_params = [param for param in params if param.trainable is True] return op(learning_rate=learning_rate, - parameters=train_params, + parameters=params, grad_clip=grad_clip, **optim_args) -- GitLab