未验证 提交 2f343f5a 编写于 作者: S shangliang Xu 提交者: GitHub

fix ema_filter_no_grad (#7974)

上级 fdbfbec6
......@@ -60,6 +60,12 @@ class ModelEMA(object):
self.cycle_epoch = cycle_epoch
self.ema_black_list = self._match_ema_black_list(
model.state_dict().keys(), ema_black_list)
bn_states_names = get_bn_running_state_names(model)
if ema_filter_no_grad:
for n, p in model.named_parameters():
if p.stop_gradient and n not in bn_states_names:
self.ema_black_list.add(n)
self.state_dict = dict()
for k, v in model.state_dict().items():
if k in self.ema_black_list:
......@@ -67,12 +73,6 @@ class ModelEMA(object):
else:
self.state_dict[k] = paddle.zeros_like(v)
bn_states_names = get_bn_running_state_names(model)
if ema_filter_no_grad:
for n, p in model.named_parameters():
if p.stop_gradient == True and n not in bn_states_names:
self.ema_black_list.append(n)
self._model_state = {
k: weakref.ref(p)
for k, p in model.state_dict().items()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册