diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index abab5372b2dcc0458531e373e14254627bc90a70..73012668658181dd54af3808271ff46a20fb444f 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -421,8 +421,10 @@ class Trainer(object): model = self.model if self.cfg.get('to_static', False): model = apply_to_static(self.cfg, model) - sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and - (self.cfg.use_gpu or self.cfg.use_mlu) and self._nranks > 1) + sync_bn = ( + getattr(self.cfg, 'norm_type', None) == 'sync_bn' and + (self.cfg.use_gpu or self.cfg.use_npu or self.cfg.use_mlu) and + self._nranks > 1) if sync_bn: model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) @@ -484,7 +486,8 @@ class Trainer(object): DataParallel) and use_fused_allreduce_gradients: with model.no_sync(): with paddle.amp.auto_cast( - enable=self.cfg.use_gpu or self.cfg.use_mlu, + enable=self.cfg.use_gpu or + self.cfg.use_npu or self.cfg.use_mlu, custom_white_list=self.custom_white_list, custom_black_list=self.custom_black_list, level=self.amp_level): @@ -498,7 +501,8 @@ class Trainer(object): list(model.parameters()), None) else: with paddle.amp.auto_cast( - enable=self.cfg.use_gpu or self.cfg.use_mlu, + enable=self.cfg.use_gpu or self.cfg.use_npu or + self.cfg.use_mlu, custom_white_list=self.custom_white_list, custom_black_list=self.custom_black_list, level=self.amp_level): @@ -612,7 +616,8 @@ class Trainer(object): # forward if self.use_amp: with paddle.amp.auto_cast( - enable=self.cfg.use_gpu or self.cfg.use_mlu, + enable=self.cfg.use_gpu or self.cfg.use_npu or + self.cfg.use_mlu, custom_white_list=self.custom_white_list, custom_black_list=self.custom_black_list, level=self.amp_level): @@ -679,7 +684,8 @@ class Trainer(object): # forward if self.use_amp: with paddle.amp.auto_cast( - enable=self.cfg.use_gpu or self.cfg.use_mlu, + enable=self.cfg.use_gpu or self.cfg.use_npu or + self.cfg.use_mlu, custom_white_list=self.custom_white_list, custom_black_list=self.custom_black_list, level=self.amp_level): diff --git a/ppdet/utils/check.py b/ppdet/utils/check.py index 5235e0ebe79a2097b8f059b071e28e680955e823..7690ade9eab0a7d859459a0be74d344446be6938 100644 --- a/ppdet/utils/check.py +++ b/ppdet/utils/check.py @@ -53,16 +53,17 @@ def check_mlu(use_mlu): def check_npu(use_npu): """ Log error and exit when set use_npu=true in paddlepaddle - cpu/gpu/xpu version. + version without paddle-custom-npu installed. """ err = "Config use_npu cannot be set as true while you are " \ - "using paddlepaddle cpu/gpu/xpu version ! \nPlease try: \n" \ - "\t1. Install paddlepaddle-npu to run model on NPU \n" \ + "using paddlepaddle version without paddle-custom-npu " \ + "installed! \nPlease try: \n" \ + "\t1. Install paddle-custom-npu to run model on NPU \n" \ "\t2. Set use_npu as false in config file to run " \ - "model on CPU/GPU/XPU" + "model on other devices supported." try: - if use_npu and not paddle.is_compiled_with_npu(): + if use_npu and not 'npu' in paddle.device.get_all_custom_device_type(): logger.error(err) sys.exit(1) except Exception as e: