未验证 提交 e077678c 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] fix reducer when input on cpu (#53662)

上级 aafaad90
...@@ -215,7 +215,11 @@ struct ConcatTensorsForAllReduce<platform::CustomDeviceContext, T> { ...@@ -215,7 +215,11 @@ struct ConcatTensorsForAllReduce<platform::CustomDeviceContext, T> {
const uint8_t *in_data = const uint8_t *in_data =
reinterpret_cast<const uint8_t *>(tensor.data<T>()); reinterpret_cast<const uint8_t *>(tensor.data<T>());
auto sz = tensor.numel() * sizeof(T); auto sz = tensor.numel() * sizeof(T);
device->MemoryCopyD2D(out_data + offset, in_data, sz, &stream); if (tensor.place().GetType() == phi::AllocationType::CPU) {
device->MemoryCopyH2D(out_data + offset, in_data, sz, &stream);
} else {
device->MemoryCopyD2D(out_data + offset, in_data, sz, &stream);
}
offset += sz; offset += sz;
} }
} }
...@@ -237,7 +241,11 @@ struct SplitTensorsForAllReduce<platform::CustomDeviceContext, T> { ...@@ -237,7 +241,11 @@ struct SplitTensorsForAllReduce<platform::CustomDeviceContext, T> {
for (auto &tensor : *p_dense_tensors) { for (auto &tensor : *p_dense_tensors) {
uint8_t *out_data = reinterpret_cast<uint8_t *>(tensor.data<T>()); uint8_t *out_data = reinterpret_cast<uint8_t *>(tensor.data<T>());
auto sz = tensor.numel() * sizeof(T); auto sz = tensor.numel() * sizeof(T);
device->MemoryCopyD2D(out_data, in_data + offset, sz, &stream); if (tensor.place().GetType() == phi::AllocationType::CPU) {
device->MemoryCopyD2H(out_data, in_data + offset, sz, &stream);
} else {
device->MemoryCopyD2D(out_data, in_data + offset, sz, &stream);
}
offset += sz; offset += sz;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册