diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu
index ec712d5869fbd968b6735a44e71666716de8b800..7b553db274d1f35daaf4c13daf4c52105db692d4 100644
--- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu
@@ -72,6 +72,40 @@ static __global__ void BNForwardInference(const T *x,
   }
 }
 
+template <typename T>
+static __global__ void InverseVariance(const BatchNormParamType<T> *variance,
+                                       const double epsilon,
+                                       const int C,
+                                       BatchNormParamType<T> *inv_variance) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid < C) {
+    inv_variance[tid] = 1 / sqrt(variance[tid] + epsilon);
+  }
+}
+
+template <typename T, phi::DataLayout layout>
+static __global__ void BN1DForwardInference(
+    const T *x,
+    const BatchNormParamType<T> *mean,
+    const BatchNormParamType<T> *inv_variance,
+    const BatchNormParamType<T> *scale,
+    const BatchNormParamType<T> *bias,
+    const int C,
+    const int N,
+    const int HxW,
+    const double epsilon,
+    T *y) {
+  int gid = blockIdx.x * blockDim.x + threadIdx.x;
+  int stride = blockDim.x * gridDim.x;
+  int num = N * C * HxW;
+  for (int i = gid; i < num; i += stride) {
+    const int c = layout == phi::DataLayout::kNCHW ? i / HxW % C : i % C;
+    BatchNormParamType<T> x_sub_mean =
+        static_cast<BatchNormParamType<T>>(x[i]) - mean[c];
+    y[i] = static_cast<T>(scale[c] * x_sub_mean * inv_variance[c] + bias[c]);
+  }
+}
+
 template <typename T, int BlockDim, phi::DataLayout layout>
 static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining(
     const T *x,
@@ -795,7 +829,7 @@ void BatchNormKernel(const Context &ctx,
       //     epsilon));
 #else
       const bool use_native_kernel =
-          ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
+          (x_dims.size() == 2 ||
            (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
       if (use_native_kernel) {
         const int block_size = 256;
@@ -814,18 +848,43 @@ void BatchNormKernel(const Context &ctx,
                   epsilon,
                   transformed_y.template data<T>());
         } else {
-          BNForwardInference<T, DataLayout::kNHWC>
-              <<<grid_size, block_size, 0, ctx.stream()>>>(
-                  transformed_x.template data<T>(),
-                  est_mean->template data<BatchNormParamType<T>>(),
-                  est_var->template data<BatchNormParamType<T>>(),
-                  scale.template data<BatchNormParamType<T>>(),
-                  bias.template data<BatchNormParamType<T>>(),
-                  C,
-                  N,
-                  H * W * D,
-                  epsilon,
-                  transformed_y.template data<T>());
+          if (x_dims.size() == 2) {
+            DenseTensor inv_var = phi::Empty<BatchNormParamType<T>>(ctx, {C});
+            auto *inv_var_ptr = inv_var.data<BatchNormParamType<T>>();
+            const int threads = 512 > C ? C : 512;
+            const int blocks = (C + 511) / 512;
+            InverseVariance<T><<<blocks, threads, 0, ctx.stream()>>>(
+                est_var->template data<BatchNormParamType<T>>(),
+                epsilon,
+                C,
+                inv_var_ptr);
+            BN1DForwardInference<T, DataLayout::kNHWC>
+                <<<grid_size, block_size, 0, ctx.stream()>>>(
+                    transformed_x.template data<T>(),
+                    est_mean->template data<BatchNormParamType<T>>(),
+                    // est_var->template data<BatchNormParamType<T>>(),
+                    inv_var_ptr,
+                    scale.template data<BatchNormParamType<T>>(),
+                    bias.template data<BatchNormParamType<T>>(),
+                    C,
+                    N,
+                    H * W * D,
+                    epsilon,
+                    transformed_y.template data<T>());
+          } else {
+            BNForwardInference<T, DataLayout::kNHWC>
+                <<<grid_size, block_size, 0, ctx.stream()>>>(
+                    transformed_x.template data<T>(),
+                    est_mean->template data<BatchNormParamType<T>>(),
+                    est_var->template data<BatchNormParamType<T>>(),
+                    scale.template data<BatchNormParamType<T>>(),
+                    bias.template data<BatchNormParamType<T>>(),
+                    C,
+                    N,
+                    H * W * D,
+                    epsilon,
+                    transformed_y.template data<T>());
+          }
         }
       } else {
         PADDLE_ENFORCE_GPU_SUCCESS(
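
Note on the approach (not part of the patch): the new path hoists the per-channel 1 / sqrt(variance + epsilon) out of the element-wise loop, so the expensive sqrt and divide run C times instead of N * C * HxW times, and the per-element work reduces to a multiply-add against the precomputed inverse standard deviation. The standalone sketch below illustrates that two-kernel scheme for a channel-last [N, C] tensor; the kernel names (inv_std_kernel, bn1d_infer_kernel) and the shapes are invented for this example and are not the Paddle kernels above.

// Minimal sketch of the two-step batch-norm inference idea, assuming a
// channel-last [N, C] float tensor. Illustrative names, not Paddle code.
#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

// Step 1: one thread per channel computes 1/sqrt(var + eps) exactly once.
__global__ void inv_std_kernel(const float *var, double eps, int C,
                               float *inv_std) {
  int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c < C) inv_std[c] = rsqrtf(var[c] + static_cast<float>(eps));
}

// Step 2: grid-stride loop over all N*C elements; channel index is i % C for
// a channel-last layout, so each element only needs a multiply-add.
__global__ void bn1d_infer_kernel(const float *x, const float *mean,
                                  const float *inv_std, const float *scale,
                                  const float *bias, int N, int C, float *y) {
  int num = N * C;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += blockDim.x * gridDim.x) {
    int c = i % C;
    y[i] = scale[c] * (x[i] - mean[c]) * inv_std[c] + bias[c];
  }
}

int main() {
  const int N = 4, C = 8;
  float h_x[N * C], h_y[N * C], h_mean[C], h_var[C], h_scale[C], h_bias[C];
  for (int i = 0; i < N * C; ++i) h_x[i] = 0.1f * i;
  for (int c = 0; c < C; ++c) {
    h_mean[c] = 0.5f; h_var[c] = 2.0f; h_scale[c] = 1.5f; h_bias[c] = -0.25f;
  }

  float *d_x, *d_y, *d_mean, *d_var, *d_scale, *d_bias, *d_inv_std;
  cudaMalloc(&d_x, sizeof(h_x));          cudaMalloc(&d_y, sizeof(h_y));
  cudaMalloc(&d_mean, sizeof(h_mean));    cudaMalloc(&d_var, sizeof(h_var));
  cudaMalloc(&d_scale, sizeof(h_scale));  cudaMalloc(&d_bias, sizeof(h_bias));
  cudaMalloc(&d_inv_std, sizeof(h_var));
  cudaMemcpy(d_x, h_x, sizeof(h_x), cudaMemcpyHostToDevice);
  cudaMemcpy(d_mean, h_mean, sizeof(h_mean), cudaMemcpyHostToDevice);
  cudaMemcpy(d_var, h_var, sizeof(h_var), cudaMemcpyHostToDevice);
  cudaMemcpy(d_scale, h_scale, sizeof(h_scale), cudaMemcpyHostToDevice);
  cudaMemcpy(d_bias, h_bias, sizeof(h_bias), cudaMemcpyHostToDevice);

  const double eps = 1e-5;
  inv_std_kernel<<<(C + 255) / 256, 256>>>(d_var, eps, C, d_inv_std);
  bn1d_infer_kernel<<<(N * C + 255) / 256, 256>>>(d_x, d_mean, d_inv_std,
                                                  d_scale, d_bias, N, C, d_y);
  cudaMemcpy(h_y, d_y, sizeof(h_y), cudaMemcpyDeviceToHost);

  // Spot-check element (n = 0, c = 3) against the reference formula on host.
  float ref = h_scale[3] * (h_x[3] - h_mean[3]) /
                  sqrtf(h_var[3] + static_cast<float>(eps)) +
              h_bias[3];
  printf("y[3] = %f (reference %f)\n", h_y[3], ref);

  cudaFree(d_x); cudaFree(d_y); cudaFree(d_mean); cudaFree(d_var);
  cudaFree(d_scale); cudaFree(d_bias); cudaFree(d_inv_std);
  return 0;
}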