Unverified · Commit 04f8c24e authored by cyber-pioneer, committed by GitHub

[Prim] simplify bn vjp code (#51933)

* simplify bn vjp code

* simplify composite rule

* polish name
Parent 648563dd
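This commit hoists the two channel-wise reductions that the batch norm VJP needs — sum(out_grad) and sum(out_grad * (x - mean)) — so each is computed once and reused across the x_grad, scale_grad, and bias_grad branches, instead of being re-derived inline in each branch. For reference (these formulas are not part of the commit itself, just the standard batch norm gradients that the code below evaluates), with g = out_grad and sums taken over the N = n·h·w elements of each channel:

\frac{\partial L}{\partial \beta} = \sum_i g_i, \qquad
\frac{\partial L}{\partial \gamma} = \frac{1}{\sqrt{\sigma^2+\epsilon}} \sum_i g_i (x_i - \mu), \qquad
\frac{\partial L}{\partial x_i} = \frac{\gamma}{\sqrt{\sigma^2+\epsilon}}
  \left( g_i - \frac{1}{N}\sum_j g_j
       - \frac{x_i - \mu}{N(\sigma^2+\epsilon)} \sum_j g_j (x_j - \mu) \right)

Both reductions appear in more than one gradient, which is why caching them as nhwc_out_grad_sum / sum_dout_mul_diff (and their NHWC counterparts) removes redundant sum kernels from the decomposed graph.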
@@ -1382,8 +1382,11 @@ void batch_norm_grad(const Tensor& x,
     case DataLayout::kNCHW: {
       auto nhwc_x = transpose<T>(x_data, nchw_to_nhwc_dim);
       auto nhwc_out_grad = transpose<T>(out_grad_data, nchw_to_nhwc_dim);
+      auto nhwc_out_grad_sum = sum<T>(nhwc_out_grad, reduce_axis, dtype, false);
       auto x_sub_mean = nhwc_x - mean_data;
+      auto sum_dout_mul_diff =
+          sum<T>(nhwc_out_grad * x_sub_mean, reduce_axis, dtype, false);
       if (x_grad) {
         if (use_global_stats) {
@@ -1392,11 +1395,8 @@ void batch_norm_grad(const Tensor& x,
           set_output<T>(nchw_x_grad, x_grad);
         } else {
           auto part1 = scale * rsqrt_var;
-          auto mean_temp1 =
-              sum<T>(nhwc_out_grad, reduce_axis, dtype, false) / nhw;
-          auto tmp = nhwc_out_grad * x_sub_mean * rsqrt_var * rsqrt_var / nhw;
-          auto mean_temp2 = sum<T>(tmp, reduce_axis, dtype, false);
+          auto mean_temp1 = nhwc_out_grad_sum / nhw;
+          auto mean_temp2 = sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var;
           auto part2 = nhwc_out_grad - mean_temp1 - x_sub_mean * mean_temp2;
           auto x_grad_data = part1 * part2;
@@ -1408,29 +1408,30 @@ void batch_norm_grad(const Tensor& x,
         }
       }
       if (scale_grad) {
-        auto scale_grad_data = sum<T>(
-            nhwc_out_grad * x_sub_mean * rsqrt_var, reduce_axis, dtype, false);
+        auto scale_grad_data = sum_dout_mul_diff * rsqrt_var;
         set_output<T>(scale_grad_data, scale_grad);
       }
       if (bias_grad) {
-        auto bias_grad_data = sum<T>(nhwc_out_grad, reduce_axis, dtype, false);
-        set_output<T>(bias_grad_data, bias_grad);
+        set_output<T>(nhwc_out_grad_sum, bias_grad);
       }
       break;
     }
     case DataLayout::kNHWC: {
       if (x_grad) {
         auto x_sub_mean = x_data - mean_data;
+        auto out_grad_data_sum =
+            sum<T>(out_grad_data, reduce_axis, dtype, false);
+        auto nhwc_sum_dout_mul_diff =
+            sum<T>(out_grad_data * x_sub_mean, reduce_axis, dtype, false);
         if (use_global_stats) {
           auto x_grad_data = scale * rsqrt_var * out_grad_data;
           set_output<T>(x_grad_data, x_grad);
         } else {
           auto part1 = scale * rsqrt_var;
-          auto mean_temp1 =
-              sum<T>(out_grad_data, reduce_axis, dtype, false) / nhw;
-          auto tmp = out_grad_data * x_sub_mean * rsqrt_var * rsqrt_var / nhw;
-          auto mean_temp2 = sum<T>(tmp, reduce_axis, dtype, false);
+          auto mean_temp1 = out_grad_data_sum / nhw;
+          auto mean_temp2 =
+              nhwc_sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var;
           auto part2 = out_grad_data - mean_temp1 - x_sub_mean * mean_temp2;
           auto x_grad_data = part1 * part2;
@@ -1440,16 +1441,11 @@ void batch_norm_grad(const Tensor& x,
           set_output<T>(x_grad_data, x_grad);
         }
       }
       if (scale_grad) {
-        auto scale_grad_data = sum<T>(out_grad_data * x_sub_mean * rsqrt_var,
-                                      reduce_axis,
-                                      dtype,
-                                      false);
+        auto scale_grad_data = nhwc_sum_dout_mul_diff * rsqrt_var;
         set_output<T>(scale_grad_data, scale_grad);
       }
       if (bias_grad) {
-        auto bias_grad_data =
-            sum<T>(out_grad_data, reduce_axis, dtype, false);
-        set_output<T>(bias_grad_data, bias_grad);
+        set_output<T>(out_grad_data_sum, bias_grad);
       }
       break;
     }
...
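To see that the refactor is purely algebraic, here is a minimal NumPy sketch (not part of the commit; the variable names mirror the C++ but the setup is mine) checking that the hoisted-sum formulation produces the same gradients as the original inline-sum formulation for an NHWC tensor:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 5))        # NHWC, C = 5
dout = rng.standard_normal(x.shape)
scale = rng.standard_normal(5)
eps = 1e-5

axes = (0, 1, 2)                              # reduce over N, H, W
nhw = x.size // x.shape[-1]
mean = x.mean(axis=axes)
rsqrt_var = 1.0 / np.sqrt(x.var(axis=axes) + eps)
x_sub_mean = x - mean

# Original formulation: every gradient re-derives its own sum<T>(...).
mean_temp1_old = dout.sum(axis=axes) / nhw
tmp = dout * x_sub_mean * rsqrt_var * rsqrt_var / nhw
mean_temp2_old = tmp.sum(axis=axes)
dx_old = scale * rsqrt_var * (dout - mean_temp1_old - x_sub_mean * mean_temp2_old)
dscale_old = (dout * x_sub_mean * rsqrt_var).sum(axis=axes)
dbias_old = dout.sum(axis=axes)

# Refactored formulation: the two reductions are computed once and reused.
dout_sum = dout.sum(axis=axes)                          # ~ out_grad_data_sum
sum_dout_mul_diff = (dout * x_sub_mean).sum(axis=axes)  # ~ nhwc_sum_dout_mul_diff
dx_new = scale * rsqrt_var * (
    dout - dout_sum / nhw - x_sub_mean * (sum_dout_mul_diff / nhw * rsqrt_var**2)
)
dscale_new = sum_dout_mul_diff * rsqrt_var
dbias_new = dout_sum

assert np.allclose(dx_old, dx_new)
assert np.allclose(dscale_old, dscale_new)
assert np.allclose(dbias_old, dbias_new)
print("hoisted sums match the inline sums")
```

The rsqrt_var factor is constant over the reduced axes, which is what lets it be pulled outside the sum in dscale and mean_temp2.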
@@ -93,7 +93,7 @@ def composite_batchnorm(
         1 if i in reduce_axes else s for i, s in enumerate(x.shape)
     )
-    half = -0.5
+    half = full([1], -0.5, x.dtype)
     if not use_run_stat:
         batch_mean = mean(x, reduce_axes)
...
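The Python-side change replaces the bare float -0.5 with a 1-element tensor of the input's dtype. The motivation is not spelled out in the commit; presumably half is later used as the exponent of a primitive pow-like op, and making it a tensor of matching dtype keeps the decomposed graph free of host-side scalars and implicit casts. A hedged illustration of the difference (the pow call here is mine, not the actual composite rule):

```python
import paddle

x = paddle.rand([2, 3], dtype='float32')
# Before: a Python float scalar mixed into tensor arithmetic.
half_scalar = -0.5
# After: a 1-element tensor whose dtype follows x, so downstream
# primitive ops see tensor operands of one consistent dtype.
half_tensor = paddle.full([1], -0.5, x.dtype)
rsqrt = paddle.pow(x + 1e-5, half_tensor)   # illustrative use only
```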