Unverified · Commit eb9d07e5 · Authored by cyber-pioneer · Committed by GitHub

fix batch_norm cuda grad kernel test mode bug (#54681)

Parent 5f92cc54
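This commit touches two files: the CUDA batch_norm backward kernel and a batch_norm test suite. As context for the diffs below, here is a minimal sketch of the configuration the title refers to, assuming a CUDA build of Paddle (illustrative only, not part of the commit): batch norm applied to an NHWC tensor in test mode, i.e. eval with global statistics, followed by a backward pass.

```python
import numpy as np
import paddle

paddle.set_device("gpu")

# NHWC input, matching the shape used by the new tests: [16, 16, 16, 8].
x = paddle.to_tensor(
    np.random.rand(16, 16, 16, 8).astype("float32"), stop_gradient=False
)
bn = paddle.nn.BatchNorm2D(
    num_features=8, data_format="NHWC", use_global_stats=True
)
bn.eval()  # test mode: forward and backward use the running statistics

y = bn(x)
y.sum().backward()
print(x.grad.shape)  # the NHWC input gradient exercises the fixed kernel path
```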
@@ -1119,30 +1119,58 @@ void BatchNormGradRawKernel(const Context &ctx,
     }
     if (compute_format == DataLayout::kNCHW) {
-      if (d_x) {
-        KeBNBackwardData<T, phi::DataLayout::kNCHW>
-            <<<grid1, block, 0, stream>>>(d_y->data<T>(),
-                                          scale.data<BatchNormParamType<T>>(),
-                                          running_var_data,
-                                          epsilon,
-                                          C,
-                                          H * W,
-                                          num,
-                                          d_x->data<T>());
-      }
-      if (d_scale && d_bias) {
-        KeBNBackwardScaleBias<T, block, phi::DataLayout::kNCHW>
-            <<<grid2, block, 0, stream>>>(
-                d_y->data<T>(),
-                x.data<T>(),
-                running_mean_data,
-                running_var_data,
-                epsilon,
-                N,
-                C,
-                H * W * D,
-                d_scale->data<BatchNormParamType<T>>(),
-                d_bias->data<BatchNormParamType<T>>());
+      if (data_layout == DataLayout::kNHWC) {
+        if (d_x) {
+          KeBNBackwardData<T, phi::DataLayout::kNHWC>
+              <<<grid1, block, 0, stream>>>(d_y->data<T>(),
+                                            scale.data<BatchNormParamType<T>>(),
+                                            running_var_data,
+                                            epsilon,
+                                            C,
+                                            H * W,
+                                            num,
+                                            d_x->data<T>());
+        }
+        if (d_scale && d_bias) {
+          KeBNBackwardScaleBias<T, block, phi::DataLayout::kNHWC>
+              <<<grid2, block, 0, stream>>>(
+                  d_y->data<T>(),
+                  x.data<T>(),
+                  running_mean_data,
+                  running_var_data,
+                  epsilon,
+                  N,
+                  C,
+                  H * W * D,
+                  d_scale->data<BatchNormParamType<T>>(),
+                  d_bias->data<BatchNormParamType<T>>());
+        }
+      } else {
+        if (d_x) {
+          KeBNBackwardData<T, phi::DataLayout::kNCHW>
+              <<<grid1, block, 0, stream>>>(d_y->data<T>(),
+                                            scale.data<BatchNormParamType<T>>(),
+                                            running_var_data,
+                                            epsilon,
+                                            C,
+                                            H * W,
+                                            num,
+                                            d_x->data<T>());
+        }
+        if (d_scale && d_bias) {
+          KeBNBackwardScaleBias<T, block, phi::DataLayout::kNCHW>
+              <<<grid2, block, 0, stream>>>(
+                  d_y->data<T>(),
+                  x.data<T>(),
+                  running_mean_data,
+                  running_var_data,
+                  epsilon,
+                  N,
+                  C,
+                  H * W * D,
+                  d_scale->data<BatchNormParamType<T>>(),
+                  d_bias->data<BatchNormParamType<T>>());
+        }
       }
     } else {
       if (d_x) {
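The hunk above is the fix itself. Reading the diff, compute_format can remain DataLayout::kNCHW on this non-cuDNN path even when the tensor's actual data_layout is NHWC (per the title, this is the test-mode case), so the kernel template must follow data_layout rather than compute_format alone. A plain-Python stand-in for the corrected dispatch (the function is illustrative, not Paddle API; only the kernel names come from the diff):

```python
def pick_backward_kernels(compute_format: str, data_layout: str) -> str:
    """Mirror of the layout dispatch in BatchNormGradRawKernel above."""
    if compute_format == "kNCHW":
        # Before this commit, this branch unconditionally launched the
        # kNCHW kernels, indexing NHWC tensors with NCHW strides.
        if data_layout == "kNHWC":
            return "KeBNBackwardData/KeBNBackwardScaleBias<kNHWC>"
        return "KeBNBackwardData/KeBNBackwardScaleBias<kNCHW>"
    return "KeBNBackwardData/KeBNBackwardScaleBias<kNHWC>"


# The test-mode NHWC case that used to pick the wrong (kNCHW) kernels:
assert pick_backward_kernels("kNCHW", "kNHWC").endswith("<kNHWC>")
```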
@@ -253,6 +253,21 @@ class TestBatchNormOpNCHWTestMode(TestBatchNormOp):
         self.use_global_stats = True
 
 
+class TestBatchNormOpNHWCTestMode(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-5
+        self.fw_comp_rtol = 1e-5
+        self.rev_comp_atol = 1e-5
+        self.rev_comp_rtol = 1e-5
+        self.dtype = "float32"
+        self.shape = [16, 16, 16, 8]
+        self.training = False
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NHWC"
+        self.use_global_stats = True
+
+
 class TestBatchNormOpNCHWFp64(TestBatchNormOp):
     def initConfig(self):
         self.fw_comp_atol = 1e-11
@@ -283,6 +298,21 @@ class TestBatchNormOpNCHWTestModeFp64(TestBatchNormOp):
         self.use_global_stats = None
 
 
+class TestBatchNormOpNHWCTestModeFp64(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-15
+        self.fw_comp_rtol = 1e-15
+        self.rev_comp_atol = 1e-15
+        self.rev_comp_rtol = 1e-15
+        self.dtype = "float64"
+        self.shape = [16, 16, 16, 8]
+        self.training = False
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NHWC"
+        self.use_global_stats = None
+
+
 class TestBatchNormOpNCHWFp16(TestBatchNormOp):
     def initConfig(self):
         self.fw_comp_atol = 1e-3
@@ -313,6 +343,21 @@ class TestBatchNormOpNCHWTestModeFp16(TestBatchNormOp):
         self.use_global_stats = None
 
 
+class TestBatchNormOpNHWCTestModeFp16(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-3
+        self.fw_comp_rtol = 1e-3
+        self.rev_comp_atol = 1e-3
+        self.rev_comp_rtol = 1e-3
+        self.dtype = "float16"
+        self.shape = [16, 16, 16, 8]
+        self.training = False
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NHWC"
+        self.use_global_stats = None
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or not core.is_bfloat16_supported(core.CUDAPlace(0)),
@@ -357,6 +402,28 @@ class TestBatchNormOpNCHWTestModebf16(TestBatchNormOp):
         self.use_global_stats = None
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestBatchNormOpNHWCTestModebf16(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-3
+        self.fw_comp_rtol = 1e-3
+        self.rev_comp_atol = 1e-3
+        self.rev_comp_rtol = 1e-3
+        self.cinn_atol = 1e-3
+        self.cinn_rtol = 1e-3
+        self.dtype = "uint16"
+        self.shape = [16, 16, 16, 8]
+        self.training = False
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NHWC"
+        self.use_global_stats = None
+
+
 class TestBatchNormOpNHWC(TestBatchNormOp):
     def initConfig(self):
         self.fw_comp_atol = 1e-5
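The new test classes mirror the existing NCHW test-mode configurations, switching the shape to channels-last ([16, 16, 16, 8]) and data_format to "NHWC", with training=False across float32/float64/float16/bfloat16. In the same spirit, here is a standalone numeric check (a sketch, not the repo's harness; `_variance` is assumed to be the layer's running-variance buffer): with fixed statistics, y = gamma * (x - mean) / sqrt(var + eps) + beta elementwise, so the gradient of sum(y) with respect to x is gamma / sqrt(var + eps), broadcast along the channel (last) axis for NHWC.

```python
import numpy as np
import paddle

paddle.set_device("gpu")
paddle.set_default_dtype("float64")

x = paddle.to_tensor(
    np.random.rand(16, 16, 16, 8).astype("float64"), stop_gradient=False
)
bn = paddle.nn.BatchNorm2D(
    num_features=8, epsilon=1e-05, data_format="NHWC", use_global_stats=True
)
bn.eval()

y = bn(x)
y.sum().backward()

# d(sum(y))/dx = gamma / sqrt(running_var + eps), identical for every
# spatial position, so checking one NHWC pixel's channel vector suffices.
expected = bn.weight.numpy() / np.sqrt(bn._variance.numpy() + 1e-05)
np.testing.assert_allclose(x.grad.numpy()[0, 0, 0, :], expected, rtol=1e-10)
```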