diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index d1b01ae05b808b229309e9689165483a11530c84..485aba7060c60abe120a5707bdf80f3751aea444 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -20,7 +20,8 @@ namespace paddle { namespace framework { void TensorCopy(const Tensor& src, const platform::Place& dst_place, - const platform::DeviceContext& ctx, Tensor* dst) { + const platform::DeviceContext& ctx, Tensor* dst, + bool sync = false) { VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to " << dst_place; src.check_memory_size(); @@ -47,9 +48,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); auto ctx_gpu_place = boost::get(ctx_place); PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place); - memory::Copy( - dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, - reinterpret_cast(ctx).stream()); + auto stream = + sync ? nullptr + : reinterpret_cast(ctx) + .stream(); + memory::Copy(dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream); } else if (platform::is_cpu_place(src_place) && platform::is_gpu_place(dst_place)) { auto src_cpu_place = boost::get(src_place); @@ -58,18 +61,22 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); auto ctx_gpu_place = boost::get(ctx_place); PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place); - memory::Copy( - dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, - reinterpret_cast(ctx).stream()); + auto stream = + sync ? nullptr + : reinterpret_cast(ctx) + .stream(); + memory::Copy(dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream); } else if (platform::is_gpu_place(src_place) && platform::is_gpu_place(dst_place)) { auto src_gpu_place = boost::get(src_place); auto dst_gpu_place = boost::get(dst_place); auto ctx_place = ctx.GetPlace(); PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); - memory::Copy( - dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, - reinterpret_cast(ctx).stream()); + auto stream = + sync ? nullptr + : reinterpret_cast(ctx) + .stream(); + memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream); } #endif } diff --git a/paddle/fluid/memory/memcpy.cc b/paddle/fluid/memory/memcpy.cc index eddcaab8befda84dd14ed46c31ac025dfbcc7ca9..347fbe7ecc737ff10489b2fc03de08d95e33e963 100644 --- a/paddle/fluid/memory/memcpy.cc +++ b/paddle/fluid/memory/memcpy.cc @@ -30,29 +30,46 @@ void Copy(platform::CPUPlace, void* dst, template <> void Copy( platform::CPUPlace dst_place, void* dst, platform::CUDAPlace src_place, - const void* src, size_t num, cudaStream_t stream) { + const void* src, size_t num, cudaStream_t stream = nullptr) { platform::SetDeviceId(src_place.device); - platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); + if (stream) { + platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); + } else { + platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost); + } } template <> void Copy( platform::CUDAPlace dst_place, void* dst, platform::CPUPlace src_place, - const void* src, size_t num, cudaStream_t stream) { + const void* src, size_t num, cudaStream_t stream = nullptr) { platform::SetDeviceId(dst_place.device); - platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); + if (stream) { + platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); + } else { + platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice); + } } template <> void Copy( platform::CUDAPlace dst_place, void* dst, platform::CUDAPlace src_place, - const void* src, size_t num, cudaStream_t stream) { + const void* src, size_t num, cudaStream_t stream = nullptr) { if (dst_place == src_place) { platform::SetDeviceId(src_place.device); - platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream); + if (stream) { + platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream); + } else { + platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice); + } } else { - platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num, - stream); + if (stream) { + platform::GpuMemcpyPeerAsync(dst, dst_place.device, src, src_place.device, + num, stream); + } else { + platform::GpuMemcpyPeerSync(dst, dst_place.device, src, src_place.device, + num, stream); + } } } @@ -81,18 +98,26 @@ template <> void Copy( platform::CUDAPinnedPlace dst_place, void* dst, platform::CUDAPlace src_place, const void* src, size_t num, - cudaStream_t stream) { + cudaStream_t stream = nullptr) { platform::SetDeviceId(src_place.device); - platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); + if (stream) { + platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); + } else { + platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost); + } } template <> void Copy( platform::CUDAPlace dst_place, void* dst, platform::CUDAPinnedPlace src_place, const void* src, size_t num, - cudaStream_t stream) { + cudaStream_t stream = nullptr) { platform::SetDeviceId(dst_place.device); - platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); + if (stream) { + platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); + } else { + platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice); + } } #endif diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index 0b7c1d6af714558d35918dac62d92d9e0f86c970..4372f23fc1dbd85e43b04a9d644977392316c2e9 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -180,7 +180,8 @@ void DoubleBufferReader::PrefetchThreadFunc() { auto* gpu_ctx = ctxs_[cached_tensor_id].get(); gpu_batch.resize(cpu_batch.size()); for (size_t i = 0; i < cpu_batch.size(); ++i) { - framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i]); + framework::TensorCopy(cpu_batch[i], place_, *gpu_ctx, &gpu_batch[i], + true); gpu_batch[i].set_lod(cpu_batch[i].lod()); } } diff --git a/paddle/fluid/platform/gpu_info.cc b/paddle/fluid/platform/gpu_info.cc index aaebeb1353a13ab16fcf98f10da59d41fd2f5b48..4cee93f3a4224cb97327254cd1679021d197a1b1 100644 --- a/paddle/fluid/platform/gpu_info.cc +++ b/paddle/fluid/platform/gpu_info.cc @@ -127,11 +127,24 @@ void GpuMemcpyAsync(void *dst, const void *src, size_t count, "cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync"); } -void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device, - size_t count, cudaStream_t stream) { +void GpuMemcpySync(void *dst, const void *src, size_t count, + enum cudaMemcpyKind kind) { + PADDLE_ENFORCE(cudaMemcpy(dst, src, count, kind), + "cudaMemcpy failed in paddle::platform::GpuMemcpySync"); +} + +void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src, + int src_device, size_t count, cudaStream_t stream) { PADDLE_ENFORCE( cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream), - "cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeer"); + "cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeerAsync"); +} + +void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src, + int src_device, size_t count) { + PADDLE_ENFORCE( + cudaMemcpyPeer(dst, dst_device, src, src_device, count), + "cudaMemcpyPeer failed in paddle::platform::GpuMemcpyPeerSync"); } void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) { diff --git a/paddle/fluid/platform/gpu_info.h b/paddle/fluid/platform/gpu_info.h index 36345e17406e22970806fa274d5a73a703517c43..f4640d3eaa2165c35e8e14690d83e9e7e7168c0b 100644 --- a/paddle/fluid/platform/gpu_info.h +++ b/paddle/fluid/platform/gpu_info.h @@ -57,9 +57,17 @@ size_t GpuMaxChunkSize(); void GpuMemcpyAsync(void *dst, const void *src, size_t count, enum cudaMemcpyKind kind, cudaStream_t stream); -//! Copy memory from one device to another device. -void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device, - size_t count, cudaStream_t stream); +//! Copy memory from address src to dst synchronously. +void GpuMemcpySync(void *dst, const void *src, size_t count, + enum cudaMemcpyKind kind); + +//! Copy memory from one device to another device asynchronously. +void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src, + int src_device, size_t count, cudaStream_t stream); + +//! Copy memory from one device to another device synchronously. +void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src, + int src_device, size_t count); //! Set memory dst with value count size asynchronously void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream);