diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index f186afe8e458d2923bebb6ae4559c63a7c8cc6bb..fa9167f05b0f8554cd2650a337a51bd31c355b6c 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -373,7 +373,7 @@ class Trainer(object): # enabel auto mixed precision mode if self.cfg.get('amp', False): scaler = amp.GradScaler( - enable=self.cfg.use_gpu, init_loss_scaling=1024) + enable=self.cfg.use_gpu or self.cfg.use_npu, init_loss_scaling=1024) self.status.update({ 'epoch_id': self.start_epoch,