From 6570814194f4d3f92666c8b6f00f3f7849d80e3b Mon Sep 17 00:00:00 2001 From: XiaoguangHu <46782768+XiaoguangHu01@users.noreply.github.com> Date: Wed, 4 May 2022 21:20:30 +0800 Subject: [PATCH] fix bug of batch_norm_grad kernel with fp16 (#42460) * fix bug of batch_norm_grad kernel with fp16 * format code --- paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu index 35d36c3287..ad3b8579dd 100644 --- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu @@ -988,10 +988,9 @@ PD_REGISTER_KERNEL(batch_norm_grad, double, phi::dtype::float16) { if (kernel_key.dtype() == phi::DataType::FLOAT16) { - kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); - kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); - kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); - kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); // x_grad + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); // scale_grad + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad } } @@ -1003,10 +1002,9 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw, double, phi::dtype::float16) { if (kernel_key.dtype() == phi::DataType::FLOAT16) { - kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); - kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); - kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); - kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); // x_grad + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); // scale_grad + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad } } @@ -1019,7 +1017,6 @@ PD_REGISTER_KERNEL(batch_norm_grad_grad, phi::BatchNormDoubleGradKernel, float, double) {} - #else PD_REGISTER_KERNEL(batch_norm_grad_grad, GPU, -- GitLab