From 23d2b079fd00ad8a06ce3dd2dcd712cb4cef6fa4 Mon Sep 17 00:00:00 2001 From: ronnywang Date: Wed, 24 Aug 2022 11:34:52 +0800 Subject: [PATCH] [CustomDevice] fix Tensor._to (#45337) --- python/paddle/fluid/dygraph/varbase_patch_methods.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index f70bfbde1e8..2b7771554b8 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -429,12 +429,14 @@ def monkey_patch_varbase(): if device is not None: if isinstance(device, str): device = paddle.device._convert_to_place(device) - elif isinstance(device, (core.CPUPlace, core.CUDAPlace, - core.CUDAPinnedPlace, core.XPUPlace)): + elif isinstance( + device, + (core.CPUPlace, core.CUDAPlace, core.CUDAPinnedPlace, + core.XPUPlace, core.CustomPlace)): pass else: raise ValueError( - "device value error, must be str, paddle.CPUPlace(), paddle.CUDAPlace(), paddle.CUDAPinnedPlace() or paddle.XPUPlace(), but the type of device is " + "device value error, must be str, paddle.CPUPlace(), paddle.CUDAPlace(), paddle.CUDAPinnedPlace(), paddle.XPUPlace() or paddle.CustomPlace(), but the type of device is " + type(device).__name__) if blocking is None: -- GitLab