# # \file generator.py # # \brief Generates the CUTLASS Library's instances # import re ################################################################################################### import enum # The following block implements enum.auto() for Python 3.5 variants that don't include it such # as the default 3.5.2 on Ubuntu 16.04. # # https://codereview.stackexchange.com/questions/177309/reimplementing-pythons-enum-auto-for-compatibility try: from enum import auto as enum_auto except ImportError: __cutlass_library_auto_enum = 0 def enum_auto() -> int: global __cutlass_library_auto_enum i = __cutlass_library_auto_enum __cutlass_library_auto_enum += 1 return i ################################################################################################### # class GeneratorTarget(enum.Enum): Library = enum_auto() # GeneratorTargetNames = { GeneratorTarget.Library: 'library' } # ################################################################################################### # class DataType(enum.Enum): b1 = enum_auto() u4 = enum_auto() u8 = enum_auto() u16 = enum_auto() u32 = enum_auto() u64 = enum_auto() s4 = enum_auto() s8 = enum_auto() s16 = enum_auto() s32 = enum_auto() s64 = enum_auto() f16 = enum_auto() bf16 = enum_auto() f32 = enum_auto() tf32 = enum_auto() f64 = enum_auto() cf16 = enum_auto() cbf16 = enum_auto() cf32 = enum_auto() ctf32 = enum_auto() cf64 = enum_auto() cs4 = enum_auto() cs8 = enum_auto() cs16 = enum_auto() cs32 = enum_auto() cs64 = enum_auto() cu4 = enum_auto() cu8 = enum_auto() cu16 = enum_auto() cu32 = enum_auto() cu64 = enum_auto() invalid = enum_auto() # ShortDataTypeNames = { DataType.s32: 'i', DataType.f16: 'h', DataType.f32: 's', DataType.f64: 'd', DataType.cf32: 'c', DataType.cf64: 'z', } # DataTypeNames = { DataType.b1: "b1", DataType.u4: "u4", DataType.u8: "u8", DataType.u16: "u16", DataType.u32: "u32", DataType.u64: "u64", DataType.s4: "s4", DataType.s8: "s8", DataType.s16: "s16", DataType.s32: "s32", DataType.s64: "s64", DataType.f16: "f16", DataType.bf16: "bf16", DataType.f32: "f32", DataType.tf32: "tf32", DataType.f64: "f64", DataType.cf16: "cf16", DataType.cbf16: "cbf16", DataType.cf32: "cf32", DataType.ctf32: "ctf32", DataType.cf64: "cf64", DataType.cu4: "cu4", DataType.cu8: "cu8", DataType.cu16: "cu16", DataType.cu32: "cu32", DataType.cu64: "cu64", DataType.cs4: "cs4", DataType.cs8: "cs8", DataType.cs16: "cs16", DataType.cs32: "cs32", DataType.cs64: "cs64", } DataTypeTag = { DataType.b1: "cutlass::uint1b_t", DataType.u4: "cutlass::uint4b_t", DataType.u8: "uint8_t", DataType.u16: "uint16_t", DataType.u32: "uint32_t", DataType.u64: "uint64_t", DataType.s4: "cutlass::int4b_t", DataType.s8: "int8_t", DataType.s16: "int16_t", DataType.s32: "int32_t", DataType.s64: "int64_t", DataType.f16: "cutlass::half_t", DataType.bf16: "cutlass::bfloat16_t", DataType.f32: "float", DataType.tf32: "cutlass::tfloat32_t", DataType.f64: "double", DataType.cf16: "cutlass::complex", DataType.cbf16: "cutlass::complex", DataType.cf32: "cutlass::complex", DataType.ctf32: "cutlass::complex", DataType.cf64: "cutlass::complex", DataType.cu4: "cutlass::complex", DataType.cu8: "cutlass::complex", DataType.cu16: "cutlass::complex", DataType.cu32: "cutlass::complex", DataType.cu64: "cutlass::complex", DataType.cs4: "cutlass::complex", DataType.cs8: "cutlass::complex", DataType.cs16: "cutlass::complex", DataType.cs32: "cutlass::complex", DataType.cs64: "cutlass::complex", } DataTypeSize = { DataType.b1: 1, DataType.u4: 4, DataType.u8: 4, DataType.u16: 16, DataType.u32: 32, DataType.u64: 64, DataType.s4: 4, DataType.s8: 8, DataType.s16: 16, DataType.s32: 32, DataType.s64: 64, DataType.f16: 16, DataType.bf16: 16, DataType.f32: 32, DataType.tf32: 32, DataType.f64: 64, DataType.cf16: 32, DataType.cbf16: 32, DataType.cf32: 64, DataType.ctf32: 32, DataType.cf64: 128, DataType.cu4: 8, DataType.cu8: 16, DataType.cu16: 32, DataType.cu32: 64, DataType.cu64: 128, DataType.cs4: 8, DataType.cs8: 16, DataType.cs16: 32, DataType.cs32: 64, DataType.cs64: 128, } ################################################################################################### # class ComplexTransform(enum.Enum): none = enum_auto() conj = enum_auto() # ComplexTransformTag = { ComplexTransform.none: 'cutlass::ComplexTransform::kNone', ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate', } # RealComplexBijection = [ (DataType.f16, DataType.cf16), (DataType.f32, DataType.cf32), (DataType.f64, DataType.cf64), ] # def is_complex(data_type): for r, c in RealComplexBijection: if data_type == c: return True return False # def get_complex_from_real(real_type): for r, c in RealComplexBijection: if real_type == r: return c return DataType.invalid # def get_real_from_complex(complex_type): for r, c in RealComplexBijection: if complex_type == c: return r return DataType.invalid # class ComplexMultiplyOp(enum.Enum): multiply_add = enum_auto() gaussian = enum_auto() ################################################################################################### # class MathOperation(enum.Enum): multiply_add = enum_auto() multiply_add_saturate = enum_auto() xor_popc = enum_auto() multiply_add_fast_bf16 = enum_auto() multiply_add_fast_f16 = enum_auto() multiply_add_complex = enum_auto() multiply_add_complex_gaussian = enum_auto() # MathOperationTag = { MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd', MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate', MathOperation.xor_popc: 'cutlass::arch::OpXorPopc', MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16', MathOperation.multiply_add_fast_f16: 'cutlass::arch::OpMultiplyAddFastF16', MathOperation.multiply_add_complex: 'cutlass::arch::OpMultiplyAddComplex', MathOperation.multiply_add_complex_gaussian: 'cutlass::arch::OpMultiplyAddGaussianComplex', } ################################################################################################### # class LayoutType(enum.Enum): ColumnMajor = enum_auto() RowMajor = enum_auto() ColumnMajorInterleaved2 = enum_auto() RowMajorInterleaved2 = enum_auto() ColumnMajorInterleaved32 = enum_auto() RowMajorInterleaved32 = enum_auto() ColumnMajorInterleaved64 = enum_auto() RowMajorInterleaved64 = enum_auto() TensorNHWC = enum_auto() TensorNDHWC = enum_auto() TensorNCHW = enum_auto() TensorNGHWC = enum_auto() TensorNC4HW4 = enum_auto() TensorC4RSK4 = enum_auto() TensorNC8HW8 = enum_auto() TensorNC16HW16 = enum_auto() TensorNC32HW32 = enum_auto() TensorNC64HW64 = enum_auto() TensorC32RSK32 = enum_auto() TensorC64RSK64 = enum_auto() TensorK4RSC4 = enum_auto() TensorCK4RS4 = enum_auto() TensorCK8RS8 = enum_auto() TensorCK16RS16 = enum_auto() # LayoutTag = { LayoutType.ColumnMajor: 'cutlass::layout::ColumnMajor', LayoutType.RowMajor: 'cutlass::layout::RowMajor', LayoutType.ColumnMajorInterleaved2: 'cutlass::layout::ColumnMajorInterleaved<2>', LayoutType.RowMajorInterleaved2: 'cutlass::layout::RowMajorInterleaved<2>', LayoutType.ColumnMajorInterleaved32: 'cutlass::layout::ColumnMajorInterleaved<32>', LayoutType.RowMajorInterleaved32: 'cutlass::layout::RowMajorInterleaved<32>', LayoutType.ColumnMajorInterleaved64: 'cutlass::layout::ColumnMajorInterleaved<64>', LayoutType.RowMajorInterleaved64: 'cutlass::layout::RowMajorInterleaved<64>', LayoutType.TensorNHWC: 'cutlass::layout::TensorNHWC', LayoutType.TensorNDHWC: 'cutlass::layout::TensorNDHWC', LayoutType.TensorNCHW: 'cutlass::layout::TensorNCHW', LayoutType.TensorNGHWC: 'cutlass::layout::TensorNGHWC', LayoutType.TensorNC4HW4: 'cutlass::layout::TensorNCxHWx<4>', LayoutType.TensorC4RSK4: 'cutlass::layout::TensorCxRSKx<4>', LayoutType.TensorNC8HW8: 'cutlass::layout::TensorNCxHWx<8>', LayoutType.TensorNC16HW16: 'cutlass::layout::TensorNCxHWx<16>', LayoutType.TensorNC32HW32: 'cutlass::layout::TensorNCxHWx<32>', LayoutType.TensorC32RSK32: 'cutlass::layout::TensorCxRSKx<32>', LayoutType.TensorNC64HW64: 'cutlass::layout::TensorNCxHWx<64>', LayoutType.TensorC64RSK64: 'cutlass::layout::TensorCxRSKx<64>', LayoutType.TensorK4RSC4: 'cutlass::layout::TensorKxRSCx<4>', LayoutType.TensorCK4RS4: 'cutlass::layout::TensorCKxRSx<4>', LayoutType.TensorCK8RS8: 'cutlass::layout::TensorCKxRSx<8>', LayoutType.TensorCK16RS16: 'cutlass::layout::TensorCKxRSx<16>', } # TransposedLayout = { LayoutType.ColumnMajor: LayoutType.RowMajor, LayoutType.RowMajor: LayoutType.ColumnMajor, LayoutType.ColumnMajorInterleaved2: LayoutType.RowMajorInterleaved2, LayoutType.RowMajorInterleaved2: LayoutType.ColumnMajorInterleaved2, LayoutType.ColumnMajorInterleaved32: LayoutType.RowMajorInterleaved32, LayoutType.RowMajorInterleaved32: LayoutType.ColumnMajorInterleaved32, LayoutType.ColumnMajorInterleaved64: LayoutType.RowMajorInterleaved64, LayoutType.RowMajorInterleaved64: LayoutType.ColumnMajorInterleaved64, LayoutType.TensorNHWC: LayoutType.TensorNHWC } # ShortLayoutTypeNames = { LayoutType.ColumnMajor: 'n', LayoutType.ColumnMajorInterleaved32: 'n2', LayoutType.ColumnMajorInterleaved32: 'n32', LayoutType.ColumnMajorInterleaved64: 'n64', LayoutType.RowMajor: 't', LayoutType.RowMajorInterleaved2: 't2', LayoutType.RowMajorInterleaved32: 't32', LayoutType.RowMajorInterleaved64: 't64', LayoutType.TensorNHWC: 'nhwc', LayoutType.TensorNDHWC: 'ndhwc', LayoutType.TensorNCHW: 'nchw', LayoutType.TensorNGHWC: 'nghwc', LayoutType.TensorNC4HW4: 'nc4hw4', LayoutType.TensorC4RSK4: 'c4rsk4', LayoutType.TensorNC8HW8: 'nc8hw8', LayoutType.TensorNC16HW16: 'nc16hw16', LayoutType.TensorNC32HW32: 'nc32hw32', LayoutType.TensorNC64HW64: 'nc64hw64', LayoutType.TensorC32RSK32: 'c32rsk32', LayoutType.TensorC64RSK64: 'c64rsk64', LayoutType.TensorK4RSC4: 'k4rsc4', LayoutType.TensorCK4RS4: 'ck4rs4', LayoutType.TensorCK8RS8: 'ck8rs8', LayoutType.TensorCK16RS16: 'ck16rs16', } # ShortComplexLayoutNames = { (LayoutType.ColumnMajor, ComplexTransform.none): 'n', (LayoutType.ColumnMajor, ComplexTransform.conj): 'c', (LayoutType.RowMajor, ComplexTransform.none): 't', (LayoutType.RowMajor, ComplexTransform.conj): 'h' } ################################################################################################### # class OpcodeClass(enum.Enum): Simt = enum_auto() TensorOp = enum_auto() WmmaTensorOp = enum_auto() OpcodeClassNames = { OpcodeClass.Simt: 'simt', OpcodeClass.TensorOp: 'tensorop', OpcodeClass.WmmaTensorOp: 'wmma_tensorop', } OpcodeClassTag = { OpcodeClass.Simt: 'cutlass::arch::OpClassSimt', OpcodeClass.TensorOp: 'cutlass::arch::OpClassTensorOp', OpcodeClass.WmmaTensorOp: 'cutlass::arch::OpClassWmmaTensorOp', } ################################################################################################### # class OperationKind(enum.Enum): Gemm = enum_auto() Conv2d = enum_auto() # OperationKindNames = { OperationKind.Gemm: 'gemm' , OperationKind.Conv2d: 'conv2d' } # class Target(enum.Enum): library = enum_auto() ArchitectureNames = { 50: 'maxwell', 60: 'pascal', 61: 'pascal', 70: 'volta', 75: 'turing', 80: 'ampere', } ################################################################################################### # def SubstituteTemplate(template, values): text = template changed = True while changed: changed = False for key, value in values.items(): regex = "\\$\\{%s\\}" % key newtext = re.sub(regex, value, text) if newtext != text: changed = True text = newtext return text ################################################################################################### # class GemmKind(enum.Enum): Gemm = enum_auto() Sparse = enum_auto() Universal = enum_auto() PlanarComplex = enum_auto() PlanarComplexArray = enum_auto() SplitKParallel = enum_auto() GemvBatchedStrided = enum_auto() # GemmKindNames = { GemmKind.Gemm: "gemm", GemmKind.Sparse: "spgemm", GemmKind.Universal: "gemm", GemmKind.PlanarComplex: "gemm_planar_complex", GemmKind.PlanarComplexArray: "gemm_planar_complex_array", GemmKind.SplitKParallel: "gemm_split_k_parallel", GemmKind.GemvBatchedStrided: "gemv_batched_strided", } # class EpilogueFunctor(enum.Enum): LinearCombination = enum_auto() LinearCombinationClamp = enum_auto() BiasAddLinearCombination = enum_auto() BiasAddLinearCombinationRelu = enum_auto() BiasAddLinearCombinationHSwish = enum_auto() BiasAddLinearCombinationClamp = enum_auto() BiasAddLinearCombinationReluClamp = enum_auto() BiasAddLinearCombinationHSwishClamp = enum_auto() # EpilogueFunctorTag = { EpilogueFunctor.LinearCombination: 'cutlass::epilogue::thread::LinearCombination', EpilogueFunctor.LinearCombinationClamp: 'cutlass::epilogue::thread::LinearCombinationClamp', EpilogueFunctor.BiasAddLinearCombination: 'cutlass::epilogue::thread::BiasAddLinearCombination', EpilogueFunctor.BiasAddLinearCombinationRelu: 'cutlass::epilogue::thread::BiasAddLinearCombinationRelu', EpilogueFunctor.BiasAddLinearCombinationHSwish: 'cutlass::epilogue::thread::BiasAddLinearCombinationHSwish', EpilogueFunctor.BiasAddLinearCombinationClamp: 'cutlass::epilogue::thread::BiasAddLinearCombinationClamp', EpilogueFunctor.BiasAddLinearCombinationReluClamp: 'cutlass::epilogue::thread::BiasAddLinearCombinationReluClamp', EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: 'cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp', } # ShortEpilogueNames = { EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: 'hswish', EpilogueFunctor.BiasAddLinearCombinationReluClamp: 'relu', EpilogueFunctor.BiasAddLinearCombinationClamp: 'id', EpilogueFunctor.BiasAddLinearCombinationHSwish: 'hswish', EpilogueFunctor.BiasAddLinearCombinationRelu: 'relu', EpilogueFunctor.BiasAddLinearCombination: 'id', } # class SwizzlingFunctor(enum.Enum): Identity1 = enum_auto() Identity2 = enum_auto() Identity4 = enum_auto() Identity8 = enum_auto() ConvFpropNCxHWx = enum_auto() ConvFpropTrans = enum_auto() ConvDgradNCxHWx = enum_auto() ConvDgradTrans = enum_auto() # SwizzlingFunctorTag = { SwizzlingFunctor.Identity1: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>', SwizzlingFunctor.Identity2: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>', SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>', SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>', SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle', SwizzlingFunctor.ConvFpropTrans: 'cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle', SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle', SwizzlingFunctor.ConvDgradTrans: 'cutlass::conv::threadblock::ConvolutionDgradTransThreadblockSwizzle', } ################################################################################################### class ConvType(enum.Enum): Convolution = enum_auto() BatchConvolution = enum_auto() Local = enum_auto() LocalShare = enum_auto() ConvTypeTag = { ConvType.Convolution: 'cutlass::conv::ConvType::kConvolution', ConvType.BatchConvolution: 'cutlass::conv::ConvType::kBatchConvolution', ConvType.Local: 'cutlass::conv::ConvType::kLocal', ConvType.LocalShare : 'cutlass::conv::ConvType::kLocalShare', } # class ConvKind(enum.Enum): Fprop = enum_auto() Dgrad = enum_auto() Wgrad = enum_auto() # ConvKindTag = { ConvKind.Fprop: 'cutlass::conv::Operator::kFprop', ConvKind.Dgrad: 'cutlass::conv::Operator::kDgrad', ConvKind.Wgrad: 'cutlass::conv::Operator::kWgrad' } ConvKindNames = { ConvKind.Fprop: 'fprop', ConvKind.Dgrad: 'dgrad', ConvKind.Wgrad: 'wgrad', } # class IteratorAlgorithm(enum.Enum): Analytic = enum_auto() Optimized = enum_auto() # IteratorAlgorithmTag = { IteratorAlgorithm.Analytic: 'cutlass::conv::IteratorAlgorithm::kAnalytic', IteratorAlgorithm.Optimized: 'cutlass::conv::IteratorAlgorithm::kOptimized', } IteratorAlgorithmNames = { IteratorAlgorithm.Analytic: 'analytic', IteratorAlgorithm.Optimized: 'optimized', } # class StrideSupport(enum.Enum): Strided = enum_auto() Unity = enum_auto() # StrideSupportTag = { StrideSupport.Strided: 'cutlass::conv::StrideSupport::kStrided', StrideSupport.Unity: 'cutlass::conv::StrideSupport::kUnity', } StrideSupportNames = { StrideSupport.Strided: '', StrideSupport.Unity: 'unity_stride', } class SpecialOptimizeDesc(enum.Enum): NoneSpecialOpt = enum_auto() ConvFilterUnity = enum_auto() DeconvDoubleUpsampling = enum_auto() SpecialOptimizeDescNames = { SpecialOptimizeDesc.NoneSpecialOpt: 'none', SpecialOptimizeDesc.ConvFilterUnity: 'conv_filter_unity', SpecialOptimizeDesc.DeconvDoubleUpsampling: 'deconv_double_upsampling', } SpecialOptimizeDescTag = { SpecialOptimizeDesc.NoneSpecialOpt: 'cutlass::conv::SpecialOptimizeDesc::NONE', SpecialOptimizeDesc.ConvFilterUnity: 'cutlass::conv::SpecialOptimizeDesc::CONV_FILTER_UNITY', SpecialOptimizeDesc.DeconvDoubleUpsampling: 'cutlass::conv::SpecialOptimizeDesc::DECONV_DOUBLE_UPSAMPLING', } class ImplicitGemmMode(enum.Enum): GemmNT = enum_auto() GemmTN = enum_auto() ImplicitGemmModeNames = { ImplicitGemmMode.GemmNT: 'gemm_nt', ImplicitGemmMode.GemmTN: 'gemm_tn', } ImplicitGemmModeTag = { ImplicitGemmMode.GemmNT: 'cutlass::conv::ImplicitGemmMode::GEMM_NT', ImplicitGemmMode.GemmTN: 'cutlass::conv::ImplicitGemmMode::GEMM_TN', } ################################################################################################### # class MathInstruction: def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class, math_operation = MathOperation.multiply_add): self.instruction_shape = instruction_shape self.element_a = element_a self.element_b = element_b self.element_accumulator = element_accumulator self.opcode_class = opcode_class self.math_operation = math_operation # class TileDescription: def __init__(self, threadblock_shape, stages, warp_count, math_instruction, min_compute, max_compute): self.threadblock_shape = threadblock_shape self.stages = stages self.warp_count = warp_count self.math_instruction = math_instruction self.minimum_compute_capability = min_compute self.maximum_compute_capability = max_compute def procedural_name(self): return "%dx%d_%dx%d" % (self.threadblock_shape[0], self.threadblock_shape[1], self.threadblock_shape[2], self.stages) # class TensorDescription: def __init__(self, element, layout, alignment = 1, complex_transform = ComplexTransform.none): self.element = element self.layout = layout self.alignment = alignment self.complex_transform = complex_transform ################################################################################################### class GlobalCnt: cnt = 0