Unverified commit 04f8c24e authored by cyber-pioneer, committed by GitHub

[Prim] simplify bn vjp code (#51933)

* simplify bn vjp code

* simplify composite rule

* polish name
Parent 648563dd
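For reference, the quantities computed in the hunks below are the standard batch-norm gradients. Writing $\hat\sigma=(\mathrm{var}+\epsilon)^{-1/2}$ for `rsqrt_var`, $d=x-\mathrm{mean}$ for `x_sub_mean`, $dy$ for the incoming output gradient, $N$ for `nhw`, and $\gamma,\beta$ for scale and bias (notation added here for exposition, not part of the commit):

```latex
% Batch-norm VJP, training path (not use_global_stats); sums run over the N/H/W axes.
\begin{aligned}
\frac{\partial L}{\partial x}
  &= \gamma\,\hat\sigma\left(dy - \frac{1}{N}\sum dy
     - d\,\hat\sigma^{2}\,\frac{1}{N}\sum dy\,d\right)
     && \text{(part1 $\cdot$ part2 in the diff)}\\
\frac{\partial L}{\partial \gamma}
  &= \hat\sigma \sum dy\,d && \text{(scale\_grad)}\\
\frac{\partial L}{\partial \beta}
  &= \sum dy && \text{(bias\_grad)}
\end{aligned}
```

The two reductions $\sum dy$ and $\sum dy\,d$ appear in all three gradients, which is why this commit hoists them into `nhwc_out_grad_sum` / `out_grad_data_sum` and `sum_dout_mul_diff` / `nhwc_sum_dout_mul_diff` and reuses them, instead of calling `sum<T>` separately in each branch.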
@@ -1382,8 +1382,11 @@ void batch_norm_grad(const Tensor& x,
case DataLayout::kNCHW: {
auto nhwc_x = transpose<T>(x_data, nchw_to_nhwc_dim);
auto nhwc_out_grad = transpose<T>(out_grad_data, nchw_to_nhwc_dim);
auto nhwc_out_grad_sum = sum<T>(nhwc_out_grad, reduce_axis, dtype, false);
auto x_sub_mean = nhwc_x - mean_data;
auto sum_dout_mul_diff =
sum<T>(nhwc_out_grad * x_sub_mean, reduce_axis, dtype, false);
if (x_grad) {
if (use_global_stats) {
@@ -1392,11 +1395,8 @@ void batch_norm_grad(const Tensor& x,
set_output<T>(nchw_x_grad, x_grad);
} else {
auto part1 = scale * rsqrt_var;
auto mean_temp1 =
sum<T>(nhwc_out_grad, reduce_axis, dtype, false) / nhw;
auto tmp = nhwc_out_grad * x_sub_mean * rsqrt_var * rsqrt_var / nhw;
auto mean_temp2 = sum<T>(tmp, reduce_axis, dtype, false);
auto mean_temp1 = nhwc_out_grad_sum / nhw;
auto mean_temp2 = sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var;
auto part2 = nhwc_out_grad - mean_temp1 - x_sub_mean * mean_temp2;
auto x_grad_data = part1 * part2;
@@ -1408,29 +1408,30 @@ void batch_norm_grad(const Tensor& x,
}
}
if (scale_grad) {
auto scale_grad_data = sum<T>(
nhwc_out_grad * x_sub_mean * rsqrt_var, reduce_axis, dtype, false);
auto scale_grad_data = sum_dout_mul_diff * rsqrt_var;
set_output<T>(scale_grad_data, scale_grad);
}
if (bias_grad) {
auto bias_grad_data = sum<T>(nhwc_out_grad, reduce_axis, dtype, false);
set_output<T>(bias_grad_data, bias_grad);
set_output<T>(nhwc_out_grad_sum, bias_grad);
}
break;
}
case DataLayout::kNHWC: {
if (x_grad) {
auto x_sub_mean = x_data - mean_data;
auto out_grad_data_sum =
sum<T>(out_grad_data, reduce_axis, dtype, false);
auto nhwc_sum_dout_mul_diff =
sum<T>(out_grad_data * x_sub_mean, reduce_axis, dtype, false);
if (use_global_stats) {
auto x_grad_data = scale * rsqrt_var * out_grad_data;
set_output<T>(x_grad_data, x_grad);
} else {
auto part1 = scale * rsqrt_var;
auto mean_temp1 =
sum<T>(out_grad_data, reduce_axis, dtype, false) / nhw;
auto tmp = out_grad_data * x_sub_mean * rsqrt_var * rsqrt_var / nhw;
auto mean_temp2 = sum<T>(tmp, reduce_axis, dtype, false);
auto mean_temp1 = out_grad_data_sum / nhw;
auto mean_temp2 =
nhwc_sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var;
auto part2 = out_grad_data - mean_temp1 - x_sub_mean * mean_temp2;
auto x_grad_data = part1 * part2;
@@ -1440,16 +1441,11 @@ void batch_norm_grad(const Tensor& x,
set_output<T>(x_grad_data, x_grad);
}
if (scale_grad) {
auto scale_grad_data = sum<T>(out_grad_data * x_sub_mean * rsqrt_var,
reduce_axis,
dtype,
false);
auto scale_grad_data = nhwc_sum_dout_mul_diff * rsqrt_var;
set_output<T>(scale_grad_data, scale_grad);
}
if (bias_grad) {
auto bias_grad_data =
sum<T>(out_grad_data, reduce_axis, dtype, false);
set_output<T>(bias_grad_data, bias_grad);
set_output<T>(out_grad_data_sum, bias_grad);
}
break;
}
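To make the refactor concrete, here is a minimal NumPy sketch of the non-global-stats path, with sums taken over the N/H/W axes; the names mirror the diff, but the helper itself is illustrative and not Paddle code:

```python
import numpy as np

def batch_norm_grad_sketch(x, out_grad, mean, var, scale, eps=1e-5, axes=(0, 1, 2)):
    """Illustrative NHWC batch-norm VJP following the refactored branch."""
    nhw = np.prod([x.shape[a] for a in axes])
    rsqrt_var = 1.0 / np.sqrt(var + eps)
    x_sub_mean = x - mean

    # The two reductions hoisted by this commit; each is computed once and reused.
    out_grad_sum = out_grad.sum(axis=axes)                       # -> bias_grad, mean_temp1
    sum_dout_mul_diff = (out_grad * x_sub_mean).sum(axis=axes)   # -> scale_grad, mean_temp2

    part1 = scale * rsqrt_var
    mean_temp1 = out_grad_sum / nhw
    mean_temp2 = sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var
    part2 = out_grad - mean_temp1 - x_sub_mean * mean_temp2

    x_grad = part1 * part2
    scale_grad = sum_dout_mul_diff * rsqrt_var
    bias_grad = out_grad_sum
    return x_grad, scale_grad, bias_grad

# Example call with NHWC data and per-channel statistics.
x = np.random.rand(2, 4, 4, 3).astype(np.float32)
dy = np.random.rand(*x.shape).astype(np.float32)
mean, var = x.mean(axis=(0, 1, 2)), x.var(axis=(0, 1, 2))
dx, dscale, dbias = batch_norm_grad_sketch(x, dy, mean, var, scale=np.ones(3, np.float32))
```

The point is only that `x_grad`, `scale_grad`, and `bias_grad` now share `out_grad_sum` and `sum_dout_mul_diff` instead of each output recomputing its own `sum<T>(...)` reduction.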
@@ -93,7 +93,7 @@ def composite_batchnorm(
1 if i in reduce_axes else s for i, s in enumerate(x.shape)
)
half = -0.5
half = full([1], -0.5, x.dtype)
if not use_run_stat:
batch_mean = mean(x, reduce_axes)
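In the Python composite rule, the exponent constant -0.5 is now materialized with `full([1], -0.5, x.dtype)`, so it is a 1-element tensor carrying x's dtype rather than a Python float. The surrounding rule is truncated in this view, so the usage below is an assumption (presumably `half` becomes the exponent when forming 1/sqrt(var + eps)); a rough NumPy stand-in of the pattern:

```python
import numpy as np

def full(shape, value, dtype):
    # Stand-in for the `full` primitive used in the diff (hypothetical helper).
    return np.full(shape, value, dtype=dtype)

x = np.random.rand(4, 8).astype(np.float16)
batch_var = x.var(axis=0)                 # per-feature variance, same dtype as x
half = full([1], -0.5, x.dtype)           # constant as a 1-element tensor in x's dtype
rsqrt_var = (batch_var + 1e-5) ** half    # vs. the old scalar exponent: ... ** -0.5
```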