未验证 提交 e3826e0a 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] fix concat & split dense tensor (#51414)

上级 6f30b14f
......@@ -60,11 +60,17 @@ struct ConcatDenseTensor<platform::CustomDeviceContext, T> {
auto *out_data = out->data<T>();
auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
size_t offset = 0;
phi::stream::Stream stream_wrapper(context.GetPlace(), context.stream());
for (const auto &tensor : in) {
const auto *in_data = tensor.data<T>();
auto sz = tensor.numel() * sizeof(T);
device->MemoryCopyD2D(out_data + offset, in_data, sz, nullptr);
offset += sz;
if (out_data + offset != in_data) {
device->MemoryCopyD2D(out_data + offset,
in_data,
tensor.numel() * sizeof(T),
&stream_wrapper);
}
offset += tensor.numel();
}
}
};
......@@ -78,11 +84,17 @@ struct SplitDenseTensor<platform::CustomDeviceContext, T> {
auto *in_data = in.data<T>();
auto *device = phi::DeviceManager::GetDeviceWithPlace(context.GetPlace());
size_t offset = 0;
phi::stream::Stream stream_wrapper(context.GetPlace(), context.stream());
for (auto *p_tensor : *out) {
auto *out_data = p_tensor->data<T>();
auto sz = p_tensor->numel() * sizeof(T);
device->MemoryCopyD2D(out_data, in_data + offset, sz, nullptr);
offset += sz;
if (out_data != in_data + offset) {
device->MemoryCopyD2D(out_data,
in_data + offset,
p_tensor->numel() * sizeof(T),
&stream_wrapper);
}
offset += p_tensor->numel();
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册