diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu index 1c6d1debbabd9333b4d2c2ff6f7ebc60b9ec914c..f65e22ec997fa4d48070ed5ea988220191f73d13 100644 --- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu @@ -382,6 +382,7 @@ static __global__ void BNBackward2DChannelLastStage2( const int N, const int HxW, const double epsilon, + const bool is_test, BatchNormParamType *block_data_ptr, BatchNormParamType *dscale, BatchNormParamType *dbias, @@ -402,7 +403,8 @@ static __global__ void BNBackward2DChannelLastStage2( BatchNormParamType ds_sum = static_cast>(0); BatchNormParamType db_sum = static_cast>(0); BatchNormParamType mean_val = means[i]; - BatchNormParamType inv_var_val = variances[i]; + BatchNormParamType inv_var_val = + is_test ? 1.0 / sqrt(variances[i] + epsilon) : variances[i]; for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size; j += inner_loop_stride) { @@ -561,6 +563,51 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData( } } +template +void SetLaunchConfigInfoForChannelLast(const Context &ctx, + DenseTensor *block_data_tensor, + DenseTensor *flag_tensor, + BatchNormParamType **block_data_ptr, + int **flag_ptr, + const int N, + const int H, + const int W, + const int D, + const int C, + const int block_size, + dim3 *block, + dim3 *grid) { + const int MAX_GRID_SIZE = 128; + const int WARP_SIZE = 32; + + int block_x = std::min(phi::funcs::details::GetLastPow2(C), WARP_SIZE); + int block_y = std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16), + block_size / block_x); + if (block_x * block_y != block_size) { + block_x = + std::min(phi::funcs::details::GetLastPow2(C), block_size / block_y); + } + int grid_x = (C + block_x - 1) / block_x; + int grid_y = std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16), + MAX_GRID_SIZE); + + block->x = block_x; + block->y = block_y; + grid->x = grid_x; + grid->y = grid_y; + + if (grid->y > 1) { + *block_data_tensor = + phi::Empty, Context>(ctx, {2 * C * grid->y}); + *flag_tensor = phi::Empty(ctx, {grid->x}); + + *block_data_ptr = block_data_tensor->data>(); + *flag_ptr = flag_tensor->data(); + funcs::SetConstant set_zero; + set_zero(ctx, flag_tensor, static_cast(0)); + } +} + template void BatchNormGradRawKernel(const Context &ctx, const DenseTensor &x, @@ -875,8 +922,6 @@ void BatchNormGradRawKernel(const Context &ctx, dim3 block; dim3 grid; const int block_size = 512; - const int MAX_GRID_SIZE = 128; - const int WARP_SIZE = 32; // init intermediate storage DenseTensor block_data_tensor; @@ -889,35 +934,20 @@ void BatchNormGradRawKernel(const Context &ctx, BatchNormParamType *block_data_ptr = nullptr; int *flag_ptr = nullptr; - int block_x = - std::min(phi::funcs::details::GetLastPow2(C), WARP_SIZE); - int block_y = - std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16), - block_size / block_x); - if (block_x * block_y != block_size) { - block_x = std::min(phi::funcs::details::GetLastPow2(C), - block_size / block_y); - } - int grid_x = (C + block_x - 1) / block_x; - int grid_y = - std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16), - MAX_GRID_SIZE); - - block.x = block_x; - block.y = block_y; - grid.x = grid_x; - grid.y = grid_y; - - if (grid.y > 1) { - block_data_tensor = phi::Empty, Context>( - ctx, {2 * C * grid.y}); - flag_tensor = phi::Empty(ctx, {grid.x}); - - block_data_ptr = block_data_tensor.data>(); - flag_ptr = flag_tensor.data(); - funcs::SetConstant set_zero; - set_zero(ctx, &flag_tensor, static_cast(0)); - } + SetLaunchConfigInfoForChannelLast(ctx, + &block_data_tensor, + &flag_tensor, + &block_data_ptr, + &flag_ptr, + N, + H, + W, + D, + C, + block_size, + &block, + &grid); + // 1. reduce_sum(x) => mean, inv_var auto *mean_ptr = saved_mean_data == nullptr @@ -967,6 +997,7 @@ void BatchNormGradRawKernel(const Context &ctx, N, H * W * D, epsilon, + false, block_data_ptr, dscale, dbias, @@ -1256,18 +1287,44 @@ void BatchNormGradRawKernel(const Context &ctx, d_x->data()); } if (d_scale && d_bias) { - KeBNBackwardScaleBias - <<>>( - d_y->data(), - x.data(), + dim3 block; + dim3 grid; + const int block_size = 512; + + // init intermediate storage + DenseTensor block_data_tensor; + DenseTensor flag_tensor; + BatchNormParamType *block_data_ptr = nullptr; + int *flag_ptr = nullptr; + + SetLaunchConfigInfoForChannelLast(ctx, + &block_data_tensor, + &flag_tensor, + &block_data_ptr, + &flag_ptr, + N, + H, + W, + D, + C, + block_size, + &block, + &grid); + BNBackward2DChannelLastStage2 + <<>>( + transformed_d_y.template data(), + transformed_x.template data(), running_mean_data, running_var_data, - epsilon, - N, C, + N, H * W * D, + epsilon, + true, + block_data_ptr, d_scale->data>(), - d_bias->data>()); + d_bias->data>(), + flag_ptr); } } }