提交 cabdc251 编写于 作者: Y Yang Nie 提交者: Tingquan Gao

Speedup EMA

上级 541326ea
......@@ -32,11 +32,14 @@ class ExponentialMovingAverage():
@paddle.no_grad()
def _update(self, model, update_fn):
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
ema_v.set_value(update_fn(ema_v, model_v))
for ema_v, model_v in zip(self.module.state_dict().values(),
model.state_dict().values()):
ema_v.set_value(update_fn(ema_v.numpy(), model_v.numpy()))
def update(self, model):
self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
self._update(
model,
update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
def set(self, model):
self._update(model, update_fn=lambda e, m: m)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册