未验证 提交 aa79bccf 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #13460 from reyoung/fix_data_transform

Wait input when data transform
...@@ -25,6 +25,10 @@ void TransDataDevice(const Tensor &in, const platform::Place &dst_place, ...@@ -25,6 +25,10 @@ void TransDataDevice(const Tensor &in, const platform::Place &dst_place,
in.place().which(), dst_place.which(), in.place().which(), dst_place.which(),
"Currently, model parallelism is only supported between CPU and CUDA"); "Currently, model parallelism is only supported between CPU and CUDA");
// NOTE(yy): TransDataDevice should wait for computation of input.
platform::DeviceContextPool::Instance().Get(in.place())->Wait();
platform::DeviceContextPool::Instance().Get(dst_place)->Wait();
// FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and // FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and
// the enforced checkings have been done in GetDeviceContext, so the // the enforced checkings have been done in GetDeviceContext, so the
// `dev_ctx->Wait()` is necessary. But `dev_ctx->Wait()` will make the program // `dev_ctx->Wait()` is necessary. But `dev_ctx->Wait()` will make the program
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册