From 8a1cdc7010e692cd9d66fde27102460ec52ff91b Mon Sep 17 00:00:00 2001
From: umiswing
Date: Fri, 21 Apr 2023 11:18:45 +0800
Subject: [PATCH] [cutlass] gather-gemm-scatter fusion on sm 75 (#53017)

---
 .../kernels/sparse/gpu/conv_grad_kernel.cu    |  4 +-
 paddle/phi/kernels/sparse/gpu/conv_kernel.cu  | 83 ++++++++++-------
 .../sparse/gpu/cutlass_generator/common.h     | 70 ++++++++++-------
 .../gather_gemm_scatter_generator.py          | 89 ++++++++++++++++++-
 .../gather_gemm_scatter_manifest.py           | 18 ++--
 .../kernels/sparse/gpu/gather_gemm_scatter.h  | 76 ++++++++++------
 6 files changed, 239 insertions(+), 101 deletions(-)

diff --git a/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
index c29dd6ee86e..5128348b5d5 100644
--- a/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
@@ -205,7 +205,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
       // (in_channels, n) * (n, out_channels)
       static cutlass::device_memory::allocation<uint8_t> workspace(
           workspace_size);
-      GatherGemmScatterDriver<T, IntT, true, false>(
+      GatherGemmScatterDriver<80, true, false>(
           dev_ctx,
           key,
           x.values().data<T>(),
@@ -223,7 +223,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
           &workspace);
       // call gemm: d_x = out_grad * transpose(kernel)
       // (n, out_channels) * (out_channels, in_channels)
-      GatherGemmScatterDriver<T, IntT, false, true>(
+      GatherGemmScatterDriver<80, false, true>(
           dev_ctx,
           key,
           out_grad.values().data<T>(),
diff --git a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
index 43e2b8c01cf..adefddd5af1 100644
--- a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
@@ -18,6 +18,7 @@ limitations under the License. */
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_meta.h"
 #include "paddle/phi/core/visit_type.h"
+#include "paddle/phi/kernels/cast_kernel.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/scatter.cu.h"
 #include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
@@ -31,6 +32,41 @@ limitations under the License. */
 namespace phi {
 namespace sparse {
 
+#define GATHER_GEMM_SCATTER(arch, input_type, x_nnz, kernel)              \
+  ({                                                                      \
+    const input_type* kernel_ptr = kernel.data<input_type>();             \
+    const input_type* x_nnz_ptr = x_nnz.data<input_type>();               \
+    for (int i = 0; i < kernel_size; i++) {                               \
+      if (h_counter_ptr[i] <= 0) {                                        \
+        continue;                                                         \
+      }                                                                   \
+      const int M = h_counter_ptr[i];                                     \
+      const int K = in_channels;                                          \
+      const int N = out_channels;                                         \
+      const input_type* tmp_kernel_ptr = kernel_ptr + i * K * N;          \
+      const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i];       \
+      const IntT* scatter_indices =                                       \
+          rulebook_ptr + rulebook_len + h_offsets_ptr[i];                 \
+      const size_t key = autotune::GenKey(M / features_num_range, N, K);  \
+      GatherGemmScatterDriver<arch, false, false>(                        \
+          dev_ctx,                                                        \
+          key,                                                            \
+          x_nnz_ptr,                                                      \
+          tmp_kernel_ptr,                                                 \
+          out_values_ptr,                                                 \
+          out_values_ptr,                                                 \
+          M,                                                              \
+          N,                                                              \
+          K,                                                              \
+          gather_indices,                                                 \
+          static_cast<const IntT*>(nullptr),                              \
+          scatter_indices,                                                \
+          static_cast<T>(1.0),                                            \
+          static_cast<T>(1.0),                                            \
+          nullptr);                                                       \
+    }                                                                     \
+  })
+
 template <typename T, typename IntT>
 void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
                         const SparseCooTensor& x,
@@ -124,10 +160,14 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
   }
 
 #ifdef PADDLE_WITH_CUTLASS
+  bool mixed_precision = dev_ctx.GetComputeCapability() >= 75 &&
+                         dev_ctx.GetComputeCapability() < 80 &&
+                         std::is_same<T, float>::value;
   bool cutlass = true;
-  if (dev_ctx.GetComputeCapability() < 80) cutlass = false;
+  if (dev_ctx.GetComputeCapability() < 75) cutlass = false;
   if (in_channels % 8 != 0 || out_channels % 8 != 0) {
     if (std::is_same<T, phi::dtype::float16>::value) cutlass = false;
+    if (mixed_precision) cutlass = false;
   }
   if (in_channels % 4 != 0 || out_channels % 4 != 0) {
     if (std::is_same<T, float>::value) cutlass = false;
@@ -141,36 +181,17 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
       phi::funcs::SetConstant<GPUContext, T> set_zero;
       set_zero(dev_ctx, out_values, static_cast<T>(0.0f));
 
-      const T* kernel_ptr = kernel.data<T>();
-      for (int i = 0; i < kernel_size; i++) {
-        if (h_counter_ptr[i] <= 0) {
-          continue;
-        }
-
-        const int M = h_counter_ptr[i];
-        const int K = in_channels;
-        const int N = out_channels;
-        const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
-        const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i];
-        const IntT* scatter_indices =
-            rulebook_ptr + rulebook_len + h_offsets_ptr[i];
-        const size_t key = autotune::GenKey(M / features_num_range, N, K);
-        GatherGemmScatterDriver<T, IntT, false, false>(
-            dev_ctx,
-            key,
-            x.non_zero_elements().data<T>(),
-            tmp_kernel_ptr,
-            out_values_ptr,
-            out_values_ptr,
-            M,
-            N,
-            K,
-            gather_indices,
-            static_cast<const IntT*>(nullptr),
-            scatter_indices,
-            static_cast<T>(1.0),
-            static_cast<T>(1.0),
-            nullptr);
-      }
+      if (mixed_precision) {
+        DenseTensor kernel_fp16 =
+            phi::Cast<T, GPUContext>(dev_ctx, kernel, DataType::FLOAT16);
+        DenseTensor x_nnz_fp16 = phi::Cast<T, GPUContext>(
+            dev_ctx, x.non_zero_elements(), DataType::FLOAT16);
+        GATHER_GEMM_SCATTER(75, phi::dtype::float16, x_nnz_fp16, kernel_fp16);
+      } else {
+        if (dev_ctx.GetComputeCapability() < 80)
+          GATHER_GEMM_SCATTER(75, T, x.non_zero_elements(), kernel);
+        else
+          GATHER_GEMM_SCATTER(80, T, x.non_zero_elements(), kernel);
+      }
     } else {
 #endif
diff --git a/paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h b/paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h
index 5a94344e8f9..71d9aa3084a 100644
--- a/paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h
+++ b/paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h
@@ -36,20 +36,20 @@ size_t constexpr max_out_channels = 256;
 static size_t workspace_size =
     sizeof(float) * max_splitk_slices * max_in_channels * max_out_channels;
 
-#define TYPEDEF_KERNEL_POINTER(kernel, dtype)                  \
-  typedef void (*kernel)(dtype const alpha,                    \
-                         dtype const beta,                     \
-                         const GPUContext& dev_ctx,            \
-                         const dtype* const a,                 \
-                         const dtype* const b,                 \
-                         const dtype* const c,                 \
-                         dtype* const d,                       \
-                         const int m,                          \
-                         const int n,                          \
-                         const int k,                          \
-                         const int32_t* a_indices,             \
-                         const int32_t* b_indices,             \
-                         const int32_t* c_d_indices,           \
+#define TYPEDEF_KERNEL_POINTER(kernel, in_type, out_type)      \
+  typedef void (*kernel)(out_type const alpha,                 \
+                         out_type const beta,                  \
+                         const GPUContext& dev_ctx,            \
+                         const in_type* const a,               \
+                         const in_type* const b,               \
+                         const out_type* const c,              \
+                         out_type* const d,                    \
+                         const int m,                          \
+                         const int n,                          \
+                         const int k,                          \
+                         const int32_t* a_indices,             \
+                         const int32_t* b_indices,             \
+                         const int32_t* c_d_indices,           \
                          void* const workspace_ptr);
 
 #define GATHER_GEMM_SCATTER_CHECK(status)                      \
@@ -58,15 +58,15 @@
   {                                                            \
     cutlass::Status error = status;                            \
     if (error != cutlass::Status::kSuccess) {                  \
       throw std::runtime_error(cutlassGetStatusString(error)); \
     }                                                          \
   }
-#define DEFINE_LAUNCH_KERNEL(dtype, cutlass_type)              \
+#define DEFINE_LAUNCH_KERNEL(in_type, out_type)                \
   template <typename Config>                                   \
-  void launchKernel(dtype const alpha,                         \
-                    dtype const beta,                          \
+  void launchKernel(out_type const alpha,                      \
+                    out_type const beta,                       \
                     const GPUContext& dev_ctx,                 \
-                    const dtype* const a,                      \
-                    const dtype* const b,                      \
-                    const dtype* const c,                      \
-                    dtype* const d,                            \
+                    const in_type* const a,                    \
+                    const in_type* const b,                    \
+                    const out_type* const c,                   \
+                    out_type* const d,                         \
                     const int m,                               \
                     const int n,                               \
                     const int k,                               \
@@ -81,12 +81,14 @@
             Config::Mode,                                      \
             problem_size_real,                                 \
             split_k_slices,                                    \
-            {static_cast<cutlass_type>(static_cast<float>(alpha)),           \
-             static_cast<cutlass_type>(static_cast<float>(beta))},           \
-            reinterpret_cast<const cutlass_type* const>(a),                  \
-            reinterpret_cast<const cutlass_type* const>(b),                  \
-            reinterpret_cast<const cutlass_type* const>(c),                  \
-            reinterpret_cast<cutlass_type* const>(d),                        \
+            {static_cast<typename Config::Gemm::EpilogueOutputOp::ElementCompute>( \
+                 static_cast<float>(alpha)),                                 \
+             static_cast<typename Config::Gemm::EpilogueOutputOp::ElementCompute>( \
+                 static_cast<float>(beta))},                                 \
+            reinterpret_cast<const typename Config::Gemm::ElementA*>(a),     \
+            reinterpret_cast<const typename Config::Gemm::ElementB*>(b),     \
+            reinterpret_cast<const typename Config::Gemm::ElementC*>(c),     \
+            reinterpret_cast<typename Config::Gemm::ElementC*>(d),           \
             m * k,                                             \
             k * n,                                             \
             m * n,                                             \
@@ -172,19 +174,23 @@
             ref_workspace,                                     \
             ref_d,                                             \
             ref_c,                                             \
-            {static_cast<cutlass_type>(static_cast<float>(alpha)),           \
-             static_cast<cutlass_type>(static_cast<float>(beta))});          \
+            {static_cast<typename Config::Gemm::EpilogueOutputOp::ElementCompute>( \
+                 static_cast<float>(alpha)),                                 \
+             static_cast<typename Config::Gemm::EpilogueOutputOp::ElementCompute>( \
+                 static_cast<float>(beta))});                                \
         status = reduction_op.initialize(reduction_args);      \
         GATHER_GEMM_SCATTER_CHECK(status);                     \
         reduction_op(dev_ctx.stream());                        \
       }                                                        \
   }
 
-TYPEDEF_KERNEL_POINTER(fp16_gather_gemm_scatter, phi::dtype::float16)
-TYPEDEF_KERNEL_POINTER(fp32_gather_gemm_scatter, float)
+TYPEDEF_KERNEL_POINTER(gather_hgemm_scatter, phi::dtype::float16, phi::float16)
+TYPEDEF_KERNEL_POINTER(gather_sgemm_scatter, float, float)
+TYPEDEF_KERNEL_POINTER(gather_sgemm_f16_scatter, phi::dtype::float16, float)
 
-DEFINE_LAUNCH_KERNEL(phi::dtype::float16, cutlass::half_t)
+DEFINE_LAUNCH_KERNEL(phi::dtype::float16, phi::dtype::float16)
 DEFINE_LAUNCH_KERNEL(float, float)
+DEFINE_LAUNCH_KERNEL(phi::dtype::float16, float)
 
 }  // namespace sparse
 }  // namespace phi
diff --git a/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py b/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py
index a2d837347b4..c19c083edc0 100644
--- a/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py
+++ b/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py
@@ -524,6 +524,91 @@ def GenerateSM80_TensorOp_1688_fast_fp32_math(
     )
 
 
+def GenerateSM75_TensorOp_1688(manifest, cuda_version, debug=False):
+
+    if not CudaToolkitVersionSatisfies(cuda_version, 10, 2):
+        return
+
+    layouts = [
+        (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
+    ]
+
+    math_instructions = [
+        MathInstruction(
+            [16, 8, 8],
+            DataType.f16,
+            DataType.f16,
+            DataType.f32,
+            OpcodeClass.TensorOp,
+            MathOperation.multiply_add,
+        ),
+        MathInstruction(
+            [16, 8, 8],
+            DataType.f16,
+            DataType.f16,
+            DataType.f16,
+            OpcodeClass.TensorOp,
+            MathOperation.multiply_add,
+        ),
+    ]
+
+    min_cc = 75
+    max_cc = 1024
+
+    for math_inst in math_instructions:
+        tile_descriptions = [
+            TileDescription(
+                [256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
+            ),
+            TileDescription(
+                [128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
+            ),
+            TileDescription(
+                [128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
+            ),
+            TileDescription(
+                [64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc
+            ),
+            TileDescription(
+                [256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc
+            ),
+            TileDescription(
+                [64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
+            ),
+            TileDescription(
+                [128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
+            ),
+            TileDescription(
+                [64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
+            ),
+            TileDescription(
+                [64, 128, 64], 2, [1, 2, 2], math_inst, min_cc, max_cc
+            ),
+        ]
+
+        if debug:
+            tile_descriptions = [
+                TileDescription(
+                    [256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
+                ),
+            ]
+
+        data_type = [
+            math_inst.element_a,
+            math_inst.element_b,
+            math_inst.element_accumulator,
+            math_inst.element_accumulator,
+        ]
+
+        CreateGatherGemmScatterOperator(
+            manifest, layouts, tile_descriptions, data_type
+        )
+
+
+def GenerateSM75(manifest, cuda_version, debug=False):
+    GenerateSM75_TensorOp_1688(manifest, cuda_version, debug)
+
+
 def GenerateSM80(manifest, cuda_version, debug=False):
     GenerateSM80_TensorOp_16816(manifest, cuda_version, debug)
     GenerateSM80_TensorOp_1688(manifest, cuda_version, debug)
@@ -582,6 +667,8 @@ if __name__ == "__main__":
     )
 
     manifest = GatherGemmScatterManifest(args)
-    GenerateSM80(manifest, args.cuda_version)
+    debug = False
+    GenerateSM75(manifest, args.cuda_version, debug)
+    GenerateSM80(manifest, args.cuda_version, debug)
 
     manifest.emit(GeneratorTarget.Library)
diff --git a/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_manifest.py b/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_manifest.py
index 280cd082d2f..92c07a655b8 100644
--- a/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_manifest.py
+++ b/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_manifest.py
@@ -42,10 +42,12 @@ namespace sparse {
 #endif
 """
         self.kernels_lists = {
-            "hnn": "static std::vector<fp16_gather_gemm_scatter> fp16_nn_kernels = {",
-            "snn": "static std::vector<fp32_gather_gemm_scatter> fp32_nn_kernels = {",
-            "snt": "static std::vector<fp32_gather_gemm_scatter> fp32_nt_kernels = {",
-            "stn": "static std::vector<fp32_gather_gemm_scatter> fp32_tn_kernels = {",
+            "hnn75": "static std::vector<gather_hgemm_scatter> sm75_fp16_nn_kernels = {",
+            "snn75": "static std::vector<gather_sgemm_f16_scatter> sm75_fp32_nn_kernels = {",
+            "hnn80": "static std::vector<gather_hgemm_scatter> sm80_fp16_nn_kernels = {",
+            "snn80": "static std::vector<gather_sgemm_scatter> sm80_fp32_nn_kernels = {",
+            "snt80": "static std::vector<gather_sgemm_scatter> sm80_fp32_nt_kernels = {",
+            "stn80": "static std::vector<gather_sgemm_scatter> sm80_fp32_tn_kernels = {",
         }
 
     def __enter__(self):
@@ -81,7 +83,9 @@
             if operations[0].layout_name() == 'tn':
                 self.kernels_lists[
-                    operations[0].short_math_name() + operations[0].layout_name()
+                    operations[0].short_math_name()
+                    + operations[0].layout_name()
+                    + str(operations[0].arch)
                 ] += (
                     """
 launchKernel<"""
@@ -91,7 +95,9 @@
                 )
             else:
                 self.kernels_lists[
-                    operations[0].short_math_name() + operations[0].layout_name()
+                    operations[0].short_math_name()
+                    + operations[0].layout_name()
+                    + str(operations[0].arch)
                 ] += (
                     """
 launchKernel<"""
diff --git a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
index 60ffd99c7f1..92e07c1e4d5 100644
--- a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
+++ b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h
@@ -27,42 +27,56 @@ namespace sparse {
 // that shapes within this range share the same key.
 constexpr int features_num_range = 10000;
 
-template <typename T, typename IntT, bool TransposeA, bool TransposeB>
+template <int ComputeCapability,
+          bool TransposeA,
+          bool TransposeB,
+          typename Input,
+          typename Output,
+          typename IntT>
 void GatherGemmScatterDriver(
     const phi::GPUContext& ctx,
     const size_t key,
-    const T* const a,
-    const T* const b,
-    const T* const c,
-    T* const d,
+    const Input* const a,
+    const Input* const b,
+    const Output* const c,
+    Output* const d,
     const int& m,
    const int& n,
     const int& k,
     const IntT* a_indices,
     const IntT* b_indices,
     const IntT* c_d_indices,
-    T alpha,
-    T beta,
-    cutlass::device_memory::allocation<uint8_t>* const workspace_ptr) {}
+    Output alpha,
+    Output beta,
+    cutlass::device_memory::allocation<uint8_t>* const workspace_ptr) {
+  PADDLE_THROW(
+      phi::errors::Unimplemented("gather_gemm_scatter fusion only supports "
+                                 "fp16_nn, fp32_nn, fp32_nt and fp32_tn now."));
+}
 
 #define EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(                       \
-    T, kernels, transpose_a, transpose_b)                                     \
+    compute_capability, transpose_a, transpose_b, in_type, out_type, kernels) \
   template <>                                                                 \
-  inline void GatherGemmScatterDriver<T, int32_t, transpose_a, transpose_b>(  \
+  inline void GatherGemmScatterDriver<compute_capability,                     \
+                                      transpose_a,                            \
+                                      transpose_b,                            \
+                                      in_type,                                \
+                                      out_type,                               \
+                                      int32_t>(                               \
      const phi::GPUContext& ctx,                                              \
      const size_t key,                                                        \
-     const T* const a,                                                        \
-     const T* const b,                                                        \
-     const T* const c,                                                        \
-     T* const d,                                                              \
+     const in_type* const a,                                                  \
+     const in_type* const b,                                                  \
+     const out_type* const c,                                                 \
+     out_type* const d,                                                       \
      const int& m,                                                            \
      const int& n,                                                            \
      const int& k,                                                            \
      const int32_t* a_indices,                                                \
      const int32_t* b_indices,                                                \
      const int32_t* c_d_indices,                                              \
-     T alpha,                                                                 \
-     T beta,                                                                  \
+     out_type alpha,                                                          \
+     out_type beta,                                                           \
      cutlass::device_memory::allocation<uint8_t>* const workspace_ptr) {      \
    auto* tuner =                                                              \
        autotune::MakeGatherGemmScatterTuner(                                  \
@@ -86,22 +100,26 @@
            workspace_ptr);                                                    \
  }
 
-EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(phi::dtype::float16,
-                                               fp16_nn_kernels,
+EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(75,
                                                false,
-                                               false)
-EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(float,
-                                               fp32_nn_kernels,
                                                false,
-                                               false)
-EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(float,
-                                               fp32_nt_kernels,
+                                               phi::dtype::float16,
+                                               phi::dtype::float16,
+                                               sm75_fp16_nn_kernels)
+EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(
+    75, false, false, phi::dtype::float16, float, sm75_fp32_nn_kernels)
+EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(80,
                                                false,
-                                               true)
-EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(float,
-                                               fp32_tn_kernels,
-                                               true,
-                                               false)
+                                               false,
+                                               phi::dtype::float16,
+                                               phi::dtype::float16,
+                                               sm80_fp16_nn_kernels)
+EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(
+    80, false, false, float, float, sm80_fp32_nn_kernels)
+EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(
+    80, false, true, float, float, sm80_fp32_nt_kernels)
+EXPLICIT_SPECIALIZE_GATHER_GEMM_SCATTER_DRIVER(
+    80, true, false, float, float, sm80_fp32_tn_kernels)
 
 }  // namespace sparse
 }  // namespace phi
--
GitLab