Unverified commit eb9d07e5, authored by cyber-pioneer, committed by GitHub

fix batch_norm cuda grad kernel test mode bug (#54681)

Parent 5f92cc54
...
@@ -1119,30 +1119,58 @@ void BatchNormGradRawKernel(const Context &ctx,
       }
       if (compute_format == DataLayout::kNCHW) {
-        if (d_x) {
-          KeBNBackwardData<T, phi::DataLayout::kNCHW>
-              <<<grid1, block, 0, stream>>>(d_y->data<T>(),
-                                            scale.data<BatchNormParamType<T>>(),
-                                            running_var_data,
-                                            epsilon,
-                                            C,
-                                            H * W,
-                                            num,
-                                            d_x->data<T>());
-        }
-        if (d_scale && d_bias) {
-          KeBNBackwardScaleBias<T, block, phi::DataLayout::kNCHW>
-              <<<grid2, block, 0, stream>>>(
-                  d_y->data<T>(),
-                  x.data<T>(),
-                  running_mean_data,
-                  running_var_data,
-                  epsilon,
-                  N,
-                  C,
-                  H * W * D,
-                  d_scale->data<BatchNormParamType<T>>(),
-                  d_bias->data<BatchNormParamType<T>>());
+        if (data_layout == DataLayout::kNHWC) {
+          if (d_x) {
+            KeBNBackwardData<T, phi::DataLayout::kNHWC>
+                <<<grid1, block, 0, stream>>>(d_y->data<T>(),
+                                              scale.data<BatchNormParamType<T>>(),
+                                              running_var_data,
+                                              epsilon,
+                                              C,
+                                              H * W,
+                                              num,
+                                              d_x->data<T>());
+          }
+          if (d_scale && d_bias) {
+            KeBNBackwardScaleBias<T, block, phi::DataLayout::kNHWC>
+                <<<grid2, block, 0, stream>>>(
+                    d_y->data<T>(),
+                    x.data<T>(),
+                    running_mean_data,
+                    running_var_data,
+                    epsilon,
+                    N,
+                    C,
+                    H * W * D,
+                    d_scale->data<BatchNormParamType<T>>(),
+                    d_bias->data<BatchNormParamType<T>>());
+          }
+        } else {
+          if (d_x) {
+            KeBNBackwardData<T, phi::DataLayout::kNCHW>
+                <<<grid1, block, 0, stream>>>(d_y->data<T>(),
+                                              scale.data<BatchNormParamType<T>>(),
+                                              running_var_data,
+                                              epsilon,
+                                              C,
+                                              H * W,
+                                              num,
+                                              d_x->data<T>());
+          }
+          if (d_scale && d_bias) {
+            KeBNBackwardScaleBias<T, block, phi::DataLayout::kNCHW>
+                <<<grid2, block, 0, stream>>>(
+                    d_y->data<T>(),
+                    x.data<T>(),
+                    running_mean_data,
+                    running_var_data,
+                    epsilon,
+                    N,
+                    C,
+                    H * W * D,
+                    d_scale->data<BatchNormParamType<T>>(),
+                    d_bias->data<BatchNormParamType<T>>());
+          }
         }
       } else {
         if (d_x) {
......
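Note on the hunk above: inside the `compute_format == DataLayout::kNCHW` branch, the backward kernels `KeBNBackwardData` and `KeBNBackwardScaleBias` used to be launched with the `kNCHW` template argument unconditionally, even when the tensors were actually laid out as NHWC, which is the test-mode bug named in the commit title. The change dispatches on `data_layout` and picks the `kNHWC` instantiation when needed. In test mode (running statistics treated as constants), the gradients these kernels compute are simple per-channel broadcasts and reductions, and the only layout-dependent part is which axis is the channel axis. The following is a minimal NumPy sketch of that reference computation; it is an illustration, not Paddle code, and the helper name and signature are made up for the example.

```python
import numpy as np

def bn_test_mode_grad_reference(dy, x, scale, running_mean, running_var,
                                epsilon=1e-5, data_format="NHWC"):
    # Hypothetical reference helper (not part of Paddle): gradients of
    # y = scale * (x - running_mean) / sqrt(running_var + epsilon) + bias
    # when the running statistics are constants (test / use_global_stats mode).
    c_axis = dy.ndim - 1 if data_format == "NHWC" else 1
    bshape = [1] * dy.ndim
    bshape[c_axis] = dy.shape[c_axis]          # broadcast shape for per-channel params
    inv_std = 1.0 / np.sqrt(running_var + epsilon)

    d_x = dy * (scale * inv_std).reshape(bshape)
    reduce_axes = tuple(i for i in range(dy.ndim) if i != c_axis)
    x_hat = (x - running_mean.reshape(bshape)) * inv_std.reshape(bshape)
    d_scale = (dy * x_hat).sum(axis=reduce_axes)
    d_bias = dy.sum(axis=reduce_axes)
    return d_x, d_scale, d_bias
```

The launch arguments in the diff (C, H * W, num, H * W * D, and so on) are flattened versions of these same per-channel reductions, which is why a wrong layout template silently reads the wrong elements rather than failing.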
...
@@ -253,6 +253,21 @@ class TestBatchNormOpNCHWTestMode(TestBatchNormOp):
         self.use_global_stats = True


+class TestBatchNormOpNHWCTestMode(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-5
+        self.fw_comp_rtol = 1e-5
+        self.rev_comp_atol = 1e-5
+        self.rev_comp_rtol = 1e-5
+        self.dtype = "float32"
+        self.shape = [16, 16, 16, 8]
+        self.training = False
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NHWC"
+        self.use_global_stats = True
+
+
 class TestBatchNormOpNCHWFp64(TestBatchNormOp):
     def initConfig(self):
         self.fw_comp_atol = 1e-11
@@ -283,6 +298,21 @@ class TestBatchNormOpNCHWTestModeFp64(TestBatchNormOp):
         self.use_global_stats = None


+class TestBatchNormOpNHWCTestModeFp64(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-15
+        self.fw_comp_rtol = 1e-15
+        self.rev_comp_atol = 1e-15
+        self.rev_comp_rtol = 1e-15
+        self.dtype = "float64"
+        self.shape = [16, 16, 16, 8]
+        self.training = False
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NHWC"
+        self.use_global_stats = None
+
+
 class TestBatchNormOpNCHWFp16(TestBatchNormOp):
     def initConfig(self):
         self.fw_comp_atol = 1e-3
@@ -313,6 +343,21 @@ class TestBatchNormOpNCHWTestModeFp16(TestBatchNormOp):
         self.use_global_stats = None


+class TestBatchNormOpNHWCTestModeFp16(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-3
+        self.fw_comp_rtol = 1e-3
+        self.rev_comp_atol = 1e-3
+        self.rev_comp_rtol = 1e-3
+        self.dtype = "float16"
+        self.shape = [16, 16, 16, 8]
+        self.training = False
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NHWC"
+        self.use_global_stats = None
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or not core.is_bfloat16_supported(core.CUDAPlace(0)),
@@ -357,6 +402,28 @@ class TestBatchNormOpNCHWTestModebf16(TestBatchNormOp):
         self.use_global_stats = None


+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestBatchNormOpNHWCTestModebf16(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-3
+        self.fw_comp_rtol = 1e-3
+        self.rev_comp_atol = 1e-3
+        self.rev_comp_rtol = 1e-3
+        self.cinn_atol = 1e-3
+        self.cinn_rtol = 1e-3
+        self.dtype = "uint16"
+        self.shape = [16, 16, 16, 8]
+        self.training = False
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NHWC"
+        self.use_global_stats = None
+
+
 class TestBatchNormOpNHWC(TestBatchNormOp):
     def initConfig(self):
         self.fw_comp_atol = 1e-5
......
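The new test classes all pin the same scenario across float32, float64, float16, and bfloat16: an NHWC tensor of shape [16, 16, 16, 8], training=False, and global (running) statistics. For context, a minimal way to reach that path from the Python API might look like the sketch below; it assumes paddle.nn.BatchNorm2D's documented data_format and use_global_stats arguments, and the shape simply mirrors the tests.

```python
import paddle

# Minimal sketch of the configuration the new tests exercise: NHWC data,
# eval/test mode, global (running) statistics. Values are illustrative.
paddle.set_device("gpu")                     # the fix is in the CUDA grad kernel

x = paddle.randn([16, 16, 16, 8])            # NHWC: channels last, C = 8
x.stop_gradient = False

bn = paddle.nn.BatchNorm2D(8, data_format="NHWC", use_global_stats=True)
bn.eval()                                    # test mode: use running mean/variance

y = bn(x)
y.sum().backward()                           # runs the batch_norm backward kernel
print(x.grad.shape)                          # [16, 16, 16, 8]
```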