Unverified commit 1080d4fc, authored by Zhang Zheng, committed via GitHub.

[AMP OP&Test] Sync_batch_norm support bfloat16 (#52921)

* [AMP OP&Test] Sync_batch_norm support bfloat16

* fix

* fix

Parent commit: 0abdcff6
......@@ -321,6 +321,28 @@ PD_REGISTER_KERNEL(sync_batch_norm,
}
}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
// Forward sync_batch_norm registration for CUDA builds with cuDNN >= 8.1:
// this branch additionally instantiates a bfloat16 kernel on top of the
// float/double/float16 set registered in the pre-8.1 branch below.
PD_REGISTER_KERNEL(sync_batch_norm,
GPU,
ALL_LAYOUT,
phi::SyncBatchNormKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {
// When the compute dtype is reduced precision (FP16 or BF16), pin the
// auxiliary tensors to float32: inputs 1-4 and outputs 1-4 are forced to
// FLOAT32 (presumably scale/bias/mean/variance and the saved statistics --
// confirm the index-to-tensor mapping against SyncBatchNormKernel's
// signature).
if (kernel_key.dtype() == phi::DataType::FLOAT16 ||
kernel_key.dtype() == phi::DataType::BFLOAT16) {
kernel->InputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->InputAt(4).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
}
}
#else
PD_REGISTER_KERNEL(sync_batch_norm,
GPU,
ALL_LAYOUT,
......@@ -340,6 +362,7 @@ PD_REGISTER_KERNEL(sync_batch_norm,
}
}
#endif
#endif
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(sync_batch_norm_grad,
......@@ -349,6 +372,16 @@ PD_REGISTER_KERNEL(sync_batch_norm_grad,
float,
phi::dtype::float16) {}
#else
#if CUDNN_VERSION_MIN(8, 1, 0)
// Backward sync_batch_norm registration for cuDNN >= 8.1: adds a bfloat16
// instantiation alongside float/double/float16. No dtype overrides are set
// here (empty customization body), unlike the forward registration.
PD_REGISTER_KERNEL(sync_batch_norm_grad,
GPU,
ALL_LAYOUT,
phi::SyncBatchNormGradKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
#else
PD_REGISTER_KERNEL(sync_batch_norm_grad,
GPU,
ALL_LAYOUT,
......@@ -357,6 +390,7 @@ PD_REGISTER_KERNEL(sync_batch_norm_grad,
double,
phi::dtype::float16) {}
#endif
#endif
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(sync_batch_norm_coo,
......
......@@ -24,7 +24,11 @@ import unittest
import numpy as np
from decorator_helper import prog_scope
from eager_op_test import OpTest, _set_use_system_allocator
from eager_op_test import (
OpTest,
_set_use_system_allocator,
convert_float_to_uint16,
)
import paddle
from paddle import fluid, nn
......@@ -85,7 +89,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
cmds += ["--fetch_list", str(fetch_list)]
if only_forward:
cmds += ["--only_forward"]
if self.dtype == np.float16:
if self.dtype == np.float16 or self.dtype == np.uint16:
cmds += ["--use_cudnn"]
p = subprocess.run(cmds)
assert p.returncode == 0, f"Fleet train: Failed: {p}"
......@@ -98,7 +102,7 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
startup = fluid.Program()
main.random_seed = seed
startup.random_seed = seed
use_cudnn = self.dtype == np.float16
use_cudnn = (self.dtype == np.float16) or (self.dtype == np.uint16)
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
data = paddle.static.data(
......@@ -144,7 +148,14 @@ class TestSyncBatchNormOpTraining(unittest.TestCase):
os.environ['FLAGS_cudnn_deterministic'] = "1"
paddle.enable_static()
scope = core.Scope()
data = np.random.random(size=self.dshape).astype(self.dtype) * 4.0 - 2
if self.dtype == np.uint16:
data = convert_float_to_uint16(
np.random.random(size=self.dshape).astype(np.float32) * 4.0 - 2
)
else:
data = (
np.random.random(size=self.dshape).astype(self.dtype) * 4.0 - 2
)
stride = self.N // core.get_cuda_device_count()
for id in range(core.get_cuda_device_count()):
filepath = os.path.join(
......@@ -269,6 +280,27 @@ class TestFP16SyncBatchNormOpTraining(TestSyncBatchNormOpTraining):
self.H = 32
self.W = 32
self.dshape = [self.N, self.C, self.H, self.W]
self.atol = 1e-3
self.data_dir = tempfile.TemporaryDirectory()
self.fleet_log_dir = tempfile.TemporaryDirectory()
@unittest.skipIf(
    not (
        core.is_compiled_with_cuda()
        and core.is_bfloat16_supported(core.CUDAPlace(0))
    ),
    "core is not compiled with CUDA or not support the bfloat16",
)
class TestBF16SyncBatchNormOpTraining(TestSyncBatchNormOpTraining):
    """Run the sync_batch_norm training test with bfloat16 input."""

    def setUp(self):
        """Configure dtype, tensor shape, tolerance and scratch dirs."""
        # numpy has no native bfloat16; BF16 tensors travel as uint16.
        self.dtype = np.uint16
        self.N, self.C, self.H, self.W = 8, 16, 32, 32
        self.dshape = [self.N, self.C, self.H, self.W]
        # Looser tolerance than the FP16 variant (1e-3): BF16 carries
        # fewer mantissa bits.
        self.atol = 1e-2
        self.data_dir = tempfile.TemporaryDirectory()
        self.fleet_log_dir = tempfile.TemporaryDirectory()
......
Markdown is supported.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register or sign in.