未验证 提交 e62c6874 编写于 作者: W Wenyu 提交者: GitHub

ema update (#5089)

上级 d94481c2
...@@ -17,6 +17,7 @@ from __future__ import division ...@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import math import math
import weakref
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
...@@ -319,19 +320,31 @@ class ModelEMA(object): ...@@ -319,19 +320,31 @@ class ModelEMA(object):
self.use_thres_step = use_thres_step self.use_thres_step = use_thres_step
self.cycle_epoch = cycle_epoch self.cycle_epoch = cycle_epoch
self._model_state = {
k: weakref.ref(p)
for k, p in model.state_dict().items()
}
def reset(self): def reset(self):
self.step = 0 self.step = 0
self.epoch = 0 self.epoch = 0
for k, v in self.state_dict.items(): for k, v in self.state_dict.items():
self.state_dict[k] = paddle.zeros_like(v) self.state_dict[k] = paddle.zeros_like(v)
def update(self, model): def update(self, model=None):
if self.use_thres_step: if self.use_thres_step:
decay = min(self.decay, (1 + self.step) / (10 + self.step)) decay = min(self.decay, (1 + self.step) / (10 + self.step))
else: else:
decay = self.decay decay = self.decay
self._decay = decay self._decay = decay
if model is not None:
model_dict = model.state_dict() model_dict = model.state_dict()
else:
model_dict = {k: p() for k, p in self._model_state.items()}
assert all(
[v is not None for _, v in model_dict.items()]), 'python gc.'
for k, v in self.state_dict.items(): for k, v in self.state_dict.items():
v = decay * v + (1 - decay) * model_dict[k] v = decay * v + (1 - decay) * model_dict[k]
v.stop_gradient = True v.stop_gradient = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册