From 1bc47c8468363ffc22190416968bfbc1a5078132 Mon Sep 17 00:00:00 2001 From: Yao Zihang <1162526220@qq.com> Date: Thu, 14 Jul 2022 16:00:52 +0800 Subject: [PATCH] Optimize batchnorm1d using 2D kernel (#43530) --- .../phi/kernels/gpu/batch_norm_grad_kernel.cu | 6 +- paddle/phi/kernels/gpu/batch_norm_kernel.cu | 522 +++++++++++++++++- .../tests/unittests/test_batch_norm_op_v2.py | 59 +- 3 files changed, 549 insertions(+), 38 deletions(-) diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu index b23b119342..0f028f42a9 100644 --- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu @@ -591,10 +591,12 @@ void BatchNormGradRawKernel(const Context &ctx, // ctx.GetPlace()), // epsilon, saved_mean_data, saved_var_data)); #else - // CUDNN PER_ACTIVATION mode only support small batch size + // CUDNN only support small batch size const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070; + const size_t CUDNN_SPATIAL_THRESHOLD = 880801; const bool use_native_kernel = - (x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD); + ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) || + (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD)); if (use_native_kernel) { if (compute_format == DataLayout::kNCHW) { BNBackward diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 7027225915..61694db7e8 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -31,6 +31,7 @@ namespace cub = hipcub; #include "paddle/phi/kernels/batch_norm_kernel.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/norm_utils.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" #include "paddle/phi/kernels/gpu/batch_norm_utils.h" #ifdef __HIPCC__ @@ -137,6 +138,398 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNForwardTraining( } } +template +__device__ __forceinline__ void merge_block_vertical( + BatchNormParamType x_sum, + BatchNormParamType x_square_sum, + BatchNormParamType *smem_sum, + BatchNormParamType *smem_square_sum, + BatchNormParamType *x_sum_out, + BatchNormParamType *x_square_sum_out) { + int tid = threadIdx.x + threadIdx.y * blockDim.x; +#pragma unroll + for (int offset = blockDim.y / 2; offset > 0; offset >>= 1) { + if (threadIdx.y < offset * 2) { + smem_sum[tid] = x_sum; + smem_square_sum[tid] = x_square_sum; + } + __syncthreads(); + if (threadIdx.y < offset) { + int pair_tid = tid + offset * blockDim.x; + x_sum += smem_sum[pair_tid]; + x_square_sum += smem_square_sum[pair_tid]; + } + } + if (threadIdx.y == 0) { + *x_sum_out = x_sum; + *x_square_sum_out = x_square_sum; + } +} + +template +__device__ __forceinline__ void merge_block_horizonal( + BatchNormParamType x_sum, + BatchNormParamType x_square_sum, + BatchNormParamType *smem_sum, + BatchNormParamType *smem_square_sum, + BatchNormParamType *x_sum_out, + BatchNormParamType *x_square_sum_out) { + int tid = threadIdx.x + threadIdx.y * blockDim.x; +#pragma unroll + for (int offset = blockDim.x / 2; offset > 0; offset >>= 1) { + if (threadIdx.x < offset * 2) { + smem_sum[tid] = x_sum; + smem_square_sum[tid] = x_square_sum; + } + __syncthreads(); + if (threadIdx.x < offset) { + int pair_tid = tid + offset; + x_sum += smem_sum[pair_tid]; + x_square_sum += smem_square_sum[pair_tid]; + } + } + if (threadIdx.x == 0) { + *x_sum_out = x_sum; + *x_square_sum_out = x_square_sum; + } +} + +template +static __global__ void BNForwardTraining2DChannelLastCompStat( + const T *x, + const BatchNormParamType *scale, + const BatchNormParamType *bias, + const int C, + const int N, + const int HxW, + const double epsilon, + double exponentialAverageFactor, + T *y, + BatchNormParamType *global_mean, + BatchNormParamType *global_variance, + BatchNormParamType *save_mean, + BatchNormParamType *save_inv_variance, + BatchNormParamType *compute_mean, + BatchNormParamType *compute_inv_var, + BatchNormParamType *block_data_ptr, + int *flag_ptr) { + int outer_size = C; + int inner_size = N * HxW; + + __shared__ BatchNormParamType smem_sum[BlockDim]; + __shared__ BatchNormParamType smem_square_sum[BlockDim]; + + int outer_loop_stride = gridDim.x * blockDim.x; + int inner_loop_stride = gridDim.y * blockDim.y; + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < outer_size; + i += outer_loop_stride) { + BatchNormParamType x_sum = static_cast>(0); + BatchNormParamType x_square_sum = static_cast>(0); + + for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size; + j += inner_loop_stride) { + const int index = j * outer_size + i; + BatchNormParamType x_i = static_cast>(x[index]); + x_sum += x_i; + x_square_sum += x_i * x_i; + } + + // vertical block sum + merge_block_vertical(x_sum, + x_square_sum, + &smem_sum[0], + &smem_square_sum[0], + &x_sum, + &x_square_sum); + + if (gridDim.y > 1) { + volatile BatchNormParamType *staging_sum = block_data_ptr; + volatile BatchNormParamType *staging_square_sum = + &block_data_ptr[C * gridDim.y]; + // write block data to global memory + if (threadIdx.y == 0) { + staging_sum[i + blockIdx.y * C] = x_sum; + staging_square_sum[i + blockIdx.y * C] = x_square_sum; + } + + // make sure write is visible to all blocks + __threadfence(); + __syncthreads(); + + __shared__ bool is_last_block_done; + // mark block done + if (threadIdx.x == 0 && threadIdx.y == 0) { + int old = atomicAdd(&flag_ptr[blockIdx.x], 1); + is_last_block_done = (old == (gridDim.y - 1)); + } + + __syncthreads(); + + if (is_last_block_done) { + x_sum = static_cast>(0); + x_square_sum = static_cast>(0); + // thread sum + for (int y = threadIdx.y; y < gridDim.y; y += blockDim.y) { + x_sum += staging_sum[i + y * C]; + x_square_sum += staging_square_sum[i + y * C]; + } + + // vertical block sum + merge_block_vertical(x_sum, + x_square_sum, + &smem_sum[0], + &smem_square_sum[0], + &x_sum, + &x_square_sum); + + // final compute + if (threadIdx.y == 0) { + BatchNormParamType compute_mean_val = x_sum / inner_size; + BatchNormParamType variance_val = + x_square_sum / inner_size - compute_mean_val * compute_mean_val; + BatchNormParamType compute_inv_var_val = + 1 / sqrt(variance_val + epsilon); + + if (save_mean && save_inv_variance) { + save_mean[i] = compute_mean_val; + save_inv_variance[i] = compute_inv_var_val; + } + global_mean[i] = (1 - exponentialAverageFactor) * compute_mean_val + + exponentialAverageFactor * global_mean[i]; + global_variance[i] = (1 - exponentialAverageFactor) * variance_val + + exponentialAverageFactor * global_variance[i]; + + compute_mean[i] = compute_mean_val; + compute_inv_var[i] = compute_inv_var_val; + } + } + } else { + if (blockIdx.y == 0 && threadIdx.y == 0) { + BatchNormParamType compute_mean_val = x_sum / inner_size; + BatchNormParamType variance_val = + x_square_sum / inner_size - compute_mean_val * compute_mean_val; + BatchNormParamType compute_inv_var_val = + 1 / sqrt(variance_val + epsilon); + + if (save_mean && save_inv_variance) { + save_mean[i] = compute_mean_val; + save_inv_variance[i] = compute_inv_var_val; + } + global_mean[i] = (1 - exponentialAverageFactor) * compute_mean_val + + exponentialAverageFactor * global_mean[i]; + global_variance[i] = (1 - exponentialAverageFactor) * variance_val + + exponentialAverageFactor * global_variance[i]; + + compute_mean[i] = compute_mean_val; + compute_inv_var[i] = compute_inv_var_val; + } + } + } +} + +template +static __global__ void BNForwardTraining2DChannelLastWriteRes( + const T *x, + const BatchNormParamType *scale, + const BatchNormParamType *bias, + const int C, + const int N, + const int HxW, + T *y, + BatchNormParamType *compute_mean, + BatchNormParamType *compute_inv_var) { + int outer_size = C; + int inner_size = N * HxW; + + int outer_loop_stride = gridDim.x * blockDim.x; + int inner_loop_stride = gridDim.y * blockDim.y; + + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < outer_size; + i += outer_loop_stride) { + BatchNormParamType mean_val = compute_mean[i]; + BatchNormParamType inv_var_val = compute_inv_var[i]; + BatchNormParamType scale_val = scale[i]; + BatchNormParamType bias_val = bias[i]; + + for (int j = blockIdx.y * blockDim.y + threadIdx.y; j < inner_size; + j += inner_loop_stride) { + const int index = j * outer_size + i; + BatchNormParamType x_sub_mean = + static_cast>(x[index]) - mean_val; + y[index] = scale_val * x_sub_mean * inv_var_val + bias_val; + } + } +} + +template +static __global__ void BNForwardTraining2DCompStat( + const T *x, + const BatchNormParamType *scale, + const BatchNormParamType *bias, + const int C, + const int N, + const int HxW, + const double epsilon, + double exponentialAverageFactor, + T *y, + BatchNormParamType *global_mean, + BatchNormParamType *global_variance, + BatchNormParamType *save_mean, + BatchNormParamType *save_inv_variance, + BatchNormParamType *compute_mean, + BatchNormParamType *compute_inv_var, + BatchNormParamType *block_data_ptr, + int *flag_ptr) { + int outer_size = C; + int inner_size = N * HxW; + + __shared__ BatchNormParamType smem_sum[BlockDim]; + __shared__ BatchNormParamType smem_square_sum[BlockDim]; + + int outer_loop_stride = gridDim.y * blockDim.y; + int inner_loop_stride = gridDim.x * blockDim.x; + + for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < outer_size; + i += outer_loop_stride) { + BatchNormParamType x_sum = static_cast>(0); + BatchNormParamType x_square_sum = static_cast>(0); + + for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < inner_size; + j += inner_loop_stride) { + const int index = (j / HxW * C + i) * HxW + j % HxW; + BatchNormParamType x_i = static_cast>(x[index]); + x_sum += x_i; + x_square_sum += x_i * x_i; + } + + // horizonal block sum + merge_block_horizonal(x_sum, + x_square_sum, + &smem_sum[0], + &smem_square_sum[0], + &x_sum, + &x_square_sum); + + if (gridDim.x > 1) { + volatile BatchNormParamType *staging_sum = block_data_ptr; + volatile BatchNormParamType *staging_square_sum = + &block_data_ptr[C * gridDim.x]; + // write block data to global memory + if (threadIdx.x == 0) { + staging_sum[i + blockIdx.x * C] = x_sum; + staging_square_sum[i + blockIdx.x * C] = x_square_sum; + } + + // make sure write is visible to all blocks + __threadfence(); + __syncthreads(); + + __shared__ bool is_last_block_done; + // mark block done + if (threadIdx.x == 0 && threadIdx.y == 0) { + int old = atomicAdd(&flag_ptr[blockIdx.y], 1); + is_last_block_done = (old == (gridDim.x - 1)); + } + + __syncthreads(); + + if (is_last_block_done) { + x_sum = static_cast>(0); + x_square_sum = static_cast>(0); + // thread sum + for (int x = threadIdx.x; x < gridDim.x; x += blockDim.x) { + x_sum += staging_sum[i + x * C]; + x_square_sum += staging_square_sum[i + x * C]; + } + + // horizonal block sum + merge_block_horizonal(x_sum, + x_square_sum, + &smem_sum[0], + &smem_square_sum[0], + &x_sum, + &x_square_sum); + + // final compute + if (threadIdx.x == 0) { + BatchNormParamType compute_mean_val = x_sum / inner_size; + BatchNormParamType variance_val = + x_square_sum / inner_size - compute_mean_val * compute_mean_val; + BatchNormParamType compute_inv_var_val = + 1 / sqrt(variance_val + epsilon); + + if (save_mean && save_inv_variance) { + save_mean[i] = compute_mean_val; + save_inv_variance[i] = compute_inv_var_val; + } + global_mean[i] = (1 - exponentialAverageFactor) * compute_mean_val + + exponentialAverageFactor * global_mean[i]; + global_variance[i] = (1 - exponentialAverageFactor) * variance_val + + exponentialAverageFactor * global_variance[i]; + + compute_mean[i] = compute_mean_val; + compute_inv_var[i] = compute_inv_var_val; + } + } + } else { + if (blockIdx.x == 0 && threadIdx.x == 0) { + BatchNormParamType compute_mean_val = x_sum / inner_size; + BatchNormParamType variance_val = + x_square_sum / inner_size - compute_mean_val * compute_mean_val; + BatchNormParamType compute_inv_var_val = + 1 / sqrt(variance_val + epsilon); + + if (save_mean && save_inv_variance) { + save_mean[i] = compute_mean_val; + save_inv_variance[i] = compute_inv_var_val; + } + global_mean[i] = (1 - exponentialAverageFactor) * compute_mean_val + + exponentialAverageFactor * global_mean[i]; + global_variance[i] = (1 - exponentialAverageFactor) * variance_val + + exponentialAverageFactor * global_variance[i]; + + compute_mean[i] = compute_mean_val; + compute_inv_var[i] = compute_inv_var_val; + } + } + } +} + +template +static __global__ void BNForwardTraining2DWriteRes( + const T *x, + const BatchNormParamType *scale, + const BatchNormParamType *bias, + const int C, + const int N, + const int HxW, + T *y, + BatchNormParamType *compute_mean, + BatchNormParamType *compute_inv_var) { + int outer_size = C; + int inner_size = N * HxW; + + int outer_loop_stride = gridDim.y * blockDim.y; + int inner_loop_stride = gridDim.x * blockDim.x; + + for (int i = blockIdx.y * blockDim.y + threadIdx.y; i < outer_size; + i += outer_loop_stride) { + BatchNormParamType mean_val = compute_mean[i]; + BatchNormParamType inv_var_val = compute_inv_var[i]; + BatchNormParamType scale_val = scale[i]; + BatchNormParamType bias_val = bias[i]; + + for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < inner_size; + j += inner_loop_stride) { + const int index = (j / HxW * C + i) * HxW + j % HxW; + BatchNormParamType x_sub_mean = + static_cast>(x[index]) - mean_val; + y[index] = scale_val * x_sub_mean * inv_var_val + bias_val; + } + } +} + template void BatchNormKernel(const Context &ctx, const DenseTensor &x, @@ -515,17 +908,63 @@ void BatchNormKernel(const Context &ctx, // static_cast(saved_variance->template mutable_data< // BatchNormParamType>(ctx.GetPlace())))); #else - // CUDNN PER_ACTIVATION mode only support small batch size const size_t CUDNN_PER_ACTIVATION_THRESHOLD = 131070; + const size_t CUDNN_SPATIAL_THRESHOLD = 880801; const bool use_native_kernel = - (x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD); + ((x_dims.size() == 2 && N >= CUDNN_PER_ACTIVATION_THRESHOLD) || + (x_dims.size() == 3 && N >= CUDNN_SPATIAL_THRESHOLD)); if (use_native_kernel) { - const int block = 512; - const int max_threads = ctx.GetMaxPhysicalThreadCount(); - const int max_blocks = std::max(max_threads / block, 1); - const int grid = std::min(C, max_blocks); - if (compute_format == DataLayout::kNCHW) { - BNForwardTraining + dim3 block; + dim3 grid; + const int block_size = 512; + const int MAX_GRID_SIZE = 128; + const int WARP_SIZE = 32; + + // init intermediate storage + DenseTensor block_data_tensor; + DenseTensor flag_tensor; + DenseTensor compute_mean_tensor = + phi::Empty, Context>(ctx, {C}); + DenseTensor compute_inv_var_tensor = + phi::Empty, Context>(ctx, {C}); + + BatchNormParamType *block_data_ptr = nullptr; + int *flag_ptr = nullptr; + + if (x_dims.size() != 2 && compute_format == DataLayout::kNCHW) { + // init block&grid config + int block_x = + std::min(phi::funcs::details::GetLastPow2(H * W * D), block_size); + int block_y = std::min(phi::funcs::details::GetLastPow2(C), + block_size / block_x); + + if (block_x * block_y != block_size) { + block_x = + std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16), + block_size / block_y); + } + + int grid_x = + std::min((N * H * W * D + block_x * 16 - 1) / (block_x * 16), + MAX_GRID_SIZE); + int grid_y = (C + block_y - 1) / block_y; + + block.x = block_x; + block.y = block_y; + grid.x = grid_x; + grid.y = grid_y; + + if (grid.x > 1) { + block_data_tensor = phi::Empty, Context>( + ctx, {2 * C * grid.x}); + flag_tensor = phi::Empty(ctx, {grid.y}); + + block_data_ptr = block_data_tensor.data>(); + flag_ptr = flag_tensor.data(); + funcs::SetConstant set_zero; + set_zero(ctx, &flag_tensor, static_cast(0)); + } + BNForwardTraining2DCompStat <<>>( transformed_x.template data(), scale.template data>(), @@ -539,9 +978,54 @@ void BatchNormKernel(const Context &ctx, mean_out->template data>(), variance_out->template data>(), saved_mean->template data>(), - saved_variance->template data>()); + saved_variance->template data>(), + compute_mean_tensor.data>(), + compute_inv_var_tensor.data>(), + block_data_ptr, + flag_ptr); + + BNForwardTraining2DWriteRes<<>>( + transformed_x.template data(), + scale.template data>(), + bias.template data>(), + C, + N, + H * W * D, + transformed_y.template data(), + compute_mean_tensor.data>(), + compute_inv_var_tensor.data>()); } else { - BNForwardTraining + // init block&grid config + int block_x = + std::min(phi::funcs::details::GetLastPow2(C), WARP_SIZE); + int block_y = + std::min(phi::funcs::details::GetLastPow2(N * H * W * D / 16), + block_size / block_x); + if (block_x * block_y != block_size) { + block_x = std::min(phi::funcs::details::GetLastPow2(C), + block_size / block_y); + } + int grid_x = (C + block_x - 1) / block_x; + int grid_y = + std::min((N * H * W * D + block_y * 16 - 1) / (block_y * 16), + MAX_GRID_SIZE); + + block.x = block_x; + block.y = block_y; + grid.x = grid_x; + grid.y = grid_y; + + if (grid.y > 1) { + block_data_tensor = phi::Empty, Context>( + ctx, {2 * C * grid.y}); + flag_tensor = phi::Empty(ctx, {grid.x}); + + block_data_ptr = block_data_tensor.data>(); + flag_ptr = flag_tensor.data(); + funcs::SetConstant set_zero; + set_zero(ctx, &flag_tensor, static_cast(0)); + } + BNForwardTraining2DChannelLastCompStat <<>>( transformed_x.template data(), scale.template data>(), @@ -555,7 +1039,23 @@ void BatchNormKernel(const Context &ctx, mean_out->template data>(), variance_out->template data>(), saved_mean->template data>(), - saved_variance->template data>()); + saved_variance->template data>(), + compute_mean_tensor.data>(), + compute_inv_var_tensor.data>(), + block_data_ptr, + flag_ptr); + + BNForwardTraining2DChannelLastWriteRes + <<>>( + transformed_x.template data(), + scale.template data>(), + bias.template data>(), + C, + N, + H * W * D, + transformed_y.template data(), + compute_mean_tensor.data>(), + compute_inv_var_tensor.data>()); } } else { #if CUDNN_VERSION_MIN(7, 4, 1) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py index cfd5d5f7c9..7aa3b8cddf 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op_v2.py @@ -82,50 +82,58 @@ class TestBatchNorm(unittest.TestCase): self.assertRaises(ValueError, error2d_dataformat) self.assertRaises(ValueError, error3d_dataformat) - def test_eager_api(self): - places = [fluid.CPUPlace()] - if core.is_compiled_with_cuda(): - places.append(fluid.CUDAPlace(0)) - for p in places: - shape = [4, 10, 4, 4] + def test_large_batch(self): - def compute_v1(x): - with fluid.dygraph.guard(p): - bn = fluid.dygraph.BatchNorm(shape[1]) - #bn = paddle.nn.BatchNorm2D(shape[1]) + def compute_baseline(x): + with fluid.dygraph.guard(p): + bn = fluid.dygraph.BatchNorm(shape[1]) + x1 = paddle.to_tensor(x) + x1.stop_gradient = False + y = bn(x1) + y.backward() + return y.numpy(), x1.gradient() + + def compute_1d(x): + with fluid.dygraph.guard(p): + with _test_eager_guard(): + bn = paddle.nn.BatchNorm1D(shape[1]) x1 = paddle.to_tensor(x) x1.stop_gradient = False y = bn(x1) y.backward() return y.numpy(), x1.gradient() - def compute_v2(x): - with fluid.dygraph.guard(p): - with _test_eager_guard(): - print("v2") - bn = paddle.nn.BatchNorm2D(shape[1]) - x1 = paddle.to_tensor(x) - x1.stop_gradient = False - y = bn(x1) - y.backward() - return y.numpy(), x1.gradient() + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + # [N, C] + shape = [200000, 4] + x = np.random.randn(*shape).astype("float32") + y1, g1 = compute_baseline(x) + y2, g2 = compute_1d(x) + self.assertTrue(np.allclose(g1, g2)) + self.assertTrue(np.allclose(y1, y2)) + # [N, C, L] + shape = [1000000, 4, 4] x = np.random.randn(*shape).astype("float32") - y1, g1 = compute_v1(x) - y2, g2 = compute_v2(x) + y1, g1 = compute_baseline(x) + y2, g2 = compute_1d(x) self.assertTrue(np.allclose(g1, g2)) self.assertTrue(np.allclose(y1, y2)) - def test_eager_api_1d(self): + def test_eager_api(self): places = [fluid.CPUPlace()] if core.is_compiled_with_cuda(): places.append(fluid.CUDAPlace(0)) for p in places: - shape = [200000, 4] + shape = [4, 10, 4, 4] def compute_v1(x): with fluid.dygraph.guard(p): bn = fluid.dygraph.BatchNorm(shape[1]) + #bn = paddle.nn.BatchNorm2D(shape[1]) x1 = paddle.to_tensor(x) x1.stop_gradient = False y = bn(x1) @@ -135,7 +143,8 @@ class TestBatchNorm(unittest.TestCase): def compute_v2(x): with fluid.dygraph.guard(p): with _test_eager_guard(): - bn = paddle.nn.BatchNorm1D(shape[1]) + print("v2") + bn = paddle.nn.BatchNorm2D(shape[1]) x1 = paddle.to_tensor(x) x1.stop_gradient = False y = bn(x1) -- GitLab