未验证 提交 507af1c8 编写于 作者: U umiswing 提交者: GitHub

Add generator scripts for cutlass (#50364)

上级 c92b1c54
......@@ -34,7 +34,14 @@ ExternalProject_Add(
PREFIX ${CUTLASS_PREFIX_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
BUILD_COMMAND
mkdir -p
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/build/generated/gemm
&& ${PYTHON_EXECUTABLE} -B
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/gather_gemm_scatter_generator.py
"${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/library/scripts/"
"${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass/build"
"${CMAKE_CUDA_COMPILER_VERSION}"
INSTALL_COMMAND ""
TEST_COMMAND "")
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append(sys.argv[1])
from gather_gemm_scatter_manifest import GatherGemmScatterManifest
from gather_gemm_scatter_operation import GatherGemmScatterOperation
from generator import (
ComplexTransform,
CudaToolkitVersionSatisfies,
EpilogueFunctor,
GemmKind,
SwizzlingFunctor,
TensorDescription,
)
from library import (
DataType,
LayoutType,
MathInstruction,
MathOperation,
OpcodeClass,
TileDescription,
)
from manifest import GeneratorTarget
def CreateGatherGemmScatterOperator(
manifest,
layouts,
tile_descriptions,
data_type,
alignment_constraints,
complex_transforms=None,
epilogue_functor=EpilogueFunctor.LinearCombination,
swizzling_functor=SwizzlingFunctor.Identity8,
):
# To use StreamK decomposition for basic GEMMs, set `swizzling_functor = SwizzlingFunctor.StreamK`
if complex_transforms is None:
complex_transforms = [
(ComplexTransform.none, ComplexTransform.none),
]
element_a, element_b, element_c, element_epilogue = data_type
operations = []
# by default, only generate the largest tile and largest alignment
# if manifest.kernel_filter == '':
# tile_descriptions = [tile_descriptions[0],]
# alignment_constraints = [alignment_constraints[0],]
for layout in layouts:
for tile_description in tile_descriptions:
for alignment in alignment_constraints:
for complex_transform in complex_transforms:
alignment_c = min(8, alignment)
A = TensorDescription(
element_a, layout[0], alignment, complex_transform[0]
)
B = TensorDescription(
element_b, layout[1], alignment, complex_transform[1]
)
C = TensorDescription(element_c, layout[2], alignment_c)
new_operation = GatherGemmScatterOperation(
GemmKind.Universal,
tile_description.minimum_compute_capability,
tile_description,
A,
B,
C,
element_epilogue,
epilogue_functor,
swizzling_functor,
)
manifest.append(new_operation)
operations.append(new_operation)
return operations
def GenerateSM70_TensorOp_884(manifest, cuda_version):
if not CudaToolkitVersionSatisfies(cuda_version, 10, 1):
return
layouts = [
(LayoutType.RowMajor, LayoutType.RowMajor, LayoutType.RowMajor),
]
math_instructions = [
MathInstruction(
[8, 8, 4],
DataType.f16,
DataType.f16,
DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add,
),
MathInstruction(
[8, 8, 4],
DataType.f16,
DataType.f16,
DataType.f16,
OpcodeClass.TensorOp,
MathOperation.multiply_add,
),
]
min_cc = 70
max_cc = 75
alignment_constraints = [8, 4, 2, 1]
for math_inst in math_instructions:
tile_descriptions = [
TileDescription(
[256, 128, 32], 2, [4, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 256, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[256, 64, 32], 2, [4, 1, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 256, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 128, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
),
TileDescription(
[64, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc
),
]
data_type = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_accumulator,
math_inst.element_accumulator,
]
CreateGatherGemmScatterOperator(
manifest,
layouts,
tile_descriptions,
data_type,
alignment_constraints,
)
# Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
if math_inst.element_a != math_inst.element_accumulator:
data_type_mixed = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_a,
math_inst.element_accumulator,
]
CreateGatherGemmScatterOperator(
manifest,
layouts,
tile_descriptions,
data_type_mixed,
alignment_constraints,
)
def GenerateSM70(manifest, cuda_version):
GenerateSM70_TensorOp_884(manifest, cuda_version)
class KernelCfg:
def __init__(
self,
architectures,
build_dir,
cuda_version,
curr_build_dir,
disable_full_archs_compilation,
filter_by_cc,
generator_target,
ignore_kernels,
interface_dir,
kernel_filter_file,
kernels,
operations,
selected_kernel_list,
):
self.architectures = architectures
self.build_dir = build_dir
self.cuda_version = cuda_version
self.curr_build_dir = curr_build_dir
self.disable_full_archs_compilation = disable_full_archs_compilation
self.filter_by_cc = filter_by_cc
self.generator_target = generator_target
self.ignore_kernels = ignore_kernels
self.interface_dir = interface_dir
self.kernel_filter_file = kernel_filter_file
self.kernels = kernels
self.operations = operations
self.selected_kernel_list = selected_kernel_list
if __name__ == "__main__":
args = KernelCfg(
architectures='70',
build_dir=sys.argv[2],
cuda_version=sys.argv[3],
curr_build_dir=sys.argv[2],
disable_full_archs_compilation=False,
filter_by_cc='True',
generator_target='library',
ignore_kernels='',
interface_dir=None,
kernel_filter_file=None,
kernels='',
operations='all',
selected_kernel_list=None,
)
manifest = GatherGemmScatterManifest(args)
GenerateSM70(manifest, args.cuda_version)
manifest.emit(GeneratorTarget.Library)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
from gather_gemm_scatter_operation import (
EmitGatherGemmScatterConfigurationLibrary,
)
from library import OperationKind, OperationKindNames
from manifest import EmitOperationKindLibrary, GeneratorTarget, Manifest
class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary):
def __init__(self, generated_path, kind, args):
super().__init__(generated_path, kind, args)
self.emitters = {
OperationKind.Gemm: EmitGatherGemmScatterConfigurationLibrary
}
self.header_template = "#pragma once\n#ifdef PADDLE_WITH_CUTLASS\n"
self.entry_template = ""
self.configuration_prototype_template = ""
self.configuration_template = ""
self.epilogue_template = "#endif"
def __enter__(self):
self.operation_path = os.path.join(
self.generated_path, OperationKindNames[self.kind]
)
os.mkdir(self.operation_path)
self.top_level_path = os.path.join(
self.operation_path,
"all_%s_operations.h" % OperationKindNames[self.kind],
)
self.top_level_file = open(self.top_level_path, "w")
self.top_level_file.write(self.header_template)
self.source_files = [
self.top_level_path,
]
return self
def emit(self, configuration_name, operations):
with self.emitters[self.kind](
self.operation_path, configuration_name
) as configuration_emitter:
for operation in operations:
configuration_emitter.emit(operation)
self.source_files.append(configuration_emitter.configuration_path)
self.configurations.append(configuration_name)
self.top_level_file.write(
'#include "'
+ self.operation_path
+ '/'
+ configuration_name
+ '.h"\n'
)
class GatherGemmScatterManifest(Manifest):
def emit(self, target=GeneratorTarget.Library):
operation_emitters = {
GeneratorTarget.Library: GatherGemmScatterEmitOperationKindLibrary
}
generated_path = os.path.join(self.curr_build_dir, 'generated')
# create generated/
if os.path.exists(generated_path):
shutil.rmtree(generated_path)
os.mkdir(generated_path)
source_files = []
# for each operation kind, emit initializer for all configurations
for operation_kind, configurations in self.operations.items():
with operation_emitters[target](
generated_path, operation_kind, self.args
) as operation_kind_emitter:
for configuration_name, operations in configurations.items():
operation_kind_emitter.emit(configuration_name, operations)
source_files += operation_kind_emitter.source_files
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import os.path
from gemm_operation import (
EmitGemmConfigurationLibrary,
EmitGemmInstance,
EpilogueFunctor,
GemmOperation,
SwizzlingFunctor,
)
from library import (
ComplexTransformTag,
DataTypeSize,
DataTypeTag,
EpilogueFunctorTag,
GemmKind,
LayoutTag,
LayoutType,
MathOperationTag,
OpcodeClassTag,
SubstituteTemplate,
SwizzlingFunctorTag,
)
class EmitGatherGemmScatterInstance(EmitGemmInstance):
def __init__(self, operation_suffix=''):
self.operation_suffix = operation_suffix
self.includes = [
"cutlass/cutlass.h",
"cutlass/numeric_types.h",
"cutlass/arch/arch.h",
"cutlass/arch/mma.h",
"cutlass/layout/matrix.h",
"cutlass/gemm/device/gemm.h",
"cutlass/gemm/device/gemm_universal_adapter.h",
"cutlass/gemm/kernel/default_gemm_universal.h",
]
self.builtin_epilogue_functor_template = """
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>
"""
self.gemm_template = """
// Gemm operator ${operation_name}
struct ${operation_name} {
using Gemm =
cutlass::gemm::device::GemmUniversal<
${element_a},
${layout_a},
${element_b},
${layout_b},
${element_c},
${layout_c},
${element_accumulator},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k}>,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor},
${swizzling_functor},
${stages},
${align_a},
${align_b},
${math_operation},
${transform_a},
${transform_b},
true, // gather a
false, // gather b
true // scatter d
>;
};
"""
def instance_template(self):
return ""
def emit(self, operation):
threadblock_shape = operation.tile_description.threadblock_shape
warp_count = operation.tile_description.warp_count
warp_shape = [
threadblock_shape[idx] // warp_count[idx] for idx in range(3)
]
transpose_layouts = {
LayoutType.ColumnMajor: LayoutType.ColumnMajor,
LayoutType.RowMajor: LayoutType.RowMajor,
}
if (
operation.A.layout in transpose_layouts.keys()
and operation.B.layout in transpose_layouts.keys()
and operation.C.layout in transpose_layouts.keys()
):
instance_layout_A = transpose_layouts[operation.A.layout]
instance_layout_B = transpose_layouts[operation.B.layout]
instance_layout_C = transpose_layouts[operation.C.layout]
gemm_template = self.gemm_template
else:
instance_layout_A, instance_layout_B, instance_layout_C = (
operation.A.layout,
operation.B.layout,
operation.C.layout,
)
gemm_template = self.gemm_template_interleaved
# Support built-in epilogue functors or user-defined functions
if isinstance(operation.epilogue_functor, enum.Enum):
epilogue_vector_length = (
min(
operation.C.alignment * DataTypeSize[operation.C.element],
128,
)
// DataTypeSize[operation.C.element]
)
values = {
'epilogue_vector_length': str(epilogue_vector_length),
'element_epilogue': str(
DataTypeTag[operation.element_epilogue]
),
'epilogue_functor': EpilogueFunctorTag[
operation.epilogue_functor
],
}
epilogue_functor = SubstituteTemplate(
self.builtin_epilogue_functor_template, values
)
else:
epilogue_functor = self.epilogue_functor.emit_declaration()
values = {
'operation_name': operation.procedural_name(),
'operation_suffix': self.operation_suffix,
'element_a': DataTypeTag[operation.A.element],
'layout_a': LayoutTag[instance_layout_A],
'element_b': DataTypeTag[operation.B.element],
'layout_b': LayoutTag[instance_layout_B],
'element_c': DataTypeTag[operation.C.element],
'layout_c': LayoutTag[instance_layout_C],
'element_accumulator': DataTypeTag[operation.accumulator_type()],
'opcode_class': OpcodeClassTag[
operation.tile_description.math_instruction.opcode_class
],
'arch': "cutlass::arch::Sm%d" % operation.arch,
'threadblock_shape_m': str(
operation.tile_description.threadblock_shape[0]
),
'threadblock_shape_n': str(
operation.tile_description.threadblock_shape[1]
),
'threadblock_shape_k': str(
operation.tile_description.threadblock_shape[2]
),
'warp_shape_m': str(warp_shape[0]),
'warp_shape_n': str(warp_shape[1]),
'warp_shape_k': str(warp_shape[2]),
'instruction_shape_m': str(
operation.tile_description.math_instruction.instruction_shape[0]
),
'instruction_shape_n': str(
operation.tile_description.math_instruction.instruction_shape[1]
),
'instruction_shape_k': str(
operation.tile_description.math_instruction.instruction_shape[2]
),
'epilogue_functor': epilogue_functor,
'swizzling_functor': SwizzlingFunctorTag[
operation.swizzling_functor
],
'stages': str(operation.tile_description.stages),
'align_a': str(operation.A.alignment),
'align_b': str(operation.B.alignment),
'transform_a': ComplexTransformTag[operation.A.complex_transform],
'transform_b': ComplexTransformTag[operation.B.complex_transform],
'math_operation': MathOperationTag[
operation.tile_description.math_instruction.math_operation
],
}
return SubstituteTemplate(gemm_template, values)
class EmitGatherGemmScatterConfigurationLibrary(EmitGemmConfigurationLibrary):
def __init__(self, operation_path, configuration_name):
self.configuration_name = configuration_name
self.configuration_path = os.path.join(
operation_path, "%s.h" % configuration_name
).replace('\\', '/')
self.instance_emitter = {
GemmKind.Universal: EmitGatherGemmScatterInstance,
}
self.gemm_kind_wrappers = {
GemmKind.Universal: 'GemmUniversalOperation',
}
self.wmma_guard_start = (
"#if defined(CUTLASS_ARCH_WMMA_SM${sm_number}_ENABLED)"
)
self.separator = """
///////////////////////////////////////////////////////////////////////////////////////////////////
"""
self.header_template = """
/*
Generated by gemm_operation.py - Do not edit.
*/
#pragma once
#ifdef PADDLE_WITH_CUTLASS
"""
self.namespace_template = """
namespace phi {
namespace sparse {
"""
self.epilogue_template = """
} // namespace sparse
} // namespace phi
#endif
"""
def __exit__(self, exception_type, exception_value, traceback):
# Write includes
for incl, _ in self.includes.items():
include_statement = "#include \"%s\"\n" % incl
self.configuration_file.write(include_statement)
self.configuration_file.write(self.separator)
self.configuration_file.write(self.namespace_template)
# Write instance definitions in top-level namespace
for instance_definition in self.instance_definitions:
self.configuration_file.write(instance_definition)
for instance_wrapper in self.instance_wrappers:
self.configuration_file.write(instance_wrapper)
self.configuration_file.write(self.epilogue_template)
self.configuration_file.close()
class GatherGemmScatterOperation(GemmOperation):
# cutlass transpose A and B in the library.py, so we transpose it back here
def __init__(
self,
gemm_kind,
arch,
tile_description,
A,
B,
C,
element_epilogue,
epilogue_functor=EpilogueFunctor.LinearCombination,
swizzling_functor=SwizzlingFunctor.Identity8,
):
super().__init__(
gemm_kind,
arch,
tile_description,
A,
B,
C,
element_epilogue,
epilogue_functor=EpilogueFunctor.LinearCombination,
swizzling_functor=SwizzlingFunctor.Identity8,
)
self.ShortLayoutTypeNames = {
LayoutType.ColumnMajor: 't',
LayoutType.ColumnMajorInterleaved2: 't2',
LayoutType.ColumnMajorInterleaved32: 't32',
LayoutType.ColumnMajorInterleaved64: 't64',
LayoutType.RowMajor: 'n',
LayoutType.RowMajorInterleaved2: 'n2',
LayoutType.RowMajorInterleaved32: 'n32',
LayoutType.RowMajorInterleaved64: 'n64',
LayoutType.TensorNHWC: 'nhwc',
LayoutType.TensorNDHWC: 'ndhwc',
LayoutType.TensorNCHW: 'nchw',
LayoutType.TensorNGHWC: 'nghwc',
LayoutType.TensorNC32HW32: 'nc32hw32',
LayoutType.TensorNC64HW64: 'nc64hw64',
LayoutType.TensorC32RSK32: 'c32rsk32',
LayoutType.TensorC64RSK64: 'c64rsk64',
}
def layout_name(self):
return "%s%s" % (
self.ShortLayoutTypeNames[self.A.layout],
self.ShortLayoutTypeNames[self.B.layout],
)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册