未验证 提交 b9a2d36d 编写于 作者: F Feng Ni 提交者: GitHub

fix amp training (#6456)

上级 4e60a034
@@ -449,7 +449,7 @@ class Trainer(object):
                     model, paddle.
                         DataParallel) and use_fused_allreduce_gradients:
                     with model.no_sync():
-                        with amp.auto_cast(
+                        with paddle.amp.auto_cast(
                                 enable=self.cfg.use_gpus,
                                 level=self.amp_level):
                             # model forward
@@ -461,7 +461,7 @@ class Trainer(object):
                         fused_allreduce_gradients(
                             list(model.parameters()), None)
                 else:
-                    with amp.auto_cast(
+                    with paddle.amp.auto_cast(
                             enable=self.cfg.use_gpu, level=self.amp_level):
                         # model forward
                         outputs = model(data)
...
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册