未验证 提交 e7c249cb 编写于 作者: ronnywang 提交者: GitHub

[CustomDevice] fix reducer (#52115)

上级 9fa98349
...@@ -207,13 +207,14 @@ struct ConcatTensorsForAllReduce<platform::CustomDeviceContext, T> { ...@@ -207,13 +207,14 @@ struct ConcatTensorsForAllReduce<platform::CustomDeviceContext, T> {
.get(); .get();
uint8_t *out_data = reinterpret_cast<uint8_t *>(out->data<T>()); uint8_t *out_data = reinterpret_cast<uint8_t *>(out->data<T>());
auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace()); auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
phi::stream::Stream stream(context.GetPlace(), context.stream());
size_t offset = 0; size_t offset = 0;
for (const auto &tensor : dense_tensors_) { for (const auto &tensor : dense_tensors_) {
const uint8_t *in_data = const uint8_t *in_data =
reinterpret_cast<const uint8_t *>(tensor.data<T>()); reinterpret_cast<const uint8_t *>(tensor.data<T>());
auto sz = tensor.numel() * sizeof(T); auto sz = tensor.numel() * sizeof(T);
device->MemoryCopyD2D(out_data + offset, in_data, sz, nullptr); device->MemoryCopyD2D(out_data + offset, in_data, sz, &stream);
offset += sz; offset += sz;
} }
} }
...@@ -229,12 +230,13 @@ struct SplitTensorsForAllReduce<platform::CustomDeviceContext, T> { ...@@ -229,12 +230,13 @@ struct SplitTensorsForAllReduce<platform::CustomDeviceContext, T> {
.get(); .get();
uint8_t *in_data = reinterpret_cast<uint8_t *>(in->data<T>()); uint8_t *in_data = reinterpret_cast<uint8_t *>(in->data<T>());
auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace()); auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
phi::stream::Stream stream(context.GetPlace(), context.stream());
size_t offset = 0; size_t offset = 0;
for (auto &tensor : *p_dense_tensors) { for (auto &tensor : *p_dense_tensors) {
uint8_t *out_data = reinterpret_cast<uint8_t *>(tensor.data<T>()); uint8_t *out_data = reinterpret_cast<uint8_t *>(tensor.data<T>());
auto sz = tensor.numel() * sizeof(T); auto sz = tensor.numel() * sizeof(T);
device->MemoryCopyD2D(out_data, in_data + offset, sz, nullptr); device->MemoryCopyD2D(out_data, in_data + offset, sz, &stream);
offset += sz; offset += sz;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册