diff --git a/paddle/fluid/framework/data_device_transform.cc b/paddle/fluid/framework/data_device_transform.cc index a876725ac0f17838458065c4b4753a03e2812801..0cd2ebcd41d54a231a9c7545a21ca3e57f89387e 100644 --- a/paddle/fluid/framework/data_device_transform.cc +++ b/paddle/fluid/framework/data_device_transform.cc @@ -38,7 +38,8 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place, auto* dev_ctx = GetDeviceContext(in.place(), dst_place); TensorCopy(in, dst_place, *dev_ctx, out); - if (platform::is_gpu_place(in.place()) && platform::is_cpu_place(dst_place)) { + + if (in.place().which() != dst_place.which()) { dev_ctx->Wait(); } }