diff --git a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu index e6f3ca336491874bb86ffcdb2606ea8841a19a35..f575b903895f92c63594bd1d653721b78ee7baa1 100644 --- a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu @@ -150,60 +150,18 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i]; const IntT* scatter_indices = rulebook_ptr + rulebook_len + h_offsets_ptr[i]; - - if constexpr (std::is_same::value && - std::is_same::value) { - fp16_gather_gemm_scatter gather_gemm_scatter = - getBestFp16Kernel(M, N, K); - gather_gemm_scatter( - dev_ctx, - reinterpret_cast( - x.non_zero_elements().data()), - reinterpret_cast(tmp_kernel_ptr), - reinterpret_cast(out_values_ptr), - reinterpret_cast(out_values_ptr), - M, - N, - K, - static_cast(gather_indices), - static_cast(scatter_indices), - static_cast(1), - static_cast(1)); - } - if constexpr (std::is_same::value && - std::is_same::value) { - fp32_gather_gemm_scatter gather_gemm_scatter = - getBestFp32Kernel(M, N, K, dev_ctx.GetComputeCapability()); - gather_gemm_scatter(dev_ctx, - x.non_zero_elements().data(), - tmp_kernel_ptr, - out_values_ptr, - out_values_ptr, - M, - N, - K, - gather_indices, - scatter_indices, - static_cast(1), - static_cast(1)); - } - if constexpr (std::is_same::value && - std::is_same::value) { - fp64_gather_gemm_scatter gather_gemm_scatter = - getBestFp64Kernel(M, N, K); - gather_gemm_scatter(dev_ctx, - x.non_zero_elements().data(), - tmp_kernel_ptr, - out_values_ptr, - out_values_ptr, - M, - N, - K, - gather_indices, - scatter_indices, - static_cast(1), - static_cast(1)); - } + dispatchKernel(dev_ctx, + x.non_zero_elements().data(), + tmp_kernel_ptr, + out_values_ptr, + out_values_ptr, + M, + N, + K, + gather_indices, + scatter_indices, + cutlass, + x.dtype()); } } else { #endif diff --git a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h 
b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h index b596ff545383fe88e66ef46ad56e778e91703460..dab35ed47737a6e2e1dcf2e40b924ad1a8a8645c 100644 --- a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h +++ b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h @@ -23,6 +23,7 @@ #include "cutlass/util/device_memory.h" #include "examples/common/helper.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/data_type.h" namespace phi { namespace sparse { typedef void (*fp16_gather_gemm_scatter)(const GPUContext& dev_ctx, @@ -115,6 +116,66 @@ void launchKernel(const GPUContext& dev_ctx, CUTLASS_CHECK(status); gemm_op(dev_ctx.stream()); } +/* dispatchKernel: chooses a gather-GEMM-scatter function pointer from the runtime `type` (FLOAT16 uses getBestFp16Kernel, FLOAT32 uses getBestFp32Kernel which additionally receives dev_ctx.GetComputeCapability(), FLOAT64 uses getBestFp64Kernel) and calls it with dev_ctx, a, b, c, d, m, n, k and the two index buffers, passing 1 for each of the last two arguments. Returns immediately when `cutlass` is false. NOTE(review): every static_cast in this text has no visible target type (the angle-bracket arguments appear to have been stripped), so the element and index types the void pointers are converted to cannot be read from here; confirm against the real header. */ static void dispatchKernel(const GPUContext& dev_ctx, + const void* const a, + const void* const b, + const void* const c, + void* const d, + const int m, + const int n, + const int k, + const void* a_indices, + const void* c_d_indices, + const bool cutlass, + const phi::DataType type) { + if (!cutlass) return; /* cutlass flag is false: no kernel is selected or called */ + + if (type == phi::DataType::FLOAT16) { + fp16_gather_gemm_scatter gather_gemm_scatter = getBestFp16Kernel(m, n, k); + gather_gemm_scatter(dev_ctx, + static_cast(a), + static_cast(b), + static_cast(c), + static_cast(d), + m, + n, + k, + static_cast(a_indices), + static_cast(c_d_indices), + static_cast(1), + static_cast(1)); + } else if (type == phi::DataType::FLOAT32) { + fp32_gather_gemm_scatter gather_gemm_scatter = + getBestFp32Kernel(m, n, k, dev_ctx.GetComputeCapability()); + gather_gemm_scatter(dev_ctx, + static_cast(a), + static_cast(b), + static_cast(c), + static_cast(d), + m, + n, + k, + static_cast(a_indices), + static_cast(c_d_indices), + static_cast(1), + static_cast(1)); + } else if (type == phi::DataType::FLOAT64) { + fp64_gather_gemm_scatter gather_gemm_scatter = getBestFp64Kernel(m, n, k); + gather_gemm_scatter(dev_ctx, + static_cast(a), + static_cast(b), + static_cast(c), + static_cast(d), + m, + n, + k, + static_cast(a_indices), + static_cast(c_d_indices), + static_cast(1), +
static_cast(1)); + } /* there is no final else: any other DataType value reaches this point without a kernel call */ +} + struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8 { using Gemm = cutlass::gemm::device::GemmUniversal< cutlass::half_t,