From 17a076d8edb544f83d6ea775cea68a7a059b1c2f Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Tue, 29 May 2018 15:06:47 +0800 Subject: [PATCH] replace TensorCopy with TensorCopySync --- .../fluid/framework/data_device_transform.cc | 27 +++++-------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/framework/data_device_transform.cc b/paddle/fluid/framework/data_device_transform.cc index 4089458a33f..6bcfc6cd55f 100644 --- a/paddle/fluid/framework/data_device_transform.cc +++ b/paddle/fluid/framework/data_device_transform.cc @@ -16,26 +16,14 @@ limitations under the License. */ namespace paddle { namespace framework { -static const platform::DeviceContext* GetDeviceContext( - const platform::Place& src_place, const platform::Place& dst_place) { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - - if (platform::is_gpu_place(src_place) && platform::is_cpu_place(dst_place)) { - return pool.Get(src_place); - } else if (platform::is_cpu_place(src_place) && - platform::is_gpu_place(dst_place)) { - return pool.Get(dst_place); - } else { - PADDLE_THROW( - "Currently, model parallelism is only supported between CPU and CUDA"); - } -} - -void TransDataDevice(const Tensor& in, const platform::Place& dst_place, - Tensor* out) { +void TransDataDevice(const Tensor &in, const platform::Place &dst_place, + Tensor *out) { VLOG(3) << "DeviceTransform in, src_place " << in.place() << " dst_place: " << dst_place; - auto* dev_ctx = GetDeviceContext(in.place(), dst_place); + + PADDLE_ENFORCE_NE( + in.place().which(), dst_place.which(), + "Currently, model parallelism is only supported between CPU and CUDA"); // FIXME(zcd): TransDataDevice is used to transform data from GPU to CPU and // the enforced checkings have been done in GetDeviceContext, so the @@ -46,8 +34,7 @@ void TransDataDevice(const Tensor& in, const platform::Place& dst_place, // the transforming is from CPU to GPU and the number of elements is little. // But the embarrassment is that this solution this solution makes training // slower. - TensorCopy(in, dst_place, *dev_ctx, out); - dev_ctx->Wait(); + TensorCopySync(in, dst_place, out); } } // namespace framework -- GitLab