未验证 提交 4c6f77d8 编写于 作者: C cyber-pioneer 提交者: GitHub

fix batch_norm grad kernel nhwc error (#54703)

上级 1df2ee6c
......@@ -1018,32 +1018,62 @@ void BatchNormGradRawKernel(const Context &ctx,
} else {
// This branch calls the CUDA kernels
if (compute_format == DataLayout::kNCHW) {
if (d_x) {
BNBackwardData<T, block, phi::DataLayout::kNCHW>
<<<grid2, block, 0, ctx.stream()>>>(
d_y->data<T>(),
scale.data<BatchNormParamType<T>>(),
saved_mean_data,
x.data<T>(),
saved_var_data,
C,
N,
H * W * D,
d_x->data<T>());
}
if (d_scale && d_bias) {
KeBNBackwardScaleBias<T, block, phi::DataLayout::kNCHW>
<<<grid2, block, 0, stream>>>(
d_y->data<T>(),
x.data<T>(),
saved_mean_data,
saved_var_data,
epsilon,
N,
C,
H * W * D,
d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
if (data_layout == DataLayout::kNHWC) {
if (d_x) {
BNBackwardData<T, block, phi::DataLayout::kNHWC>
<<<grid2, block, 0, ctx.stream()>>>(
d_y->data<T>(),
scale.data<BatchNormParamType<T>>(),
saved_mean_data,
x.data<T>(),
saved_var_data,
C,
N,
H * W * D,
d_x->data<T>());
}
if (d_scale && d_bias) {
KeBNBackwardScaleBias<T, block, phi::DataLayout::kNHWC>
<<<grid2, block, 0, stream>>>(
d_y->data<T>(),
x.data<T>(),
saved_mean_data,
saved_var_data,
epsilon,
N,
C,
H * W * D,
d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
}
} else {
if (d_x) {
BNBackwardData<T, block, phi::DataLayout::kNCHW>
<<<grid2, block, 0, ctx.stream()>>>(
d_y->data<T>(),
scale.data<BatchNormParamType<T>>(),
saved_mean_data,
x.data<T>(),
saved_var_data,
C,
N,
H * W * D,
d_x->data<T>());
}
if (d_scale && d_bias) {
KeBNBackwardScaleBias<T, block, phi::DataLayout::kNCHW>
<<<grid2, block, 0, stream>>>(
d_y->data<T>(),
x.data<T>(),
saved_mean_data,
saved_var_data,
epsilon,
N,
C,
H * W * D,
d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>());
}
}
} else {
if (d_x) {
......
......@@ -97,8 +97,7 @@ class TestBatchNormOp(OpTest):
check_prim=True,
only_check_prim=True,
)
elif self.data_format == "NCHW" and paddle.is_compiled_with_cuda():
# The original batch_norm CUDA kernel produces a different NHWC x_grad depending on whether scale_grad and bias_grad are also calculated
if paddle.is_compiled_with_cuda():
self.check_grad_with_place(
core.CUDAPlace(0),
["X"],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册