From cabdc251fef6b98d57542e0434d2dcc34fedad64 Mon Sep 17 00:00:00 2001 From: Yang Nie Date: Thu, 2 Mar 2023 03:40:25 +0800 Subject: [PATCH] Speedup EMA --- ppcls/utils/ema.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ppcls/utils/ema.py b/ppcls/utils/ema.py index 9a3b65cc..7f356b88 100644 --- a/ppcls/utils/ema.py +++ b/ppcls/utils/ema.py @@ -32,11 +32,14 @@ class ExponentialMovingAverage(): @paddle.no_grad() def _update(self, model, update_fn): - for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): - ema_v.set_value(update_fn(ema_v, model_v)) + for ema_v, model_v in zip(self.module.state_dict().values(), + model.state_dict().values()): + ema_v.set_value(update_fn(ema_v.numpy(), model_v.numpy())) def update(self, model): - self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + self._update( + model, + update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) def set(self, model): - self._update(model, update_fn=lambda e, m: m) \ No newline at end of file + self._update(model, update_fn=lambda e, m: m) -- GitLab