未验证 提交 ad49e0fb 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim] Simplify bn vjp (#54012)

* recompute bn grad

* fix test case

---------
Co-authored-by: Nsunli <466530738@qq.com>
上级 36da353d
......@@ -1366,9 +1366,8 @@ void batch_norm_grad(const Tensor& x,
auto nhwc_out_grad = transpose<T>(out_grad_data, nchw_to_nhwc_dim);
auto nhwc_out_grad_sum = sum<T>(nhwc_out_grad, reduce_axis, dtype, false);
auto x_sub_mean = nhwc_x - mean_data;
auto sum_dout_mul_diff =
sum<T>(nhwc_out_grad * x_sub_mean, reduce_axis, dtype, false);
auto sum_dout_mul_diff = sum<T>(
nhwc_out_grad * (nhwc_x - mean_data), reduce_axis, dtype, false);
if (x_grad) {
if (use_global_stats) {
......@@ -1382,7 +1381,8 @@ void batch_norm_grad(const Tensor& x,
auto part1 = scale * rsqrt_var;
auto mean_temp1 = nhwc_out_grad_sum / nhw;
auto mean_temp2 = sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var;
auto part2 = nhwc_out_grad - mean_temp1 - x_sub_mean * mean_temp2;
auto part2 =
nhwc_out_grad - mean_temp1 - (nhwc_x - mean_data) * mean_temp2;
auto x_grad_data = part1 * part2;
auto nchw_x_grad = transpose<T>(x_grad_data, nhwc_to_nchw_dim);
......@@ -1403,11 +1403,10 @@ void batch_norm_grad(const Tensor& x,
}
case DataLayout::kNHWC: {
if (x_grad) {
auto x_sub_mean = x_data - mean_data;
auto out_grad_data_sum =
sum<T>(out_grad_data, reduce_axis, dtype, false);
auto nhwc_sum_dout_mul_diff =
sum<T>(out_grad_data * x_sub_mean, reduce_axis, dtype, false);
auto nhwc_sum_dout_mul_diff = sum<T>(
out_grad_data * (x_data - mean_data), reduce_axis, dtype, false);
if (use_global_stats) {
auto x_grad_data = scale * rsqrt_var * out_grad_data;
if (x.dtype() == phi::DataType::FLOAT16) {
......@@ -1420,7 +1419,8 @@ void batch_norm_grad(const Tensor& x,
auto mean_temp1 = out_grad_data_sum / nhw;
auto mean_temp2 =
nhwc_sum_dout_mul_diff / nhw * rsqrt_var * rsqrt_var;
auto part2 = out_grad_data - mean_temp1 - x_sub_mean * mean_temp2;
auto part2 =
out_grad_data - mean_temp1 - (x_data - mean_data) * mean_temp2;
auto x_grad_data = part1 * part2;
if (x.dtype() == phi::DataType::FLOAT16) {
......
......@@ -46,17 +46,16 @@ epoch_num = 1
# The results in CI are as follows:
DY2ST_PRIM_CINN_GT = [
5.828786849975586,
8.332858085632324,
5.026939868927002,
8.475804328918457,
8.017110824584961,
7.8353095054626465,
9.731267929077148,
8.193124771118164,
8.155317306518555,
10.185102462768555,
8.332863807678223,
5.0373005867004395,
8.464998245239258,
8.20099925994873,
7.576723098754883,
9.679173469543457,
8.381753921508789,
8.10612678527832,
10.124727249145508,
]
if core.is_compiled_with_cuda():
paddle.set_flags({'FLAGS_cudnn_deterministic': True})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册