未验证 提交 ad49e0fb 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim] Simplify bn vjp (#54012)

* recompute bn grad

* fix test case

---------
Co-authored-by: sunli <466530738@qq.com>
上级 36da353d
...@@ -1366,9 +1366,8 @@ void batch_norm_grad(const Tensor& x, ...@@ -1366,9 +1366,8 @@ void batch_norm_grad(const Tensor& x,
auto nhwc_out_grad = transpose<T>(out_grad_data, nchw_to_nhwc_dim); auto nhwc_out_grad = transpose<T>(out_grad_data, nchw_to_nhwc_dim);
auto nhwc_out_grad_sum = sum<T>(nhwc_out_grad, reduce_axis, dtype, false); auto nhwc_out_grad_sum = sum<T>(nhwc_out_grad, reduce_axis, dtype, false);
auto x_sub_mean = nhwc_x - mean_data; auto sum_dout_mul_diff = sum<T>(
auto sum_dout_mul_diff = nhwc_out_grad * (nhwc_x - mean_data), reduce_axis, dtype, false);
sum<T>(nhwc_out_grad * x_sub_mean, reduce_axis, dtype, false);
if (x_grad) { if (x_grad) {
if (use_global_stats) { if (use_global_stats) {
...@@ -1382,7 +1381,8 @@ void batch_norm_grad(const Tensor& x, ...@@ -1382,7 +1381,8 @@ void batch_norm_grad(const Tensor& x,
auto part1 = scale * rsqrt_var; auto part1 = scale * rsqrt_var;
auto mean_temp1 = nhwc_out_grad_sum / nhw; auto mean_temp1 = nhwc_out_grad_sum / nhw;
auto mean_temp2 = sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var; auto mean_temp2 = sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var;
auto part2 = nhwc_out_grad - mean_temp1 - x_sub_mean * mean_temp2; auto part2 =
nhwc_out_grad - mean_temp1 - (nhwc_x - mean_data) * mean_temp2;
auto x_grad_data = part1 * part2; auto x_grad_data = part1 * part2;
auto nchw_x_grad = transpose<T>(x_grad_data, nhwc_to_nchw_dim); auto nchw_x_grad = transpose<T>(x_grad_data, nhwc_to_nchw_dim);
...@@ -1403,11 +1403,10 @@ void batch_norm_grad(const Tensor& x, ...@@ -1403,11 +1403,10 @@ void batch_norm_grad(const Tensor& x,
} }
case DataLayout::kNHWC: { case DataLayout::kNHWC: {
if (x_grad) { if (x_grad) {
auto x_sub_mean = x_data - mean_data;
auto out_grad_data_sum = auto out_grad_data_sum =
sum<T>(out_grad_data, reduce_axis, dtype, false); sum<T>(out_grad_data, reduce_axis, dtype, false);
auto nhwc_sum_dout_mul_diff = auto nhwc_sum_dout_mul_diff = sum<T>(
sum<T>(out_grad_data * x_sub_mean, reduce_axis, dtype, false); out_grad_data * (x_data - mean_data), reduce_axis, dtype, false);
if (use_global_stats) { if (use_global_stats) {
auto x_grad_data = scale * rsqrt_var * out_grad_data; auto x_grad_data = scale * rsqrt_var * out_grad_data;
if (x.dtype() == phi::DataType::FLOAT16) { if (x.dtype() == phi::DataType::FLOAT16) {
...@@ -1420,7 +1419,8 @@ void batch_norm_grad(const Tensor& x, ...@@ -1420,7 +1419,8 @@ void batch_norm_grad(const Tensor& x,
auto mean_temp1 = out_grad_data_sum / nhw; auto mean_temp1 = out_grad_data_sum / nhw;
auto mean_temp2 = auto mean_temp2 =
nhwc_sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var; nhwc_sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var;
auto part2 = out_grad_data - mean_temp1 - x_sub_mean * mean_temp2; auto part2 =
out_grad_data - mean_temp1 - (x_data - mean_data) * mean_temp2;
auto x_grad_data = part1 * part2; auto x_grad_data = part1 * part2;
if (x.dtype() == phi::DataType::FLOAT16) { if (x.dtype() == phi::DataType::FLOAT16) {
......
...@@ -46,17 +46,16 @@ epoch_num = 1 ...@@ -46,17 +46,16 @@ epoch_num = 1
# The results in ci are as follows: # The results in ci are as follows:
DY2ST_PRIM_CINN_GT = [ DY2ST_PRIM_CINN_GT = [
5.828786849975586, 5.828786849975586,
8.332858085632324, 8.332863807678223,
5.026939868927002, 5.0373005867004395,
8.475804328918457, 8.464998245239258,
8.017110824584961, 8.20099925994873,
7.8353095054626465, 7.576723098754883,
9.731267929077148, 9.679173469543457,
8.193124771118164, 8.381753921508789,
8.155317306518555, 8.10612678527832,
10.185102462768555, 10.124727249145508,
] ]
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
paddle.set_flags({'FLAGS_cudnn_deterministic': True}) paddle.set_flags({'FLAGS_cudnn_deterministic': True})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册