From e97b89873a4ec2f57b54225b432eebbffad4fb2f Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Wed, 15 Nov 2017 12:25:23 -0800 Subject: [PATCH] "fix accuracy kernel bug" (#5673) * "fix accuracy kernel bug" * "relaunch ci" --- paddle/operators/accuracy_op.cu | 23 +++++++++++++---------- paddle/platform/gpu_info.cc | 5 +++++ paddle/platform/gpu_info.h | 3 +++ 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu index b575c682f0d..d2dcab4e548 100644 --- a/paddle/operators/accuracy_op.cu +++ b/paddle/operators/accuracy_op.cu @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "paddle/operators/accuracy_op.h" #include "paddle/platform/cuda_helper.h" +#include "paddle/platform/gpu_info.h" namespace paddle { namespace operators { @@ -73,26 +74,28 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { int num_samples = static_cast(inference->dims()[0]); size_t infer_width = inference->dims()[1]; - PADDLE_ENFORCE(cudaMemset(accuracy_data, 0, sizeof(float))); - // cudaMemset((void**)&correct_data, 0, sizeof(float)); + auto stream = ctx.cuda_device_context().stream(); + platform::GpuMemsetAsync(accuracy_data, 0, sizeof(float), stream); if (num_samples == 0) { return; } - cudaMemcpy(total_data, &num_samples, sizeof(int), cudaMemcpyHostToDevice); + platform::GpuMemcpyAsync(total_data, &num_samples, sizeof(int), + cudaMemcpyHostToDevice, stream); - AccuracyCudaKernel<<< - 1, PADDLE_CUDA_NUM_THREADS, 0, ctx.cuda_device_context().stream()>>>( + AccuracyCudaKernel< + PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>( num_samples, infer_width, indices_data, label_data, correct_data, accuracy_data); int d_num_samples, d_num_correct; float d_accuracy; - cudaMemcpy(&d_num_correct, correct_data, sizeof(int), - cudaMemcpyDeviceToHost); - cudaMemcpy(&d_num_samples, total_data, sizeof(int), cudaMemcpyDeviceToHost); - cudaMemcpy(&d_accuracy, accuracy_data, sizeof(float), - 
cudaMemcpyDeviceToHost); + platform::GpuMemcpyAsync(&d_num_correct, correct_data, sizeof(int), + cudaMemcpyDeviceToHost, stream); + platform::GpuMemcpyAsync(&d_num_samples, total_data, sizeof(int), + cudaMemcpyDeviceToHost, stream); + platform::GpuMemcpyAsync(&d_accuracy, accuracy_data, sizeof(float), + cudaMemcpyDeviceToHost, stream); } }; diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc index f3455a87338..36b216d8721 100644 --- a/paddle/platform/gpu_info.cc +++ b/paddle/platform/gpu_info.cc @@ -109,5 +109,10 @@ void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device, cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream), "cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeer"); } + +void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) { + PADDLE_ENFORCE(cudaMemsetAsync(dst, value, count, stream), + "cudaMemsetAsync failed in paddle::platform::GpuMemsetAsync"); +} } // namespace platform } // namespace paddle diff --git a/paddle/platform/gpu_info.h b/paddle/platform/gpu_info.h index 37665b97d76..db961f3838a 100644 --- a/paddle/platform/gpu_info.h +++ b/paddle/platform/gpu_info.h @@ -60,6 +60,9 @@ void GpuMemcpySync(void *dst, const void *src, size_t count, void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device, size_t count, cudaStream_t stream); +//! Asynchronously set the first count bytes of memory at dst to value on the given stream +void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream); + } // namespace platform } // namespace paddle -- GitLab