diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 753e3dc6762d314cfda454a84df3e0f6b2f3d6f3..2e2ebfda797e0b0e10eb61f28685e5416d7a974a 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -2515,6 +2515,7 @@ func : BatchNormInferMeta kernel : func : sync_batch_norm + data_type : x backward : sync_batch_norm_grad # take_along_axis diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 8731946ba2d424503ae631f191259a667190363a..67698369278476b87b700df6950abd5a6659794b 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -1190,7 +1190,16 @@ PD_REGISTER_KERNEL(batch_norm, ALL_LAYOUT, phi::BatchNormKernel, float, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); +} #else PD_REGISTER_KERNEL(batch_norm, GPU, @@ -1200,6 +1209,10 @@ PD_REGISTER_KERNEL(batch_norm, double, phi::dtype::float16) { if (kernel_key.dtype() == phi::DataType::FLOAT16) { + kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); diff --git a/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu b/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu index a1d4b681ca0536cd4ad59c872aa3a5e4c96bc454..ed8a8c333442d0f098da8c1935202d5ba5b7956b 100644 --- a/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu @@ -178,7 +178,18 @@ PD_REGISTER_KERNEL(sync_batch_norm, ALL_LAYOUT, phi::SyncBatchNormKernel, float, - phi::dtype::float16) {} + phi::dtype::float16) { + if (kernel_key.dtype() == phi::DataType::FLOAT16) { + kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); + } +} #else PD_REGISTER_KERNEL(sync_batch_norm, GPU, @@ -186,5 +197,16 @@ PD_REGISTER_KERNEL(sync_batch_norm, phi::SyncBatchNormKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16) { + if (kernel_key.dtype() == phi::DataType::FLOAT16) { + kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); + } +} #endif