未验证 提交 693de9f0 编写于 作者: W WangZhen 提交者: GitHub

Fix accuracy fp16 kernel return fp32 tensor error (#48803)

上级 93b7ccf5
...@@ -26,13 +26,13 @@ ...@@ -26,13 +26,13 @@
namespace phi { namespace phi {
using phi::PADDLE_CUDA_NUM_THREADS; using phi::PADDLE_CUDA_NUM_THREADS;
template <int BlockSize> template <int BlockSize, typename T>
__global__ void AccuracyCudaKernel(const int N, __global__ void AccuracyCudaKernel(const int N,
const int D, const int D,
const int64_t* Xdata, const int64_t* Xdata,
const int64_t* labeldata, const int64_t* labeldata,
int* correct_data, int* correct_data,
float* accuracy, T* accuracy,
int* total_data) { int* total_data) {
int count = 0; int count = 0;
__shared__ int total[BlockSize]; __shared__ int total[BlockSize];
...@@ -64,7 +64,7 @@ __global__ void AccuracyCudaKernel(const int N, ...@@ -64,7 +64,7 @@ __global__ void AccuracyCudaKernel(const int N,
#endif #endif
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
*correct_data = result; *correct_data = result;
*accuracy = static_cast<float>(result) / static_cast<float>(N); *accuracy = static_cast<T>(result) / static_cast<T>(N);
*total_data = N; *total_data = N;
} }
} }
...@@ -84,18 +84,18 @@ void AccuracyRawKernel(const Context& dev_ctx, ...@@ -84,18 +84,18 @@ void AccuracyRawKernel(const Context& dev_ctx,
int* correct_data = dev_ctx.template Alloc<int>(correct); int* correct_data = dev_ctx.template Alloc<int>(correct);
int* total_data = dev_ctx.template Alloc<int>(total); int* total_data = dev_ctx.template Alloc<int>(total);
float* accuracy_data = dev_ctx.template Alloc<float>(accuracy); T* accuracy_data = dev_ctx.template Alloc<T>(accuracy);
int num_samples = static_cast<int>(inference.dims()[0]); int num_samples = static_cast<int>(inference.dims()[0]);
size_t infer_width = inference.dims()[1]; size_t infer_width = inference.dims()[1];
auto stream = dev_ctx.stream(); auto stream = dev_ctx.stream();
phi::backends::gpu::GpuMemsetAsync(accuracy_data, 0, sizeof(float), stream); phi::backends::gpu::GpuMemsetAsync(accuracy_data, 0, sizeof(T), stream);
if (num_samples == 0) { if (num_samples == 0) {
return; return;
} }
AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS> AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS, T>
<<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(num_samples, <<<1, PADDLE_CUDA_NUM_THREADS, 0, stream>>>(num_samples,
infer_width, infer_width,
indices_data, indices_data,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册