未验证 提交 e7c249cb 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] fix reducer (#52115)

上级 9fa98349
......@@ -207,13 +207,14 @@ struct ConcatTensorsForAllReduce<platform::CustomDeviceContext, T> {
.get();
uint8_t *out_data = reinterpret_cast<uint8_t *>(out->data<T>());
auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
+ phi::stream::Stream stream(context.GetPlace(), context.stream());
size_t offset = 0;
for (const auto &tensor : dense_tensors_) {
const uint8_t *in_data =
reinterpret_cast<const uint8_t *>(tensor.data<T>());
auto sz = tensor.numel() * sizeof(T);
- device->MemoryCopyD2D(out_data + offset, in_data, sz, nullptr);
+ device->MemoryCopyD2D(out_data + offset, in_data, sz, &stream);
offset += sz;
}
}
......@@ -229,12 +230,13 @@ struct SplitTensorsForAllReduce<platform::CustomDeviceContext, T> {
.get();
uint8_t *in_data = reinterpret_cast<uint8_t *>(in->data<T>());
auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
+ phi::stream::Stream stream(context.GetPlace(), context.stream());
size_t offset = 0;
for (auto &tensor : *p_dense_tensors) {
uint8_t *out_data = reinterpret_cast<uint8_t *>(tensor.data<T>());
auto sz = tensor.numel() * sizeof(T);
- device->MemoryCopyD2D(out_data, in_data + offset, sz, nullptr);
+ device->MemoryCopyD2D(out_data, in_data + offset, sz, &stream);
offset += sz;
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册