diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake
index 3b731e0c112b69662fee582dd5f4dbb5df2a1e68..11f7630c9724c65a0db249e7cad1b72c1d483933 100644
--- a/cmake/external/cutlass.cmake
+++ b/cmake/external/cutlass.cmake
@@ -39,12 +39,14 @@ ExternalProject_Add(
   UPDATE_COMMAND ""
   CONFIGURE_COMMAND ""
   BUILD_COMMAND
+    rm -rf
+    ${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build &&
     mkdir -p
-    ${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/build/generated/gemm
+    ${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build/generated/gemm
     && ${PYTHON_EXECUTABLE} -B
-    ${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py
+    ${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py
     "${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/library/scripts/"
-    "${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/build"
+    "${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build"
     "${CMAKE_CUDA_COMPILER_VERSION}"
   INSTALL_COMMAND ""
   TEST_COMMAND "")
diff --git a/paddle/phi/kernels/autotune/auto_tune_base.h b/paddle/phi/kernels/autotune/auto_tune_base.h
index 908b6560e28c18d6342c879db1caa3706107df40..606ecc3c59ca72ba70cabe5cfc95c0d0da2eaccc 100644
--- a/paddle/phi/kernels/autotune/auto_tune_base.h
+++ b/paddle/phi/kernels/autotune/auto_tune_base.h
@@ -177,6 +177,85 @@ class MatmulAutoTuner
   }
 };
 
+template <typename T, typename ReturnType, typename... Args>
+class GatherGemmScatterAutoTuner
+    : public AutoTuneBase<T, KernelCallback<T, ReturnType, T, T, Args...>> {
+ public:
+  static GatherGemmScatterAutoTuner<T, ReturnType, Args...>* Instance(
+      ReturnType (*func)(T, T, Args...)) {
+    static std::once_flag gather_gemm_scatter_init_flag;
+    static std::unique_ptr<GatherGemmScatterAutoTuner<T, ReturnType, Args...>>
+        instance;
+    std::call_once(gather_gemm_scatter_init_flag, [&] {
+      auto obj = MakeCallback<T>(func);
+      instance.reset(new GatherGemmScatterAutoTuner<T, ReturnType, Args...>);
+      instance->AddCallBack(func);
+    });
+    return instance.get();
+  }
+
+  void Run(const phi::GPUContext& ctx,
+           const size_t key,
+           T const alpha,
+           T const beta,
+           Args... args) {
+    this->is_init_ = true;
+    this->CheckKernelSize();
+    auto& cache = AutoTuneCache::Instance().GetGatherGemmScatter<T>();
+
+    if (cache.Find(key)) {
+      auto best_idx = cache.Get(key);
+      this->kernels_[best_idx].Run(alpha, beta, args...);
+
+    } else {
+      // Time the candidates with alpha = 0 and beta = 1 so that the value of
+      // d is not changed while picking the best kernel.
+      auto best_idx =
+          PickBestKernel(ctx, static_cast<T>(0), static_cast<T>(1), args...);
+      cache.Set(key, best_idx);
+      this->kernels_[best_idx].Run(alpha, beta, args...);
+    }
+  }
+
+ protected:
+  size_t PickBestKernel(const phi::GPUContext& ctx,
+                        const T& alpha,
+                        const T& beta,
+                        Args&... args) {
+    std::lock_guard<std::mutex> lock(this->mutex_);
+    constexpr size_t NO_KERNEL_WORKS = -1;
+    size_t best_idx = NO_KERNEL_WORKS;
+    float min_time = std::numeric_limits<float>::max();
+
+    // The timing runs are executed on the default stream.
+    for (int i = 0; i < this->kernels_.size(); ++i) {
+      float time = 0;
+      // Some kernels may require more shared memory than is available;
+      // skip those kernels.
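+      // Note (assumption): RunAndMeasureKernel is inherited from
+      // AutoTuneBase and presumably warms the candidate up once and then
+      // times it with GPU events on the default stream. A configuration the
+      // device cannot launch (e.g. one requesting more shared memory than
+      // the SM provides) surfaces as the std::runtime_error thrown by
+      // GATHER_GEMM_SCATTER_CHECK, which the catch below turns into a skip.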
+      try {
+        time = this->RunAndMeasureKernel(ctx, i, alpha, beta, args...);
+        if (time < min_time) {
+          min_time = time;
+          best_idx = i;
+        }
+      } catch (const std::runtime_error& error) {
+        VLOG(3) << "kernels_[" << i << "] failed with error: " << error.what();
+      }
+    }
+    if (best_idx == NO_KERNEL_WORKS) {
+      LOG(ERROR) << "No kernel works!\n";
+      exit(-1);
+    }
+    VLOG(3) << "best kernel idx is " << best_idx;
+    return best_idx;
+  }
+};
+template <typename T, typename ReturnType, typename... Args>
+static GatherGemmScatterAutoTuner<T, ReturnType, Args...>*
+MakeGatherGemmScatterTuner(ReturnType (*func)(T, T, Args...)) {
+  return GatherGemmScatterAutoTuner<T, ReturnType, Args...>::Instance(func);
+}
+
 // Define the auto_tuner initial object.
 #define DEFINE_AUTOTUNER_COMMON_OBJ(name) \
   template <typename T>                   \
diff --git a/paddle/phi/kernels/autotune/cache.h b/paddle/phi/kernels/autotune/cache.h
index 711c8a063f78ad5ecc462996dd579c5981a831b5..c5122f0260cb2aa5351ec76ba2d6a36ab49e6ebc 100644
--- a/paddle/phi/kernels/autotune/cache.h
+++ b/paddle/phi/kernels/autotune/cache.h
@@ -45,13 +45,15 @@ enum class AlgorithmType {
   kConvBackwardFilter = 3,
   kTranspose = 4,
   kMatmul = 5,
+  kGatherGemmScatterFP16NN = 6,
+  kGatherGemmScatterFP32NN = 7,
 #if !defined(PADDLE_WITH_CUDNN_FRONTEND)
-  kAlgorithmCount = 6
+  kAlgorithmCount = 8
 #else
-  kConvForwardV8 = 6,
-  kConvBackwardDataV8 = 7,
-  kConvBackwardFilterV8 = 8,
-  kAlgorithmCount = 9
+  kConvForwardV8 = 8,
+  kConvBackwardDataV8 = 9,
+  kConvBackwardFilterV8 = 10,
+  kAlgorithmCount = 11
 #endif
 };
 
@@ -88,6 +90,20 @@ class AutoTuneCache {
     return conv_auto_tune_map_[static_cast<int64_t>(algo_type)];
   }
 
+  template <typename T>
+  typename std::enable_if<std::is_same<T, float>::value,
+                          AlgorithmsCacheMap&>::type
+  GetGatherGemmScatter() {
+    return Get(AlgorithmType::kGatherGemmScatterFP32NN);
+  }
+
+  template <typename T>
+  typename std::enable_if<std::is_same<T, phi::dtype::float16>::value,
+                          AlgorithmsCacheMap&>::type
+  GetGatherGemmScatter() {
+    return Get(AlgorithmType::kGatherGemmScatterFP16NN);
+  }
+
 #ifdef PADDLE_WITH_CUDNN_FRONTEND
   CudnnFrontendPlanCache& GetConvV8(const AlgorithmType& algo_type) {
     return cudnn_v8_auto_tune_map_[static_cast<int64_t>(algo_type)];
diff --git a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
index a29941d7646d8f7f3c8709e5206bd49cc0f866ee..aa3ae43397ab5f4bbb231bb2cad7c902cc7d248b 100644
--- a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
+++ b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu
@@ -125,12 +125,16 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
 
 #ifdef PADDLE_WITH_CUTLASS
   bool cutlass = true;
-  if (dev_ctx.GetComputeCapability() < 75) cutlass = false;
-  if (in_channels % 4 != 0 || out_channels % 4 != 0) {
+  if (dev_ctx.GetComputeCapability() < 80) cutlass = false;
+  if (in_channels % 8 != 0 || out_channels % 8 != 0) {
     if (std::is_same<T, phi::dtype::float16>::value) cutlass = false;
+  }
+  if (in_channels % 4 != 0 || out_channels % 4 != 0) {
     if (std::is_same<T, float>::value) cutlass = false;
   }
+  if (std::is_same<T, double>::value) cutlass = false;
   if (!std::is_same<IntT, int32_t>::value) cutlass = false;
+
   if (cutlass) {
     auto* out_values = out->mutable_non_zero_elements();
     T* out_values_ptr = out_values->data<T>();
@@ -150,18 +154,18 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
       const IntT* gather_indices = rulebook_ptr + h_offsets_ptr[i];
       const IntT* scatter_indices =
           rulebook_ptr + rulebook_len + h_offsets_ptr[i];
-      dispatchKernel(dev_ctx,
-                     x.non_zero_elements().data<T>(),
-                     tmp_kernel_ptr,
-                     out_values_ptr,
-                     out_values_ptr,
-                     M,
-                     N,
-                     K,
-                     gather_indices,
-                     scatter_indices,
-                     cutlass,
-                     x.dtype());
+      GatherGemmScatterDriver(dev_ctx,
+                              x.non_zero_elements().data<T>(),
+                              tmp_kernel_ptr,
+                              out_values_ptr,
+                              out_values_ptr,
+                              M,
+                              N,
+                              K,
+                              gather_indices,
+ scatter_indices, + static_cast(1.0), + static_cast(1.0)); } } else { #endif diff --git a/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py b/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py deleted file mode 100644 index 49c6e5975181a0e50a7f87ccf2c6fe16a8cea669..0000000000000000000000000000000000000000 --- a/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py +++ /dev/null @@ -1,250 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys - -sys.path.append(sys.argv[1]) -from gather_gemm_scatter_manifest import GatherGemmScatterManifest -from gather_gemm_scatter_operation import GatherGemmScatterOperation -from generator import ( - ComplexTransform, - CudaToolkitVersionSatisfies, - EpilogueFunctor, - GemmKind, - SwizzlingFunctor, - TensorDescription, -) -from library import ( - DataType, - LayoutType, - MathInstruction, - MathOperation, - OpcodeClass, - TileDescription, -) -from manifest import GeneratorTarget - - -def CreateGatherGemmScatterOperator( - manifest, - layouts, - tile_descriptions, - data_type, - alignment_constraints, - complex_transforms=None, - epilogue_functor=EpilogueFunctor.LinearCombination, - swizzling_functor=SwizzlingFunctor.Identity8, -): - # To use StreamK decomposition for basic GEMMs, set `swizzling_functor = SwizzlingFunctor.StreamK` - - if complex_transforms is None: - complex_transforms = [ - (ComplexTransform.none, ComplexTransform.none), - ] - - element_a, element_b, element_c, element_epilogue = data_type - - operations = [] - - # by default, only generate the largest tile and largest alignment - # if manifest.kernel_filter == '': - # tile_descriptions = [tile_descriptions[0],] - # alignment_constraints = [alignment_constraints[0],] - - for layout in layouts: - for tile_description in tile_descriptions: - for alignment in alignment_constraints: - for complex_transform in complex_transforms: - - alignment_c = min(8, alignment) - - A = TensorDescription( - element_a, layout[0], alignment, complex_transform[0] - ) - B = TensorDescription( - element_b, layout[1], alignment, complex_transform[1] - ) - C = TensorDescription(element_c, layout[2], alignment_c) - - new_operation = GatherGemmScatterOperation( - GemmKind.Universal, - tile_description.minimum_compute_capability, - tile_description, - A, - B, - C, - element_epilogue, - epilogue_functor, - swizzling_functor, - ) - - manifest.append(new_operation) - operations.append(new_operation) - - return operations - - -def GenerateSM70_TensorOp_884(manifest, cuda_version): - - if not CudaToolkitVersionSatisfies(cuda_version, 10, 1): - return - - layouts = [ - (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), - ] - - math_instructions = [ - MathInstruction( - [8, 8, 4], - DataType.f16, - DataType.f16, - DataType.f32, - OpcodeClass.TensorOp, - MathOperation.multiply_add, - ), - MathInstruction( - [8, 8, 4], - DataType.f16, - DataType.f16, - DataType.f16, - 
OpcodeClass.TensorOp, - MathOperation.multiply_add, - ), - ] - - min_cc = 70 - max_cc = 75 - - alignment_constraints = [8, 4, 2, 1] - - for math_inst in math_instructions: - tile_descriptions = [ - TileDescription( - [256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc - ), - TileDescription( - [64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc - ), - ] - - data_type = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_accumulator, - math_inst.element_accumulator, - ] - - CreateGatherGemmScatterOperator( - manifest, - layouts, - tile_descriptions, - data_type, - alignment_constraints, - ) - - # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) - if math_inst.element_a != math_inst.element_accumulator: - - data_type_mixed = [ - math_inst.element_a, - math_inst.element_b, - math_inst.element_a, - math_inst.element_accumulator, - ] - - CreateGatherGemmScatterOperator( - manifest, - layouts, - tile_descriptions, - data_type_mixed, - alignment_constraints, - ) - - -def GenerateSM70(manifest, cuda_version): - GenerateSM70_TensorOp_884(manifest, cuda_version) - - -class KernelCfg: - def __init__( - self, - architectures, - build_dir, - cuda_version, - curr_build_dir, - disable_full_archs_compilation, - filter_by_cc, - generator_target, - ignore_kernels, - interface_dir, - kernel_filter_file, - kernels, - operations, - selected_kernel_list, - ): - self.architectures = architectures - self.build_dir = build_dir - self.cuda_version = cuda_version - self.curr_build_dir = curr_build_dir - self.disable_full_archs_compilation = disable_full_archs_compilation - self.filter_by_cc = filter_by_cc - self.generator_target = generator_target - self.ignore_kernels = ignore_kernels - self.interface_dir = interface_dir - self.kernel_filter_file = kernel_filter_file - self.kernels = kernels - self.operations = operations - self.selected_kernel_list = selected_kernel_list - - -if __name__ == "__main__": - - args = KernelCfg( - architectures='70', - build_dir=sys.argv[2], - cuda_version=sys.argv[3], - curr_build_dir=sys.argv[2], - disable_full_archs_compilation=False, - filter_by_cc='True', - generator_target='library', - ignore_kernels='', - interface_dir=None, - kernel_filter_file=None, - kernels='', - operations='all', - selected_kernel_list=None, - ) - manifest = GatherGemmScatterManifest(args) - - GenerateSM70(manifest, args.cuda_version) - - manifest.emit(GeneratorTarget.Library) diff --git a/paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h b/paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h new file mode 100644 index 0000000000000000000000000000000000000000..9732cb89ae899a5e554a016c2b744b70d9a3dc60 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h @@ -0,0 +1,103 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#ifdef PADDLE_WITH_CUTLASS
+#include "cutlass/arch/mma.h"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/device/gemm_universal.h"
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/half.h"
+#include "cutlass/util/device_memory.h"
+#include "examples/common/helper.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+namespace phi {
+namespace sparse {
+#define TYPEDEF_KERNEL_POINTER(kernel, dtype)       \
+  typedef void (*kernel)(dtype const alpha,         \
+                         dtype const beta,          \
+                         const GPUContext& dev_ctx, \
+                         const dtype* const a,      \
+                         const dtype* const b,      \
+                         const dtype* const c,      \
+                         dtype* const d,            \
+                         const int m,               \
+                         const int n,               \
+                         const int k,               \
+                         const int32_t* a_indices,  \
+                         const int32_t* c_d_indices);
+#define GATHER_GEMM_SCATTER_CHECK(status)                      \
+  {                                                            \
+    cutlass::Status error = status;                            \
+    if (error != cutlass::Status::kSuccess) {                  \
+      throw std::runtime_error(cutlassGetStatusString(error)); \
+    }                                                          \
+  }
+#define DEFINE_LAUNCH_KERNEL(dtype, cutlass_type)                          \
+  template <typename Gemm>                                                 \
+  void launchKernel(dtype const alpha,                                     \
+                    dtype const beta,                                      \
+                    const GPUContext& dev_ctx,                             \
+                    const dtype* const a,                                  \
+                    const dtype* const b,                                  \
+                    const dtype* const c,                                  \
+                    dtype* const d,                                        \
+                    const int m,                                           \
+                    const int n,                                           \
+                    const int k,                                           \
+                    const int32_t* a_indices,                              \
+                    const int32_t* c_d_indices) {                          \
+    cutlass::gemm::GemmCoord problem_size_real({m, n, k});                 \
+    int split_k_slices = 1;                                                \
+    typename Gemm::Arguments arguments{                                    \
+        cutlass::gemm::GemmUniversalMode::kGemm,                           \
+        problem_size_real,                                                 \
+        split_k_slices,                                                    \
+        {static_cast<const cutlass_type>(static_cast<const float>(alpha)), \
+         static_cast<const cutlass_type>(static_cast<const float>(beta))}, \
+        reinterpret_cast<const cutlass_type* const>(a),                    \
+        reinterpret_cast<const cutlass_type* const>(b),                    \
+        reinterpret_cast<const cutlass_type* const>(c),                    \
+        reinterpret_cast<cutlass_type* const>(d),                          \
+        cutlass::layout::RowMajor().capacity(problem_size_real.mk()),      \
+        cutlass::layout::RowMajor().capacity(problem_size_real.kn()),      \
+        cutlass::layout::RowMajor().capacity(problem_size_real.mn()),      \
+        cutlass::layout::RowMajor().capacity(problem_size_real.mn()),      \
+        problem_size_real.k(),                                             \
+        problem_size_real.n(),                                             \
+        problem_size_real.n(),                                             \
+        problem_size_real.n(),                                             \
+        a_indices,                                                         \
+        nullptr,                                                           \
+        c_d_indices};                                                      \
+    size_t workspace_size = Gemm::get_workspace_size(arguments);           \
+    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size); \
+    Gemm gemm_op;                                                          \
+    cutlass::Status status = gemm_op.can_implement(arguments);             \
+    GATHER_GEMM_SCATTER_CHECK(status);                                     \
+    status = gemm_op.initialize(arguments, workspace.get());               \
+    GATHER_GEMM_SCATTER_CHECK(status);                                     \
+    gemm_op(dev_ctx.stream());                                             \
+  }
+
+TYPEDEF_KERNEL_POINTER(fp16_gather_gemm_scatter, phi::dtype::float16)
+TYPEDEF_KERNEL_POINTER(fp32_gather_gemm_scatter, float)
+
+DEFINE_LAUNCH_KERNEL(phi::dtype::float16, cutlass::half_t)
+DEFINE_LAUNCH_KERNEL(float, float)
+
+}  // namespace sparse
+}  // namespace phi
+#endif
diff --git a/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py b/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..21b59f067a6090a04c7ac18badebdc45e09cb2e3
--- /dev/null
+++ 
b/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py @@ -0,0 +1,552 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +sys.path.append(sys.argv[1]) +from gather_gemm_scatter_manifest import GatherGemmScatterManifest +from gather_gemm_scatter_operation import GatherGemmScatterOperation +from generator import ( + ComplexTransform, + CudaToolkitVersionSatisfies, + EpilogueFunctor, + GemmKind, + SwizzlingFunctor, + TensorDescription, +) +from library import ( + DataType, + LayoutType, + MathInstruction, + MathOperation, + OpcodeClass, + TileDescription, +) +from manifest import GeneratorTarget + + +def CreateGatherGemmScatterOperator( + manifest, + layouts, + tile_descriptions, + data_type, + complex_transforms=None, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity8, +): + # To use StreamK decomposition for basic GEMMs, set `swizzling_functor = SwizzlingFunctor.StreamK` + + if complex_transforms is None: + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + ] + + element_a, element_b, element_c, element_epilogue = data_type + + alignment_constraints = [0] + if 'f16' == element_a.name or 'bf16' == element_a.name: + alignment_constraints = [8] + elif 'f32' == element_a.name or 'tf32' == element_a.name: + alignment_constraints = [4] + elif 'f64' == element_a.name: + alignment_constraints = [1] + + operations = [] + + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + for complex_transform in complex_transforms: + + alignment_c = min(8, alignment) + + A = TensorDescription( + element_a, layout[0], alignment, complex_transform[0] + ) + B = TensorDescription( + element_b, layout[1], alignment, complex_transform[1] + ) + C = TensorDescription(element_c, layout[2], alignment_c) + + new_operation = GatherGemmScatterOperation( + GemmKind.Universal, + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + epilogue_functor, + swizzling_functor, + ) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + + +def GenerateSM80_TensorOp_16816(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), + ] + + math_instructions = [ + MathInstruction( + [16, 8, 16], + DataType.f16, + DataType.f16, + DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ), + ] + + min_cc = 80 + max_cc = 1024 + + alignment_constraints = [8] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription( + [256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [256, 64, 32], 3, [4, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [256, 64, 
32], 4, [4, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 128, 32], 6, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 64, 32], 10, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [256, 128, 64], 3, [4, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 256, 64], 3, [2, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [256, 64, 64], 4, [4, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 256, 64], 4, [1, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [256, 64, 64], 3, [4, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 256, 64], 3, [1, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc + ), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGatherGemmScatterOperator( + manifest, layouts, tile_descriptions, data_type + ) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGatherGemmScatterOperator( + manifest, layouts, tile_descriptions, data_type_mixed + ) + + +def GenerateSM80_TensorOp_1688(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), + ] + + math_instructions = [ + MathInstruction( + [16, 8, 8], + DataType.tf32, + DataType.tf32, + DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ) + ] + + min_cc = 80 + max_cc = 1024 + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription( + [256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc + ), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGatherGemmScatterOperator( + manifest, layouts, tile_descriptions, data_type + ) + + CreateGatherGemmScatterOperator( + manifest, layouts, tile_descriptions, data_type_mixed + ) + + +def GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), + ] + + math_instructions = [ + MathInstruction( + [16, 8, 8], + DataType.tf32, + DataType.tf32, + DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ), + ] + + min_cc = 80 + max_cc = 1024 + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription( + [256, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [256, 64, 16], 4, [4, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + 
[64, 256, 16], 4, [1, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 16], 5, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 128, 16], 6, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 64, 16], 10, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [256, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [256, 64, 32], 4, [4, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 256, 32], 4, [1, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 32], 4, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 64, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc + ), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateGatherGemmScatterOperator( + manifest, layouts, tile_descriptions, data_type + ) + + +def GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), + ] + + math_instructions = [ + MathInstruction( + [16, 8, 8], + DataType.f32, + DataType.f32, + DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add_fast_f32, + ), + ] + + min_cc = 80 + max_cc = 1024 + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription( + [128, 128, 16], 4, [4, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [256, 64, 16], 3, [4, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 256, 16], 3, [2, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 128, 16], 4, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 64, 16], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [256, 64, 32], 3, [4, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 256, 32], 3, [2, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 128, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 64, 32], 3, [2, 2, 1], math_inst, min_cc, max_cc + ), + ] + + data_type = [DataType.f32, DataType.f32, DataType.f32, DataType.f32] + + CreateGatherGemmScatterOperator( + manifest, layouts, tile_descriptions, data_type + ) + + +def GenerateSM80(manifest, cuda_version): + GenerateSM80_TensorOp_16816(manifest, cuda_version) + GenerateSM80_TensorOp_1688(manifest, cuda_version) + GenerateSM80_TensorOp_1688_fast_math(manifest, cuda_version) + GenerateSM80_TensorOp_1688_fast_fp32_math(manifest, cuda_version) + + +class KernelCfg: + def __init__( + self, + architectures, + build_dir, + 
cuda_version,
+        curr_build_dir,
+        disable_full_archs_compilation,
+        filter_by_cc,
+        generator_target,
+        ignore_kernels,
+        interface_dir,
+        kernel_filter_file,
+        kernels,
+        operations,
+        selected_kernel_list,
+    ):
+        self.architectures = architectures
+        self.build_dir = build_dir
+        self.cuda_version = cuda_version
+        self.curr_build_dir = curr_build_dir
+        self.disable_full_archs_compilation = disable_full_archs_compilation
+        self.filter_by_cc = filter_by_cc
+        self.generator_target = generator_target
+        self.ignore_kernels = ignore_kernels
+        self.interface_dir = interface_dir
+        self.kernel_filter_file = kernel_filter_file
+        self.kernels = kernels
+        self.operations = operations
+        self.selected_kernel_list = selected_kernel_list
+
+
+if __name__ == "__main__":
+
+    args = KernelCfg(
+        architectures='80',
+        build_dir=sys.argv[2],
+        cuda_version=sys.argv[3],
+        curr_build_dir=sys.argv[2],
+        disable_full_archs_compilation=False,
+        filter_by_cc='True',
+        generator_target='library',
+        ignore_kernels='',
+        interface_dir=None,
+        kernel_filter_file=None,
+        kernels='',
+        operations='all',
+        selected_kernel_list=None,
+    )
+    manifest = GatherGemmScatterManifest(args)
+
+    GenerateSM80(manifest, args.cuda_version)
+
+    manifest.emit(GeneratorTarget.Library)
diff --git a/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_manifest.py b/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_manifest.py
similarity index 64%
rename from paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_manifest.py
rename to paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_manifest.py
index 1cb48af13a1208789d938305c8022d4f8a4fca6e..a9ac554ede6e42587e50ee9ae7e2dfe1e3f466c6 100644
--- a/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_manifest.py
+++ b/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_manifest.py
@@ -18,7 +18,7 @@ import shutil
 from gather_gemm_scatter_operation import (
     EmitGatherGemmScatterConfigurationLibrary,
 )
-from library import OperationKind, OperationKindNames
+from library import OperationKind, OperationKindNames, SubstituteTemplate
 from manifest import EmitOperationKindLibrary, GeneratorTarget, Manifest
 
 
@@ -28,11 +28,25 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary):
         self.emitters = {
             OperationKind.Gemm: EmitGatherGemmScatterConfigurationLibrary
         }
-        self.header_template = "#pragma once\n#ifdef PADDLE_WITH_CUTLASS\n"
+        self.header_template = "#pragma once\n#ifdef PADDLE_WITH_CUTLASS\n#include \"paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h\"\n"
         self.entry_template = ""
         self.configuration_prototype_template = ""
         self.configuration_template = ""
-        self.epilogue_template = "#endif"
+        self.namespace_template = """
+namespace phi {
+namespace sparse {
+"""
+        self.epilogue_template = """
+} // namespace sparse
+} // namespace phi
+#endif
+"""
+        self.fp16_kernels_list = (
+            "static std::vector<fp16_gather_gemm_scatter> fp16_kernels = {\n"
+        )
+        self.fp32_kernels_list = (
+            "static std::vector<fp32_gather_gemm_scatter> fp32_kernels = {\n"
+        )
 
     def __enter__(self):
         self.operation_path = os.path.join(
@@ -64,6 +78,21 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary):
             self.source_files.append(configuration_emitter.configuration_path)
             self.configurations.append(configuration_name)
 
+        if 'h' == operations[0].short_math_name():
+            self.fp16_kernels_list += (
+                """
+launchKernel<"""
+                + configuration_name
+                + "::Gemm>,"
+            )
+        if 's' == operations[0].short_math_name():
+            self.fp32_kernels_list += (
+                """
+launchKernel<"""
+                + 
configuration_name + + "::Gemm>," + ) + self.top_level_file.write( '#include "' + self.operation_path @@ -72,6 +101,30 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary): + '.h"\n' ) + def __exit__(self, exception_type, exception_value, traceback): + self.top_level_file.write( + SubstituteTemplate( + self.entry_template, + {'operation_name': OperationKindNames[self.kind]}, + ) + ) + + for configuration_name in self.configurations: + self.top_level_file.write( + SubstituteTemplate( + self.configuration_template, + {'configuration_name': configuration_name}, + ) + ) + + self.fp16_kernels_list += "\n};\n" + self.fp32_kernels_list += "\n};\n" + self.top_level_file.write(self.namespace_template) + self.top_level_file.write(self.fp16_kernels_list) + self.top_level_file.write(self.fp32_kernels_list) + self.top_level_file.write(self.epilogue_template) + self.top_level_file.close() + class GatherGemmScatterManifest(Manifest): def emit(self, target=GeneratorTarget.Library): diff --git a/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_operation.py b/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_operation.py similarity index 95% rename from paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_operation.py rename to paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_operation.py index 1b2296b8c44541fe3932ef73af06a3fb0d7a2337..b14036c55ddc2577ee4ccbf0b462b1c278b163e3 100644 --- a/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_operation.py +++ b/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_operation.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import collections import enum import os.path @@ -40,16 +41,7 @@ from library import ( class EmitGatherGemmScatterInstance(EmitGemmInstance): def __init__(self, operation_suffix=''): self.operation_suffix = operation_suffix - self.includes = [ - "cutlass/cutlass.h", - "cutlass/numeric_types.h", - "cutlass/arch/arch.h", - "cutlass/arch/mma.h", - "cutlass/layout/matrix.h", - "cutlass/gemm/device/gemm.h", - "cutlass/gemm/device/gemm_universal_adapter.h", - "cutlass/gemm/kernel/default_gemm_universal.h", - ] + self.includes = [] self.builtin_epilogue_functor_template = """ ${epilogue_functor}< ${element_c}, @@ -247,6 +239,18 @@ namespace sparse { #endif """ + def __enter__(self): + self.configuration_file = open(self.configuration_path, "w") + self.configuration_file.write(self.header_template) + self.configuration_file.write(self.separator) + + self.includes = collections.OrderedDict([]) + self.instance_definitions = [] + self.instance_wrappers = [] + + self.operations = [] + return self + def __exit__(self, exception_type, exception_value, traceback): # Write includes diff --git a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu deleted file mode 100644 index cfbaa7f1d63068db4942a1102db9b95b13649c56..0000000000000000000000000000000000000000 --- a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.cu +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifdef PADDLE_WITH_CUTLASS -#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h" -namespace phi { -namespace sparse { -fp16_gather_gemm_scatter getBestFp16Kernel(const int M, - const int N, - const int K) { - if (K == 4 && N == 16) { - return launchKernel; - } - if (K == 16 && N == 16) { - return launchKernel; - } - if (K == 16 && N == 32) { - return launchKernel; - } - if (K == 32 && N == 32) { - return launchKernel; - } - if (K == 32 && N == 64) { - return launchKernel; - } - if (K == 64 && N == 64) { - if (M > 100000) - launchKernel< - cutlass::half_t, - cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8::Gemm>; - if (M > 20000) - launchKernel< - cutlass::half_t, - cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8::Gemm>; - if (M > 15000) - return launchKernel< - cutlass::half_t, - cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8::Gemm>; - return launchKernel; - } - if (K == 128) { - if (M >= 5000) - return launchKernel< - cutlass::half_t, - cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8::Gemm>; - return launchKernel; - } - if (N == 128) { - return launchKernel; - } - return launchKernel; -} -fp32_gather_gemm_scatter getBestFp32Kernel(const int M, - const int N, - const int K, - const int SM) { - if (SM == 75) { - return launchKernel< - float, - cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4::Gemm>; - } - if (K == 4 && N == 16) { - return launchKernel< - float, - cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; - } - if (K == 16 && N == 16) { - return launchKernel< - float, - cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; - } - if (K == 16 && N == 32) { - if (M >= 10000) - return launchKernel< - float, - cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>; - return launchKernel< - float, - cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; - } - if (K == 32 && N == 32) { - if (M >= 10000) - return launchKernel< - float, - cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>; - return launchKernel< - float, - cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; - } - if (K == 32 && N == 64) { - if (M >= 10000) - return launchKernel< - float, - cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>; - return launchKernel< - float, - cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; - } - if (K == 64 && N == 64) { - if (M >= 15000) - return launchKernel< - float, - cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4::Gemm>; - return launchKernel< - float, - cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; - } - if (K == 128) { - if (M >= 100000) - return launchKernel< - float, - cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>; - if (M >= 5000) - return launchKernel< - float, - cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4::Gemm>; - return launchKernel< - float, - cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>; - } - if (N == 128) { - if (M >= 100000) - return launchKernel< - float, - cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4::Gemm>; - if (M >= 5000) - return launchKernel< - float, - 
cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4::Gemm>; - return launchKernel< - float, - cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4::Gemm>; - } - return launchKernel< - float, - cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4::Gemm>; -} -fp64_gather_gemm_scatter getBestFp64Kernel(const int M, - const int N, - const int K) { - if (K == 4 && N == 16) { - return launchKernel; - } - if (K == 16 && N == 16) { - if (M >= 10000) - return launchKernel; - return launchKernel; - } - if (K == 16 && N == 32) { - return launchKernel; - } - if (K == 32 && N == 32) { - return launchKernel; - } - if (K == 32 && N == 64) { - return launchKernel; - } - if (K == 64 && N == 64) { - return launchKernel; - } - return launchKernel; -} - -} // namespace sparse -} // namespace phi -#endif diff --git a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h index dab35ed47737a6e2e1dcf2e40b924ad1a8a8645c..158a875b5403c0e4241fe050a3817e13fb752b8e 100644 --- a/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h +++ b/paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h @@ -13,628 +13,75 @@ // limitations under the License. #pragma once +#include #ifdef PADDLE_WITH_CUTLASS -#include "cutlass/arch/mma.h" -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/gemm/device/gemm_grouped.h" -#include "cutlass/gemm/device/gemm_universal.h" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/util/device_memory.h" -#include "examples/common/helper.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/common/data_type.h" +#include "paddle/phi/kernels/autotune/auto_tune_base.h" +#include "paddle/phi/kernels/sparse/gpu/cutlass_generator/build/generated/gemm/all_gemm_operations.h" + namespace phi { namespace sparse { -typedef void (*fp16_gather_gemm_scatter)(const GPUContext& dev_ctx, - const cutlass::half_t* const a, - const cutlass::half_t* const b, - const cutlass::half_t* const c, - cutlass::half_t* const d, - const int m, - const int n, - const int k, - const int32_t* a_indices, - const int32_t* c_d_indices, - cutlass::half_t const alpha, - cutlass::half_t const beta); -typedef void (*fp32_gather_gemm_scatter)(const GPUContext& dev_ctx, - const float* const a, - const float* const b, - const float* const c, - float* const d, - const int m, - const int n, - const int k, - const int32_t* a_indices, - const int32_t* c_d_indices, - float const alpha, - float const beta); -typedef void (*fp64_gather_gemm_scatter)(const GPUContext& dev_ctx, - const double* const a, - const double* const b, - const double* const c, - double* const d, - const int m, - const int n, - const int k, - const int32_t* a_indices, - const int32_t* c_d_indices, - double const alpha, - double const beta); -fp16_gather_gemm_scatter getBestFp16Kernel(const int M, - const int K, - const int N); -fp32_gather_gemm_scatter getBestFp32Kernel(const int M, - const int K, - const int N, - const int SM); -fp64_gather_gemm_scatter getBestFp64Kernel(const int M, - const int K, - const int N); -template -void launchKernel(const GPUContext& dev_ctx, - const T* const a, - const T* const b, - const T* const c, - T* const d, - const int m, - const int n, - const int k, - const int32_t* a_indices, - const int32_t* c_d_indices, - T const alpha, - T const beta) { - cutlass::gemm::GemmCoord problem_size_real({m, n, k}); - int split_k_slices = 1; - typename Gemm::Arguments arguments{ - 
cutlass::gemm::GemmUniversalMode::kGemm, - problem_size_real, - split_k_slices, - {alpha, beta}, - a, - b, - c, - d, - cutlass::layout::RowMajor().capacity(problem_size_real.mk()), - cutlass::layout::RowMajor().capacity(problem_size_real.kn()), - cutlass::layout::RowMajor().capacity(problem_size_real.mn()), - cutlass::layout::RowMajor().capacity(problem_size_real.mn()), - problem_size_real.k(), - problem_size_real.n(), - problem_size_real.n(), - problem_size_real.n(), - a_indices, - nullptr, - c_d_indices}; - size_t workspace_size = Gemm::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - Gemm gemm_op; - cutlass::Status status = gemm_op.can_implement(arguments); - CUTLASS_CHECK(status); - status = gemm_op.initialize(arguments, workspace.get()); - CUTLASS_CHECK(status); - gemm_op(dev_ctx.stream()); -} -static void dispatchKernel(const GPUContext& dev_ctx, - const void* const a, - const void* const b, - const void* const c, - void* const d, - const int m, - const int n, - const int k, - const void* a_indices, - const void* c_d_indices, - const bool cutlass, - const phi::DataType type) { - if (!cutlass) return; - if (type == phi::DataType::FLOAT16) { - fp16_gather_gemm_scatter gather_gemm_scatter = getBestFp16Kernel(m, n, k); - gather_gemm_scatter(dev_ctx, - static_cast(a), - static_cast(b), - static_cast(c), - static_cast(d), - m, - n, - k, - static_cast(a_indices), - static_cast(c_d_indices), - static_cast(1), - static_cast(1)); - } else if (type == phi::DataType::FLOAT32) { - fp32_gather_gemm_scatter gather_gemm_scatter = - getBestFp32Kernel(m, n, k, dev_ctx.GetComputeCapability()); - gather_gemm_scatter(dev_ctx, - static_cast(a), - static_cast(b), - static_cast(c), - static_cast(d), - m, - n, - k, - static_cast(a_indices), - static_cast(c_d_indices), - static_cast(1), - static_cast(1)); - } else if (type == phi::DataType::FLOAT64) { - fp64_gather_gemm_scatter gather_gemm_scatter = getBestFp64Kernel(m, n, k); - gather_gemm_scatter(dev_ctx, - static_cast(a), - static_cast(b), - static_cast(c), - static_cast(d), - m, - n, - k, - static_cast(a_indices), - static_cast(c_d_indices), - static_cast(1), - static_cast(1)); +// To reduce tuning time, map shape (m,n,k) to (m/features_num_range,n,k) so +// that shapes in this range share the same key. 
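+// For example, with features_num_range = 10000, m = 12000 and m = 19999 both
+// yield GenKey(1, n, k) and therefore reuse one cached kernel choice, while
+// m = 20000 falls into the next bucket and is tuned separately.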
+constexpr int features_num_range = 10000;
+
+#define DEFINE_GATHER_GEMM_SCATTER_DRIVER(dtype, kernels)                    \
+  template <typename T, typename IntT>                                       \
+  typename std::enable_if<std::is_same<T, dtype>::value &&                   \
+                              std::is_same<IntT, int32_t>::value,            \
+                          void>::type                                        \
+  GatherGemmScatterDriver(const phi::GPUContext& ctx,                        \
+                          const T* const a,                                  \
+                          const T* const b,                                  \
+                          const T* const c,                                  \
+                          T* const d,                                        \
+                          const int& m,                                      \
+                          const int& n,                                      \
+                          const int& k,                                      \
+                          const IntT* a_indices,                             \
+                          const IntT* c_d_indices,                           \
+                          T alpha,                                           \
+                          T beta) {                                          \
+    auto* tuner = autotune::MakeGatherGemmScatterTuner(kernels[0]);          \
+    for (auto i = 1; i < kernels.size(); i++) tuner->AddCallBack(kernels[i]); \
+    size_t key = autotune::GenKey(m / features_num_range, n, k);             \
+    tuner->Run(ctx,                                                          \
+               key,                                                          \
+               alpha,                                                        \
+               beta,                                                         \
+               ctx,                                                          \
+               a,                                                            \
+               b,                                                            \
+               c,                                                            \
+               d,                                                            \
+               m,                                                            \
+               n,                                                            \
+               k,                                                            \
+               a_indices,                                                    \
+               c_d_indices);                                                 \
+  }
-}
-struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align8 {
-  using Gemm = cutlass::gemm::device::GemmUniversal<
-      cutlass::half_t,
-      cutlass::layout::RowMajor,
-      cutlass::half_t,
-      cutlass::layout::RowMajor,
-      cutlass::half_t,
-      cutlass::layout::RowMajor,
-      cutlass::half_t,
-      cutlass::arch::OpClassTensorOp,
-      cutlass::arch::Sm75,
-      cutlass::gemm::GemmShape<128, 64, 32>,
-      cutlass::gemm::GemmShape<64, 32, 32>,
-      cutlass::gemm::GemmShape<16, 8, 8>,
-      cutlass::epilogue::thread::LinearCombination,
-      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
-      2,
-      8,
-      8,
-      cutlass::arch::OpMultiplyAdd,
-      cutlass::ComplexTransform::kNone,
-      cutlass::ComplexTransform::kNone,
-      true,
-      false,
-      true>;
-};
-struct cutlass_tensorop_h1688gemm_64x128_32x2_nn_align8 {
-  using Gemm = cutlass::gemm::device::GemmUniversal<
-      cutlass::half_t,
-      cutlass::layout::RowMajor,
-      cutlass::half_t,
-      cutlass::layout::RowMajor,
-      cutlass::half_t,
-      cutlass::layout::RowMajor,
-      cutlass::half_t,
-      cutlass::arch::OpClassTensorOp,
-      cutlass::arch::Sm75,
-      cutlass::gemm::GemmShape<64, 128, 32>,
-      cutlass::gemm::GemmShape<32, 64, 32>,
-      cutlass::gemm::GemmShape<16, 8, 8>,
-      cutlass::epilogue::thread::LinearCombination,
-      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
-      2,
-      8,
-      8,
-      cutlass::arch::OpMultiplyAdd,
-      cutlass::ComplexTransform::kNone,
-      cutlass::ComplexTransform::kNone,
-      true,
-      false,
-      true>;
-};
-struct cutlass_tensorop_h1688gemm_128x64_32x2_nn_align4 {
-  using Gemm = cutlass::gemm::device::GemmUniversal<
-      cutlass::half_t,
-      cutlass::layout::RowMajor,
-      cutlass::half_t,
-      cutlass::layout::RowMajor,
-      cutlass::half_t,
-      cutlass::layout::RowMajor,
-      cutlass::half_t,
-      cutlass::arch::OpClassTensorOp,
-      cutlass::arch::Sm75,
-      cutlass::gemm::GemmShape<128, 64, 32>,
-      cutlass::gemm::GemmShape<64, 32, 32>,
-      cutlass::gemm::GemmShape<16, 8, 8>,
-      cutlass::epilogue::thread::LinearCombination,
-      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
-      2,
-      4,
-      4,
-      cutlass::arch::OpMultiplyAdd,
-      cutlass::ComplexTransform::kNone,
-      cutlass::ComplexTransform::kNone,
-      true,
-      false,
-      true>;
-};
-struct cutlass_tensorop_h1688gemm_64x64_32x2_nn_align4 {
-  using Gemm = cutlass::gemm::device::GemmUniversal<
-      cutlass::half_t,
-      cutlass::layout::RowMajor,
-      cutlass::half_t,
-      cutlass::layout::RowMajor,
-      cutlass::half_t,
-      cutlass::layout::RowMajor,
-      cutlass::half_t,
-      cutlass::arch::OpClassTensorOp,
-      cutlass::arch::Sm75,
-      cutlass::gemm::GemmShape<64, 64, 32>,
-      cutlass::gemm::GemmShape<32, 32, 32>,
-      cutlass::gemm::GemmShape<16, 8, 8>,
-      cutlass::epilogue::thread::LinearCombination,
-      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
-      2,
-      4,
-      4,
-
cutlass::arch::OpMultiplyAdd, - cutlass::ComplexTransform::kNone, - cutlass::ComplexTransform::kNone, - true, - false, - true>; -}; -struct cutlass_tensorop_h1688gemm_64x64_32x2_nn_align8 { - using Gemm = cutlass::gemm::device::GemmUniversal< - cutlass::half_t, - cutlass::layout::RowMajor, - cutlass::half_t, - cutlass::layout::RowMajor, - cutlass::half_t, - cutlass::layout::RowMajor, - cutlass::half_t, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm75, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<32, 32, 32>, - cutlass::gemm::GemmShape<16, 8, 8>, - cutlass::epilogue::thread::LinearCombination, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, - 2, - 8, - 8, - cutlass::arch::OpMultiplyAdd, - cutlass::ComplexTransform::kNone, - cutlass::ComplexTransform::kNone, - true, - false, - true>; -}; -struct cutlass_tensorop_h16816gemm_64x64_64x5_nn_align8 { - using Gemm = cutlass::gemm::device::GemmUniversal< - cutlass::half_t, - cutlass::layout::RowMajor, - cutlass::half_t, - cutlass::layout::RowMajor, - cutlass::half_t, - cutlass::layout::RowMajor, - cutlass::half_t, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<64, 64, 64>, - cutlass::gemm::GemmShape<32, 32, 64>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, - 5, - 8, - 8, - cutlass::arch::OpMultiplyAdd, - cutlass::ComplexTransform::kNone, - cutlass::ComplexTransform::kNone, - true, - false, - true>; -}; -struct cutlass_tensorop_f16_s1688gemm_f16_64x128_32x2_nn_align8 { - using Gemm = cutlass::gemm::device::GemmUniversal< - cutlass::half_t, - cutlass::layout::RowMajor, - cutlass::half_t, - cutlass::layout::RowMajor, - cutlass::half_t, - cutlass::layout::RowMajor, - float, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm75, - cutlass::gemm::GemmShape<64, 128, 32>, - cutlass::gemm::GemmShape<32, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 8>, - cutlass::epilogue::thread:: - LinearCombination, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, - 2, - 8, - 8, - cutlass::arch::OpMultiplyAdd, - cutlass::ComplexTransform::kNone, - cutlass::ComplexTransform::kNone, - true, - false, - true>; -}; -struct cutlass_tensorop_f16_s1688gemm_f16_64x64_32x2_nn_align8 { - using Gemm = cutlass::gemm::device::GemmUniversal< - cutlass::half_t, - cutlass::layout::RowMajor, - cutlass::half_t, - cutlass::layout::RowMajor, - cutlass::half_t, - cutlass::layout::RowMajor, - float, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm75, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<32, 32, 32>, - cutlass::gemm::GemmShape<16, 8, 8>, - cutlass::epilogue::thread:: - LinearCombination, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, - 2, - 8, - 8, - cutlass::arch::OpMultiplyAdd, - cutlass::ComplexTransform::kNone, - cutlass::ComplexTransform::kNone, - true, - false, - true>; -}; -struct cutlass_tensorop_s1688f16gemm_64x64_16x10_nn_align4 { - using Gemm = cutlass::gemm::device::GemmUniversal< - float, - cutlass::layout::RowMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<64, 64, 16>, - cutlass::gemm::GemmShape<32, 32, 16>, - cutlass::gemm::GemmShape<16, 8, 8>, - cutlass::epilogue::thread::LinearCombination, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, - 10, - 4, - 4, - 
cutlass::arch::OpMultiplyAddFastF16, - cutlass::ComplexTransform::kNone, - cutlass::ComplexTransform::kNone, - true, - false, - true>; -}; -struct cutlass_tensorop_s1688f16gemm_128x128_16x3_nn_align4 { - using Gemm = cutlass::gemm::device::GemmUniversal< - float, - cutlass::layout::RowMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 16>, - cutlass::gemm::GemmShape<64, 64, 16>, - cutlass::gemm::GemmShape<16, 8, 8>, - cutlass::epilogue::thread::LinearCombination, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, - 3, - 4, - 4, - cutlass::arch::OpMultiplyAddFastF16, - cutlass::ComplexTransform::kNone, - cutlass::ComplexTransform::kNone, - true, - false, - true>; -}; -struct cutlass_tensorop_s1688f16gemm_256x64_16x4_nn_align4 { - using Gemm = cutlass::gemm::device::GemmUniversal< - float, - cutlass::layout::RowMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 64, 16>, - cutlass::gemm::GemmShape<64, 64, 16>, - cutlass::gemm::GemmShape<16, 8, 8>, - cutlass::epilogue::thread::LinearCombination, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, - 4, - 4, - 4, - cutlass::arch::OpMultiplyAddFastF16, - cutlass::ComplexTransform::kNone, - cutlass::ComplexTransform::kNone, - true, - false, - true>; -}; -struct cutlass_tensorop_s1688tf32gemm_256x128_16x3_nn_align4 { - using Gemm = cutlass::gemm::device::GemmUniversal< - float, - cutlass::layout::RowMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<256, 128, 16>, - cutlass::gemm::GemmShape<64, 64, 16>, - cutlass::gemm::GemmShape<16, 8, 8>, - cutlass::epilogue::thread::LinearCombination, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, - 3, - 4, - 4, - cutlass::arch::OpMultiplyAdd, - cutlass::ComplexTransform::kNone, - cutlass::ComplexTransform::kNone, - true, - false, - true>; -}; -struct cutlass_tensorop_s1688f16gemm_64x128_16x6_nn_align4 { - using Gemm = cutlass::gemm::device::GemmUniversal< - float, - cutlass::layout::RowMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<64, 128, 16>, - cutlass::gemm::GemmShape<32, 64, 16>, - cutlass::gemm::GemmShape<16, 8, 8>, - cutlass::epilogue::thread::LinearCombination, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, - 6, - 4, - 4, - cutlass::arch::OpMultiplyAddFastF16, - cutlass::ComplexTransform::kNone, - cutlass::ComplexTransform::kNone, - true, - false, - true>; -}; -struct cutlass_tensorop_s1688gemm_64x64_16x3_nn_align4 { - using Gemm = cutlass::gemm::device::GemmUniversal< - float, - cutlass::layout::RowMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::layout::RowMajor, - float, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<64, 64, 16>, - cutlass::gemm::GemmShape<32, 32, 16>, - cutlass::gemm::GemmShape<16, 8, 8>, - cutlass::epilogue::thread::LinearCombination, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, - 3, - 4, - 4, - cutlass::arch::OpMultiplyAddFastF32, - cutlass::ComplexTransform::kNone, - cutlass::ComplexTransform::kNone, - true, - false, - true>; -}; -struct 
cutlass_tensorop_d884gemm_16x32_16x5_nn_align1 {
-  using Gemm = cutlass::gemm::device::GemmUniversal<
-      double,
-      cutlass::layout::RowMajor,
-      double,
-      cutlass::layout::RowMajor,
-      double,
-      cutlass::layout::RowMajor,
-      double,
-      cutlass::arch::OpClassTensorOp,
-      cutlass::arch::Sm80,
-      cutlass::gemm::GemmShape<16, 32, 16>,
-      cutlass::gemm::GemmShape<16, 16, 16>,
-      cutlass::gemm::GemmShape<8, 8, 4>,
-      cutlass::epilogue::thread::LinearCombination,
-      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
-      5,
-      1,
-      1,
-      cutlass::arch::OpMultiplyAdd,
-      cutlass::ComplexTransform::kNone,
-      cutlass::ComplexTransform::kNone,
-      true,
-      false,
-      true>;
-};
-struct cutlass_tensorop_d884gemm_32x16_16x5_nn_align1 {
-  using Gemm = cutlass::gemm::device::GemmUniversal<
-      double,
-      cutlass::layout::RowMajor,
-      double,
-      cutlass::layout::RowMajor,
-      double,
-      cutlass::layout::RowMajor,
-      double,
-      cutlass::arch::OpClassTensorOp,
-      cutlass::arch::Sm80,
-      cutlass::gemm::GemmShape<32, 16, 16>,
-      cutlass::gemm::GemmShape<16, 16, 16>,
-      cutlass::gemm::GemmShape<8, 8, 4>,
-      cutlass::epilogue::thread::LinearCombination,
-      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
-      5,
-      1,
-      1,
-      cutlass::arch::OpMultiplyAdd,
-      cutlass::ComplexTransform::kNone,
-      cutlass::ComplexTransform::kNone,
-      true,
-      false,
-      true>;
-};
+template <typename T, typename IntT>
+typename std::enable_if<std::is_same<T, double>::value ||
+                            !std::is_same<IntT, int32_t>::value,
+                        void>::type
+GatherGemmScatterDriver(const phi::GPUContext& ctx,
+                        const T* const a,
+                        const T* const b,
+                        const T* const c,
+                        T* const d,
+                        const int& m,
+                        const int& n,
+                        const int& k,
+                        const IntT* a_indices,
+                        const IntT* c_d_indices,
+                        T alpha,
+                        T beta) {}
-
-// sm75
-struct cutlass_tensorop_s1688gemm_f16_64x64_32x2_nn_align4 {
-  using Gemm = cutlass::gemm::device::GemmUniversal<
-      cutlass::half_t,
-      cutlass::layout::RowMajor,
-      cutlass::half_t,
-      cutlass::layout::RowMajor,
-      float,
-      cutlass::layout::RowMajor,
-      float,
-      cutlass::arch::OpClassTensorOp,
-      cutlass::arch::Sm75,
-      cutlass::gemm::GemmShape<64, 64, 32>,
-      cutlass::gemm::GemmShape<32, 32, 32>,
-      cutlass::gemm::GemmShape<16, 8, 8>,
-      cutlass::epilogue::thread::LinearCombination,
-      cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>,
-      2,
-      8,
-      8,
-      cutlass::arch::OpMultiplyAdd>;
-};
+DEFINE_GATHER_GEMM_SCATTER_DRIVER(phi::dtype::float16, fp16_kernels)
+DEFINE_GATHER_GEMM_SCATTER_DRIVER(float, fp32_kernels)
 
 }  // namespace sparse
 }  // namespace phi
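Taken together with the conv_kernel.cu hunk above: the driver buckets the problem shape, autotunes across the generated kernel list on a cache miss, and replays the cached winner afterwards. Below is a minimal, illustrative sketch of a call site; the wrapper function and its parameter names are hypothetical, while the GatherGemmScatterDriver signature and semantics are taken from the header in this patch.

```cpp
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"

// Hypothetical helper mirroring the Conv3dCooGPUKernel call site: compute
// D = alpha * gather(A) @ B + beta * C and scatter the M x N result rows.
void RunGatherGemmScatter(const phi::GPUContext& dev_ctx,
                          const float* x_values,      // values of the sparse input (A)
                          const float* kernel_slice,  // K x N block of the dense filter (B)
                          float* out_values,          // used as both C (read) and D (written)
                          int M, int N, int K,
                          const int32_t* gather_indices,     // rows of A to gather
                          const int32_t* scatter_indices) {  // rows of D to scatter
  // The first call for a shape bucket (m / features_num_range, n, k) times
  // every generated kernel with alpha = 0, beta = 1 (so out_values is
  // preserved), caches the fastest index, and replays it on later calls.
  phi::sparse::GatherGemmScatterDriver(dev_ctx,
                                       x_values,
                                       kernel_slice,
                                       out_values,
                                       out_values,
                                       M,
                                       N,
                                       K,
                                       gather_indices,
                                       scatter_indices,
                                       static_cast<float>(1.0),
                                       static_cast<float>(1.0));
}
```

Note that the fp32_kernels candidate list is populated at build time by gather_gemm_scatter_generator.py, so the set of tuning candidates depends on which SM80 tile configurations were emitted.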