From 227a5112f68d00aa57016c95415e025a7c631f57 Mon Sep 17 00:00:00 2001
From: zhangkaihuo
Date: Wed, 14 Dec 2022 11:18:42 +0800
Subject: [PATCH] [Sparse] Optimize performance of sparse conv on T4 (#49009)

---
 paddle/phi/kernels/sparse/gpu/conv.cu.h       | 110 +++++++++++++++++-
 paddle/phi/kernels/sparse/gpu/conv_kernel.cu  |   4 +-
 .../kernels/sparse/gpu/gather_gemm_scatter.cu |   8 +-
 .../kernels/sparse/gpu/gather_gemm_scatter.h  |  27 ++++-
 4 files changed, 140 insertions(+), 9 deletions(-)

diff --git a/paddle/phi/kernels/sparse/gpu/conv.cu.h b/paddle/phi/kernels/sparse/gpu/conv.cu.h
index 8618171b8f9..61457e506b2 100644
--- a/paddle/phi/kernels/sparse/gpu/conv.cu.h
+++ b/paddle/phi/kernels/sparse/gpu/conv.cu.h
@@ -15,8 +15,14 @@ limitations under the License. */
 #pragma once
 #include <thrust/remove.h>
-#include <thrust/sort.h>
 #include <thrust/unique.h>
+#ifdef __NVCC__
+#include <cub/block/block_scan.cuh>
+#endif
+#ifdef __HIPCC__
+#include <hipcub/hipcub.hpp>
+namespace cub = hipcub;
+#endif
 #include "paddle/phi/kernels/sparse/conv_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
@@ -199,6 +205,88 @@ __global__ void UniqueKernel(const IntT* in_indexs,
   }
 }
 
+inline __device__ uint32_t BitCount(const uint32_t data) {
+  uint32_t count = data;
+  count = (count & 0x55555555) + ((count >> 1) & 0x55555555);
+  count = (count & 0x33333333) + ((count >> 2) & 0x33333333);
+  count = (count & 0x0f0f0f0f) + ((count >> 4) & 0x0f0f0f0f);
+  count = (count & 0x00ff00ff) + ((count >> 8) & 0x00ff00ff);
+  count = (count & 0x0000ffff) + ((count >> 16) & 0x0000ffff);
+  return count;
+}
+
+static __global__ void GetOutIndexsCounter(const int* flags,
+                                           const int n,
+                                           int* out) {
+  int tid = threadIdx.x + blockDim.x * blockIdx.x;
+  __shared__ int block_count;
+  if (threadIdx.x == 0) {
+    block_count = 0;
+  }
+  __syncthreads();
+
+  if (tid < n) {
+    // get the count of 1 in flags[tid]
+    uint32_t count = BitCount(static_cast<uint32_t>(flags[tid]));
+    // add to block_count
+    // TODO(zhangkaihuo): replace with block reduce_sum
+    atomicAdd(&block_count, static_cast<int>(count));
+  }
+  __syncthreads();
+  // write to out
+  if (threadIdx.x == 0) {
+    out[blockIdx.x] = block_count;
+  }
+}
+
+template <int BS>
+__global__ void GetOutIndexs(const int* flags,
+                             const int n,
+                             const int* offsets,
+                             const int out_nnz,
+                             int* out) {
+  int tid = threadIdx.x + blockDim.x * blockIdx.x;
+  __shared__ int block_counts[BS];
+  __shared__ int block_outs[BS * 32];
+
+  int count = 0;
+
+  if (tid < n) {
+    // get the count of 1 in flags[tid]
+    int flag = flags[tid];
+    count = BitCount(static_cast<uint32_t>(flag));
+  }
+
+  // call block prefix_sum
+  // using namespace cub;
+  typedef cub::BlockScan<int, BS> BlockScan;
+  __shared__ typename BlockScan::TempStorage temp_storage;
+  BlockScan(temp_storage).ExclusiveSum(count, count);
+  __syncthreads();
+
+  // write the out indexs to block_outs
+  if (tid < n) {
+    // get the count of 1 in flags[tid]
+    int flag = flags[tid];
+    // int j = block_counts[threadIdx.x];
+    int j = count;
+    // TODO(zhangkaihuo): opt the loop
+    for (int i = 0; i < 32; ++i) {
+      if ((1 & (flag >> i)) == 1) {
+        block_outs[j++] = (tid << 5) + i;
+      }
+    }
+  }
+
+  __syncthreads();
+  // copy block_outs to the global out
+  int start = offsets[blockIdx.x];
+  int end = blockIdx.x == gridDim.x - 1 ? out_nnz : offsets[blockIdx.x + 1];
+  for (int i = threadIdx.x; i < end - start; i += blockDim.x) {
+    out[start + i] = block_outs[i];
+  }
+}
+
 template <typename IntT>
 __global__ void GroupIndexs(const int* out_index_table,
                             const int n,
@@ -725,13 +813,25 @@ int ProductRuleBook(const Context& dev_ctx,
                           gpuMemcpyDeviceToHost,
                           dev_ctx.stream());
   dev_ctx.Wait();
+
+  const int threads = 256;
+  const int blocks = (index_flags.numel() + threads - 1) / threads;
+  GetOutIndexsCounter<<<blocks, threads, 0, dev_ctx.stream()>>>(
+      index_flags_ptr, index_flags.numel(), out_index_table_ptr);
 #ifdef PADDLE_WITH_HIP
-  thrust::sort(thrust::hip::par.on(dev_ctx.stream()),
+  thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()),
 #else
-  thrust::sort(thrust::cuda::par.on(dev_ctx.stream()),
+  thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
 #endif
-               out_index_ptr,
-               out_index_ptr + out_nnz);
+                         out_index_table_ptr,
+                         out_index_table_ptr + blocks,
+                         out_index_table_ptr);
+  GetOutIndexs<threads>
+      <<<blocks, threads, 0, dev_ctx.stream()>>>(index_flags_ptr,
+                                                 index_flags.numel(),
+                                                 out_index_table_ptr,
+                                                 out_nnz,
+                                                 out_index_ptr);
 
   const int64_t sparse_dim = 4;
   phi::DenseTensor out_indices =
diff --git a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
index 87037581e52..e6f3ca33649 100644
--- a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
@@ -125,7 +125,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
 
 #ifdef PADDLE_WITH_CUTLASS
   bool cutlass = true;
-  if (dev_ctx.GetComputeCapability() < 80) cutlass = false;
+  if (dev_ctx.GetComputeCapability() < 75) cutlass = false;
   if (in_channels % 4 != 0 || out_channels % 4 != 0) {
     if (std::is_same<T, phi::dtype::float16>::value) cutlass = false;
     if (std::is_same<T, float>::value) cutlass = false;
@@ -173,7 +173,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
       if constexpr (std::is_same<T, float>::value &&
                     std::is_same<IntT, int32_t>::value) {
         fp32_gather_gemm_scatter gather_gemm_scatter =
-            getBestFp32Kernel(M, N, K);
+            getBestFp32Kernel(M, N, K, dev_ctx.GetComputeCapability());
         gather_gemm_scatter(dev_ctx,
                             x.non_zero_elements().data<T>(),
                             tmp_kernel_ptr,
diff --git a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu
index 48727c8f851..cfbaa7f1d63 100644
--- a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu
+++ b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu
@@ -72,7 +72,13 @@ fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
 }
 fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
                                            const int N,
-                                           const int K) {
+                                           const int K,
+                                           const int SM) {
+  if (SM == 75) {
+    return launchKernel<
+        float,
+        cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4::Gemm>;
+  }
   if (K == 4 && N == 16) {
     return launchKernel<
         float,
diff --git a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
index 462cd710340..b596ff54538 100644
--- a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
+++ b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
@@ -66,7 +66,8 @@ fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
                                            const int N);
 fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
                                            const int K,
-                                           const int N);
+                                           const int N,
+                                           const int SM);
 fp64_gather_gemm_scatter getBestFp64Kernel(const int M,
                                            const int K,
                                            const int N);
@@ -550,6 +551,30 @@ struct cutlass_tensorop_d884gemm_32x16_16x5_nn_align1 {
                                                  false,
                                                  true>;
 };
+
+// sm75
+struct cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4 {
+  using Gemm = cutlass::gemm::device::GemmUniversal<
+      cutlass::half_t,
+      cutlass::layout::RowMajor,
+      cutlass::half_t,
+      cutlass::layout::RowMajor,
+      float,
+      cutlass::layout::RowMajor,
+      float,
+      cutlass::arch::OpClassTensorOp,
+      cutlass::arch::Sm75,
+      cutlass::gemm::GemmShape<64, 64, 32>,
+      cutlass::gemm::GemmShape<32, 32, 32>,
+      cutlass::gemm::GemmShape<16, 8, 8>,
+      cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
+      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
+      2,
+      8,
+      8,
+      cutlass::arch::OpMultiplyAdd>;
+};
+
 }  // namespace sparse
 }  // namespace phi
 #endif
--
GitLab
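
The kernels added above implement a bitmask compaction: GetOutIndexsCounter counts set bits per block, thrust::exclusive_scan turns the per-block counts into offsets, and GetOutIndexs expands each 32-bit flag word into output indices placed at its offset, so the result comes out already sorted and the previous thrust::sort is no longer needed. Below is a minimal host-side C++ sketch of the same logic, not part of the patch; it collapses the two-level (per-block, per-word) scan into a single scan, and CompactFlags is an illustrative name rather than a Paddle function.

// Host-side reference for the bitmask compaction done on the GPU by
// GetOutIndexsCounter + exclusive scan + GetOutIndexs in the patch above.
#include <cstddef>
#include <cstdint>
#include <vector>

// Portable popcount, same bit-twiddling as the patch's BitCount().
inline uint32_t BitCount(uint32_t data) {
  uint32_t count = data;
  count = (count & 0x55555555) + ((count >> 1) & 0x55555555);
  count = (count & 0x33333333) + ((count >> 2) & 0x33333333);
  count = (count & 0x0f0f0f0f) + ((count >> 4) & 0x0f0f0f0f);
  count = (count & 0x00ff00ff) + ((count >> 8) & 0x00ff00ff);
  count = (count & 0x0000ffff) + ((count >> 16) & 0x0000ffff);
  return count;
}

// flags[i] is a 32-bit mask; bit b set means index i * 32 + b is an output
// position. Returns the sorted list of those indices.
std::vector<int> CompactFlags(const std::vector<uint32_t>& flags) {
  // Step 1: count set bits per word (GetOutIndexsCounter does this per block).
  std::vector<int> counts(flags.size());
  for (size_t i = 0; i < flags.size(); ++i) counts[i] = BitCount(flags[i]);

  // Step 2: exclusive prefix sum of the counts (thrust::exclusive_scan).
  std::vector<int> offsets(flags.size());
  int running = 0;
  for (size_t i = 0; i < flags.size(); ++i) {
    offsets[i] = running;
    running += counts[i];
  }

  // Step 3: expand each mask into indices starting at its offset (GetOutIndexs).
  std::vector<int> out(running);
  for (size_t i = 0; i < flags.size(); ++i) {
    int j = offsets[i];
    for (int b = 0; b < 32; ++b) {
      if ((flags[i] >> b) & 1u) out[j++] = (static_cast<int>(i) << 5) + b;
    }
  }
  return out;  // already sorted, which is why the old thrust::sort was dropped
}

For example, flags = {0x5u} (bits 0 and 2 set in the first word) makes CompactFlags return {0, 2}.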