diff --git a/ppcls/utils/ema.py b/ppcls/utils/ema.py index 7f356b88061ce6dc33d94f4940b2a2dc64cb880d..9a3b65ccda5ec2cdb8e7384b78c60c1e0738481f 100644 --- a/ppcls/utils/ema.py +++ b/ppcls/utils/ema.py @@ -32,14 +32,11 @@ class ExponentialMovingAverage(): @paddle.no_grad() def _update(self, model, update_fn): - for ema_v, model_v in zip(self.module.state_dict().values(), - model.state_dict().values()): - ema_v.set_value(update_fn(ema_v.numpy(), model_v.numpy())) + for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): + ema_v.set_value(update_fn(ema_v, model_v)) def update(self, model): - self._update( - model, - update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) def set(self, model): - self._update(model, update_fn=lambda e, m: m) + self._update(model, update_fn=lambda e, m: m) \ No newline at end of file