diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 05ed06284f3516187d56768e095cb98346626d0e..dc63f83eccfcd89ce2280bd2bc25f487800950bd 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -449,7 +449,7 @@ class Trainer(object): model, paddle. DataParallel) and use_fused_allreduce_gradients: with model.no_sync(): - with amp.auto_cast( + with paddle.amp.auto_cast( enable=self.cfg.use_gpus, level=self.amp_level): # model forward @@ -461,7 +461,7 @@ class Trainer(object): fused_allreduce_gradients( list(model.parameters()), None) else: - with amp.auto_cast( + with paddle.amp.auto_cast( enable=self.cfg.use_gpu, level=self.amp_level): # model forward outputs = model(data)