From 2bf0d1c872a107bd05903f86ad9743047f237482 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Thu, 23 Mar 2023 16:50:41 +0800 Subject: [PATCH] [Fix Bug] Fix customOP + customDevice scenario selects wrong place (#51996) --- paddle/phi/api/lib/kernel_dispatch.cc | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/paddle/phi/api/lib/kernel_dispatch.cc b/paddle/phi/api/lib/kernel_dispatch.cc index 7a0ed32c825..90a0afcf1e6 100644 --- a/paddle/phi/api/lib/kernel_dispatch.cc +++ b/paddle/phi/api/lib/kernel_dispatch.cc @@ -58,6 +58,7 @@ bool HasAllocation(const phi::TensorBase& t) { BackendSet GetTensorBackendSet(const phi::TensorBase& t) { if (HasAllocation(t) && t.place().GetType() != AllocationType::UNDEFINED) { #ifdef PADDLE_WITH_CUSTOM_DEVICE + // See Note [ Why `SetDevice` when parsing custom place? ] if (t.place().GetType() == AllocationType::CUSTOM) { phi::DeviceManager::SetDevice(t.place()); } @@ -128,6 +129,20 @@ DataType ParseDataTypeWithInputOrder(DataType dtype, const Tensor& tensor) { } Backend ParseBackend(const Place& place) { +#ifdef PADDLE_WITH_CUSTOM_DEVICE + /** + * [ Why `SetDevice` when parsing custom place? ] + * Users are able to call C++ APIs under customOP + customDevice scenario. To + * make sure `GetDevice` function outputs the accurate place when executing + * `GetDeviceContextByBackend` function in C++ API, we need to call + * `SetDevice` first. However, in dygraph mode, `SetDevice` is called at + * CPython level and calling C++ API directly in customOP cannot reach + * CPython. Hence, we need to manually set the device here. + */ + if (place.GetType() == AllocationType::CUSTOM) { + phi::DeviceManager::SetDevice(place); + } +#endif return phi::TransToPhiBackend(place); } Backend ParseBackend(const Tensor& tensor) { -- GitLab