未验证 提交 793c35ef 编写于 作者: S shentanyue 提交者: GitHub

fix npu:0 stage (#47729)

上级 caca5687
......@@ -249,28 +249,6 @@ def _convert_to_place(device):
avaliable_xpu_device = re.match(r'xpu:\d+', lower_device)
avaliable_npu_device = re.match(r'npu:\d+', lower_device)
avaliable_mlu_device = re.match(r'mlu:\d+', lower_device)
if (
not avaliable_gpu_device
and not avaliable_xpu_device
and not avaliable_npu_device
and not avaliable_mlu_device
):
device_info_list = device.split(':', 1)
device_type = device_info_list[0]
if device_type in core.get_all_custom_device_type():
device_id = device_info_list[1]
device_id = int(device_id)
place = core.CustomPlace(device_type, device_id)
else:
raise ValueError(
"The device must be a string which is like 'cpu', {}".format(
', '.join(
"'{}', '{}:x'".format(x, x)
for x in ['gpu', 'xpu', 'npu', 'mlu']
+ core.get_all_custom_device_type()
)
)
)
if avaliable_gpu_device:
if not core.is_compiled_with_cuda():
raise ValueError(
......@@ -293,10 +271,20 @@ def _convert_to_place(device):
place = core.XPUPlace(device_id)
if avaliable_npu_device:
if not core.is_compiled_with_npu():
raise ValueError(
"The device should not be {}, since PaddlePaddle is "
"not compiled with NPU".format(avaliable_npu_device)
)
device_info_list = device.split(':', 1)
device_type = device_info_list[0]
if device_type in core.get_all_custom_device_type():
device_id = device_info_list[1]
device_id = int(device_id)
place = core.CustomPlace(device_type, device_id)
return place
else:
raise ValueError(
"The device should not be {}, since PaddlePaddle is "
"not compiled with NPU or compiled with custom device".format(
avaliable_npu_device
)
)
device_info_list = device.split(':', 1)
device_id = device_info_list[1]
device_id = int(device_id)
......@@ -311,6 +299,28 @@ def _convert_to_place(device):
device_id = device_info_list[1]
device_id = int(device_id)
place = core.MLUPlace(device_id)
if (
not avaliable_gpu_device
and not avaliable_xpu_device
and not avaliable_npu_device
and not avaliable_mlu_device
):
device_info_list = device.split(':', 1)
device_type = device_info_list[0]
if device_type in core.get_all_custom_device_type():
device_id = device_info_list[1]
device_id = int(device_id)
place = core.CustomPlace(device_type, device_id)
else:
raise ValueError(
"The device must be a string which is like 'cpu', {}".format(
', '.join(
"'{}', '{}:x'".format(x, x)
for x in ['gpu', 'xpu', 'npu', 'mlu']
+ core.get_all_custom_device_type()
)
)
)
return place
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册