diff --git a/paddle/phi/api/lib/kernel_dispatch.cc b/paddle/phi/api/lib/kernel_dispatch.cc
index 7a0ed32c825262d192fd2c300e0d91f81bc739dd..90a0afcf1e6d35f750e980980285aad66c750de5 100644
--- a/paddle/phi/api/lib/kernel_dispatch.cc
+++ b/paddle/phi/api/lib/kernel_dispatch.cc
@@ -58,6 +58,7 @@ bool HasAllocation(const phi::TensorBase& t) {
 BackendSet GetTensorBackendSet(const phi::TensorBase& t) {
   if (HasAllocation(t) && t.place().GetType() != AllocationType::UNDEFINED) {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
+    // See Note [ Why `SetDevice` when parsing custom place? ]
     if (t.place().GetType() == AllocationType::CUSTOM) {
       phi::DeviceManager::SetDevice(t.place());
     }
@@ -128,6 +129,20 @@ DataType ParseDataTypeWithInputOrder(DataType dtype, const Tensor& tensor) {
 }
 
 Backend ParseBackend(const Place& place) {
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  /**
+   * [ Why `SetDevice` when parsing custom place? ]
+   * Users are able to call C++ APIs under customOP + customDevice scenario. To
+   * make sure `GetDevice` function outputs the accurate place when executing
+   * `GetDeviceContextByBackend` function in C++ API, we need to call
+   * `SetDevice` first. However, in dygraph mode, `SetDevice` is called at
+   * CPython level and calling C++ API directly in customOP cannot reach
+   * CPython. Hence, we need to manually set the device here.
+   */
+  if (place.GetType() == AllocationType::CUSTOM) {
+    phi::DeviceManager::SetDevice(place);
+  }
+#endif
   return phi::TransToPhiBackend(place);
 }
 Backend ParseBackend(const Tensor& tensor) {