未验证 提交 e62c6874 编写于 作者: W Wenyu 提交者: GitHub

ema update (#5089)

上级 d94481c2
......@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function
import math
import weakref
import paddle
import paddle.nn as nn
......@@ -319,19 +320,31 @@ class ModelEMA(object):
self.use_thres_step = use_thres_step
self.cycle_epoch = cycle_epoch
self._model_state = {
k: weakref.ref(p)
for k, p in model.state_dict().items()
}
def reset(self):
self.step = 0
self.epoch = 0
for k, v in self.state_dict.items():
self.state_dict[k] = paddle.zeros_like(v)
def update(self, model):
def update(self, model=None):
if self.use_thres_step:
decay = min(self.decay, (1 + self.step) / (10 + self.step))
else:
decay = self.decay
self._decay = decay
model_dict = model.state_dict()
if model is not None:
model_dict = model.state_dict()
else:
model_dict = {k: p() for k, p in self._model_state.items()}
assert all(
[v is not None for _, v in model_dict.items()]), 'python gc.'
for k, v in self.state_dict.items():
v = decay * v + (1 - decay) * model_dict[k]
v.stop_gradient = True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册