diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 043b3eceb4889e03ea99225578799ef50cf2b441..54f6955c580a6a1b2df1bba7303fce32e5432cda 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -91,7 +91,7 @@ class Engine(object): self.vdl_writer = LogWriter(logdir=vdl_writer_path) # set device - assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"] + assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu", "npu"] self.device = paddle.set_device(self.config["Global"]["device"]) logger.info('train with paddle {} and device {}'.format( paddle.__version__, self.device)) diff --git a/ppcls/static/train.py b/ppcls/static/train.py index a3aa0b591ce2db7d1066f1fada521e3a91cfd239..5c56c17cb6d9a1e17bf9dc02bd8ab5a4b89b14fb 100644 --- a/ppcls/static/train.py +++ b/ppcls/static/train.py @@ -91,14 +91,17 @@ def main(args): os.environ[k] = AMP_RELATED_FLAGS_SETTING[k] use_xpu = global_config.get("use_xpu", False) + use_npu = global_config.get("use_npu", False) assert ( - use_gpu and use_xpu - ) is not True, "gpu and xpu can not be true in the same time in static mode!" + use_gpu and use_xpu and use_npu + ) is not True, "gpu, xpu and npu can not be true in the same time in static mode!" if use_gpu: device = paddle.set_device('gpu') elif use_xpu: device = paddle.set_device('xpu') + elif use_npu: + device = paddle.set_device('npu') else: device = paddle.set_device('cpu')