未验证 提交 125e48c3 编写于 作者: Ruibiao Chen 提交者: GitHub

Fix copy bug for same src and dst Tensor (#44992)

* Fix copy bug for same src and dst Tensor

* Improve code design

* Fix errors
上级 be931dfe
......
@@ -35,6 +35,19 @@ void Copy(const Context& dev_ctx,
auto* src_ptr = src.data();
const auto& src_place = src.place();
if (&src == dst) {
if (paddle::platform::is_same_place(src_place, dst_place)) {
VLOG(6) << "Skip copy the same data(" << src_ptr << ") from " << src_place
<< " to " << dst_place;
} else {
VLOG(6) << "Src and dst are the same Tensor, in-place copy data("
<< src_ptr << ") from " << src_place << " to " << dst_place;
const DenseTensor src_copy = src;
Copy(dev_ctx, src_copy, dst_place, blocking, dst);
}
return;
}
VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
<< dst_place;
......
......
@@ -571,13 +571,11 @@ void ClassCenterSampleKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(remapped_label));
// step 14: Get sampled class center for output
paddle::framework::TensorCopySync(
num_classes_per_device, phi::CPUPlace(), &num_classes_per_device);
// phi::Copy<Context>(dev_ctx,
// num_classes_per_device,
// phi::CPUPlace(),
// true,
// &num_classes_per_device);
phi::Copy<Context>(dev_ctx,
num_classes_per_device,
phi::CPUPlace(),
true,
&num_classes_per_device);
T actual_num_samples = num_classes_per_device.data<T>()[rank + 1];
sampled_local_class_center->Resize(phi::make_ddim({actual_num_samples}));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册