未验证 提交 7c62d2ab 编写于 作者: S shentanyue 提交者: GitHub

change ascend to npu (#47641)

上级 297f5efe
...@@ -160,14 +160,17 @@ void AnalysisConfig::SetXpuDeviceId(int device_id) { ...@@ -160,14 +160,17 @@ void AnalysisConfig::SetXpuDeviceId(int device_id) {
} }
void AnalysisConfig::EnableNpu(int device_id) { void AnalysisConfig::EnableNpu(int device_id) {
#ifdef PADDLE_WITH_ASCEND_CL #if defined(PADDLE_WITH_ASCEND_CL)
use_npu_ = true; use_npu_ = true;
npu_device_id_ = device_id; npu_device_id_ = device_id;
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
use_custom_device_ = true;
custom_device_id_ = device_id;
custom_device_type_ = "npu";
#else #else
LOG(ERROR) << "Please compile with npu to EnableNpu()"; LOG(ERROR) << "Please compile with npu to EnableNpu()";
use_npu_ = false; use_npu_ = false;
#endif #endif
Update(); Update();
} }
......
...@@ -195,7 +195,13 @@ def get_cudnn_version(): ...@@ -195,7 +195,13 @@ def get_cudnn_version():
def _convert_to_place(device): def _convert_to_place(device):
lower_device = device.lower() lower_device = device.lower()
if lower_device == 'cpu': if device in core.get_all_custom_device_type():
selected_devices = os.getenv(
"FLAGS_selected_{}s".format(device), "0"
).split(",")
device_id = int(selected_devices[0])
place = core.CustomPlace(device, device_id)
elif lower_device == 'cpu':
place = core.CPUPlace() place = core.CPUPlace()
elif lower_device == 'gpu': elif lower_device == 'gpu':
if not core.is_compiled_with_cuda(): if not core.is_compiled_with_cuda():
...@@ -238,12 +244,6 @@ def _convert_to_place(device): ...@@ -238,12 +244,6 @@ def _convert_to_place(device):
selected_mlus = os.getenv("FLAGS_selected_mlus", "0").split(",") selected_mlus = os.getenv("FLAGS_selected_mlus", "0").split(",")
device_id = int(selected_mlus[0]) device_id = int(selected_mlus[0])
place = core.MLUPlace(device_id) place = core.MLUPlace(device_id)
elif device in core.get_all_custom_device_type():
selected_devices = os.getenv(
"FLAGS_selected_{}s".format(device), "0"
).split(",")
device_id = int(selected_devices[0])
place = core.CustomPlace(device, device_id)
else: else:
avaliable_gpu_device = re.match(r'gpu:\d+', lower_device) avaliable_gpu_device = re.match(r'gpu:\d+', lower_device)
avaliable_xpu_device = re.match(r'xpu:\d+', lower_device) avaliable_xpu_device = re.match(r'xpu:\d+', lower_device)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册