提交 5b524810 编写于 作者: C chengduo 提交者: dzhwinter

refine accuracy_op.cu (#6774)

上级 0f1c685c
......@@ -26,7 +26,7 @@ template <int BlockSize>
__global__ void AccuracyCudaKernel(const int N, const int D,
const int64_t* Xdata,
const int64_t* labeldata, int* correct_data,
float* accuracy) {
float* accuracy, int* total_data) {
int count = 0;
__shared__ int total[BlockSize];
......@@ -47,6 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D,
if (threadIdx.x == 0) {
*correct_data = result;
*accuracy = static_cast<float>(result) / static_cast<float>(N);
*total_data = N;
}
}
......@@ -80,22 +81,11 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
if (num_samples == 0) {
return;
}
platform::GpuMemcpyAsync(total_data, &num_samples, sizeof(int),
cudaMemcpyHostToDevice, stream);
AccuracyCudaKernel<
PADDLE_CUDA_NUM_THREADS><<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(
num_samples, infer_width, indices_data, label_data, correct_data,
accuracy_data);
int d_num_samples, d_num_correct;
float d_accuracy;
platform::GpuMemcpyAsync(&d_num_correct, correct_data, sizeof(int),
cudaMemcpyDeviceToHost, stream);
platform::GpuMemcpyAsync(&d_num_samples, total_data, sizeof(int),
cudaMemcpyDeviceToHost, stream);
platform::GpuMemcpyAsync(&d_accuracy, accuracy_data, sizeof(float),
cudaMemcpyDeviceToHost, stream);
accuracy_data, total_data);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册