diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index eee868900b585445721195d13af06a5640e43b5b..822b71c71d31db07b79a2a72eb1cf977c7c1ab07 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -34,7 +34,14 @@ ExternalProject_Add( PREFIX ${CUTLASS_PREFIX_DIR} UPDATE_COMMAND "" CONFIGURE_COMMAND "" - BUILD_COMMAND "" + BUILD_COMMAND + mkdir -p + ${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/build/generated/gemm + && ${PYTHON_EXECUTABLE} -B + ${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py + "${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/library/scripts/" + "${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/build" + "${CMAKE_CUDA_COMPILER_VERSION}" INSTALL_COMMAND "" TEST_COMMAND "") diff --git a/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py b/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..49c6e5975181a0e50a7f87ccf2c6fe16a8cea669 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py @@ -0,0 +1,250 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
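+
+# This script is invoked by the BUILD_COMMAND added to cmake/external/cutlass.cmake:
+#   sys.argv[1] - the CUTLASS tools/library/scripts/ directory (appended to sys.path below)
+#   sys.argv[2] - the build directory that receives the generated/gemm headers
+#   sys.argv[3] - the CUDA version reported by CMAKE_CUDA_COMPILER_VERSION
+# An illustrative manual invocation, with placeholder interpreter, paths and version
+# rather than the values CMake substitutes, would be:
+#   python -B gather_gemm_scatter_generator.py \
+#       <cutlass_src>/tools/library/scripts/ \
+#       <paddle_src>/paddle/phi/kernels/sparse/gpu/cutlass/build \
+#       11.2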
+ +import sys + +sys.path.append(sys.argv[1]) +from gather_gemm_scatter_manifest import GatherGemmScatterManifest +from gather_gemm_scatter_operation import GatherGemmScatterOperation +from generator import ( + ComplexTransform, + CudaToolkitVersionSatisfies, + EpilogueFunctor, + GemmKind, + SwizzlingFunctor, + TensorDescription, +) +from library import ( + DataType, + LayoutType, + MathInstruction, + MathOperation, + OpcodeClass, + TileDescription, +) +from manifest import GeneratorTarget + + +def CreateGatherGemmScatterOperator( + manifest, + layouts, + tile_descriptions, + data_type, + alignment_constraints, + complex_transforms=None, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity8, +): + # To use StreamK decomposition for basic GEMMs, set `swizzling_functor = SwizzlingFunctor.StreamK` + + if complex_transforms is None: + complex_transforms = [ + (ComplexTransform.none, ComplexTransform.none), + ] + + element_a, element_b, element_c, element_epilogue = data_type + + operations = [] + + # by default, only generate the largest tile and largest alignment + # if manifest.kernel_filter == '': + # tile_descriptions = [tile_descriptions[0],] + # alignment_constraints = [alignment_constraints[0],] + + for layout in layouts: + for tile_description in tile_descriptions: + for alignment in alignment_constraints: + for complex_transform in complex_transforms: + + alignment_c = min(8, alignment) + + A = TensorDescription( + element_a, layout[0], alignment, complex_transform[0] + ) + B = TensorDescription( + element_b, layout[1], alignment, complex_transform[1] + ) + C = TensorDescription(element_c, layout[2], alignment_c) + + new_operation = GatherGemmScatterOperation( + GemmKind.Universal, + tile_description.minimum_compute_capability, + tile_description, + A, + B, + C, + element_epilogue, + epilogue_functor, + swizzling_functor, + ) + + manifest.append(new_operation) + operations.append(new_operation) + + return operations + + +def GenerateSM70_TensorOp_884(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 10, 1): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor), + ] + + math_instructions = [ + MathInstruction( + [8, 8, 4], + DataType.f16, + DataType.f16, + DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ), + MathInstruction( + [8, 8, 4], + DataType.f16, + DataType.f16, + DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add, + ), + ] + + min_cc = 70 + max_cc = 75 + + alignment_constraints = [8, 4, 2, 1] + + for math_inst in math_instructions: + tile_descriptions = [ + TileDescription( + [256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc + ), + TileDescription( + [64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc + ), + ] + + data_type = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_accumulator, + math_inst.element_accumulator, + ] + + CreateGatherGemmScatterOperator( + manifest, + layouts, + tile_descriptions, + data_type, + 
alignment_constraints, + ) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + CreateGatherGemmScatterOperator( + manifest, + layouts, + tile_descriptions, + data_type_mixed, + alignment_constraints, + ) + + +def GenerateSM70(manifest, cuda_version): + GenerateSM70_TensorOp_884(manifest, cuda_version) + + +class KernelCfg: + def __init__( + self, + architectures, + build_dir, + cuda_version, + curr_build_dir, + disable_full_archs_compilation, + filter_by_cc, + generator_target, + ignore_kernels, + interface_dir, + kernel_filter_file, + kernels, + operations, + selected_kernel_list, + ): + self.architectures = architectures + self.build_dir = build_dir + self.cuda_version = cuda_version + self.curr_build_dir = curr_build_dir + self.disable_full_archs_compilation = disable_full_archs_compilation + self.filter_by_cc = filter_by_cc + self.generator_target = generator_target + self.ignore_kernels = ignore_kernels + self.interface_dir = interface_dir + self.kernel_filter_file = kernel_filter_file + self.kernels = kernels + self.operations = operations + self.selected_kernel_list = selected_kernel_list + + +if __name__ == "__main__": + + args = KernelCfg( + architectures='70', + build_dir=sys.argv[2], + cuda_version=sys.argv[3], + curr_build_dir=sys.argv[2], + disable_full_archs_compilation=False, + filter_by_cc='True', + generator_target='library', + ignore_kernels='', + interface_dir=None, + kernel_filter_file=None, + kernels='', + operations='all', + selected_kernel_list=None, + ) + manifest = GatherGemmScatterManifest(args) + + GenerateSM70(manifest, args.cuda_version) + + manifest.emit(GeneratorTarget.Library) diff --git a/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_manifest.py b/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_manifest.py new file mode 100644 index 0000000000000000000000000000000000000000..1cb48af13a1208789d938305c8022d4f8a4fca6e --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_manifest.py @@ -0,0 +1,101 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
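+
+# GatherGemmScatterManifest specializes CUTLASS's Manifest/EmitOperationKindLibrary so
+# that emission stops after writing per-configuration headers plus one aggregating
+# header per operation kind, all guarded by PADDLE_WITH_CUTLASS. Relative to the build
+# directory passed on the command line, the output layout is roughly:
+#
+#   generated/
+#       gemm/
+#           all_gemm_operations.h    # includes every configuration header below
+#           <configuration_name>.h   # one header per emitted GEMM configuration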
+ +import os +import shutil + +from gather_gemm_scatter_operation import ( + EmitGatherGemmScatterConfigurationLibrary, +) +from library import OperationKind, OperationKindNames +from manifest import EmitOperationKindLibrary, GeneratorTarget, Manifest + + +class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary): + def __init__(self, generated_path, kind, args): + super().__init__(generated_path, kind, args) + self.emitters = { + OperationKind.Gemm: EmitGatherGemmScatterConfigurationLibrary + } + self.header_template = "#pragma once\n#ifdef PADDLE_WITH_CUTLASS\n" + self.entry_template = "" + self.configuration_prototype_template = "" + self.configuration_template = "" + self.epilogue_template = "#endif" + + def __enter__(self): + self.operation_path = os.path.join( + self.generated_path, OperationKindNames[self.kind] + ) + os.mkdir(self.operation_path) + + self.top_level_path = os.path.join( + self.operation_path, + "all_%s_operations.h" % OperationKindNames[self.kind], + ) + + self.top_level_file = open(self.top_level_path, "w") + self.top_level_file.write(self.header_template) + + self.source_files = [ + self.top_level_path, + ] + + return self + + def emit(self, configuration_name, operations): + with self.emitters[self.kind]( + self.operation_path, configuration_name + ) as configuration_emitter: + for operation in operations: + configuration_emitter.emit(operation) + + self.source_files.append(configuration_emitter.configuration_path) + + self.configurations.append(configuration_name) + self.top_level_file.write( + '#include "' + + self.operation_path + + '/' + + configuration_name + + '.h"\n' + ) + + +class GatherGemmScatterManifest(Manifest): + def emit(self, target=GeneratorTarget.Library): + + operation_emitters = { + GeneratorTarget.Library: GatherGemmScatterEmitOperationKindLibrary + } + + generated_path = os.path.join(self.curr_build_dir, 'generated') + + # create generated/ + if os.path.exists(generated_path): + shutil.rmtree(generated_path) + + os.mkdir(generated_path) + + source_files = [] + + # for each operation kind, emit initializer for all configurations + for operation_kind, configurations in self.operations.items(): + with operation_emitters[target]( + generated_path, operation_kind, self.args + ) as operation_kind_emitter: + for configuration_name, operations in configurations.items(): + operation_kind_emitter.emit(configuration_name, operations) + + source_files += operation_kind_emitter.source_files diff --git a/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_operation.py b/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_operation.py new file mode 100644 index 0000000000000000000000000000000000000000..1b2296b8c44541fe3932ef73af06a3fb0d7a2337 --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_operation.py @@ -0,0 +1,320 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
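+
+# The emitters below reuse CUTLASS's gemm_operation.py machinery, but instantiate
+# cutlass::gemm::device::GemmUniversal with the gather/scatter arguments enabled
+# (gather A, scatter D; B stays dense). Each generated configuration header therefore
+# defines, roughly, one struct per operation of the form:
+#
+#   struct <operation_name> {
+#     using Gemm = cutlass::gemm::device::GemmUniversal<
+#         ...,    // element/layout/tile parameters filled in from the operation
+#         true,   // gather a
+#         false,  // gather b
+#         true    // scatter d
+#         >;
+#   };
+#
+# where <operation_name> is a placeholder for GemmOperation.procedural_name().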
+ +import enum +import os.path + +from gemm_operation import ( + EmitGemmConfigurationLibrary, + EmitGemmInstance, + EpilogueFunctor, + GemmOperation, + SwizzlingFunctor, +) +from library import ( + ComplexTransformTag, + DataTypeSize, + DataTypeTag, + EpilogueFunctorTag, + GemmKind, + LayoutTag, + LayoutType, + MathOperationTag, + OpcodeClassTag, + SubstituteTemplate, + SwizzlingFunctorTag, +) + + +class EmitGatherGemmScatterInstance(EmitGemmInstance): + def __init__(self, operation_suffix=''): + self.operation_suffix = operation_suffix + self.includes = [ + "cutlass/cutlass.h", + "cutlass/numeric_types.h", + "cutlass/arch/arch.h", + "cutlass/arch/mma.h", + "cutlass/layout/matrix.h", + "cutlass/gemm/device/gemm.h", + "cutlass/gemm/device/gemm_universal_adapter.h", + "cutlass/gemm/kernel/default_gemm_universal.h", + ] + self.builtin_epilogue_functor_template = """ + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + > +""" + self.gemm_template = """ +// Gemm operator ${operation_name} +struct ${operation_name} { + using Gemm = + cutlass::gemm::device::GemmUniversal< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}, + ${swizzling_functor}, + ${stages}, + ${align_a}, + ${align_b}, + ${math_operation}, + ${transform_a}, + ${transform_b}, + true, // gather a + false, // gather b + true // scatter d + >; +}; +""" + + def instance_template(self): + return "" + + def emit(self, operation): + + threadblock_shape = operation.tile_description.threadblock_shape + warp_count = operation.tile_description.warp_count + + warp_shape = [ + threadblock_shape[idx] // warp_count[idx] for idx in range(3) + ] + + transpose_layouts = { + LayoutType.ColumnMajor: LayoutType.ColumnMajor, + LayoutType.RowMajor: LayoutType.RowMajor, + } + + if ( + operation.A.layout in transpose_layouts.keys() + and operation.B.layout in transpose_layouts.keys() + and operation.C.layout in transpose_layouts.keys() + ): + + instance_layout_A = transpose_layouts[operation.A.layout] + instance_layout_B = transpose_layouts[operation.B.layout] + instance_layout_C = transpose_layouts[operation.C.layout] + + gemm_template = self.gemm_template + else: + instance_layout_A, instance_layout_B, instance_layout_C = ( + operation.A.layout, + operation.B.layout, + operation.C.layout, + ) + + gemm_template = self.gemm_template_interleaved + + # Support built-in epilogue functors or user-defined functions + if isinstance(operation.epilogue_functor, enum.Enum): + + epilogue_vector_length = ( + min( + operation.C.alignment * DataTypeSize[operation.C.element], + 128, + ) + // DataTypeSize[operation.C.element] + ) + + values = { + 'epilogue_vector_length': str(epilogue_vector_length), + 'element_epilogue': str( + DataTypeTag[operation.element_epilogue] + ), + 'epilogue_functor': EpilogueFunctorTag[ + operation.epilogue_functor + ], + } + epilogue_functor = SubstituteTemplate( + self.builtin_epilogue_functor_template, values + ) + else: + epilogue_functor = self.epilogue_functor.emit_declaration() + + values = { + 'operation_name': operation.procedural_name(), + 'operation_suffix': 
self.operation_suffix, + 'element_a': DataTypeTag[operation.A.element], + 'layout_a': LayoutTag[instance_layout_A], + 'element_b': DataTypeTag[operation.B.element], + 'layout_b': LayoutTag[instance_layout_B], + 'element_c': DataTypeTag[operation.C.element], + 'layout_c': LayoutTag[instance_layout_C], + 'element_accumulator': DataTypeTag[operation.accumulator_type()], + 'opcode_class': OpcodeClassTag[ + operation.tile_description.math_instruction.opcode_class + ], + 'arch': "cutlass::arch::Sm%d" % operation.arch, + 'threadblock_shape_m': str( + operation.tile_description.threadblock_shape[0] + ), + 'threadblock_shape_n': str( + operation.tile_description.threadblock_shape[1] + ), + 'threadblock_shape_k': str( + operation.tile_description.threadblock_shape[2] + ), + 'warp_shape_m': str(warp_shape[0]), + 'warp_shape_n': str(warp_shape[1]), + 'warp_shape_k': str(warp_shape[2]), + 'instruction_shape_m': str( + operation.tile_description.math_instruction.instruction_shape[0] + ), + 'instruction_shape_n': str( + operation.tile_description.math_instruction.instruction_shape[1] + ), + 'instruction_shape_k': str( + operation.tile_description.math_instruction.instruction_shape[2] + ), + 'epilogue_functor': epilogue_functor, + 'swizzling_functor': SwizzlingFunctorTag[ + operation.swizzling_functor + ], + 'stages': str(operation.tile_description.stages), + 'align_a': str(operation.A.alignment), + 'align_b': str(operation.B.alignment), + 'transform_a': ComplexTransformTag[operation.A.complex_transform], + 'transform_b': ComplexTransformTag[operation.B.complex_transform], + 'math_operation': MathOperationTag[ + operation.tile_description.math_instruction.math_operation + ], + } + + return SubstituteTemplate(gemm_template, values) + + +class EmitGatherGemmScatterConfigurationLibrary(EmitGemmConfigurationLibrary): + def __init__(self, operation_path, configuration_name): + self.configuration_name = configuration_name + self.configuration_path = os.path.join( + operation_path, "%s.h" % configuration_name + ).replace('\\', '/') + + self.instance_emitter = { + GemmKind.Universal: EmitGatherGemmScatterInstance, + } + + self.gemm_kind_wrappers = { + GemmKind.Universal: 'GemmUniversalOperation', + } + + self.wmma_guard_start = ( + "#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)" + ) + + self.separator = """ +/////////////////////////////////////////////////////////////////////////////////////////////////// + +""" + + self.header_template = """ +/* + Generated by gemm_operation.py - Do not edit. 
+*/ +#pragma once +#ifdef PADDLE_WITH_CUTLASS +""" + + self.namespace_template = """ +namespace phi { +namespace sparse { +""" + self.epilogue_template = """ +} // namespace sparse +} // namespace phi +#endif +""" + + def __exit__(self, exception_type, exception_value, traceback): + + # Write includes + for incl, _ in self.includes.items(): + include_statement = "#include \"%s\"\n" % incl + self.configuration_file.write(include_statement) + + self.configuration_file.write(self.separator) + self.configuration_file.write(self.namespace_template) + + # Write instance definitions in top-level namespace + for instance_definition in self.instance_definitions: + self.configuration_file.write(instance_definition) + + for instance_wrapper in self.instance_wrappers: + self.configuration_file.write(instance_wrapper) + + self.configuration_file.write(self.epilogue_template) + self.configuration_file.close() + + +class GatherGemmScatterOperation(GemmOperation): + # cutlass transpose A and B in the library.py, so we transpose it back here + def __init__( + self, + gemm_kind, + arch, + tile_description, + A, + B, + C, + element_epilogue, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity8, + ): + + super().__init__( + gemm_kind, + arch, + tile_description, + A, + B, + C, + element_epilogue, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity8, + ) + self.ShortLayoutTypeNames = { + LayoutType.ColumnMajor: 't', + LayoutType.ColumnMajorInterleaved2: 't2', + LayoutType.ColumnMajorInterleaved32: 't32', + LayoutType.ColumnMajorInterleaved64: 't64', + LayoutType.RowMajor: 'n', + LayoutType.RowMajorInterleaved2: 'n2', + LayoutType.RowMajorInterleaved32: 'n32', + LayoutType.RowMajorInterleaved64: 'n64', + LayoutType.TensorNHWC: 'nhwc', + LayoutType.TensorNDHWC: 'ndhwc', + LayoutType.TensorNCHW: 'nchw', + LayoutType.TensorNGHWC: 'nghwc', + LayoutType.TensorNC32HW32: 'nc32hw32', + LayoutType.TensorNC64HW64: 'nc64hw64', + LayoutType.TensorC32RSK32: 'c32rsk32', + LayoutType.TensorC64RSK64: 'c64rsk64', + } + + def layout_name(self): + return "%s%s" % ( + self.ShortLayoutTypeNames[self.A.layout], + self.ShortLayoutTypeNames[self.B.layout], + )
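+
+    # Note on naming: ShortLayoutTypeNames above deliberately inverts the mapping used
+    # by CUTLASS's library.py (here RowMajor -> 'n', ColumnMajor -> 't'), matching the
+    # "transpose it back" behaviour of this class. For the RowMajor x RowMajor x
+    # RowMajor layouts emitted by gather_gemm_scatter_generator.py, layout_name()
+    # therefore returns "nn", which then feeds into the procedural kernel names.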