From 93c470034a96b4963b3b0740105ef5f92ac70f96 Mon Sep 17 00:00:00 2001 From: chengduo Date: Fri, 18 May 2018 02:46:38 +0800 Subject: [PATCH] fix DataTransFunc (#10752) --- paddle/fluid/framework/data_device_transform.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/data_device_transform.cc b/paddle/fluid/framework/data_device_transform.cc index 85dbb39e6f..a876725ac0 100644 --- a/paddle/fluid/framework/data_device_transform.cc +++ b/paddle/fluid/framework/data_device_transform.cc @@ -36,9 +36,11 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place, VLOG(3) << "DeviceTransform in, src_place " << in.place() << " dst_place: " << dst_place; auto* dev_ctx = GetDeviceContext(in.place(), dst_place); - dev_ctx->Wait(); + TensorCopy(in, dst_place, *dev_ctx, out); - dev_ctx->Wait(); + if (platform::is_gpu_place(in.place()) && platform::is_cpu_place(dst_place)) { + dev_ctx->Wait(); + } } } // namespace framework -- GitLab