未验证 提交 eb610740 编写于 作者: Z zhangkaihuo 提交者: GitHub

[cherry-pick] Fix bn performance degradation (#50382)

As the title says. Cherry-picks: #48563, #50287
上级 59fec5d6
......@@ -855,7 +855,8 @@ void BatchNormGradRawKernel(const Context &ctx,
}
// cuDNN only supports small batch sizes
bool use_native_nhwc =
d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC &&
H * W >= CUDNN_SPATIAL_THRESHOLD_EVAL)
: false;
const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
......@@ -933,6 +934,21 @@ void BatchNormGradRawKernel(const Context &ctx,
flag_ptr);
}
// 2. reduce_sum(x, dy, mean) => dscale, dbias
BatchNormParamType<T> *dscale = nullptr;
BatchNormParamType<T> *dbias = nullptr;
bool with_scale = false;
if (d_scale && d_bias) {
dscale = ctx.template Alloc<BatchNormParamType<T>>(d_scale);
dbias = ctx.template Alloc<BatchNormParamType<T>>(d_bias);
} else {
DenseTensor dscale_mem =
phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
DenseTensor dbias_mem =
phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
dscale = dscale_mem.data<BatchNormParamType<T>>();
dbias = dbias_mem.data<BatchNormParamType<T>>();
}
BNBackward2DChannelLastStage2<T, block_size>
<<<grid, block, 0, ctx.stream()>>>(
transformed_d_y.template data<T>(),
......@@ -944,8 +960,8 @@ void BatchNormGradRawKernel(const Context &ctx,
H * W * D,
epsilon,
block_data_ptr,
ctx.template Alloc<BatchNormParamType<T>>(d_scale),
ctx.template Alloc<BatchNormParamType<T>>(d_bias),
dscale,
dbias,
flag_ptr);
// 3. elementwise_mul(scale, mean, inv_var, dy, dscale, dbias) => dx
......@@ -954,8 +970,8 @@ void BatchNormGradRawKernel(const Context &ctx,
transformed_d_y.template data<T>(),
transformed_x.template data<T>(),
scale.template data<BatchNormParamType<T>>(),
d_scale->data<BatchNormParamType<T>>(),
d_bias->data<BatchNormParamType<T>>(),
dscale,
dbias,
mean_ptr,
variance_ptr,
C,
......@@ -1165,6 +1181,7 @@ void BatchNormGradRawKernel(const Context &ctx,
paddle::platform::dynload::cudnnDestroyTensorDescriptor(
bn_param_desc_));
#endif
} else {
const auto *running_mean = mean.get_ptr();
const auto *running_var = variance.get_ptr();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册