未验证 提交 b9a2d36d 编写于 作者: F Feng Ni 提交者: GitHub

fix amp training (#6456)

上级 4e60a034
......@@ -449,7 +449,7 @@ class Trainer(object):
model, paddle.
DataParallel) and use_fused_allreduce_gradients:
with model.no_sync():
with amp.auto_cast(
with paddle.amp.auto_cast(
enable=self.cfg.use_gpus,
level=self.amp_level):
# model forward
......@@ -461,7 +461,7 @@ class Trainer(object):
fused_allreduce_gradients(
list(model.parameters()), None)
else:
with amp.auto_cast(
with paddle.amp.auto_cast(
enable=self.cfg.use_gpu, level=self.amp_level):
# model forward
outputs = model(data)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册