From 8760817ab4db7190da7b0f539d275e5bde3d811d Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 22 Jun 2022 11:13:02 +0800 Subject: [PATCH] fix tensor copy bug (#43299) (#43728) --- paddle/phi/api/lib/tensor_copy.cc | 13 +++++++++---- paddle/phi/api/lib/tensor_copy.h | 2 +- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/paddle/phi/api/lib/tensor_copy.cc b/paddle/phi/api/lib/tensor_copy.cc index 57e3c28d8cb..0fbdee25761 100644 --- a/paddle/phi/api/lib/tensor_copy.cc +++ b/paddle/phi/api/lib/tensor_copy.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/phi/api/lib/tensor_copy.h" + +#include "paddle/phi/api/include/context_pool.h" #include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/utils/storage.h" @@ -24,18 +26,21 @@ limitations under the License. */ namespace paddle { namespace experimental { -void copy(const Tensor& src, Place place, bool blocking, Tensor* dst) { +void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst) { auto kernel_key_set = ParseKernelKeyByInputArgs(src); kernel_key_set.backend_set = kernel_key_set.backend_set | BackendSet(phi::TransToPhiBackend(place)); auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); - auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( - "copy", kernel_key); VLOG(6) << "copy API kernel key: " << kernel_key; + auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( + "copy", kernel_key); VLOG(6) << "copy API kernel: " << kernel; - auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); + auto target_place = phi::TransToPhiPlace(kernel_key.backend()); + auto& pool = paddle::experimental::DeviceContextPool::Instance(); + auto* dev_ctx = pool.GetMutable( + target_place.GetType() == place.GetType() ? place : target_place); auto dense_x = TensorToDenseTensor(src); diff --git a/paddle/phi/api/lib/tensor_copy.h b/paddle/phi/api/lib/tensor_copy.h index 3ce45853319..4a50b78be85 100644 --- a/paddle/phi/api/lib/tensor_copy.h +++ b/paddle/phi/api/lib/tensor_copy.h @@ -19,7 +19,7 @@ limitations under the License. */ namespace paddle { namespace experimental { -void copy(const Tensor& src, Place place, bool blocking, Tensor* dst); +void copy(const Tensor& src, const Place& place, bool blocking, Tensor* dst); } // namespace experimental } // namespace paddle -- GitLab