diff --git a/ppcls/utils/ema.py b/ppcls/utils/ema.py index 9a3b65ccda5ec2cdb8e7384b78c60c1e0738481f..7f356b88061ce6dc33d94f4940b2a2dc64cb880d 100644 --- a/ppcls/utils/ema.py +++ b/ppcls/utils/ema.py @@ -32,11 +32,14 @@ class ExponentialMovingAverage(): @paddle.no_grad() def _update(self, model, update_fn): - for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): - ema_v.set_value(update_fn(ema_v, model_v)) + for ema_v, model_v in zip(self.module.state_dict().values(), + model.state_dict().values()): + ema_v.set_value(update_fn(ema_v.numpy(), model_v.numpy())) def update(self, model): - self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + self._update( + model, + update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) def set(self, model): - self._update(model, update_fn=lambda e, m: m) \ No newline at end of file + self._update(model, update_fn=lambda e, m: m)