From f67cfe2c2a6b1386b7bff301b8aca8c13fc4a83b Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Thu, 25 May 2023 02:37:47 +0000 Subject: [PATCH] fix ema: set_value() -> paddle.assign() --- ppcls/utils/ema.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/ppcls/utils/ema.py b/ppcls/utils/ema.py index 9a3b65cc..8cdb3dfa 100644 --- a/ppcls/utils/ema.py +++ b/ppcls/utils/ema.py @@ -32,11 +32,14 @@ class ExponentialMovingAverage(): @paddle.no_grad() def _update(self, model, update_fn): - for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): - ema_v.set_value(update_fn(ema_v, model_v)) + for ema_v, model_v in zip(self.module.state_dict().values(), + model.state_dict().values()): + paddle.assign(update_fn(ema_v, model_v), ema_v) def update(self, model): - self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + self._update( + model, + update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) def set(self, model): - self._update(model, update_fn=lambda e, m: m) \ No newline at end of file + self._update(model, update_fn=lambda e, m: m) -- GitLab