From e62c687486c0881759ffd49b736afb5ccaa3d717 Mon Sep 17 00:00:00 2001 From: Wenyu Date: Mon, 10 Jan 2022 21:41:25 +0800 Subject: [PATCH] ema update (#5089) --- ppdet/optimizer.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/ppdet/optimizer.py b/ppdet/optimizer.py index bed95fab8..a9a64adf3 100644 --- a/ppdet/optimizer.py +++ b/ppdet/optimizer.py @@ -17,6 +17,7 @@ from __future__ import division from __future__ import print_function import math +import weakref import paddle import paddle.nn as nn @@ -319,19 +320,31 @@ class ModelEMA(object): self.use_thres_step = use_thres_step self.cycle_epoch = cycle_epoch + self._model_state = { + k: weakref.ref(p) + for k, p in model.state_dict().items() + } + def reset(self): self.step = 0 self.epoch = 0 for k, v in self.state_dict.items(): self.state_dict[k] = paddle.zeros_like(v) - def update(self, model): + def update(self, model=None): if self.use_thres_step: decay = min(self.decay, (1 + self.step) / (10 + self.step)) else: decay = self.decay self._decay = decay - model_dict = model.state_dict() + + if model is not None: + model_dict = model.state_dict() + else: + model_dict = {k: p() for k, p in self._model_state.items()} + assert all( + [v is not None for _, v in model_dict.items()]), 'python gc.' + for k, v in self.state_dict.items(): v = decay * v + (1 - decay) * model_dict[k] v.stop_gradient = True -- GitLab