From 88216f63734ef12049be22e23e2e3ddd435dc044 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 8 Jun 2022 13:12:28 +0800 Subject: [PATCH] fix tensor copy bug (#43299) --- paddle/phi/api/lib/tensor_copy.cc | 12 ++++++++---- paddle/phi/api/lib/tensor_copy.h | 2 +- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/paddle/phi/api/lib/tensor_copy.cc b/paddle/phi/api/lib/tensor_copy.cc index 5f8c2ed71e9..fb18a3b05c7 100644 --- a/paddle/phi/api/lib/tensor_copy.cc +++ b/paddle/phi/api/lib/tensor_copy.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/phi/api/lib/tensor_copy.h" +#include "paddle/phi/api/include/context_pool.h" #include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/core/compat/convert_utils.h" @@ -24,18 +25,21 @@ limitations under the License. */ namespace paddle { namespace experimental { -void copy(const Tensor& src, Place place, bool blocking, Tensor* dst) { +void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst) { auto kernel_key_set = ParseKernelKeyByInputArgs(src); kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place)); auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "copy", kernel_key); VLOG(6) << "copy API kernel key: " << kernel_key; + auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + "copy", kernel_key); VLOG(6) << "copy API kernel: " << kernel; - auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); + auto target_place = phi::TransToPhiPlace(kernel_key.backend()); + auto& pool = paddle::experimental::DeviceContextPool::Instance(); + auto* dev_ctx = pool.GetMutable( + target_place.GetType() == place.GetType() ? place : target_place); auto dense_x = TensorToDenseTensor(src); diff --git a/paddle/phi/api/lib/tensor_copy.h b/paddle/phi/api/lib/tensor_copy.h index 3ce45853319..4a50b78be85 100644 --- a/paddle/phi/api/lib/tensor_copy.h +++ b/paddle/phi/api/lib/tensor_copy.h @@ -19,7 +19,7 @@ limitations under the License. */ namespace paddle { namespace experimental { -void copy(const Tensor& src, Place place, bool blocking, Tensor* dst); +void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst); } // namespace experimental } // namespace paddle -- GitLab