提交 c32e2b09 编写于 作者: Y Yang Nie 提交者: Tingquan Gao

Revert "Speedup EMA"

This reverts commit 35fc732dadac4761852b18512b5c5df8785e36df.
上级 001cdb09
......@@ -32,14 +32,11 @@ class ExponentialMovingAverage():
@paddle.no_grad()
def _update(self, model, update_fn):
for ema_v, model_v in zip(self.module.state_dict().values(),
model.state_dict().values()):
ema_v.set_value(update_fn(ema_v.numpy(), model_v.numpy()))
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
ema_v.set_value(update_fn(ema_v, model_v))
def update(self, model):
self._update(
model,
update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
def set(self, model):
self._update(model, update_fn=lambda e, m: m)
self._update(model, update_fn=lambda e, m: m)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册