Unverified · Commit 387f2276 authored by Leo Chen, committed by GitHub

[NPU] refine npu data_device_transform (#33224)

Parent 4540456b
@@ -26,6 +26,13 @@ void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
platform::errors::Unavailable("Currently, model parallelism is only "
"supported between CPU and CUDA."));
// NOTE(zhiqiu): Special case for CPU->NPU, avoid stream sync.
if (platform::is_cpu_place(in.place()) && platform::is_npu_place(dst_place)) {
TensorCopy(in, dst_place,
*platform::DeviceContextPool::Instance().Get(dst_place), out);
return;
}
// NOTE(yy): TransDataDevice should wait for computation of input.
if (!platform::is_cuda_pinned_place(in.place())) {
platform::DeviceContextPool::Instance().Get(in.place())->Wait();