From 4c576870d7314c35302a74f16185da7afe25a8cc Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Tue, 17 Jan 2023 11:29:09 +0800 Subject: [PATCH] SetDevice when parse TensorBase (#49860) --- paddle/phi/api/lib/kernel_dispatch.cc | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/paddle/phi/api/lib/kernel_dispatch.cc b/paddle/phi/api/lib/kernel_dispatch.cc index 074da80bbfb..73569b38731 100644 --- a/paddle/phi/api/lib/kernel_dispatch.cc +++ b/paddle/phi/api/lib/kernel_dispatch.cc @@ -22,6 +22,9 @@ limitations under the License. */ #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/string_tensor_utils.h" #include "paddle/phi/core/tensor_utils.h" +#ifdef PADDLE_WITH_CUSTOM_DEVICE +#include "paddle/phi/backends/device_manager.h" +#endif namespace paddle { namespace experimental { @@ -54,6 +57,11 @@ bool HasAllocation(const phi::TensorBase& t) { BackendSet GetTensorBackendSet(const phi::TensorBase& t) { if (HasAllocation(t) && t.place().GetType() != AllocationType::UNDEFINED) { +#ifdef PADDLE_WITH_CUSTOM_DEVICE + if (t.place().GetType() == AllocationType::CUSTOM) { + phi::DeviceManager::SetDevice(t.place()); + } +#endif phi::Backend backend_key = phi::TransToPhiBackend(t.place()); BackendSet backend_set(backend_key); if (backend_key == Backend::GPU && phi::DenseTensor::classof(&t) && -- GitLab