diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
index 0f028f42a956c8502ce4e8b83c4781849cb6fcb4..cfd04cf6a8ef8db1904343fbd679e0aec6276231 100644
--- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu
@@ -21,8 +21,10 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/batch_norm_kernel.h"
+#include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/norm_utils.h"
+#include "paddle/phi/kernels/funcs/reduce_function.h"
 #include "paddle/phi/kernels/gpu/batch_norm_utils.h"
 
 #ifdef __HIPCC__
@@ -197,6 +199,7 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackward(
         x_sum += x_i;
         x_square_sum += x_i * x_i;
       }
+
       x_sum = BlockReduce(mean_storage).Reduce(x_sum, cub::Sum());
       x_square_sum =
           BlockReduce(variance_storeage).Reduce(x_square_sum, cub::Sum());
@@ -218,6 +221,7 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackward(
           dy_i * (static_cast<BatchNormParamType<T>>(x[index]) - mean_val);
       db_sum += dy_i;
     }
+
     ds_sum = BlockReduce(ds_storage).Reduce(ds_sum, cub::Sum());
     db_sum = BlockReduce(db_storage).Reduce(db_sum, cub::Sum());
     if (threadIdx.x == 0) {
@@ -241,6 +245,327 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackward(
   }
 }
 
+// Reduces two per-thread partial sums along threadIdx.y: threads that share a
+// threadIdx.x are combined, and the totals are stored through the *_out
+// pointers by the threadIdx.y == 0 row only. The function contains
+// __syncthreads(), so every thread of the block has to call it. smem_sum and
+// smem_square_sum need blockDim.x * blockDim.y elements each, and the halving
+// loop only covers every row when blockDim.y is a power of two.
+template <typename T>
+__device__ __forceinline__ void BlockReduceByVetical(
+    BatchNormParamType<T> x_sum,
+    BatchNormParamType<T> x_square_sum,
+    BatchNormParamType<T> *smem_sum,
+    BatchNormParamType<T> *smem_square_sum,
+    BatchNormParamType<T> *x_sum_out,
+    BatchNormParamType<T> *x_square_sum_out) {
+  int tid = threadIdx.x + threadIdx.y * blockDim.x;
+#pragma unroll
+  for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) {
+    if (threadIdx.y < offset * 2) {
+      smem_sum[tid] = x_sum;
+      smem_square_sum[tid] = x_square_sum;
+    }
+    __syncthreads();
+    if (threadIdx.y < offset) {
+      int pair_tid = tid + offset * blockDim.x;
+      x_sum += smem_sum[pair_tid];
+      x_square_sum += smem_square_sum[pair_tid];
+    }
+  }
+  if (threadIdx.y == 0) {
+    *x_sum_out = x_sum;
+    *x_square_sum_out = x_square_sum;
+  }
+}
+
+// Stage 1 of the 2D channel-last backward pass: per-channel mean and
+// 1 / sqrt(variance + epsilon) of x, with x indexed as x[j * C + i].
+// blockDim.x/gridDim.x tile the C channels, blockDim.y/gridDim.y tile the
+// N * HxW rows. Requires blockDim.x * blockDim.y <= BlockDim, blockDim.y a
+// power of two, and one channel tile per block (gridDim.x * blockDim.x >= C).
+// When gridDim.y > 1, block_data_ptr must hold 2 * C * gridDim.y elements and
+// flag_ptr must hold gridDim.x counters that are zero on entry; they are zero
+// again on exit.
+template <typename T, int BlockDim>
+static __global__ void BNBackward2DChannelLastStage1(
+    const T *x,
+    const int C,
+    const int N,
+    const int HxW,
+    const double epsilon,
+    BatchNormParamType<T> *block_data_ptr,
+    BatchNormParamType<T> *compute_mean,
+    BatchNormParamType<T> *compute_inv_var,
+    int *flag_ptr) {
+  int outer_size = C;
+  int inner_size = N * HxW;
+
+  __shared__ BatchNormParamType<T> smem_sum[BlockDim];
+  __shared__ BatchNormParamType<T> smem_square_sum[BlockDim];
+  __shared__ bool is_last_block_done;
+
+  int outer_loop_stride = gridDim.x * blockDim.x;
+  int inner_loop_stride = gridDim.y * blockDim.y;
+
+  // The loop bound only depends on blockIdx, so all threads of a block run the
+  // same number of iterations and reach every __syncthreads() below; threads
+  // whose channel is out of range just skip the memory accesses.
+  for (int i_base = blockIdx.x * blockDim.x; i_base < outer_size;
+       i_base += outer_loop_stride) {
+    const int i = i_base + threadIdx.x;
+    const bool in_range = i < outer_size;
+    BatchNormParamType<T> x_sum = static_cast<BatchNormParamType<T>>(0);
+    BatchNormParamType<T> x_square_sum =
+        static_cast<BatchNormParamType<T>>(0);
+
+    if (in_range) {
+      for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size;
+           j += inner_loop_stride) {
+        // 64-bit offset, N * HxW * C can exceed INT_MAX
+        const int64_t index = static_cast<int64_t>(j) * outer_size + i;
+        BatchNormParamType<T> x_i =
+            static_cast<BatchNormParamType<T>>(x[index]);
+        x_sum += x_i;
+        x_square_sum += x_i * x_i;
+      }
+    }
+
+    // vertical block sum
+    BlockReduceByVetical<T>(x_sum,
+                            x_square_sum,
+                            &smem_sum[0],
+                            &smem_square_sum[0],
+                            &x_sum,
+                            &x_square_sum);
+
+    // with a single block along y the block sum already is the total
+    bool finalize = (gridDim.y == 1);
+
+    if (gridDim.y > 1) {
+      volatile BatchNormParamType<T> *staging_sum = block_data_ptr;
+      volatile BatchNormParamType<T> *staging_square_sum =
+          &block_data_ptr[C * gridDim.y];
+      // write block data to global memory
+      if (in_range && threadIdx.y == 0) {
+        staging_sum[i + blockIdx.y * C] = x_sum;
+        staging_square_sum[i + blockIdx.y * C] = x_square_sum;
+      }
+
+      // make sure write is visible to all blocks
+      __threadfence();
+      __syncthreads();
+
+      // mark block done
+      if (threadIdx.x == 0 && threadIdx.y == 0) {
+        int old = atomicAdd(&flag_ptr[blockIdx.x], 1);
+        is_last_block_done = (old == (gridDim.y - 1));
+      }
+
+      __syncthreads();
+
+      if (is_last_block_done) {
+        x_sum = static_cast<BatchNormParamType<T>>(0);
+        x_square_sum = static_cast<BatchNormParamType<T>>(0);
+        // thread sum
+        if (in_range) {
+          for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
+            x_sum += staging_sum[i + y * C];
+            x_square_sum += staging_square_sum[i + y * C];
+          }
+        }
+
+        // vertical block sum
+        BlockReduceByVetical<T>(x_sum,
+                                x_square_sum,
+                                &smem_sum[0],
+                                &smem_square_sum[0],
+                                &x_sum,
+                                &x_square_sum);
+
+        // reset the counter so the next kernel that shares flag_ptr starts
+        // counting from zero
+        if (threadIdx.x == 0 && threadIdx.y == 0) {
+          flag_ptr[blockIdx.x] = 0;
+        }
+        finalize = true;
+      }
+    }
+
+    // final compute
+    if (finalize && in_range && threadIdx.y == 0) {
+      BatchNormParamType<T> compute_mean_val = x_sum / inner_size;
+      BatchNormParamType<T> variance_val =
+          x_square_sum / inner_size - compute_mean_val * compute_mean_val;
+      BatchNormParamType<T> compute_inv_var_val =
+          1 / sqrt(variance_val + epsilon);
+
+      compute_mean[i] = compute_mean_val;
+      compute_inv_var[i] = compute_inv_var_val;
+    }
+  }
+}
+
+// Stage 2 of the 2D channel-last backward pass: per-channel
+// dscale = sum(dy * (x - mean)) * inv_var and dbias = sum(dy), where mean and
+// inv_var are read from `means` and `variances`. Launch layout, scratch buffer
+// and flag_ptr requirements are the same as for
+// BNBackward2DChannelLastStage1.
+template <typename T, int BlockDim>
+static __global__ void BNBackward2DChannelLastStage2(
+    const T *dy,
+    const T *x,
+    const BatchNormParamType<T> *means,
+    const BatchNormParamType<T> *variances,
+    const int C,
+    const int N,
+    const int HxW,
+    const double epsilon,
+    BatchNormParamType<T> *block_data_ptr,
+    BatchNormParamType<T> *dscale,
+    BatchNormParamType<T> *dbias,
+    int *flag_ptr) {
+  int outer_size = C;
+  int inner_size = N * HxW;
+
+  __shared__ BatchNormParamType<T> smem_ds_sum[BlockDim];
+  __shared__ BatchNormParamType<T> smem_db_sum[BlockDim];
+  __shared__ bool is_last_block_done;
+
+  int outer_loop_stride = gridDim.x * blockDim.x;
+  int inner_loop_stride = gridDim.y * blockDim.y;
+
+  // block-uniform trip count so that every thread reaches the barriers below
+  for (int i_base = blockIdx.x * blockDim.x; i_base < outer_size;
+       i_base += outer_loop_stride) {
+    const int i = i_base + threadIdx.x;
+    const bool in_range = i < outer_size;
+    BatchNormParamType<T> ds_sum = static_cast<BatchNormParamType<T>>(0);
+    BatchNormParamType<T> db_sum = static_cast<BatchNormParamType<T>>(0);
+    BatchNormParamType<T> inv_var_val = static_cast<BatchNormParamType<T>>(0);
+
+    if (in_range) {
+      const BatchNormParamType<T> mean_val = means[i];
+      inv_var_val = variances[i];
+      for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size;
+           j += inner_loop_stride) {
+        const int64_t index = static_cast<int64_t>(j) * outer_size + i;
+        BatchNormParamType<T> dy_i =
+            static_cast<BatchNormParamType<T>>(dy[index]);
+        ds_sum +=
+            dy_i * (static_cast<BatchNormParamType<T>>(x[index]) - mean_val);
+        db_sum += dy_i;
+      }
+    }
+
+    // vertical block sum
+    BlockReduceByVetical<T>(
+        ds_sum, db_sum, &smem_ds_sum[0], &smem_db_sum[0], &ds_sum, &db_sum);
+
+    // with a single block along y the block sum already is the total
+    bool finalize = (gridDim.y == 1);
+
+    if (gridDim.y > 1) {
+      volatile BatchNormParamType<T> *staging_ds_sum = block_data_ptr;
+      volatile BatchNormParamType<T> *staging_db_sum =
+          &block_data_ptr[C * gridDim.y];
+      // write block data to global memory
+      if (in_range && threadIdx.y == 0) {
+        staging_ds_sum[i + blockIdx.y * C] = ds_sum;
+        staging_db_sum[i + blockIdx.y * C] = db_sum;
+      }
+
+      // make sure write is visible to all blocks
+      __threadfence();
+      __syncthreads();
+
+      // mark block done
+      if (threadIdx.x == 0 && threadIdx.y == 0) {
+        int old = atomicAdd(&flag_ptr[blockIdx.x], 1);
+        is_last_block_done = (old == (gridDim.y - 1));
+      }
+
+      __syncthreads();
+
+      if (is_last_block_done) {
+        ds_sum = static_cast<BatchNormParamType<T>>(0);
+        db_sum = static_cast<BatchNormParamType<T>>(0);
+        // thread sum
+        if (in_range) {
+          for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) {
+            ds_sum += staging_ds_sum[i + y * C];
+            db_sum += staging_db_sum[i + y * C];
+          }
+        }
+
+        // vertical block sum
+        BlockReduceByVetical<T>(ds_sum,
+                                db_sum,
+                                &smem_ds_sum[0],
+                                &smem_db_sum[0],
+                                &ds_sum,
+                                &db_sum);
+
+        // reset the counter so the next kernel that shares flag_ptr starts
+        // counting from zero
+        if (threadIdx.x == 0 && threadIdx.y == 0) {
+          flag_ptr[blockIdx.x] = 0;
+        }
+        finalize = true;
+      }
+    }
+
+    // final compute
+    if (finalize && in_range && threadIdx.y == 0) {
+      dscale[i] = ds_sum * inv_var_val;
+      dbias[i] = db_sum;
+    }
+  }
+}
+
+// Stage 3 of the 2D channel-last backward pass: element-wise
+// dx = scale * inv_var *
+//      (dy - dbias / M - (x - mean) * inv_var * dscale / M), M = N * HxW.
+// The kernel contains no barriers; both loops advance by the full grid extent
+// in their dimension.
+template <typename T, int BlockDim>
+static __global__ void BNBackward2DChannelLastStage3(
+    const T *dy,
+    const T *x,
+    const BatchNormParamType<T> *scale,
+    const BatchNormParamType<T> *dscales,
+    const BatchNormParamType<T> *dbias,
+    const BatchNormParamType<T> *means,
+    const BatchNormParamType<T> *variances,
+    const int C,
+    const int N,
+    const int HxW,
+    const double epsilon,
+    T *dx) {
+  const int outer_size = C;
+  const int inner_size = N * HxW;
+  int outer_loop_stride = gridDim.x * blockDim.x;
+  int inner_loop_stride = gridDim.y * blockDim.y;
+
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < outer_size;
+       i += outer_loop_stride) {
+    BatchNormParamType<T> mean_val = means[i];
+    BatchNormParamType<T> inv_var_val = variances[i];
+    BatchNormParamType<T> dscale_val = dscales[i];
+    BatchNormParamType<T> dbias_val = dbias[i];
+
+    for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size;
+         j += inner_loop_stride) {
+      const int64_t index = static_cast<int64_t>(j) * outer_size + i;
+      dx[index] = scale[i] * inv_var_val *
+                  (static_cast<BatchNormParamType<T>>(dy[index]) -
+                   dbias_val / static_cast<BatchNormParamType<T>>(inner_size) -
+                   (static_cast<BatchNormParamType<T>>(x[index]) - mean_val) *
+                       inv_var_val * dscale_val / inner_size);
+    }
+  }
+}
+
 template <typename T, int BlockDim, phi::DataLayout layout>
 static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData(
     const T *dy,
@@ -592,42 +917,151 @@ void BatchNormGradRawKernel(const Context &ctx,
 //         epsilon, saved_mean_data, saved_var_data));
 #else
   // CUDNN only support small batch size
-  const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
+  const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
   const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
   const bool use_native_kernel =
       ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
        (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD));
   if (use_native_kernel) {
-    if (compute_format == DataLayout::kNCHW) {
-      BNBackward<T, block, DataLayout::kNCHW>
-          <<<grid2, block, 0, ctx.stream()>>>(
-              transformed_d_y.template data<T>(),
-              transformed_x.template data<T>(),
-              scale.template data<BatchNormParamType<T>>(),
-              saved_mean_data,
-              saved_var_data,
-              C,
-              N,
-              H * W * D,
-              epsilon,
-              transformed_d_x.template data<T>(),
-              ctx.template Alloc<BatchNormParamType<T>>(d_scale),
-              ctx.template Alloc<BatchNormParamType<T>>(d_bias));
-    } else {
-      BNBackward<T, block, DataLayout::kNHWC>
-          <<<grid2, block, 0, ctx.stream()>>>(
-              transformed_d_y.template data<T>(),
-              transformed_x.template data<T>(),
-              scale.template data<BatchNormParamType<T>>(),
-              saved_mean_data,
-              saved_var_data,
-              C,
-              N,
-              H * W * D,
-              epsilon,
-              transformed_d_x.template data<T>(),
-              ctx.template Alloc<BatchNormParamType<T>>(d_scale),
-              ctx.template Alloc<BatchNormParamType<T>>(d_bias));
+    if (x_dims.size() == 2) {
+      dim3 block;
+      dim3 grid;
+      const int block_size = 512;
+      const int MAX_GRID_SIZE = 128;
+      const int WARP_SIZE = 32;
+
+      // init intermediate storage
+      DenseTensor block_data_tensor;
+      DenseTensor flag_tensor;
+      DenseTensor compute_mean_tensor =
+          phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
+      DenseTensor compute_inv_var_tensor =
+          phi::Empty<BatchNormParamType<T>, Context>(ctx, {C});
+
+      BatchNormParamType<T> *block_data_ptr = nullptr;
+      int *flag_ptr = nullptr;
+
+      int block_x =
+          std::min(phi::funcs::details::GetLastPow2(C), WARP_SIZE);
+      int block_y =
+          std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16),
+                   block_size / block_x);
+      if (block_x * block_y != block_size) {
+        block_x = std::min(phi::funcs::details::GetLastPow2(C),
+                           block_size / block_y);
+      }
+      int grid_x = (C + block_x - 1) / block_x;
+      int grid_y =
+          std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16),
+                   MAX_GRID_SIZE);
+
+      block.x = block_x;
+      block.y = block_y;
+      grid.x = grid_x;
+      grid.y = grid_y;
+
+      if (grid.y > 1) {
+        block_data_tensor = phi::Empty<BatchNormParamType<T>, Context>(
+            ctx, {2 * C * grid.y});
+        flag_tensor = phi::Empty<int, Context>(ctx, {grid.x});
+
+        block_data_ptr = block_data_tensor.data<BatchNormParamType<T>>();
+        flag_ptr = flag_tensor.data<int>();
+        // arrival counters; each stage kernel leaves them at zero again
+        funcs::SetConstant<Context, int> set_zero;
+        set_zero(ctx, &flag_tensor, static_cast<int>(0));
+      }
+
+      // Recompute the statistics when either saved buffer is missing and use
+      // the recomputed pair for both, so mean and inv_var always belong
+      // together.
+      const bool need_stats =
+          (saved_mean_data == nullptr || saved_var_data == nullptr);
+      auto *mean_ptr =
+          need_stats ? compute_mean_tensor.data<BatchNormParamType<T>>()
+                     : saved_mean_data;
+      auto *variance_ptr =
+          need_stats ? compute_inv_var_tensor.data<BatchNormParamType<T>>()
+                     : saved_var_data;
+
+      // 1. reduce_sum(x) => mean, inv_var
+      if (need_stats) {
+        BNBackward2DChannelLastStage1<T, block_size>
+            <<<grid, block, 0, ctx.stream()>>>(
+                transformed_x.template data<T>(),
+                C,
+                N,
+                H * W * D,
+                epsilon,
+                block_data_ptr,
+                compute_mean_tensor.data<BatchNormParamType<T>>(),
+                compute_inv_var_tensor.data<BatchNormParamType<T>>(),
+                flag_ptr);
+      }
+
+      // 2. reduce_sum(x, dy, mean) => dscale, dbias
+      BNBackward2DChannelLastStage2<T, block_size>
+          <<<grid, block, 0, ctx.stream()>>>(
+              transformed_d_y.template data<T>(),
+              transformed_x.template data<T>(),
+              mean_ptr,
+              variance_ptr,
+              C,
+              N,
+              H * W * D,
+              epsilon,
+              block_data_ptr,
+              ctx.template Alloc<BatchNormParamType<T>>(d_scale),
+              ctx.template Alloc<BatchNormParamType<T>>(d_bias),
+              flag_ptr);
+
+      // 3. elementwise_mul(scale, mean, inv_var, dy, dscale, dbias) => dx
+      BNBackward2DChannelLastStage3<T, block_size>
+          <<<grid, block, 0, ctx.stream()>>>(
+              transformed_d_y.template data<T>(),
+              transformed_x.template data<T>(),
+              scale.template data<BatchNormParamType<T>>(),
+              d_scale->data<BatchNormParamType<T>>(),
+              d_bias->data<BatchNormParamType<T>>(),
+              mean_ptr,
+              variance_ptr,
+              C,
+              N,
+              H * W * D,
+              epsilon,
+              transformed_d_x.template data<T>());
+    } else {
+      if (compute_format == DataLayout::kNCHW) {
+        BNBackward<T, block, DataLayout::kNCHW>
+            <<<grid2, block, 0, ctx.stream()>>>(
+                transformed_d_y.template data<T>(),
+                transformed_x.template data<T>(),
+                scale.template data<BatchNormParamType<T>>(),
+                saved_mean_data,
+                saved_var_data,
+                C,
+                N,
+                H * W * D,
+                epsilon,
+                transformed_d_x.template data<T>(),
+                ctx.template Alloc<BatchNormParamType<T>>(d_scale),
+                ctx.template Alloc<BatchNormParamType<T>>(d_bias));
+      } else {
+        BNBackward<T, block, DataLayout::kNHWC>
+            <<<grid2, block, 0, ctx.stream()>>>(
+                transformed_d_y.template data<T>(),
+                transformed_x.template data<T>(),
+                scale.template data<BatchNormParamType<T>>(),
+                saved_mean_data,
+                saved_var_data,
+                C,
+                N,
+                H * W * D,
+                epsilon,
+                transformed_d_x.template data<T>(),
+                ctx.template Alloc<BatchNormParamType<T>>(d_scale),
+                ctx.template Alloc<BatchNormParamType<T>>(d_bias));
+      }
     }
   } else {
 #if CUDNN_VERSION_MIN(7, 4, 1)
diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu
index 61694db7e8ed36bf7b9475b48104d4ec7203f590..8731946ba2d424503ae631f191259a667190363a 100644
--- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu
@@ -908,7 +908,7 @@ void BatchNormKernel(const Context &ctx,
 //         static_cast<void *>(saved_variance->template mutable_data<
 //             BatchNormParamType<T>>(ctx.GetPlace()))));
 #else
-  const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070;
+  const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 10240;
   const size_t CUDNN_SPATIAL_THRESHOLD = 880801;
   const bool use_native_kernel =
       ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) ||
diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py
index 7aa3b8cddf80cbcd202d1e844af648539af33ab7..7c569b700311ca8d8f5e0674d5db51d669e40a49 100644
--- a/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py
+++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py
@@ -323,6 +323,41 @@ class TestBatchNormChannelLast(unittest.TestCase):
                 else:
                     self.assertEqual(np.allclose(y1.numpy(), y2.numpy()), True)
 
+    def test_1d_opt(self):
+        with fluid.dygraph.guard():
+            batch_size = 13700
+            channels = 16
+            shape = (batch_size, channels)
+            x = paddle.randn(shape)
+            x_4d = x.reshape((batch_size, channels, 1, 1))
+
+            x.stop_gradient = False
+            x_4d.stop_gradient = False
+
+            bn1d = paddle.nn.BatchNorm1D(channels)
+            bn2d = paddle.nn.BatchNorm2D(channels)
+
+            y = bn1d(x)
+            y2 = bn2d(x_4d)
+
+            y.backward()
+            y2.backward()
+
+            # assert_allclose reports the mismatching values and, unlike a bare
+            # assert, is not stripped when Python runs with -O
+            np.testing.assert_allclose(y.numpy().flatten(),
+                                       y2.numpy().flatten(),
+                                       atol=1e-5,
+                                       rtol=1e-5)
+            np.testing.assert_allclose(bn1d.weight.grad.numpy().flatten(),
+                                       bn2d.weight.grad.numpy().flatten(),
+                                       atol=1e-5,
+                                       rtol=1e-5)
+            np.testing.assert_allclose(bn1d.bias.grad.numpy().flatten(),
+                                       bn2d.bias.grad.numpy().flatten(),
+                                       atol=1e-5,
+                                       rtol=1e-5)
+
 
 
 class TestBatchNormUseGlobalStats(unittest.TestCase):