未验证 提交 47ea4d87 编写于 作者: Z Zhang Ting 提交者: GitHub

Fix sync_bn kernel selection for fp16 (#44876)

* fix sync error for fp16

* fix kernel register

* add dtype of kernel in lecacy yamal
上级 73be70a3
......@@ -2515,6 +2515,7 @@
func : BatchNormInferMeta
kernel :
func : sync_batch_norm
data_type : x
backward : sync_batch_norm_grad
# take_along_axis
......
......@@ -1190,7 +1190,16 @@ PD_REGISTER_KERNEL(batch_norm,
ALL_LAYOUT,
phi::BatchNormKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16) {
kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
}
#else
PD_REGISTER_KERNEL(batch_norm,
GPU,
......@@ -1200,6 +1209,10 @@ PD_REGISTER_KERNEL(batch_norm,
double,
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
......
......@@ -178,7 +178,18 @@ PD_REGISTER_KERNEL(sync_batch_norm,
ALL_LAYOUT,
phi::SyncBatchNormKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
}
}
#else
PD_REGISTER_KERNEL(sync_batch_norm,
GPU,
......@@ -186,5 +197,16 @@ PD_REGISTER_KERNEL(sync_batch_norm,
phi::SyncBatchNormKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
}
}
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册