From 7c62d2ab45421c349aa92d25f2bc1108e9e876f6 Mon Sep 17 00:00:00 2001 From: shentanyue <34421038+shentanyue@users.noreply.github.com> Date: Fri, 4 Nov 2022 16:47:52 +0800 Subject: [PATCH] change ascend to npu (#47641) --- paddle/fluid/inference/api/analysis_config.cc | 7 +++++-- python/paddle/device/__init__.py | 14 +++++++------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 8c9f02a4d3..319a3ea018 100755 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -160,14 +160,17 @@ void AnalysisConfig::SetXpuDeviceId(int device_id) { } void AnalysisConfig::EnableNpu(int device_id) { -#ifdef PADDLE_WITH_ASCEND_CL +#if defined(PADDLE_WITH_ASCEND_CL) use_npu_ = true; npu_device_id_ = device_id; +#elif defined(PADDLE_WITH_CUSTOM_DEVICE) + use_custom_device_ = true; + custom_device_id_ = device_id; + custom_device_type_ = "npu"; #else LOG(ERROR) << "Please compile with npu to EnableNpu()"; use_npu_ = false; #endif - Update(); } diff --git a/python/paddle/device/__init__.py b/python/paddle/device/__init__.py index 6e14fd5047..8be0e692f7 100644 --- a/python/paddle/device/__init__.py +++ b/python/paddle/device/__init__.py @@ -195,7 +195,13 @@ def get_cudnn_version(): def _convert_to_place(device): lower_device = device.lower() - if lower_device == 'cpu': + if device in core.get_all_custom_device_type(): + selected_devices = os.getenv( + "FLAGS_selected_{}s".format(device), "0" + ).split(",") + device_id = int(selected_devices[0]) + place = core.CustomPlace(device, device_id) + elif lower_device == 'cpu': place = core.CPUPlace() elif lower_device == 'gpu': if not core.is_compiled_with_cuda(): @@ -238,12 +244,6 @@ def _convert_to_place(device): selected_mlus = os.getenv("FLAGS_selected_mlus", "0").split(",") device_id = int(selected_mlus[0]) place = core.MLUPlace(device_id) - elif device in core.get_all_custom_device_type(): - selected_devices = os.getenv( - "FLAGS_selected_{}s".format(device), "0" - ).split(",") - device_id = int(selected_devices[0]) - place = core.CustomPlace(device, device_id) else: avaliable_gpu_device = re.match(r'gpu:\d+', lower_device) avaliable_xpu_device = re.match(r'xpu:\d+', lower_device) -- GitLab