From 14e722043da22c4a43595ee3b415b2fe0197893d Mon Sep 17 00:00:00 2001 From: furnace <34057289+windstamp@users.noreply.github.com> Date: Wed, 23 Mar 2022 11:14:44 +0800 Subject: [PATCH] [NPU] fix fp16 (#5417) --- ppdet/engine/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index f186afe8e..fa9167f05 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -373,7 +373,7 @@ class Trainer(object): # enabel auto mixed precision mode if self.cfg.get('amp', False): scaler = amp.GradScaler( - enable=self.cfg.use_gpu, init_loss_scaling=1024) + enable=self.cfg.use_gpu or self.cfg.use_npu, init_loss_scaling=1024) self.status.update({ 'epoch_id': self.start_epoch, -- GitLab