提交 93c47003 编写于 作者: C chengduo 提交者: Yang Yang(Tony)

fix DataTransFunc (#10752)

上级 d0a62bfc
...@@ -36,9 +36,11 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place, ...@@ -36,9 +36,11 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place,
VLOG(3) << "DeviceTransform in, src_place " << in.place() VLOG(3) << "DeviceTransform in, src_place " << in.place()
<< " dst_place: " << dst_place; << " dst_place: " << dst_place;
auto* dev_ctx = GetDeviceContext(in.place(), dst_place); auto* dev_ctx = GetDeviceContext(in.place(), dst_place);
dev_ctx->Wait();
TensorCopy(in, dst_place, *dev_ctx, out); TensorCopy(in, dst_place, *dev_ctx, out);
if (platform::is_gpu_place(in.place()) && platform::is_cpu_place(dst_place)) {
dev_ctx->Wait(); dev_ctx->Wait();
}
} }
} // namespace framework } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册