未验证 提交 992250bf 编写于 作者: N niuliling123 提交者: GitHub

Modified the kernel policy when the compute format is NHWC (#48563)

上级 5c64d84f
...@@ -858,15 +858,20 @@ void BatchNormGradRawKernel(const Context &ctx, ...@@ -858,15 +858,20 @@ void BatchNormGradRawKernel(const Context &ctx,
// ctx.GetPlace()), // ctx.GetPlace()),
// epsilon, saved_mean_data, saved_var_data)); // epsilon, saved_mean_data, saved_var_data));
#else #else
// CUDNN only support small batch size }
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070; // CUDNN only support small batch size
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240; // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const size_t CUDNN_SPATIAL_THRESHOLD = 880801; const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
const bool use_native_kernel = const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) || bool use_native_nhwc =
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD)); d_x ? (x_dims.size() == 4 && compute_format == DataLayout::kNHWC)
if (use_native_kernel) { : false;
if (x_dims.size() == 2) { const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
if (use_native_nhwc || (d_x && d_scale && d_bias)) {
if (use_native_kernel || use_native_nhwc) {
if (x_dims.size() == 2 || use_native_nhwc) {
dim3 block; dim3 block;
dim3 grid; dim3 grid;
const int block_size = 512; const int block_size = 512;
...@@ -937,6 +942,21 @@ void BatchNormGradRawKernel(const Context &ctx, ...@@ -937,6 +942,21 @@ void BatchNormGradRawKernel(const Context &ctx,
flag_ptr); flag_ptr);
} }
// 2. reduce_sum(x, dy, mean) => dscale, dbias // 2. reduce_sum(x, dy, mean) => dscale, dbias
BatchNormParamType<T> *dscale = nullptr;
BatchNormParamType<T> *dbias = nullptr;
bool with_scale = false;
if (d_scale && d_bias) {
dscale = ctx.template Alloc<BatchNormParamType<T>>(d_scale);
dbias = ctx.template Alloc<BatchNormParamType<T>>(d_bias);
} else {
DenseTensor dscale_mem =
phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
DenseTensor dbias_mem =
phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
dscale = dscale_mem.data<BatchNormParamType<T>>();
dbias = dbias_mem.data<BatchNormParamType<T>>();
}
BNBackward2DChannelLastStage2<T, block_size> BNBackward2DChannelLastStage2<T, block_size>
<<<grid, block, 0, ctx.stream()>>>( <<<grid, block, 0, ctx.stream()>>>(
transformed_d_y.template data<T>(), transformed_d_y.template data<T>(),
...@@ -948,8 +968,8 @@ void BatchNormGradRawKernel(const Context &ctx, ...@@ -948,8 +968,8 @@ void BatchNormGradRawKernel(const Context &ctx,
H * W * D, H * W * D,
epsilon, epsilon,
block_data_ptr, block_data_ptr,
ctx.template Alloc<BatchNormParamType<T>>(d_scale), dscale,
ctx.template Alloc<BatchNormParamType<T>>(d_bias), dbias,
flag_ptr); flag_ptr);
// 3. elementwise_mul(scale, mean, inv_var, dy, dscale, dbias) => dx // 3. elementwise_mul(scale, mean, inv_var, dy, dscale, dbias) => dx
...@@ -958,8 +978,8 @@ void BatchNormGradRawKernel(const Context &ctx, ...@@ -958,8 +978,8 @@ void BatchNormGradRawKernel(const Context &ctx,
transformed_d_y.template data<T>(), transformed_d_y.template data<T>(),
transformed_x.template data<T>(), transformed_x.template data<T>(),
scale.template data<BatchNormParamType<T>>(), scale.template data<BatchNormParamType<T>>(),
d_scale->data<BatchNormParamType<T>>(), dscale,
d_bias->data<BatchNormParamType<T>>(), dbias,
mean_ptr, mean_ptr,
variance_ptr, variance_ptr,
C, C,
...@@ -1169,6 +1189,7 @@ void BatchNormGradRawKernel(const Context &ctx, ...@@ -1169,6 +1189,7 @@ void BatchNormGradRawKernel(const Context &ctx,
paddle::platform::dynload::cudnnDestroyTensorDescriptor( paddle::platform::dynload::cudnnDestroyTensorDescriptor(
bn_param_desc_)); bn_param_desc_));
#endif #endif
} else { } else {
const auto *running_mean = mean.get_ptr(); const auto *running_mean = mean.get_ptr();
const auto *running_var = variance.get_ptr(); const auto *running_var = variance.get_ptr();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册