diff --git a/paddle/fluid/framework/data_device_transform.cc b/paddle/fluid/framework/data_device_transform.cc index 6bcfc6cd55f02f0d4f0f6e3170e7cc19ce666a28..fee6ba40047053ed5662fe044eceb0c687bd4db9 100644 --- a/paddle/fluid/framework/data_device_transform.cc +++ b/paddle/fluid/framework/data_device_transform.cc @@ -25,6 +25,10 @@ void TransDataDevice(const Tensor &in, const platform::Place &dst_place, in.place().which(), dst_place.which(), "Currently, model parallelism is only supported between CPU and CUDA"); + // NOTE(yy): TransDataDevice should wait for computation of input. + platform::DeviceContextPool::Instance().Get(in.place())->Wait(); + platform::DeviceContextPool::Instance().Get(dst_place)->Wait(); + // FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and // the enforced checkings have been done in GetDeviceContext, so the // `dev_ctx->Wait()` is necessary. But `dev_ctx->Wait()` will make the program