Unverified commit 227a5112, authored by zhangkaihuo, committed by GitHub

[Sparse] Optimize performance of sparse conv on T4 (#49009)

Parent commit: 032cbfc2
...@@ -15,8 +15,14 @@ limitations under the License. */
#pragma once
#include <thrust/remove.h>
#include <thrust/sort.h>
#include <thrust/unique.h>
#ifdef __NVCC__
#include <cub/block/block_scan.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
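// Note (illustrative): on ROCm builds hipcub exposes the same BlockScan
// interface as CUB, so aliasing the hipcub namespace as cub lets the kernels
// below use cub::BlockScan on both CUDA and HIP without further changes.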
#include "paddle/phi/kernels/sparse/conv_kernel.h" #include "paddle/phi/kernels/sparse/conv_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
...@@ -199,6 +205,88 @@ __global__ void UniqueKernel(const IntT* in_indexs, ...@@ -199,6 +205,88 @@ __global__ void UniqueKernel(const IntT* in_indexs,
} }
} }
inline __device__ uint32_t BitCount(const uint32_t data) {
uint32_t count = data;
count = (count & 0x55555555) + ((count >> 1) & 0x55555555);
count = (count & 0x33333333) + ((count >> 2) & 0x33333333);
count = (count & 0x0f0f0f0f) + ((count >> 4) & 0x0f0f0f0f);
count = (count & 0x00ff00ff) + ((count >> 8) & 0x00ff00ff);
count = (count & 0x0000ffff) + ((count >> 16) & 0x0000ffff);
return count;
}
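// BitCount above is a branch-free SWAR population count: each step sums
// adjacent bit groups (1, 2, 4, 8, then 16 bits wide) until the number of set
// bits accumulates in the low bits. In device code the same result could also
// be obtained with the __popc intrinsic, e.g. (illustrative):
//   uint32_t count = __popc(data);  // counts set bits in a 32-bit word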
static __global__ void GetOutIndexsCounter(const int* flags,
const int n,
int* out) {
int tid = threadIdx.x + blockDim.x * blockIdx.x;
__shared__ int block_count;
if (threadIdx.x == 0) {
block_count = 0;
}
__syncthreads();
if (tid < n) {
// get the count of 1 in flags[tid]
uint32_t count = BitCount(static_cast<uint32_t>(flags[tid]));
// add to block_count
// TODO(zhangkaihuo): replace with block reduce_sum
atomicAdd(&block_count, static_cast<int>(count));
}
__syncthreads();
// write to out
if (threadIdx.x == 0) {
out[blockIdx.x] = block_count;
}
}
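// GetOutIndexsCounter: each thread reads one 32-bit word of the flag bitmap
// and counts its set bits; the per-thread counts are accumulated into a
// per-block total via atomicAdd on shared memory, and thread 0 writes that
// total to out[blockIdx.x]. The per-block counts are later converted into
// per-block write offsets by an exclusive scan.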
template <int BS>
__global__ void GetOutIndexs(const int* flags,
const int n,
const int* offsets,
const int out_nnz,
int* out) {
int tid = threadIdx.x + blockDim.x * blockIdx.x;
__shared__ int block_counts[BS];
__shared__ int block_outs[BS * 32];
int count = 0;
if (tid < n) {
// get the count of 1 in flags[tid]
int flag = flags[tid];
count = BitCount(static_cast<uint32_t>(flag));
}
// call block prefix_sum
// using namespace cub;
typedef cub::BlockScan<int, BS> BlockScan;
__shared__ typename BlockScan::TempStorage temp_storage;
BlockScan(temp_storage).ExclusiveSum(count, count);
__syncthreads();
// write index to out
if (tid < n) {
// get the count of 1 in flags[tid]
int flag = flags[tid];
// int j = block_counts[threadIdx.x];
int j = count;
// TODO(zhangkaihuo): opt the loop
for (int i = 0; i < 32; ++i) {
if ((1 & (flag >> i)) == 1) {
block_outs[j++] = (tid << 5) + i;
}
}
}
__syncthreads();
// write to block_outs
int start = offsets[blockIdx.x];
int end = blockIdx.x == gridDim.x - 1 ? out_nnz : offsets[blockIdx.x + 1];
for (int i = threadIdx.x; i < end - start; i += blockDim.x) {
out[start + i] = block_outs[i];
}
}
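// GetOutIndexs: cub::BlockScan computes an exclusive prefix sum of the
// per-thread bit counts, giving each thread its write position inside the
// block's shared buffer. Bit i of flags[tid] corresponds to global index
// tid * 32 + i (written as (tid << 5) + i), so expanding the set bits in
// order produces the ordered non-zero indices for this block; the block then
// copies its shared buffer to out at the offset precomputed for blockIdx.x.
// Worked example (illustrative): flags = {0b0101, 0b0010} expands to
// indices {0, 2, 33}.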
template <typename IntT>
__global__ void GroupIndexs(const int* out_index_table,
const int n,
...@@ -725,13 +813,25 @@ int ProductRuleBook(const Context& dev_ctx,
gpuMemcpyDeviceToHost,
dev_ctx.stream());
dev_ctx.Wait();
const int threads = 256;
const int blocks = (index_flags.numel() + threads - 1) / threads;
GetOutIndexsCounter<<<blocks, threads, 0, dev_ctx.stream()>>>(
index_flags_ptr, index_flags.numel(), out_index_table_ptr);
#ifdef PADDLE_WITH_HIP
- thrust::sort(thrust::hip::par.on(dev_ctx.stream()),
+ thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()),
#else
- thrust::sort(thrust::cuda::par.on(dev_ctx.stream()),
+ thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()),
#endif
- out_index_ptr,
- out_index_ptr + out_nnz);
+ out_index_table_ptr,
+ out_index_table_ptr + blocks,
+ out_index_table_ptr);
GetOutIndexs<threads>
<<<blocks, threads, 0, dev_ctx.stream()>>>(index_flags_ptr,
index_flags.numel(),
out_index_table_ptr,
out_nnz,
out_index_ptr);
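// The counter kernel, exclusive scan, and expansion kernel above replace the
// previous thrust::sort over out_index: GetOutIndexsCounter counts the set
// flag bits per block, thrust::exclusive_scan turns those counts into
// per-block offsets in out_index_table_ptr (e.g. counts {3, 1, 4} become
// offsets {0, 3, 4}), and GetOutIndexs expands the bitmap into the final,
// already ordered index list, avoiding a device-wide sort of out_nnz elements.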
const int64_t sparse_dim = 4;
phi::DenseTensor out_indices =
...
...@@ -125,7 +125,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
#ifdef PADDLE_WITH_CUTLASS
bool cutlass = true;
- if (dev_ctx.GetComputeCapability() < 80) cutlass = false;
+ if (dev_ctx.GetComputeCapability() < 75) cutlass = false;
if (in_channels % 4 != 0 || out_channels % 4 != 0) {
if (std::is_same<T, phi::dtype::float16>::value) cutlass = false;
if (std::is_same<T, float>::value) cutlass = false;
...@@ -173,7 +173,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
if constexpr (std::is_same<T, float>::value &&
std::is_same<IntT, int32_t>::value) {
fp32_gather_gemm_scatter gather_gemm_scatter =
- getBestFp32Kernel(M, N, K);
+ getBestFp32Kernel(M, N, K, dev_ctx.GetComputeCapability());
gather_gemm_scatter(dev_ctx,
x.non_zero_elements().data<T>(),
tmp_kernel_ptr,
...
...@@ -72,7 +72,13 @@ fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
}
fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
const int N,
- const int K) {
+ const int K,
+ const int SM) {
if (SM == 75) {
return launchKernel<
float,
cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4::Gemm>;
}
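// The branch above makes the fp32 path on SM 7.5 (e.g. T4) dispatch to a
// single tensor-core kernel that reads fp16 operands and accumulates in fp32
// (the cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4 configuration
// defined in the header), since Sm75 tensor cores do not take fp32 inputs
// directly.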
if (K == 4 && N == 16) {
return launchKernel<
float,
...
...@@ -66,7 +66,8 @@ fp16_gather_gemm_scatter getBestFp16Kernel(const int M,
const int N);
fp32_gather_gemm_scatter getBestFp32Kernel(const int M,
const int K,
- const int N);
+ const int N,
+ const int SM);
fp64_gather_gemm_scatter getBestFp64Kernel(const int M,
const int K,
const int N);
...@@ -550,6 +551,30 @@ struct cutlass_tensorop_d884gemm_32x16_16x5_nn_align1 {
false,
true>;
};
// sm75
struct cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4 {
using Gemm = cutlass::gemm::device::GemmUniversal<
cutlass::half_t,
cutlass::layout::RowMajor,
cutlass::half_t,
cutlass::layout::RowMajor,
float,
cutlass::layout::RowMajor,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>,
cutlass::gemm::GemmShape<16, 8, 8>,
cutlass::epilogue::thread::LinearCombination<float, 4, float, float>,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
2,
8,
8,
cutlass::arch::OpMultiplyAdd>;
};
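// Template parameters above, read in order: fp16 A and B in row-major, fp32
// C/D in row-major, fp32 accumulation, tensor-op math on Sm75, a 64x64x32
// threadblock tile, 32x32x32 warp tile and 16x8x8 MMA instruction shape, a
// LinearCombination epilogue with vector width 4, identity threadblock
// swizzle, 2 pipeline stages, and alignment 8 for the A/B operands.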
} // namespace sparse
} // namespace phi
#endif