diff --git a/ppdet/optimizer.py b/ppdet/optimizer.py
index bed95fab87eab648b86498477057aefb17d07dd6..a9a64adf3ba20c1adf656c28ec231f0e34a715a7 100644
--- a/ppdet/optimizer.py
+++ b/ppdet/optimizer.py
@@ -17,6 +17,7 @@ from __future__ import division
 from __future__ import print_function
 
 import math
+import weakref
 
 import paddle
 import paddle.nn as nn
@@ -319,19 +320,50 @@ class ModelEMA(object):
         self.use_thres_step = use_thres_step
         self.cycle_epoch = cycle_epoch
 
+        # Reference the model's state entries weakly: update() can then be
+        # called without a model, and this table keeps nothing alive.
+        self._model_state = {
+            k: weakref.ref(p)
+            for k, p in model.state_dict().items()
+        }
+
     def reset(self):
         self.step = 0
         self.epoch = 0
         for k, v in self.state_dict.items():
             self.state_dict[k] = paddle.zeros_like(v)
 
-    def update(self, model):
+    def update(self, model=None):
+        """Mix each EMA tensor with the matching current model weight.
+
+        Args:
+            model: Object whose ``state_dict()`` supplies the current
+                weights. When None, the tensors weakly referenced in
+                ``__init__`` are used instead.
+
+        Raises:
+            RuntimeError: If ``model`` is None and a weakly referenced
+                tensor has already been garbage collected.
+        """
         if self.use_thres_step:
             decay = min(self.decay, (1 + self.step) / (10 + self.step))
         else:
             decay = self.decay
         self._decay = decay
-        model_dict = model.state_dict()
+
+        if model is not None:
+            model_dict = model.state_dict()
+        else:
+            model_dict = {k: p() for k, p in self._model_state.items()}
+            # A dead weakref yields None; fail loudly instead of using an
+            # assert, which is stripped under ``python -O``.
+            dead = [k for k, v in model_dict.items() if v is None]
+            if dead:
+                raise RuntimeError(
+                    'ModelEMA.update() was called without a model, but '
+                    'these weakly referenced parameters were garbage '
+                    'collected: {}'.format(dead))
+
         for k, v in self.state_dict.items():
             v = decay * v + (1 - decay) * model_dict[k]
             v.stop_gradient = True