From b9a2d36d656285ce510fab368138257800cb66a2 Mon Sep 17 00:00:00 2001 From: Feng Ni Date: Mon, 18 Jul 2022 10:36:26 +0800 Subject: [PATCH] fix amp training (#6456) --- ppdet/engine/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 05ed06284..dc63f83ec 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -449,7 +449,7 @@ class Trainer(object): model, paddle. DataParallel) and use_fused_allreduce_gradients: with model.no_sync(): - with amp.auto_cast( + with paddle.amp.auto_cast( enable=self.cfg.use_gpus, level=self.amp_level): # model forward @@ -461,7 +461,7 @@ class Trainer(object): fused_allreduce_gradients( list(model.parameters()), None) else: - with amp.auto_cast( + with paddle.amp.auto_cast( enable=self.cfg.use_gpu, level=self.amp_level): # model forward outputs = model(data) -- GitLab