Unverified commit 992250bf, authored by niuliling123, committed by GitHub

Modified the kernel dispatch policy for the case where the compute layout is NHWC (#48563)

Parent 5c64d84f
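Summary for context: this commit widens the dispatch policy in BatchNormGradRawKernel so that the hand-written (native) CUDA kernels are selected not only when the batch size exceeds cuDNN's limits, but also whenever the input gradient is computed for a 4-D tensor in NHWC layout. A minimal, self-contained restatement of the two predicates follows; it is a sketch, and the enum and helper names are simplified stand-ins for the types used in the diff, not Paddle's API:

#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified stand-in for Paddle's layout enum.
enum class DataLayout { kNCHW, kNHWC };

// New in this commit: prefer the native channel-last kernels whenever dx is
// requested for a 4-D NHWC tensor, regardless of batch size.
bool UseNativeNHWC(bool wants_dx,
                   const std::vector<int64_t> &x_dims,
                   DataLayout compute_format) {
  return wants_dx && x_dims.size() == 4 &&
         compute_format == DataLayout::kNHWC;
}

// Pre-existing policy: fall back to native kernels only when cuDNN's
// batch-size limits are exceeded (cuDNN only supports small batch sizes).
bool UseNativeKernel(const std::vector<int64_t> &x_dims, size_t N) {
  const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
  const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
  return (x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
         (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD);
}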
@@ -858,15 +858,20 @@ void BatchNormGradRawKernel(const Context &ctx,
     // ctx.GetPlace()),
     // epsilon, saved_mean_data, saved_var_data));
 #else
-    // CUDNN only support small batch size
-    // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
-    const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
-    const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
-    const bool use_native_kernel =
-        ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
-         (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
-    if (use_native_kernel) {
-      if (x_dims.size() == 2) {
+    }
+    // CUDNN only support small batch size
+    // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
+    const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
+    const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
+    bool use_native_nhwc =
+        d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
+            : false;
+    const bool use_native_kernel =
+        ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
+         (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
+    if (use_native_nhwc || (d_x && d_scale && d_bias)) {
+      if (use_native_kernel || use_native_nhwc) {
+        if (x_dims.size() == 2 || use_native_nhwc) {
       dim3 block;
       dim3 grid;
       const int block_size = 512;
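Read together, the conditions added in this hunk give roughly the following branch skeleton (reconstructed from the visible hunks only; bodies elided and indentation approximate):

// Reconstructed control-flow sketch, not a verbatim excerpt.
if (use_native_nhwc || (d_x && d_scale && d_bias)) {
  if (use_native_kernel || use_native_nhwc) {
    if (x_dims.size() == 2 || use_native_nhwc) {
      // native per-activation / channel-last backward (stages 1-3 below)
    }
    // otherwise: native spatial backward
  }
  // otherwise: cuDNN batch-norm backward (the #else/cuDNN path whose
  // descriptors are destroyed in the last hunk)
}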
@@ -937,6 +942,21 @@ void BatchNormGradRawKernel(const Context &ctx,
                 flag_ptr);
       }
       // 2. reduce_sum(x, dy, mean) => dscale, dbias
+      BatchNormParamType<T> *dscale = nullptr;
+      BatchNormParamType<T> *dbias = nullptr;
+      bool with_scale = false;
+      if (d_scale && d_bias) {
+        dscale = ctx.template Alloc<BatchNormParamType<T>>(d_scale);
+        dbias = ctx.template Alloc<BatchNormParamType<T>>(d_bias);
+      } else {
+        DenseTensor dscale_mem =
+            phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
+        DenseTensor dbias_mem =
+            phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
+        dscale = dscale_mem.data<BatchNormParamType<T>>();
+        dbias = dbias_mem.data<BatchNormParamType<T>>();
+      }
+
       BNBackward2DChannelLastStage2<T, block_size>
           <<<grid, block, 0, ctx.stream()>>>(
               transformed_d_y.template data<T>(),
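The block added above makes stage 2 independent of which gradients the caller requested: when d_scale/d_bias are null, per-channel scratch tensors are allocated so BNBackward2DChannelLastStage2 always has valid output pointers. (The added with_scale flag is declared but not used in the hunks shown here.) One subtlety, which is my reading rather than anything stated in the commit: dscale_mem and dbias_mem are declared inside the else scope while the raw pointers escape it, which relies on the allocator keeping those buffers alive until the queued kernels finish. A variant of the same pattern with explicit lifetimes, hoisting the scratch tensors to the enclosing scope, would look like this sketch (same phi calls as in the diff):

// Sketch: same fallback pattern, with scratch tensors hoisted so the
// buffers provably outlive the asynchronous kernel launches.
DenseTensor dscale_mem, dbias_mem;
BatchNormParamType<T> *dscale = nullptr;
BatchNormParamType<T> *dbias = nullptr;
if (d_scale && d_bias) {
  // Caller wants the gradients: write directly into the output tensors.
  dscale = ctx.template Alloc<BatchNormParamType<T>>(d_scale);
  dbias = ctx.template Alloc<BatchNormParamType<T>>(d_bias);
} else {
  // Caller does not: allocate {C}-sized scratch and discard it afterwards.
  dscale_mem = phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
  dbias_mem = phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
  dscale = dscale_mem.data<BatchNormParamType<T>>();
  dbias = dbias_mem.data<BatchNormParamType<T>>();
}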
@@ -948,8 +968,8 @@ void BatchNormGradRawKernel(const Context &ctx,
               H * W * D,
               epsilon,
               block_data_ptr,
-              ctx.template Alloc<BatchNormParamType<T>>(d_scale),
-              ctx.template Alloc<BatchNormParamType<T>>(d_bias),
+              dscale,
+              dbias,
               flag_ptr);
       // 3. elementwise_mul(scale, mean, inv_var, dy, dscale, dbias) => dx
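For reference, the staged comments map onto the standard batch-normalization backward formulas (training mode, per channel c over the m = N·H·W·D elements sharing that channel); this is textbook math, not taken from the diff:

$$\hat{x}_{i,c} = \frac{x_{i,c} - \mu_c}{\sqrt{\sigma_c^2 + \varepsilon}},\qquad
d\beta_c = \sum_i dy_{i,c},\qquad
d\gamma_c = \sum_i dy_{i,c}\,\hat{x}_{i,c}$$

$$dx_{i,c} = \frac{\gamma_c}{\sqrt{\sigma_c^2 + \varepsilon}}
\left( dy_{i,c} - \frac{d\beta_c}{m} - \hat{x}_{i,c}\,\frac{d\gamma_c}{m} \right)$$

Stage 2 (BNBackward2DChannelLastStage2) computes the two per-channel sums, dγ (dscale) and dβ (dbias); stage 3 combines them elementwise into dx, which is why the kernel arguments in the hunks above switch from the d_scale/d_bias tensors to the unconditional dscale/dbias pointers.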
@@ -958,8 +978,8 @@ void BatchNormGradRawKernel(const Context &ctx,
               transformed_d_y.template data<T>(),
               transformed_x.template data<T>(),
               scale.template data<BatchNormParamType<T>>(),
-              d_scale->data<BatchNormParamType<T>>(),
-              d_bias->data<BatchNormParamType<T>>(),
+              dscale,
+              dbias,
               mean_ptr,
               variance_ptr,
               C,
@@ -1169,6 +1189,7 @@ void BatchNormGradRawKernel(const Context &ctx,
           paddle::platform::dynload::cudnnDestroyTensorDescriptor(
               bn_param_desc_));
 #endif
+    } else {
       const auto *running_mean = mean.get_ptr();
       const auto *running_var = variance.get_ptr();
......