未验证 提交 47d7ee5a 编写于 作者: W Wenyu 提交者: GitHub

fix v dtype in ema black list (#8466)

上级 9ed7ad25
...@@ -69,7 +69,7 @@ class ModelEMA(object): ...@@ -69,7 +69,7 @@ class ModelEMA(object):
self.state_dict = dict() self.state_dict = dict()
for k, v in model.state_dict().items(): for k, v in model.state_dict().items():
if k in self.ema_black_list: if k in self.ema_black_list:
self.state_dict[k] = v.astype('float32') self.state_dict[k] = v
else: else:
self.state_dict[k] = paddle.zeros_like(v, dtype='float32') self.state_dict[k] = paddle.zeros_like(v, dtype='float32')
...@@ -127,7 +127,7 @@ class ModelEMA(object): ...@@ -127,7 +127,7 @@ class ModelEMA(object):
for k, v in self.state_dict.items(): for k, v in self.state_dict.items():
if k in self.ema_black_list: if k in self.ema_black_list:
v.stop_gradient = True v.stop_gradient = True
state_dict[k] = v.astype(model_dict[k].dtype) state_dict[k] = v
else: else:
if self.ema_decay_type != 'exponential': if self.ema_decay_type != 'exponential':
v = v / (1 - self._decay**self.step) v = v / (1 - self._decay**self.step)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册