提交 f67cfe2c 编写于 作者: G gaotingquan 提交者: cuicheng01

fix ema: set_value() -> paddle.assign()

上级 2823e48b
...@@ -32,11 +32,14 @@ class ExponentialMovingAverage(): ...@@ -32,11 +32,14 @@ class ExponentialMovingAverage():
@paddle.no_grad() @paddle.no_grad()
def _update(self, model, update_fn): def _update(self, model, update_fn):
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): for ema_v, model_v in zip(self.module.state_dict().values(),
ema_v.set_value(update_fn(ema_v, model_v)) model.state_dict().values()):
paddle.assign(update_fn(ema_v, model_v), ema_v)
def update(self, model): def update(self, model):
self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) self._update(
model,
update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
def set(self, model): def set(self, model):
self._update(model, update_fn=lambda e, m: m) self._update(model, update_fn=lambda e, m: m)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册