diff --git a/ppdet/optimizer/ema.py b/ppdet/optimizer/ema.py index e81214f47057a75ed905f61a26bb83a2f1decec6..84cc9ac285d72865e4ec65b86528310bc9d7a8bc 100644 --- a/ppdet/optimizer/ema.py +++ b/ppdet/optimizer/ema.py @@ -69,7 +69,7 @@ class ModelEMA(object): self.state_dict = dict() for k, v in model.state_dict().items(): if k in self.ema_black_list: - self.state_dict[k] = v.astype('float32') + self.state_dict[k] = v else: self.state_dict[k] = paddle.zeros_like(v, dtype='float32') @@ -127,7 +127,7 @@ class ModelEMA(object): for k, v in self.state_dict.items(): if k in self.ema_black_list: v.stop_gradient = True - state_dict[k] = v.astype(model_dict[k].dtype) + state_dict[k] = v else: if self.ema_decay_type != 'exponential': v = v / (1 - self._decay**self.step)