diff --git a/paddle/fluid/framework/data_device_transform.cc b/paddle/fluid/framework/data_device_transform.cc index 85dbb39e6fba735471446b5e5e71a612282c498a..a876725ac0f17838458065c4b4753a03e2812801 100644 --- a/paddle/fluid/framework/data_device_transform.cc +++ b/paddle/fluid/framework/data_device_transform.cc @@ -36,9 +36,11 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place, VLOG(3) << "DeviceTransform in, src_place " << in.place() << " dst_place: " << dst_place; auto* dev_ctx = GetDeviceContext(in.place(), dst_place); - dev_ctx->Wait(); + TensorCopy(in, dst_place, *dev_ctx, out); - dev_ctx->Wait(); + if (platform::is_gpu_place(in.place()) && platform::is_cpu_place(dst_place)) { + dev_ctx->Wait(); + } } } // namespace framework diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index eb559114e95dbcc7ee35561f2b9781a891cc2772..ac1f3f44ae8703c3e0c792bd9a2e658f1341ec15 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -203,8 +203,8 @@ if(WITH_DISTRIBUTE) set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor) - cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS send_op listen_and_serv_op executor) if(WITH_GPU) + cc_test(test_send_nccl_id SRCS test_send_nccl_id.cc DEPS send_op listen_and_serv_op executor) op_library(gen_nccl_id_op DEPS nccl_common sendrecvop_grpc) set_source_files_properties(gen_nccl_id_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) else()