diff --git a/paddle/phi/api/lib/tensor_copy.cc b/paddle/phi/api/lib/tensor_copy.cc index 5f8c2ed71e9395508eaac787d86be70345471312..fb18a3b05c77e4b3f7b12ab59418f45ceb0bdf56 100644 --- a/paddle/phi/api/lib/tensor_copy.cc +++ b/paddle/phi/api/lib/tensor_copy.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/phi/api/lib/tensor_copy.h" +#include "paddle/phi/api/include/context_pool.h" #include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/core/compat/convert_utils.h" @@ -24,18 +25,21 @@ limitations under the License. */ namespace paddle { namespace experimental { -void copy(const Tensor& src, Place place, bool blocking, Tensor* dst) { +void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst) { auto kernel_key_set = ParseKernelKeyByInputArgs(src); kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place)); auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "copy", kernel_key); VLOG(6) << "copy API kernel key: " << kernel_key; + auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + "copy", kernel_key); VLOG(6) << "copy API kernel: " << kernel; - auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); + auto target_place = phi::TransToPhiPlace(kernel_key.backend()); + auto& pool = paddle::experimental::DeviceContextPool::Instance(); + auto* dev_ctx = pool.GetMutable( + target_place.GetType() == place.GetType() ? place : target_place); auto dense_x = TensorToDenseTensor(src); diff --git a/paddle/phi/api/lib/tensor_copy.h b/paddle/phi/api/lib/tensor_copy.h index 3ce45853319ecf24b21a1305288bdd441f1c1e1c..4a50b78be85d0862e0aa85786f9e09919d00518f 100644 --- a/paddle/phi/api/lib/tensor_copy.h +++ b/paddle/phi/api/lib/tensor_copy.h @@ -19,7 +19,7 @@ limitations under the License. */ namespace paddle { namespace experimental { -void copy(const Tensor& src, Place place, bool blocking, Tensor* dst); +void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst); } // namespace experimental } // namespace paddle