未验证 提交 10d647f0 编写于 作者: D Double_V 提交者: GitHub

optimizer: fix dict object has no attribute trainable (#6311)

* add FGD distill code

* add configs

* add doc

* fix pretrain

* pre-commit

* fix ci

* fix readme

* fix readme

* fix ci

* fix param groups

* fix
上级 d2e7bd38
...@@ -341,7 +341,8 @@ class OptimizerBuilder(): ...@@ -341,7 +341,8 @@ class OptimizerBuilder():
_params = { _params = {
n: p n: p
for n, p in model.named_parameters() for n, p in model.named_parameters()
if any([k in n for k in group['params']]) if any([k in n
for k in group['params']] and p.trainable is True)
} }
_group = group.copy() _group = group.copy()
_group.update({'params': list(_params.values())}) _group.update({'params': list(_params.values())})
...@@ -350,7 +351,8 @@ class OptimizerBuilder(): ...@@ -350,7 +351,8 @@ class OptimizerBuilder():
visited.extend(list(_params.keys())) visited.extend(list(_params.keys()))
ext_params = [ ext_params = [
p for n, p in model.named_parameters() if n not in visited p for n, p in model.named_parameters()
if n not in visited and p.trainable is True
] ]
if len(ext_params) < len(model.parameters()): if len(ext_params) < len(model.parameters()):
...@@ -360,10 +362,10 @@ class OptimizerBuilder(): ...@@ -360,10 +362,10 @@ class OptimizerBuilder():
raise RuntimeError raise RuntimeError
else: else:
params = model.parameters() _params = model.parameters()
params = [param for param in _params if param.trainable is True]
train_params = [param for param in params if param.trainable is True]
return op(learning_rate=learning_rate, return op(learning_rate=learning_rate,
parameters=train_params, parameters=params,
grad_clip=grad_clip, grad_clip=grad_clip,
**optim_args) **optim_args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册