未验证 提交 4c6f77d8 编写于 作者: C cyber-pioneer 提交者: GitHub

Fix NHWC error in the batch_norm grad kernel (#54703)

上级 1df2ee6c
...@@ -1018,32 +1018,62 @@ void BatchNormGradRawKernel(const Context &ctx, ...@@ -1018,32 +1018,62 @@ void BatchNormGradRawKernel(const Context &ctx,
} else { } else {
// This branch calls CUDA kernels // This branch calls CUDA kernels
if (compute_format == DataLayout::kNCHW) { if (compute_format == DataLayout::kNCHW) {
if (d_x) { if (data_layout == DataLayout::kNHWC) {
BNBackwardData<T, block, phi::DataLayout::kNCHW> if (d_x) {
<<<grid2, block, 0, ctx.stream()>>>( BNBackwardData<T, block, phi::DataLayout::kNHWC>
d_y->data<T>(), <<<grid2, block, 0, ctx.stream()>>>(
scale.data<BatchNormParamType<T>>(), d_y->data<T>(),
saved_mean_data, scale.data<BatchNormParamType<T>>(),
x.data<T>(), saved_mean_data,
saved_var_data, x.data<T>(),
C, saved_var_data,
N, C,
H * W * D, N,
d_x->data<T>()); H * W * D,
} d_x->data<T>());
if (d_scale && d_bias) { }
KeBNBackwardScaleBias<T, block, phi::DataLayout::kNCHW> if (d_scale && d_bias) {
<<<grid2, block, 0, stream>>>( KeBNBackwardScaleBias<T, block, phi::DataLayout::kNHWC>
d_y->data<T>(), <<<grid2, block, 0, stream>>>(
x.data<T>(), d_y->data<T>(),
saved_mean_data, x.data<T>(),
saved_var_data, saved_mean_data,
epsilon, saved_var_data,
N, epsilon,
C, N,
H * W * D, C,
d_scale->data<BatchNormParamType<T>>(), H * W * D,
d_bias->data<BatchNormParamType<T>>()); d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
}
} else {
if (d_x) {
BNBackwardData<T, block, phi::DataLayout::kNCHW>
<<<grid2, block, 0, ctx.stream()>>>(
d_y->data<T>(),
scale.data<BatchNormParamType<T>>(),
saved_mean_data,
x.data<T>(),
saved_var_data,
C,
N,
H * W * D,
d_x->data<T>());
}
if (d_scale && d_bias) {
KeBNBackwardScaleBias<T, block, phi::DataLayout::kNCHW>
<<<grid2, block, 0, stream>>>(
d_y->data<T>(),
x.data<T>(),
saved_mean_data,
saved_var_data,
epsilon,
N,
C,
H * W * D,
d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
}
} }
} else { } else {
if (d_x) { if (d_x) {
......
...@@ -97,8 +97,7 @@ class TestBatchNormOp(OpTest): ...@@ -97,8 +97,7 @@ class TestBatchNormOp(OpTest):
check_prim=True, check_prim=True,
only_check_prim=True, only_check_prim=True,
) )
elif self.data_format == "NCHW" and paddle.is_compiled_with_cuda(): if paddle.is_compiled_with_cuda():
# The original batch_norm CUDA kernel yields a different NHWC x_grad depending on whether scale_grad and bias_grad are also calculated
self.check_grad_with_place( self.check_grad_with_place(
core.CUDAPlace(0), core.CUDAPlace(0),
["X"], ["X"],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册