未验证 提交 47ea4d87 编写于 作者: Z Zhang Ting 提交者: GitHub

Fix sync_bn kernel selection for fp16 (#44876)

* fix sync error for fp16

* fix kernel register

* add dtype of kernel in lecacy yamal
上级 73be70a3
...@@ -2515,6 +2515,7 @@ ...@@ -2515,6 +2515,7 @@
func : BatchNormInferMeta func : BatchNormInferMeta
kernel : kernel :
func : sync_batch_norm func : sync_batch_norm
data_type : x
backward : sync_batch_norm_grad backward : sync_batch_norm_grad
# take_along_axis # take_along_axis
......
...@@ -1190,7 +1190,16 @@ PD_REGISTER_KERNEL(batch_norm, ...@@ -1190,7 +1190,16 @@ PD_REGISTER_KERNEL(batch_norm,
ALL_LAYOUT, ALL_LAYOUT,
phi::BatchNormKernel, phi::BatchNormKernel,
float, float,
phi::dtype::float16) {} phi::dtype::float16) {
kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
}
#else #else
PD_REGISTER_KERNEL(batch_norm, PD_REGISTER_KERNEL(batch_norm,
GPU, GPU,
...@@ -1200,6 +1209,10 @@ PD_REGISTER_KERNEL(batch_norm, ...@@ -1200,6 +1209,10 @@ PD_REGISTER_KERNEL(batch_norm,
double, double,
phi::dtype::float16) { phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) { if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32); kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
......
...@@ -178,7 +178,18 @@ PD_REGISTER_KERNEL(sync_batch_norm, ...@@ -178,7 +178,18 @@ PD_REGISTER_KERNEL(sync_batch_norm,
ALL_LAYOUT, ALL_LAYOUT,
phi::SyncBatchNormKernel, phi::SyncBatchNormKernel,
float, float,
phi::dtype::float16) {} phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
}
}
#else #else
PD_REGISTER_KERNEL(sync_batch_norm, PD_REGISTER_KERNEL(sync_batch_norm,
GPU, GPU,
...@@ -186,5 +197,16 @@ PD_REGISTER_KERNEL(sync_batch_norm, ...@@ -186,5 +197,16 @@ PD_REGISTER_KERNEL(sync_batch_norm,
phi::SyncBatchNormKernel, phi::SyncBatchNormKernel,
float, float,
double, double,
phi::dtype::float16) {} phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
}
}
#endif #endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册