未验证 提交 eb610740 编写于 作者: Z zhangkaihuo 提交者: GitHub

[cherry-pick] Fix bn performance degradation (#50382)

As the title says; cherry-picks #48563 and #50287.
上级 59fec5d6
...@@ -855,7 +855,8 @@ void BatchNormGradRawKernel(const Context &ctx, ...@@ -855,7 +855,8 @@ void BatchNormGradRawKernel(const Context &ctx,
} }
// CUDNN only support small batch size // CUDNN only support small batch size
bool use_native_nhwc = bool use_native_nhwc =
d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC) d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC &&
H * W >= CUDNN_SPATIAL_THRESHOLD_EVAL)
: false; : false;
const bool use_native_kernel = const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) || ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
...@@ -933,6 +934,21 @@ void BatchNormGradRawKernel(const Context &ctx, ...@@ -933,6 +934,21 @@ void BatchNormGradRawKernel(const Context &ctx,
flag_ptr); flag_ptr);
} }
// 2. reduce_sum(x, dy, mean) => dscale, dbias // 2. reduce_sum(x, dy, mean) => dscale, dbias
BatchNormParamType<T> *dscale = nullptr;
BatchNormParamType<T> *dbias = nullptr;
bool with_scale = false;
if (d_scale && d_bias) {
dscale = ctx.template Alloc<BatchNormParamType<T>>(d_scale);
dbias = ctx.template Alloc<BatchNormParamType<T>>(d_bias);
} else {
DenseTensor dscale_mem =
phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
DenseTensor dbias_mem =
phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
dscale = dscale_mem.data<BatchNormParamType<T>>();
dbias = dbias_mem.data<BatchNormParamType<T>>();
}
BNBackward2DChannelLastStage2<T, block_size> BNBackward2DChannelLastStage2<T, block_size>
<<<grid, block, 0, ctx.stream()>>>( <<<grid, block, 0, ctx.stream()>>>(
transformed_d_y.template data<T>(), transformed_d_y.template data<T>(),
...@@ -944,8 +960,8 @@ void BatchNormGradRawKernel(const Context &ctx, ...@@ -944,8 +960,8 @@ void BatchNormGradRawKernel(const Context &ctx,
H * W * D, H * W * D,
epsilon, epsilon,
block_data_ptr, block_data_ptr,
ctx.template Alloc<BatchNormParamType<T>>(d_scale), dscale,
ctx.template Alloc<BatchNormParamType<T>>(d_bias), dbias,
flag_ptr); flag_ptr);
// 3. elementwise_mul(scale, mean, inv_var, dy, dscale, dbias) => dx // 3. elementwise_mul(scale, mean, inv_var, dy, dscale, dbias) => dx
...@@ -954,8 +970,8 @@ void BatchNormGradRawKernel(const Context &ctx, ...@@ -954,8 +970,8 @@ void BatchNormGradRawKernel(const Context &ctx,
transformed_d_y.template data<T>(), transformed_d_y.template data<T>(),
transformed_x.template data<T>(), transformed_x.template data<T>(),
scale.template data<BatchNormParamType<T>>(), scale.template data<BatchNormParamType<T>>(),
d_scale->data<BatchNormParamType<T>>(), dscale,
d_bias->data<BatchNormParamType<T>>(), dbias,
mean_ptr, mean_ptr,
variance_ptr, variance_ptr,
C, C,
...@@ -1165,6 +1181,7 @@ void BatchNormGradRawKernel(const Context &ctx, ...@@ -1165,6 +1181,7 @@ void BatchNormGradRawKernel(const Context &ctx,
paddle::platform::dynload::cudnnDestroyTensorDescriptor( paddle::platform::dynload::cudnnDestroyTensorDescriptor(
bn_param_desc_)); bn_param_desc_));
#endif #endif
} else { } else {
const auto *running_mean = mean.get_ptr(); const auto *running_mean = mean.get_ptr();
const auto *running_var = variance.get_ptr(); const auto *running_var = variance.get_ptr();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册