未验证 提交 a5745864 编写于 作者: X XiaoguangHu 提交者: GitHub

[cherry-pick 2.3] fix bug of batch_norm_grad kernel with fp16 (#42461)

* fix bug of batch_norm_grad kernel with fp16

* format code
上级 87e6149c
...@@ -987,10 +987,9 @@ PD_REGISTER_KERNEL(batch_norm_grad, ...@@ -987,10 +987,9 @@ PD_REGISTER_KERNEL(batch_norm_grad,
double, double,
phi::dtype::float16) { phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) { if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); // x_grad
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); // scale_grad
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
} }
} }
...@@ -1002,10 +1001,9 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw, ...@@ -1002,10 +1001,9 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw,
double, double,
phi::dtype::float16) { phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) { if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); // x_grad
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); // scale_grad
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
} }
} }
...@@ -1018,7 +1016,6 @@ PD_REGISTER_KERNEL(batch_norm_grad_grad, ...@@ -1018,7 +1016,6 @@ PD_REGISTER_KERNEL(batch_norm_grad_grad,
phi::BatchNormDoubleGradKernel, phi::BatchNormDoubleGradKernel,
float, float,
double) {} double) {}
#else #else
PD_REGISTER_KERNEL(batch_norm_grad_grad, PD_REGISTER_KERNEL(batch_norm_grad_grad,
GPU, GPU,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册