提交 5b524810 编写于 作者: C chengduo 提交者: dzhwinter

refine accuracy_op.cu (#6774)

上级 0f1c685c
...@@ -26,7 +26,7 @@ template <int BlockSize> ...@@ -26,7 +26,7 @@ template <int BlockSize>
__global__ void AccuracyCudaKernel(const int N, const int D, __global__ void AccuracyCudaKernel(const int N, const int D,
const int64_t* Xdata, const int64_t* Xdata,
const int64_t* labeldata, int* correct_data, const int64_t* labeldata, int* correct_data,
float* accuracy) { float* accuracy, int* total_data) {
int count = 0; int count = 0;
__shared__ int total[BlockSize]; __shared__ int total[BlockSize];
...@@ -47,6 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D, ...@@ -47,6 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D,
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
*correct_data = result; *correct_data = result;
*accuracy = static_cast<float>(result) / static_cast<float>(N); *accuracy = static_cast<float>(result) / static_cast<float>(N);
*total_data = N;
} }
} }
...@@ -80,22 +81,11 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> { ...@@ -80,22 +81,11 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
if (num_samples == 0) { if (num_samples == 0) {
return; return;
} }
platform::GpuMemcpyAsync(total_data, &num_samples, sizeof(int),
cudaMemcpyHostToDevice, stream);
AccuracyCudaKernel< AccuracyCudaKernel<
PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>( PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
num_samples, infer_width, indices_data, label_data, correct_data, num_samples, infer_width, indices_data, label_data, correct_data,
accuracy_data); accuracy_data, total_data);
int d_num_samples, d_num_correct;
float d_accuracy;
platform::GpuMemcpyAsync(&d_num_correct, correct_data, sizeof(int),
cudaMemcpyDeviceToHost, stream);
platform::GpuMemcpyAsync(&d_num_samples, total_data, sizeof(int),
cudaMemcpyDeviceToHost, stream);
platform::GpuMemcpyAsync(&d_accuracy, accuracy_data, sizeof(float),
cudaMemcpyDeviceToHost, stream);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册