diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 8c9f02a4d37b3f985fac9b667ed915ca87fd2a7c..319a3ea018d7d3a4a9502962283af9a3ca01ed09 100755
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -160,14 +160,17 @@ void AnalysisConfig::SetXpuDeviceId(int device_id) {
 }
 
 void AnalysisConfig::EnableNpu(int device_id) {
-#ifdef PADDLE_WITH_ASCEND_CL
+#if defined(PADDLE_WITH_ASCEND_CL)
   use_npu_ = true;
   npu_device_id_ = device_id;
+#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
+  use_custom_device_ = true;
+  custom_device_id_ = device_id;
+  custom_device_type_ = "npu";
 #else
   LOG(ERROR) << "Please compile with npu to EnableNpu()";
   use_npu_ = false;
 #endif
-
   Update();
 }
 
diff --git a/python/paddle/device/__init__.py b/python/paddle/device/__init__.py
index 6e14fd504784e2244c5647f9352c068e6c58350a..8be0e692f76a36e991f432c5e2dfb53ec3c3c096 100644
--- a/python/paddle/device/__init__.py
+++ b/python/paddle/device/__init__.py
@@ -195,7 +195,13 @@ def get_cudnn_version():
 
 def _convert_to_place(device):
     lower_device = device.lower()
-    if lower_device == 'cpu':
+    if device in core.get_all_custom_device_type():
+        selected_devices = os.getenv(
+            "FLAGS_selected_{}s".format(device), "0"
+        ).split(",")
+        device_id = int(selected_devices[0])
+        place = core.CustomPlace(device, device_id)
+    elif lower_device == 'cpu':
         place = core.CPUPlace()
     elif lower_device == 'gpu':
         if not core.is_compiled_with_cuda():
@@ -238,12 +244,6 @@
         selected_mlus = os.getenv("FLAGS_selected_mlus", "0").split(",")
         device_id = int(selected_mlus[0])
         place = core.MLUPlace(device_id)
-    elif device in core.get_all_custom_device_type():
-        selected_devices = os.getenv(
-            "FLAGS_selected_{}s".format(device), "0"
-        ).split(",")
-        device_id = int(selected_devices[0])
-        place = core.CustomPlace(device, device_id)
     else:
         avaliable_gpu_device = re.match(r'gpu:\d+', lower_device)
         avaliable_xpu_device = re.match(r'xpu:\d+', lower_device)