diff --git a/ppdet/optimizer/optimizer.py b/ppdet/optimizer/optimizer.py index bcba182d3229c74206b8623b003d0e3e72160506..e6d5fe44ae2ccbf9f852d00a812ba4f6e98465e0 100644 --- a/ppdet/optimizer/optimizer.py +++ b/ppdet/optimizer/optimizer.py @@ -341,7 +341,8 @@ class OptimizerBuilder(): _params = { n: p for n, p in model.named_parameters() - if any([k in n for k in group['params']]) + if any([k in n + for k in group['params']] and p.trainable is True) } _group = group.copy() _group.update({'params': list(_params.values())}) @@ -350,7 +351,8 @@ class OptimizerBuilder(): visited.extend(list(_params.keys())) ext_params = [ - p for n, p in model.named_parameters() if n not in visited + p for n, p in model.named_parameters() + if n not in visited and p.trainable is True ] if len(ext_params) < len(model.parameters()): @@ -360,10 +362,10 @@ class OptimizerBuilder(): raise RuntimeError else: - params = model.parameters() + _params = model.parameters() + params = [param for param in _params if param.trainable is True] - train_params = [param for param in params if param.trainable is True] return op(learning_rate=learning_rate, - parameters=train_params, + parameters=params, grad_clip=grad_clip, **optim_args)