From 125e48c3f6aae095b2280e1eab0e01c1dcc187b3 Mon Sep 17 00:00:00 2001 From: Ruibiao Chen Date: Tue, 9 Aug 2022 20:04:34 +0800 Subject: [PATCH] Fix copy bug for same src and dst Tensor (#44992) * Fix copy bug for same src and dst Tensor * Improve code design * Fix errors --- paddle/phi/core/tensor_utils.cc | 13 +++++++++++++ .../phi/kernels/gpu/class_center_sample_kernel.cu | 12 +++++------- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/paddle/phi/core/tensor_utils.cc b/paddle/phi/core/tensor_utils.cc index 8b3d4a14273..dcd25180e29 100644 --- a/paddle/phi/core/tensor_utils.cc +++ b/paddle/phi/core/tensor_utils.cc @@ -35,6 +35,19 @@ void Copy(const Context& dev_ctx, auto* src_ptr = src.data(); const auto& src_place = src.place(); + if (&src == dst) { + if (paddle::platform::is_same_place(src_place, dst_place)) { + VLOG(6) << "Skip copy the same data(" << src_ptr << ") from " << src_place + << " to " << dst_place; + } else { + VLOG(6) << "Src and dst are the same Tensor, in-place copy data(" + << src_ptr << ") from " << src_place << " to " << dst_place; + const DenseTensor src_copy = src; + Copy(dev_ctx, src_copy, dst_place, blocking, dst); + } + return; + } + VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to " << dst_place; diff --git a/paddle/phi/kernels/gpu/class_center_sample_kernel.cu b/paddle/phi/kernels/gpu/class_center_sample_kernel.cu index 64c3f1dfc9b..eb92a4488e5 100644 --- a/paddle/phi/kernels/gpu/class_center_sample_kernel.cu +++ b/paddle/phi/kernels/gpu/class_center_sample_kernel.cu @@ -571,13 +571,11 @@ void ClassCenterSampleKernel(const Context& dev_ctx, dev_ctx.template Alloc(remapped_label)); // step 14: Get sampled class center for output - paddle::framework::TensorCopySync( - num_classes_per_device, phi::CPUPlace(), &num_classes_per_device); - // phi::Copy(dev_ctx, - // num_classes_per_device, - // phi::CPUPlace(), - // true, - // &num_classes_per_device); + phi::Copy(dev_ctx, + num_classes_per_device, + phi::CPUPlace(), + true, + &num_classes_per_device); T actual_num_samples = num_classes_per_device.data()[rank + 1]; sampled_local_class_center->Resize(phi::make_ddim({actual_num_samples})); -- GitLab