未验证 提交 33730ae7 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] fix get_paddle_place (#55225)

上级 df21f815
...@@ -7623,8 +7623,15 @@ def _get_paddle_place(place): ...@@ -7623,8 +7623,15 @@ def _get_paddle_place(place):
device_id = int(device_id) device_id = int(device_id)
return core.IPUPlace(device_id) return core.IPUPlace(device_id)
place_info_list = place.split(':', 1)
device_type = place_info_list[0]
if device_type in core.get_all_custom_device_type():
device_id = place_info_list[1]
device_id = int(device_id)
return core.CustomPlace(device_type, device_id)
raise ValueError( raise ValueError(
f"Paddle supports CPUPlace, CUDAPlace, CUDAPinnedPlace, XPUPlace and IPUPlace, but received {place}." f"Paddle supports CPUPlace, CUDAPlace, CUDAPinnedPlace, XPUPlace, IPUPlace and CustomPlace, but received {place}."
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册