From 2d6f3a56a451a1718b3e3828f7285c5071fc5bd1 Mon Sep 17 00:00:00 2001 From: duanyanhui <45005871+YanhuiDua@users.noreply.github.com> Date: Tue, 11 Apr 2023 09:56:08 +0800 Subject: [PATCH] update npu api (#9688) --- tools/program.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/tools/program.py b/tools/program.py index f36620a9..b11d1a09 100755 --- a/tools/program.py +++ b/tools/program.py @@ -134,9 +134,18 @@ def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False): if use_xpu and not paddle.device.is_compiled_with_xpu(): print(err.format("use_xpu", "xpu", "xpu", "use_xpu")) sys.exit(1) - if use_npu and not paddle.device.is_compiled_with_npu(): - print(err.format("use_npu", "npu", "npu", "use_npu")) - sys.exit(1) + if use_npu: + if int(paddle.version.major) != 0 and int( + paddle.version.major) <= 2 and int( + paddle.version.minor) <= 4: + if not paddle.device.is_compiled_with_npu(): + print(err.format("use_npu", "npu", "npu", "use_npu")) + sys.exit(1) + # is_compiled_with_npu() has been updated after paddle-2.4 + else: + if not paddle.device.is_compiled_with_custom_device("npu"): + print(err.format("use_npu", "npu", "npu", "use_npu")) + sys.exit(1) if use_mlu and not paddle.device.is_compiled_with_mlu(): print(err.format("use_mlu", "mlu", "mlu", "use_mlu")) sys.exit(1) -- GitLab