From 47ea4d879df3c830ebb2cf591bdfbad526fff508 Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Mon, 8 Aug 2022 16:42:23 +0800 Subject: [PATCH] Fix sync_bn kernel selection for fp16 (#44876) * fix sync error for fp16 * fix kernel register * add dtype of kernel in lecacy yamal --- paddle/phi/api/yaml/legacy_api.yaml | 1 + paddle/phi/kernels/gpu/batch_norm_kernel.cu | 15 ++++++++++- .../phi/kernels/gpu/sync_batch_norm_kernel.cu | 26 +++++++++++++++++-- 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/paddle/phi/api/yaml/legacy_api.yaml b/paddle/phi/api/yaml/legacy_api.yaml index 753e3dc676..2e2ebfda79 100755 --- a/paddle/phi/api/yaml/legacy_api.yaml +++ b/paddle/phi/api/yaml/legacy_api.yaml @@ -2515,6 +2515,7 @@ func : BatchNormInferMeta kernel : func : sync_batch_norm + data_type : x backward : sync_batch_norm_grad # take_along_axis diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 8731946ba2..6769836927 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -1190,7 +1190,16 @@ PD_REGISTER_KERNEL(batch_norm, ALL_LAYOUT, phi::BatchNormKernel, float, - phi::dtype::float16) {} + phi::dtype::float16) { + kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); +} #else PD_REGISTER_KERNEL(batch_norm, GPU, @@ -1200,6 +1209,10 @@ PD_REGISTER_KERNEL(batch_norm, double, phi::dtype::float16) { if (kernel_key.dtype() == phi::DataType::FLOAT16) { + kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); diff --git a/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu b/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu index a1d4b681ca..ed8a8c3334 100644 --- a/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu @@ -178,7 +178,18 @@ PD_REGISTER_KERNEL(sync_batch_norm, ALL_LAYOUT, phi::SyncBatchNormKernel, float, - phi::dtype::float16) {} + phi::dtype::float16) { + if (kernel_key.dtype() == phi::DataType::FLOAT16) { + kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); + } +} #else PD_REGISTER_KERNEL(sync_batch_norm, GPU, @@ -186,5 +197,16 @@ PD_REGISTER_KERNEL(sync_batch_norm, phi::SyncBatchNormKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16) { + if (kernel_key.dtype() == phi::DataType::FLOAT16) { + kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); + kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32); + } +} #endif -- GitLab