diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 823f790f41030e6156dbd106abfa7471f846e3e0..8a461760ef0c9be7868c874fa4a1d9dd326b6066 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -7623,8 +7623,15 @@ def _get_paddle_place(place): device_id = int(device_id) return core.IPUPlace(device_id) + place_info_list = place.split(':', 1) + device_type = place_info_list[0] + if device_type in core.get_all_custom_device_type(): + device_id = place_info_list[1] + device_id = int(device_id) + return core.CustomPlace(device_type, device_id) + raise ValueError( - f"Paddle supports CPUPlace, CUDAPlace, CUDAPinnedPlace, XPUPlace and IPUPlace, but received {place}." + f"Paddle supports CPUPlace, CUDAPlace, CUDAPinnedPlace, XPUPlace, IPUPlace and CustomPlace, but received {place}." )