From f4c320ec6855dbe8d3e73ceab03532d23dacad2a Mon Sep 17 00:00:00 2001 From: Wenyu Date: Mon, 24 Jul 2023 18:53:27 +0800 Subject: [PATCH] fix v dtype in ema blacklist (#8465) --- ppdet/optimizer/ema.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ppdet/optimizer/ema.py b/ppdet/optimizer/ema.py index e81214f47..84cc9ac28 100644 --- a/ppdet/optimizer/ema.py +++ b/ppdet/optimizer/ema.py @@ -69,7 +69,7 @@ class ModelEMA(object): self.state_dict = dict() for k, v in model.state_dict().items(): if k in self.ema_black_list: - self.state_dict[k] = v.astype('float32') + self.state_dict[k] = v else: self.state_dict[k] = paddle.zeros_like(v, dtype='float32') @@ -127,7 +127,7 @@ class ModelEMA(object): for k, v in self.state_dict.items(): if k in self.ema_black_list: v.stop_gradient = True - state_dict[k] = v.astype(model_dict[k].dtype) + state_dict[k] = v else: if self.ema_decay_type != 'exponential': v = v / (1 - self._decay**self.step) -- GitLab