From 5b52481058088da18c90c920bc815181badbf534 Mon Sep 17 00:00:00 2001 From: chengduo Date: Wed, 20 Dec 2017 20:34:47 +0800 Subject: [PATCH] refine accuracy_op.cu (#6774) --- paddle/operators/accuracy_op.cu | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu index 539a93530..dd51aad10 100644 --- a/paddle/operators/accuracy_op.cu +++ b/paddle/operators/accuracy_op.cu @@ -26,7 +26,7 @@ template __global__ void AccuracyCudaKernel(const int N, const int D, const int64_t* Xdata, const int64_t* labeldata, int* correct_data, - float* accuracy) { + float* accuracy, int* total_data) { int count = 0; __shared__ int total[BlockSize]; @@ -47,6 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D, if (threadIdx.x == 0) { *correct_data = result; *accuracy = static_cast(result) / static_cast(N); + *total_data = N; } } @@ -80,22 +81,11 @@ class AccuracyOpCUDAKernel : public framework::OpKernel { if (num_samples == 0) { return; } - platform::GpuMemcpyAsync(total_data, &num_samples, sizeof(int), - cudaMemcpyHostToDevice, stream); AccuracyCudaKernel< PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>( num_samples, infer_width, indices_data, label_data, correct_data, - accuracy_data); - - int d_num_samples, d_num_correct; - float d_accuracy; - platform::GpuMemcpyAsync(&d_num_correct, correct_data, sizeof(int), - cudaMemcpyDeviceToHost, stream); - platform::GpuMemcpyAsync(&d_num_samples, total_data, sizeof(int), - cudaMemcpyDeviceToHost, stream); - platform::GpuMemcpyAsync(&d_accuracy, accuracy_data, sizeof(float), - cudaMemcpyDeviceToHost, stream); + accuracy_data, total_data); } }; -- GitLab