From 793c35ef4c00f078a76b7a9db7f3a4406fcb4601 Mon Sep 17 00:00:00 2001 From: shentanyue <34421038+shentanyue@users.noreply.github.com> Date: Tue, 8 Nov 2022 16:32:42 +0800 Subject: [PATCH] fix npu:0 stage (#47729) --- python/paddle/device/__init__.py | 62 ++++++++++++++++++-------------- 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/python/paddle/device/__init__.py b/python/paddle/device/__init__.py index 8be0e692f76..f8d5dbd8b9d 100644 --- a/python/paddle/device/__init__.py +++ b/python/paddle/device/__init__.py @@ -249,28 +249,6 @@ def _convert_to_place(device): avaliable_xpu_device = re.match(r'xpu:\d+', lower_device) avaliable_npu_device = re.match(r'npu:\d+', lower_device) avaliable_mlu_device = re.match(r'mlu:\d+', lower_device) - if ( - not avaliable_gpu_device - and not avaliable_xpu_device - and not avaliable_npu_device - and not avaliable_mlu_device - ): - device_info_list = device.split(':', 1) - device_type = device_info_list[0] - if device_type in core.get_all_custom_device_type(): - device_id = device_info_list[1] - device_id = int(device_id) - place = core.CustomPlace(device_type, device_id) - else: - raise ValueError( - "The device must be a string which is like 'cpu', {}".format( - ', '.join( - "'{}', '{}:x'".format(x, x) - for x in ['gpu', 'xpu', 'npu', 'mlu'] - + core.get_all_custom_device_type() - ) - ) - ) if avaliable_gpu_device: if not core.is_compiled_with_cuda(): raise ValueError( @@ -293,10 +271,20 @@ def _convert_to_place(device): place = core.XPUPlace(device_id) if avaliable_npu_device: if not core.is_compiled_with_npu(): - raise ValueError( - "The device should not be {}, since PaddlePaddle is " - "not compiled with NPU".format(avaliable_npu_device) - ) + device_info_list = device.split(':', 1) + device_type = device_info_list[0] + if device_type in core.get_all_custom_device_type(): + device_id = device_info_list[1] + device_id = int(device_id) + place = core.CustomPlace(device_type, device_id) + return place + else: + raise ValueError( + "The device should not be {}, since PaddlePaddle is " + "not compiled with NPU or compiled with custom device".format( + avaliable_npu_device + ) + ) device_info_list = device.split(':', 1) device_id = device_info_list[1] device_id = int(device_id) @@ -311,6 +299,28 @@ def _convert_to_place(device): device_id = device_info_list[1] device_id = int(device_id) place = core.MLUPlace(device_id) + if ( + not avaliable_gpu_device + and not avaliable_xpu_device + and not avaliable_npu_device + and not avaliable_mlu_device + ): + device_info_list = device.split(':', 1) + device_type = device_info_list[0] + if device_type in core.get_all_custom_device_type(): + device_id = device_info_list[1] + device_id = int(device_id) + place = core.CustomPlace(device_type, device_id) + else: + raise ValueError( + "The device must be a string which is like 'cpu', {}".format( + ', '.join( + "'{}', '{}:x'".format(x, x) + for x in ['gpu', 'xpu', 'npu', 'mlu'] + + core.get_all_custom_device_type() + ) + ) + ) return place -- GitLab