提交 f67cfe2c 编写于 作者: G gaotingquan 提交者: cuicheng01

fix ema: set_value() -> paddle.assign()

上级 2823e48b
......@@ -32,11 +32,14 @@ class ExponentialMovingAverage():
@paddle.no_grad()
def _update(self, model, update_fn):
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
ema_v.set_value(update_fn(ema_v, model_v))
for ema_v, model_v in zip(self.module.state_dict().values(),
model.state_dict().values()):
paddle.assign(update_fn(ema_v, model_v), ema_v)
def update(self, model):
self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
self._update(
model,
update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
def set(self, model):
self._update(model, update_fn=lambda e, m: m)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册