未验证 提交 47d7ee5a 编写于 作者: W Wenyu 提交者: GitHub

fix v dtype in ema black list (#8466)

上级 9ed7ad25
......@@ -69,7 +69,7 @@ class ModelEMA(object):
self.state_dict = dict()
for k, v in model.state_dict().items():
if k in self.ema_black_list:
self.state_dict[k] = v.astype('float32')
self.state_dict[k] = v
else:
self.state_dict[k] = paddle.zeros_like(v, dtype='float32')
......@@ -127,7 +127,7 @@ class ModelEMA(object):
for k, v in self.state_dict.items():
if k in self.ema_black_list:
v.stop_gradient = True
state_dict[k] = v.astype(model_dict[k].dtype)
state_dict[k] = v
else:
if self.ema_decay_type != 'exponential':
v = v / (1 - self._decay**self.step)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册