未验证 提交 c42cbb14 编写于 作者: Z zhangkaihuo 提交者: GitHub

BN1D inference supports large batch_size (#44977)

上级 e8de9dfd
...@@ -691,6 +691,9 @@ void BatchNormKernel(const Context &ctx, ...@@ -691,6 +691,9 @@ void BatchNormKernel(const Context &ctx,
auto handle = ctx.cudnn_handle(); auto handle = ctx.cudnn_handle();
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
// Now, depending on whether we are running test or not, we have two paths. // Now, depending on whether we are running test or not, we have two paths.
// It is training mode when it's not inference AND not using pre-trained // It is training mode when it's not inference AND not using pre-trained
// model. // model.
...@@ -793,23 +796,58 @@ void BatchNormKernel(const Context &ctx, ...@@ -793,23 +796,58 @@ void BatchNormKernel(const Context &ctx,
// est_var->template data<BatchNormParamType<T>>())), // est_var->template data<BatchNormParamType<T>>())),
// epsilon)); // epsilon));
#else #else
PADDLE_ENFORCE_GPU_SUCCESS( const bool use_native_kernel =
paddle::platform::dynload::cudnnBatchNormalizationForwardInference( ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
handle, (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
// Note: PERSISTENT not implemented for inference if (use_native_kernel) {
CUDNN_BATCHNORM_SPATIAL, const int block_size = 256;
CudnnDataType<T>::kOne(), const int grid_size = (N * C * H * W * D + block_size - 1) / block_size;
CudnnDataType<T>::kZero(), if (compute_format == DataLayout::kNCHW) {
data_desc_, BNForwardInference<T, DataLayout::kNCHW>
transformed_x.template data<T>(), <<<grid_size, block_size, 0, ctx.stream()>>>(
data_desc_, transformed_x.template data<T>(),
ctx.template Alloc<T>(&transformed_y), est_mean->template data<BatchNormParamType<T>>(),
bn_param_desc_, est_var->template data<BatchNormParamType<T>>(),
scale.template data<BatchNormParamType<T>>(), scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(), bias.template data<BatchNormParamType<T>>(),
est_mean->template data<BatchNormParamType<T>>(), C,
est_var->template data<BatchNormParamType<T>>(), N,
epsilon)); H * W * D,
epsilon,
transformed_y.template data<T>());
} else {
BNForwardInference<T, DataLayout::kNHWC>
<<<grid_size, block_size, 0, ctx.stream()>>>(
transformed_x.template data<T>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(),
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
C,
N,
H * W * D,
epsilon,
transformed_y.template data<T>());
}
} else {
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnBatchNormalizationForwardInference(
handle,
// Note: PERSISTENT not implemented for inference
CUDNN_BATCHNORM_SPATIAL,
CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(),
data_desc_,
transformed_x.template data<T>(),
data_desc_,
ctx.template Alloc<T>(&transformed_y),
bn_param_desc_,
scale.template data<BatchNormParamType<T>>(),
bias.template data<BatchNormParamType<T>>(),
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(),
epsilon));
}
#endif #endif
} else { } else {
// if MomentumTensor is set, use MomentumTensor value, momentum // if MomentumTensor is set, use MomentumTensor value, momentum
...@@ -909,8 +947,6 @@ void BatchNormKernel(const Context &ctx, ...@@ -909,8 +947,6 @@ void BatchNormKernel(const Context &ctx,
// BatchNormParamType<T>>(ctx.GetPlace())))); // BatchNormParamType<T>>(ctx.GetPlace()))));
#else #else
// const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070; // const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
const bool use_native_kernel = const bool use_native_kernel =
((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) || ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
(x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD)); (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册