未验证 提交 10d647f0 编写于 作者: D Double_V 提交者: GitHub

optimizer: fix dict object has no attribute trainable (#6311)

* add FGD distill code

* add configs

* add doc

* fix pretrain

* pre-commit

* fix ci

* fix readme

* fix readme

* fix ci

* fix param groups

* fix
上级 d2e7bd38
......@@ -341,7 +341,8 @@ class OptimizerBuilder():
_params = {
n: p
for n, p in model.named_parameters()
if any([k in n for k in group['params']])
if any([k in n
for k in group['params']] and p.trainable is True)
}
_group = group.copy()
_group.update({'params': list(_params.values())})
......@@ -350,7 +351,8 @@ class OptimizerBuilder():
visited.extend(list(_params.keys()))
ext_params = [
p for n, p in model.named_parameters() if n not in visited
p for n, p in model.named_parameters()
if n not in visited and p.trainable is True
]
if len(ext_params) < len(model.parameters()):
......@@ -360,10 +362,10 @@ class OptimizerBuilder():
raise RuntimeError
else:
params = model.parameters()
_params = model.parameters()
params = [param for param in _params if param.trainable is True]
train_params = [param for param in params if param.trainable is True]
return op(learning_rate=learning_rate,
parameters=train_params,
parameters=params,
grad_clip=grad_clip,
**optim_args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册