未验证 提交 d9fb639c 作者: zhangkaihuo 提交者: GitHub

Fix batch_norm momentum (#51120)

上级 625e30b7
@@ -872,6 +872,7 @@ void BatchNormKernel(const Context &ctx,
 } else {
 double this_factor = 1. - momentum;
 #ifdef PADDLE_WITH_HIP
+this_factor = momentum;
 const int num = transformed_x.numel();
 const int block = 256;
 const int max_threads = ctx.GetMaxPhysicalThreadCount();
@@ -945,6 +946,7 @@ void BatchNormKernel(const Context &ctx,
 ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
 (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD_TRAIN));
 if (use_native_kernel) {
+double this_factor = momentum;
 dim3 block;
 dim3 grid;
 const int block_size = 512;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册