From e3826e0a2942f135f4234176e8bb62dd4f229958 Mon Sep 17 00:00:00 2001 From: ronnywang Date: Fri, 10 Mar 2023 10:56:01 +0800 Subject: [PATCH] [CustomDevice] fix concat & split dense tensor (#51414) --- paddle/fluid/pybind/process_group_utils.h | 24 +++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/pybind/process_group_utils.h b/paddle/fluid/pybind/process_group_utils.h index 3e4cdd743f4..a35962ce841 100644 --- a/paddle/fluid/pybind/process_group_utils.h +++ b/paddle/fluid/pybind/process_group_utils.h @@ -60,11 +60,17 @@ struct ConcatDenseTensor { auto *out_data = out->data(); auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace()); size_t offset = 0; + phi::stream::Stream stream_wrapper(context.GetPlace(), context.stream()); + for (const auto &tensor : in) { const auto *in_data = tensor.data(); - auto sz = tensor.numel() * sizeof(T); - device->MemoryCopyD2D(out_data + offset, in_data, sz, nullptr); - offset += sz; + if (out_data + offset != in_data) { + device->MemoryCopyD2D(out_data + offset, + in_data, + tensor.numel() * sizeof(T), + &stream_wrapper); + } + offset += tensor.numel(); } } }; @@ -78,11 +84,17 @@ struct SplitDenseTensor { auto *in_data = in.data(); auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace()); size_t offset = 0; + phi::stream::Stream stream_wrapper(context.GetPlace(), context.stream()); + for (auto *p_tensor : *out) { auto *out_data = p_tensor->data(); - auto sz = p_tensor->numel() * sizeof(T); - device->MemoryCopyD2D(out_data, in_data + offset, sz, nullptr); - offset += sz; + if (out_data != in_data + offset) { + device->MemoryCopyD2D(out_data, + in_data + offset, + p_tensor->numel() * sizeof(T), + &stream_wrapper); + } + offset += p_tensor->numel(); } } }; -- GitLab