Unverified commit 523f8a26, authored by Guoxia Wang, committed by GitHub

[AMP OP&Test] support bf16 for batch norm (#52407)

* [AMP OP&Test] support bf16 for batchnorm

* codestyle

* Update batch_norm_grad_kernel.cu

* Update batch_norm_kernel.cu

* fix codestyle

* fix

* fix

* fix

* fix

* fix

* Update batch_norm_kernel.cc
Parent a2060568
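
The diff below registers bfloat16 for the GPU batch norm forward, inference, and gradient kernels, gated on cuDNN >= 8.1, while scale, bias, and the saved/running statistics stay in float32. As a rough sketch of what this enables (hypothetical dygraph code, not part of this PR, assuming a Paddle build with CUDA and cuDNN >= 8.1):

```python
import paddle

# Batch norm on a bfloat16 activation tensor; the layer's parameters and
# running statistics remain float32, matching the kernel registrations below.
x = paddle.randn([2, 3, 4, 5]).cast('bfloat16')
bn = paddle.nn.BatchNorm2D(3)
y = bn(x)

print(y.dtype)          # expected: paddle.bfloat16
print(bn.weight.dtype)  # paddle.float32, scale/bias are not downcast
```
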
@@ -14,6 +14,7 @@
#include "paddle/phi/kernels/batch_norm_kernel.h"
#include "paddle/phi/backends/gpu/gpu_dnn.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
@@ -66,6 +67,22 @@ PD_REGISTER_KERNEL(batch_norm_infer,
float,
double) {}
#ifdef PADDLE_WITH_CUDA
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(batch_norm_infer,
GPU,
ALL_LAYOUT,
phi::BatchNormInferKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
kernel_key.dtype() == phi::DataType::BFLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
}
}
#else
PD_REGISTER_KERNEL(batch_norm_infer,
GPU,
ALL_LAYOUT,
@@ -79,6 +96,7 @@ PD_REGISTER_KERNEL(batch_norm_infer,
}
}
#endif
#endif
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(batch_norm_infer,
GPU,
......
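
The batch_norm_infer registration above promotes the mean and variance outputs to FLOAT32 whenever the kernel dtype is float16 or bfloat16. A small numpy illustration of why (illustrative only, not Paddle code): a running statistic accumulated at bfloat16 precision stagnates far from its true value, because bfloat16 carries only an 8-bit mantissa.

```python
import numpy as np

def to_bf16(x):
    # Truncate float32 values to the bfloat16 bit pattern (keep the upper 16 bits).
    x = np.asarray(x, dtype=np.float32)
    return (x.view(np.uint32) & np.uint32(0xFFFF0000)).view(np.float32)

momentum, batch_mean = 0.9, np.float32(0.1234)
ema_fp32 = np.float32(0.0)  # running mean kept in float32
ema_bf16 = np.float32(0.0)  # running mean truncated to bfloat16 after every step
for _ in range(1000):
    ema_fp32 = momentum * ema_fp32 + (1 - momentum) * batch_mean
    ema_bf16 = to_bf16(momentum * ema_bf16 + (1 - momentum) * batch_mean)

print(abs(ema_fp32 - batch_mean))  # ~1e-7: float32 tracks the target closely
print(abs(ema_bf16 - batch_mean))  # ~5e-3: the update stalls on the coarse bf16 grid
```
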
@@ -1314,14 +1314,18 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw,
float,
phi::dtype::float16) {}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(batch_norm_grad,
GPU,
ALL_LAYOUT,
phi::BatchNormGradKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
kernel_key.dtype() == phi::DataType::BFLOAT16) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); // x_grad
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); // scale_grad
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad
@@ -1334,6 +1338,22 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw,
phi::BatchNormGradRawKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
kernel_key.dtype() == phi::DataType::BFLOAT16) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); // x_grad
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); // scale_grad
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad
}
}
#else
PD_REGISTER_KERNEL(batch_norm_grad,
GPU,
ALL_LAYOUT,
phi::BatchNormGradKernel,
float,
double,
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); // x_grad
@@ -1342,6 +1362,20 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw,
}
}
PD_REGISTER_KERNEL(batch_norm_grad_raw,
GPU,
ALL_LAYOUT,
phi::BatchNormGradRawKernel,
float,
double,
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); // x_grad
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); // scale_grad
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad
}
}
#endif
#endif
#ifdef PADDLE_WITH_HIP
......
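
The gradient registrations above likewise keep x_grad, scale_grad, and bias_grad in FLOAT32 for float16 and bfloat16 kernel keys. scale_grad and bias_grad are reductions over N*H*W elements per channel, and a long running sum held at bfloat16 precision stalls once the accumulator dwarfs each addend. A quick numpy sketch (illustrative only, not Paddle code):

```python
import numpy as np

def to_bf16(x):
    # Truncate float32 values to the bfloat16 bit pattern (keep the upper 16 bits).
    x = np.asarray(x, dtype=np.float32)
    return (x.view(np.uint32) & np.uint32(0xFFFF0000)).view(np.float32)

terms = np.full(10_000, 0.01, dtype=np.float32)  # stand-in for per-element grad terms
acc_fp32 = np.float32(0.0)
acc_bf16 = np.float32(0.0)
for t in terms:
    acc_fp32 = acc_fp32 + t
    acc_bf16 = to_bf16(acc_bf16 + t)

print(acc_fp32)  # ~100.0
print(acc_bf16)  # 2.0: the sum stalls once the bf16 spacing there (0.015625) exceeds 0.01
```
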
@@ -1221,6 +1221,7 @@ PD_REGISTER_KERNEL(batch_norm,
ALL_LAYOUT,
phi::BatchNormKernel,
float,
phi::dtype::bfloat16,
phi::dtype::float16) {
kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
@@ -1232,6 +1233,28 @@ PD_REGISTER_KERNEL(batch_norm,
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(batch_norm,
GPU,
ALL_LAYOUT,
phi::BatchNormKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
kernel_key.dtype() == phi::DataType::BFLOAT16) {
kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
}
}
#else
PD_REGISTER_KERNEL(batch_norm,
GPU,
ALL_LAYOUT,
@@ -1250,5 +1273,6 @@ PD_REGISTER_KERNEL(batch_norm,
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
}
}
#endif
#endif
@@ -16,7 +16,12 @@ import os
import unittest
import numpy as np
from eager_op_test import OpTest, _set_use_system_allocator
from eager_op_test import (
OpTest,
_set_use_system_allocator,
convert_float_to_uint16,
convert_uint16_to_float,
)
from op import Operator
import paddle
@@ -239,6 +244,9 @@ class TestBatchNormOpInference(unittest.TestCase):
raise ValueError("Unknown data layout.")
scale_shape = [c]
if dtype == np.uint16:
x_val = np.random.random_sample(x_shape).astype(np.float32)
else:
x_val = np.random.random_sample(x_shape).astype(dtype)
# generate some negative values to test case with relu fused
x_val = x_val - 0.5
@@ -248,12 +256,20 @@ class TestBatchNormOpInference(unittest.TestCase):
mean = np.zeros(scale_shape).astype(np.float32)
variance = np.ones(scale_shape).astype(np.float32)
if dtype == np.uint16:
y_out = _reference_testing(
x_val, scale_val, bias_val, mean, variance, epsilon, data_layout
).astype(np.float32)
y_out = convert_float_to_uint16(y_out)
else:
y_out = _reference_testing(
x_val, scale_val, bias_val, mean, variance, epsilon, data_layout
).astype(dtype)
if self.fuse_with_relu:
y_out = np.maximum(y_out, 0)
if dtype == np.uint16:
x_val = convert_float_to_uint16(x_val)
scope = core.Scope()
# create input
@@ -324,6 +340,11 @@ class TestBatchNormOpInference(unittest.TestCase):
y_tensor._set_dims(dims)
# check inference result
atol = 1e-3
if dtype == np.uint16:
y_tensor = convert_uint16_to_float(y_tensor)
y_out = convert_uint16_to_float(y_out)
atol = 1e-2
self.__assert_close(
y_tensor,
y_out,
@@ -335,7 +356,7 @@ class TestBatchNormOpInference(unittest.TestCase):
+ str(np.dtype(dtype))
+ str(np.array(y_tensor))
+ str(y_out),
atol=1e-3,
atol=atol,
)
def test_check_output(self):
@@ -376,6 +397,29 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference):
self.check_with_place(place, data_format, self.dtype, [2, 3])
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support the bfloat16",
)
class TestBF16BatchNormOpInference(TestBatchNormOpInference):
def setUp(self):
self.dtype = np.uint16
self.use_mkldnn = False
self.fuse_with_relu = False
self.init_kernel_type()
def test_check_output(self):
places = [core.CUDAPlace(0)]
for place in places:
# for data_format in ["NCHW", "NHWC"]:
for data_format in ["NCHW"]:
self.check_with_place(
place, data_format, self.dtype, [2, 3, 4, 5]
)
self.check_with_place(place, data_format, self.dtype, [2, 3])
class TestBatchNormOpTraining(unittest.TestCase):
def setUp(self):
self.use_mkldnn = False
......
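
The test stores bfloat16 data as np.uint16 and round-trips it through convert_float_to_uint16 / convert_uint16_to_float, relaxing atol from 1e-3 to 1e-2 because bfloat16 keeps only about two to three significant decimal digits. A minimal stand-in for those helpers, assuming the usual upper-16-bit encoding (the real eager_op_test utilities may round to nearest rather than truncate):

```python
import numpy as np

def float_to_uint16(x):
    # Keep the upper 16 bits of the float32 pattern: sign, 8 exponent bits, 7 mantissa bits.
    return np.right_shift(np.asarray(x, dtype=np.float32).view(np.uint32), 16).astype(np.uint16)

def uint16_to_float(x):
    # Re-expand the stored bits to float32 by zero-padding the low mantissa bits.
    return np.left_shift(np.asarray(x, dtype=np.uint16).astype(np.uint32), 16).view(np.float32)

vals = np.array([0.1, 1.0, 3.14159, 1000.5], dtype=np.float32)
roundtrip = uint16_to_float(float_to_uint16(vals))
print(np.max(np.abs(roundtrip - vals) / np.abs(vals)))  # ~4e-3, roughly 2**-8 relative error
```
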