Unverified commit 523f8a26, authored by Guoxia Wang, committed via GitHub

[AMP OP&Test] support bf16 for batch norm (#52407)

* [AMP OP&Test] support bf16 for batchnorm

* codestyle

* Update batch_norm_grad_kernel.cu

* Update batch_norm_kernel.cu

* fix codestyle

* fix

* fix

* fix

* fix

* fix

* Update batch_norm_kernel.cc
Parent commit: a2060568
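For context, a minimal usage sketch (not part of this commit) of what the new bfloat16 path enables from the Python side. The exact calls shown here (paddle.nn.BatchNorm2D, Tensor.astype("bfloat16"), the input shape) are illustrative assumptions, and the snippet presumes a CUDA build with cuDNN >= 8.1, matching the version guard added in the diff below.

import paddle

# Illustrative sketch only: run batch norm directly on a bfloat16 input.
# The scale/bias/mean/variance parameters stay float32; with this change the
# GPU kernel also accepts a bfloat16 input dtype when cuDNN >= 8.1 is available.
paddle.set_device("gpu")
x = paddle.randn([2, 3, 4, 5]).astype("bfloat16")  # assumed NCHW input shape
bn = paddle.nn.BatchNorm2D(3)
y = bn(x)
print(y.dtype)  # expected: paddle.bfloat16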
...@@ -14,6 +14,7 @@
#include "paddle/phi/kernels/batch_norm_kernel.h"
#include "paddle/phi/backends/gpu/gpu_dnn.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/empty_kernel.h"
...@@ -66,6 +67,22 @@ PD_REGISTER_KERNEL(batch_norm_infer,
                   float,
                   double) {}
#ifdef PADDLE_WITH_CUDA
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(batch_norm_infer,
                   GPU,
                   ALL_LAYOUT,
                   phi::BatchNormInferKernel,
                   float,
                   double,
                   phi::dtype::bfloat16,
                   phi::dtype::float16) {
  if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
      kernel_key.dtype() == phi::DataType::BFLOAT16) {
    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
  }
}
#else
PD_REGISTER_KERNEL(batch_norm_infer,
                   GPU,
                   ALL_LAYOUT,
...@@ -79,6 +96,7 @@ PD_REGISTER_KERNEL(batch_norm_infer,
  }
}
#endif
#endif
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(batch_norm_infer,
                   GPU,
...
...@@ -1314,14 +1314,18 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw,
                   float,
                   phi::dtype::float16) {}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(batch_norm_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::BatchNormGradKernel,
                   float,
                   double,
                   phi::dtype::bfloat16,
                   phi::dtype::float16) {
  if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
      kernel_key.dtype() == phi::DataType::BFLOAT16) {
    kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);  // x_grad
    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);  // scale_grad
    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);  // bias_grad
...@@ -1334,6 +1338,22 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw,
                   phi::BatchNormGradRawKernel,
                   float,
                   double,
                   phi::dtype::bfloat16,
                   phi::dtype::float16) {
  if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
      kernel_key.dtype() == phi::DataType::BFLOAT16) {
    kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);  // x_grad
    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);  // scale_grad
    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);  // bias_grad
  }
}
#else
PD_REGISTER_KERNEL(batch_norm_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::BatchNormGradKernel,
                   float,
                   double,
                   phi::dtype::float16) {
  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
    kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);  // x_grad
...@@ -1342,6 +1362,20 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw,
  }
}
PD_REGISTER_KERNEL(batch_norm_grad_raw,
                   GPU,
                   ALL_LAYOUT,
                   phi::BatchNormGradRawKernel,
                   float,
                   double,
                   phi::dtype::float16) {
  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
    kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);  // x_grad
    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);  // scale_grad
    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);  // bias_grad
  }
}
#endif
#endif
#ifdef PADDLE_WITH_HIP
...
...@@ -1221,6 +1221,7 @@ PD_REGISTER_KERNEL(batch_norm,
                   ALL_LAYOUT,
                   phi::BatchNormKernel,
                   float,
                   phi::dtype::bfloat16,
                   phi::dtype::float16) {
  kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
  kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
...@@ -1232,6 +1233,28 @@ PD_REGISTER_KERNEL(batch_norm,
  kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(batch_norm,
                   GPU,
                   ALL_LAYOUT,
                   phi::BatchNormKernel,
                   float,
                   double,
                   phi::dtype::bfloat16,
                   phi::dtype::float16) {
  if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
      kernel_key.dtype() == phi::DataType::BFLOAT16) {
    kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
    kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
    kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
    kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
  }
}
#else
PD_REGISTER_KERNEL(batch_norm,
                   GPU,
                   ALL_LAYOUT,
...@@ -1250,5 +1273,6 @@ PD_REGISTER_KERNEL(batch_norm,
    kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
  }
}
#endif
#endif
...@@ -16,7 +16,12 @@ import os
import unittest
import numpy as np
from eager_op_test import (
    OpTest,
    _set_use_system_allocator,
    convert_float_to_uint16,
    convert_uint16_to_float,
)
from op import Operator
import paddle
...@@ -239,6 +244,9 @@ class TestBatchNormOpInference(unittest.TestCase):
            raise ValueError("Unknown data layout.")
        scale_shape = [c]
        if dtype == np.uint16:
            x_val = np.random.random_sample(x_shape).astype(np.float32)
        else:
            x_val = np.random.random_sample(x_shape).astype(dtype)
        # generate some negative values to test case with relu fused
        x_val = x_val - 0.5
...@@ -248,12 +256,20 @@ class TestBatchNormOpInference(unittest.TestCase):
        mean = np.zeros(scale_shape).astype(np.float32)
        variance = np.ones(scale_shape).astype(np.float32)
        if dtype == np.uint16:
            y_out = _reference_testing(
                x_val, scale_val, bias_val, mean, variance, epsilon, data_layout
            ).astype(np.float32)
            y_out = convert_float_to_uint16(y_out)
        else:
            y_out = _reference_testing(
                x_val, scale_val, bias_val, mean, variance, epsilon, data_layout
            ).astype(dtype)
        if self.fuse_with_relu:
            y_out = np.maximum(y_out, 0)
        if dtype == np.uint16:
            x_val = convert_float_to_uint16(x_val)
        scope = core.Scope()
        # create input
...@@ -324,6 +340,11 @@ class TestBatchNormOpInference(unittest.TestCase):
        y_tensor._set_dims(dims)
        # check inference result
        atol = 1e-3
        if dtype == np.uint16:
            y_tensor = convert_uint16_to_float(y_tensor)
            y_out = convert_uint16_to_float(y_out)
            atol = 1e-2
        self.__assert_close(
            y_tensor,
            y_out,
...@@ -335,7 +356,7 @@ class TestBatchNormOpInference(unittest.TestCase):
            + str(np.dtype(dtype))
            + str(np.array(y_tensor))
            + str(y_out),
            atol=atol,
        )

    def test_check_output(self):
...@@ -376,6 +397,29 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference):
                self.check_with_place(place, data_format, self.dtype, [2, 3])


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or not support the bfloat16",
)
class TestBF16BatchNormOpInference(TestBatchNormOpInference):
    def setUp(self):
        self.dtype = np.uint16
        self.use_mkldnn = False
        self.fuse_with_relu = False
        self.init_kernel_type()

    def test_check_output(self):
        places = [core.CUDAPlace(0)]
        for place in places:
            # for data_format in ["NCHW", "NHWC"]:
            for data_format in ["NCHW"]:
                self.check_with_place(
                    place, data_format, self.dtype, [2, 3, 4, 5]
                )
                self.check_with_place(place, data_format, self.dtype, [2, 3])


class TestBatchNormOpTraining(unittest.TestCase):
    def setUp(self):
        self.use_mkldnn = False
...
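A note on the test's dtype convention: bfloat16 arrays are carried as np.uint16 bit patterns, which is why the new BF16 test sets self.dtype = np.uint16 and converts with convert_float_to_uint16 / convert_uint16_to_float. The rough numpy sketch below is not the real helpers (whose names and rounding behaviour may differ; actual conversion typically rounds rather than truncates); it only illustrates the round-trip involved and why the tolerance is relaxed to atol=1e-2, since bfloat16 keeps only the upper 16 bits of a float32 (an 8-bit mantissa).

import numpy as np


def float32_to_bf16_bits(x):
    # Keep the upper 16 bits of each float32 value (sign, exponent and the top
    # 7 mantissa bits); bfloat16 payloads are commonly carried as uint16 this way.
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits >> 16).astype(np.uint16)


def bf16_bits_to_float32(u):
    # Re-expand by placing the stored 16 bits back into the high half of a float32.
    bits = (np.asarray(u, dtype=np.uint32) << 16).astype(np.uint32)
    return bits.view(np.float32)


x = np.random.random_sample((2, 3)).astype(np.float32) - 0.5
roundtrip = bf16_bits_to_float32(float32_to_bf16_bits(x))
# The round-trip drops the low 16 mantissa bits, which is what motivates the
# looser atol used for the uint16 (bfloat16) inference check above.
print(np.max(np.abs(roundtrip - x)))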