From 04f8c24eab869f362021df2ba93cdd03bf24fa2d Mon Sep 17 00:00:00 2001 From: cyber-pioneer <116002591+cyber-pioneer@users.noreply.github.com> Date: Mon, 3 Apr 2023 16:54:37 +0800 Subject: [PATCH] [Prim] simplify bn vjp code (#51933) * simplify bn vjp code * simplify composite rule * polish name --- .../composite_backward_api.h | 36 +++++++++---------- .../incubate/autograd/composite_rules.py | 2 +- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index 98d5ca4845b..35b8fd65bdc 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -1382,8 +1382,11 @@ void batch_norm_grad(const Tensor& x, case DataLayout::kNCHW: { auto nhwc_x = transpose(x_data, nchw_to_nhwc_dim); auto nhwc_out_grad = transpose(out_grad_data, nchw_to_nhwc_dim); + auto nhwc_out_grad_sum = sum(nhwc_out_grad, reduce_axis, dtype, false); auto x_sub_mean = nhwc_x - mean_data; + auto sum_dout_mul_diff = + sum(nhwc_out_grad * x_sub_mean, reduce_axis, dtype, false); if (x_grad) { if (use_global_stats) { @@ -1392,11 +1395,8 @@ void batch_norm_grad(const Tensor& x, set_output(nchw_x_grad, x_grad); } else { auto part1 = scale * rsqrt_var; - auto mean_temp1 = - sum(nhwc_out_grad, reduce_axis, dtype, false) / nhw; - - auto tmp = nhwc_out_grad * x_sub_mean * rsqrt_var * rsqrt_var / nhw; - auto mean_temp2 = sum(tmp, reduce_axis, dtype, false); + auto mean_temp1 = nhwc_out_grad_sum / nhw; + auto mean_temp2 = sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var; auto part2 = nhwc_out_grad - mean_temp1 - x_sub_mean * mean_temp2; auto x_grad_data = part1 * part2; @@ -1408,29 +1408,30 @@ void batch_norm_grad(const Tensor& x, } } if (scale_grad) { - auto scale_grad_data = sum( - nhwc_out_grad * x_sub_mean * rsqrt_var, reduce_axis, dtype, false); + auto scale_grad_data = sum_dout_mul_diff * rsqrt_var; set_output(scale_grad_data, scale_grad); } if (bias_grad) { - auto bias_grad_data = sum(nhwc_out_grad, reduce_axis, dtype, false); - set_output(bias_grad_data, bias_grad); + set_output(nhwc_out_grad_sum, bias_grad); } break; } case DataLayout::kNHWC: { if (x_grad) { auto x_sub_mean = x_data - mean_data; + auto out_grad_data_sum = + sum(out_grad_data, reduce_axis, dtype, false); + auto nhwc_sum_dout_mul_diff = + sum(out_grad_data * x_sub_mean, reduce_axis, dtype, false); if (use_global_stats) { auto x_grad_data = scale * rsqrt_var * out_grad_data; set_output(x_grad_data, x_grad); } else { auto part1 = scale * rsqrt_var; - auto mean_temp1 = - sum(out_grad_data, reduce_axis, dtype, false) / nhw; - auto tmp = out_grad_data * x_sub_mean * rsqrt_var * rsqrt_var / nhw; - auto mean_temp2 = sum(tmp, reduce_axis, dtype, false); + auto mean_temp1 = out_grad_data_sum / nhw; + auto mean_temp2 = + nhwc_sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var; auto part2 = out_grad_data - mean_temp1 - x_sub_mean * mean_temp2; auto x_grad_data = part1 * part2; @@ -1440,16 +1441,11 @@ void batch_norm_grad(const Tensor& x, set_output(x_grad_data, x_grad); } if (scale_grad) { - auto scale_grad_data = sum(out_grad_data * x_sub_mean * rsqrt_var, - reduce_axis, - dtype, - false); + auto scale_grad_data = nhwc_sum_dout_mul_diff * rsqrt_var; set_output(scale_grad_data, scale_grad); } if (bias_grad) { - auto bias_grad_data = - sum(out_grad_data, reduce_axis, dtype, false); - set_output(bias_grad_data, bias_grad); + set_output(out_grad_data_sum, bias_grad); } break; } diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index 4e04232f119..ba5d4af6b4b 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -93,7 +93,7 @@ def composite_batchnorm( 1 if i in reduce_axes else s for i, s in enumerate(x.shape) ) - half = -0.5 + half = full([1], -0.5, x.dtype) if not use_run_stat: batch_mean = mean(x, reduce_axes) -- GitLab