From b18feaab33eeff87cfc56813f7beccb7fd4c3873 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 6 Jul 2021 16:33:46 +0800 Subject: [PATCH] feat(dnn/cuda): use cutlass remove shared load imma conv kernel GitOrigin-RevId: 0b5574f52669ba88237967d3f00fceca0857b80f --- .../cutlass_generator/conv2d_operation.py | 29 +- dnn/scripts/cutlass_generator/generator.py | 58 +- dnn/scripts/cutlass_generator/library.py | 20 +- dnn/scripts/cutlass_generator/list.bzl | Bin 45553 -> 49641 bytes dnn/src/cuda/conv_bias/algo.cpp | 60 +- dnn/src/cuda/conv_bias/algo.h | 3 + .../conv_bias/cutlass_convolution_wrapper.cuh | 12 +- .../cutlass_convolution_wrapper_int4.cu | 595 ++++++++++++++++++ ...cu => cutlass_convolution_wrapper_int8.cu} | 570 ++--------------- .../cuda/conv_bias/cutlass_reorder_filter.cu | 194 ++++++ .../cuda/conv_bias/cutlass_reorder_filter.cuh | 33 + .../implicit_gemm_int4_int4_nchw64_imma.cpp | 2 +- .../implicit_gemm_int4_int4_nhwc_imma.cpp | 4 +- .../implicit_gemm_int4_nchw64_imma_base.cpp | 44 +- .../implicit_gemm_int4_nhwc_imma_base.cpp | 40 +- .../implicit_gemm_int8_nchw32_imma.cpp | 75 +-- .../implicit_gemm_uint4_int4_nchw64_imma.cpp | 3 +- .../implicit_gemm_uint4_int4_nhwc_imma.cpp | 4 +- dnn/test/cuda/conv_bias_int8.cpp | 32 +- 19 files changed, 1063 insertions(+), 715 deletions(-) create mode 100644 dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int4.cu rename dnn/src/cuda/conv_bias/{cutlass_convolution_wrapper.cu => cutlass_convolution_wrapper_int8.cu} (58%) create mode 100644 dnn/src/cuda/conv_bias/cutlass_reorder_filter.cu create mode 100644 dnn/src/cuda/conv_bias/cutlass_reorder_filter.cuh diff --git a/dnn/scripts/cutlass_generator/conv2d_operation.py b/dnn/scripts/cutlass_generator/conv2d_operation.py index 6cd180479..324c2a524 100644 --- a/dnn/scripts/cutlass_generator/conv2d_operation.py +++ b/dnn/scripts/cutlass_generator/conv2d_operation.py @@ -20,7 +20,7 @@ class Conv2dOperation: # def __init__(self, conv_kind, conv_type, arch, tile_description, src, flt, bias, dst, element_epilogue, \ epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity4, \ - need_load_from_const = True, implicit_gemm_mode = ImplicitGemmMode.GemmNt): + need_load_from_const = True, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False): self.operation_kind = OperationKind.Conv2d self.conv_kind = conv_kind @@ -36,6 +36,7 @@ class Conv2dOperation: self.swizzling_functor = swizzling_functor self.need_load_from_const = need_load_from_const self.implicit_gemm_mode = implicit_gemm_mode + self.without_shared_load = without_shared_load # def accumulator_type(self): accum = self.tile_description.math_instruction.element_accumulator @@ -58,11 +59,15 @@ class Conv2dOperation: unity_kernel = '' if not self.need_load_from_const: - unity_kernel = '_1x1' + unity_kernel = '_1x1' - return "%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \ + reorder_k = '' + if self.without_shared_load: + reorder_k = '_roc' + + return "%s%s%s%s%s%s_%s" % (ShortDataTypeNames[self.accumulator_type()], \ inst_shape, intermediate_type, ConvKindNames[self.conv_kind], unity_kernel, \ - ShortEpilogueNames[self.epilogue_functor]) + reorder_k, ShortEpilogueNames[self.epilogue_functor]) # def extended_name(self): @@ -177,7 +182,8 @@ using Convolution = ${alignment_filter}, ${nonuninity_kernel}, ${math_operator}, - ${implicit_gemm_mode}>; + ${implicit_gemm_mode}, + ${without_shared_load}>; """ @@ -219,7 +225,8 @@ using Convolution = 'alignment_filter': 
str(operation.flt.alignment), 'nonuninity_kernel': str(operation.need_load_from_const).lower(), 'math_operator': MathOperationTag[operation.tile_description.math_instruction.math_operation], - 'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode] + 'implicit_gemm_mode': ImplicitGemmModeTag[operation.implicit_gemm_mode], + 'without_shared_load': str(operation.without_shared_load).lower() } return SubstituteTemplate(self.template, values) @@ -312,13 +319,13 @@ using Deconvolution = # def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_layout, dst_type, min_cc, src_align = 32, flt_align = 32, dst_align = 128, \ - skip_unity_kernel = False, implicit_gemm_mode = ImplicitGemmMode.GemmNt): + skip_unity_kernel = False, implicit_gemm_mode = ImplicitGemmMode.GemmNT, without_shared_load = False): operations = [] element_epilogue = DataType.f32 if conv_kind == ConvKind.Fprop: - if src_layout == LayoutType.TensorNHWC: - swizzling_functor = SwizzlingFunctor.ConvFpropNHWC + if implicit_gemm_mode == ImplicitGemmMode.GemmTN: + swizzling_functor = SwizzlingFunctor.ConvFpropTrans else: swizzling_functor = SwizzlingFunctor.ConvFpropNCxHWx else: @@ -399,10 +406,10 @@ def GenerateConv2d(conv_kind, tile_descriptions, src_layout, flt_layout, dst_lay bias = TensorDescription(bias_type, dst_layout, max(1, int(32 / DataTypeSize[bias_type]))) dst = TensorDescription(dst_type, dst_layout, int(dst_align / DataTypeSize[dst_type])) - new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, True, implicit_gemm_mode) + new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, True, implicit_gemm_mode, without_shared_load) operations.append(new_operation) if not skip_unity_kernel: - new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, False, implicit_gemm_mode) + new_operation = Conv2dOperation(conv_kind, ConvType.Convolution, min_cc, tile, src, flt, bias, dst, element_epilogue, epilogue, swizzling_functor, False, implicit_gemm_mode, without_shared_load) operations.append(new_operation) return operations diff --git a/dnn/scripts/cutlass_generator/generator.py b/dnn/scripts/cutlass_generator/generator.py index 61117b45f..ed6149de7 100644 --- a/dnn/scripts/cutlass_generator/generator.py +++ b/dnn/scripts/cutlass_generator/generator.py @@ -175,12 +175,10 @@ def GenerateConv2d_Simt(args): TileDescription([128, 128, 32], 2, [2, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 32], 2, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 128, 32], 2, [1, 4, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 32, 32], 2, [2, 1, 1], math_inst, min_cc, max_cc), TileDescription([ 32, 128, 32], 2, [1, 2, 1], math_inst, min_cc, max_cc), TileDescription([ 32, 64, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), TileDescription([ 64, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 32, 32], 2, [1, 1, 1], math_inst, min_cc, max_cc), TileDescription([ 16, 128, 16], 1, [1, 1, 1], math_inst, min_cc, max_cc), TileDescription([ 16, 64, 8], 2, [1, 1, 1], math_inst, min_cc, max_cc), ] @@ -223,28 +221,36 @@ def GenerateConv2d_TensorOp_8816(args): for dst_type, dst_layout in zip(dst_types, dst_layouts): if dst_layout == 
LayoutType.TensorNC32HW32: tile_descriptions = [ - TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 64, 64], 2, [1, 4, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc), ] + operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], + dst_layout, dst_type, min_cc, 128, 128, 64, + False, ImplicitGemmMode.GemmTN, True) else: assert dst_layout == LayoutType.TensorNC4HW4 tile_descriptions = [ - TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 256, 64], 2, [2, 4, 1], math_inst, min_cc, max_cc), + TileDescription([256, 128, 64], 2, [4, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 64, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), - TileDescription([ 32, 64, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 64], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 1, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([ 64, 128, 32], 1, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([ 32, 128, 32], 1, [1, 2, 1], math_inst, min_cc, max_cc), ] - operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], + operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], dst_layout, dst_type, min_cc, 128, 128, 64, False) + return operations def GenerateConv2d_TensorOp_8832(args): @@ -279,12 +285,14 @@ def GenerateConv2d_TensorOp_8832(args): for dst_layout in dst_layouts: dst_type = math_inst.element_b tile_descriptions = [ - TileDescription([256, 128, 128], 2, [4, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 256, 128], 2, [2, 4, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 128], 2, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 128], 2, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc), ] operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], dst_layout, dst_type, min_cc, 128, 128, 64, - True) + True, ImplicitGemmMode.GemmTN, True) layouts_nhwc = [ (LayoutType.TensorNHWC, LayoutType.TensorNC8HW8, 32), @@ -299,14 +307,21 @@ def GenerateConv2d_TensorOp_8832(args): for math_inst in math_instructions: for layout in layouts_nhwc: for dst_layout in dst_layouts_nhwc: - dst_type = math_inst.element_b - tile_descriptions = [ - 
TileDescription([128, 32, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc), - TileDescription([128, 64, 64], 2, [2, 1, 1], math_inst, min_cc, max_cc), - ] - operations += GenerateConv2d(ConvKind.Fprop, tile_descriptions, layout[0], layout[1], - dst_layout, dst_type, min_cc, layout[2], layout[2], 32, - False, ImplicitGemmMode.GemmTn) + dst_type = math_inst.element_b + tile_descriptions = [ + TileDescription([128, 32, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 1, [2, 1, 1], math_inst, min_cc, max_cc), + ] + for tile in tile_descriptions: + operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], + dst_layout, dst_type, min_cc, layout[2], layout[2], 32, + False, ImplicitGemmMode.GemmTN, False) + if tile.threadblock_shape[1] == 32 or tile.threadblock_shape[1] == 64: + dst_align = 32 if tile.threadblock_shape[1] == 32 else 64 + operations += GenerateConv2d(ConvKind.Fprop, [tile], layout[0], layout[1], + dst_layout, dst_type, min_cc, layout[2], layout[2], dst_align, + False, ImplicitGemmMode.GemmTN, True) + return operations def GenerateDeconv_Simt(args): @@ -649,3 +664,4 @@ if __name__ == "__main__": # ################################################################################################### + \ No newline at end of file diff --git a/dnn/scripts/cutlass_generator/library.py b/dnn/scripts/cutlass_generator/library.py index 5e82a2952..ae6ff2462 100644 --- a/dnn/scripts/cutlass_generator/library.py +++ b/dnn/scripts/cutlass_generator/library.py @@ -464,10 +464,10 @@ EpilogueFunctorTag = { ShortEpilogueNames = { EpilogueFunctor.BiasAddLinearCombinationHSwishClamp: 'hswish', EpilogueFunctor.BiasAddLinearCombinationReluClamp: 'relu', - EpilogueFunctor.BiasAddLinearCombinationClamp: 'identity', + EpilogueFunctor.BiasAddLinearCombinationClamp: 'id', EpilogueFunctor.BiasAddLinearCombinationHSwish: 'hswish', EpilogueFunctor.BiasAddLinearCombinationRelu: 'relu', - EpilogueFunctor.BiasAddLinearCombination: 'identity', + EpilogueFunctor.BiasAddLinearCombination: 'id', } @@ -482,7 +482,7 @@ class SwizzlingFunctor(enum.Enum): Identity4 = enum_auto() Identity8 = enum_auto() ConvFpropNCxHWx = enum_auto() - ConvFpropNHWC = enum_auto() + ConvFpropTrans = enum_auto() ConvDgradNCxHWx = enum_auto() # @@ -492,7 +492,7 @@ SwizzlingFunctorTag = { SwizzlingFunctor.Identity4: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>', SwizzlingFunctor.Identity8: 'cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>', SwizzlingFunctor.ConvFpropNCxHWx: 'cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle', - SwizzlingFunctor.ConvFpropNHWC: 'cutlass::conv::threadblock::ConvolutionFpropNHWCThreadblockSwizzle', + SwizzlingFunctor.ConvFpropTrans: 'cutlass::conv::threadblock::ConvolutionFpropTransThreadblockSwizzle', SwizzlingFunctor.ConvDgradNCxHWx: 'cutlass::conv::threadblock::ConvolutionDgradNCxHWxThreadblockSwizzle', } @@ -563,17 +563,17 @@ StrideSupportNames = { } class ImplicitGemmMode(enum.Enum): - GemmNt = enum_auto() - GemmTn = enum_auto() + GemmNT = enum_auto() + GemmTN = enum_auto() ImplicitGemmModeNames = { - ImplicitGemmMode.GemmNt: 'gemm_nt', - ImplicitGemmMode.GemmTn: 'gemm_tn', + ImplicitGemmMode.GemmNT: 'gemm_nt', + ImplicitGemmMode.GemmTN: 'gemm_tn', } ImplicitGemmModeTag = { - ImplicitGemmMode.GemmNt: 'cutlass::conv::ImplicitGemmMode::GEMM_NT', - ImplicitGemmMode.GemmTn: 'cutlass::conv::ImplicitGemmMode::GEMM_TN', + ImplicitGemmMode.GemmNT: 'cutlass::conv::ImplicitGemmMode::GEMM_NT', + ImplicitGemmMode.GemmTN: 
'cutlass::conv::ImplicitGemmMode::GEMM_TN', } ################################################################################################### diff --git a/dnn/scripts/cutlass_generator/list.bzl b/dnn/scripts/cutlass_generator/list.bzl index d1e10ac57eaf264fb0e3c6022daff0d7b2c5ffc7..d84c0a3b78e58126301efe07923b791d91349077 100644 GIT binary patch literal 49641 zcmcJY(QX?#5=Gzh6~^rAVgfl{LofS0AF)^n9@!HcaU4TZ)|tQG%Vx8S)l^GetXqS? z)+Fh}s=Aj&Hq}=Car4L1+tcCU^z?Z1=ifJvZ*QNE`@_ro>G*m7^Yr>~Jnp{czwqPm z8LxJ)FUQmG?|*;#RQ>Vrn>YWy`Ol{t{P$b=ljFHGKH_SctNe7W7dz3m=e zU%tL=e&7H6?VpQ(|M>p;a@u{{y+0ftUSFSH&4A6lcSJJ*7u?^uk~9;(+*^TNM!*am`(N*nW&$p-#n=?qOh8L;W@#o=Yi@64?Fba6 zG@;%>?Ff{{si&EcI;B%kI|5O9i)%-uFrC%4Bj5ov=waD**vN=LZH!kmeZh`MwegDD z!AQ%W)4z5^y25ECJ3{7Z=l;(Ms21JM5zPc#Q0m%0q>X?XTKv;Fq?v#VY%w;|(@em8 zc4lcN)N5{MW$g$QW@;@Gc7Hyc_Rmj`yW{Eb@`&%7q1?#*G#cyf z*WIOeyO#)?`-*VeL)cb;i*LBE2;sIb2?+rct|y1PDF-zE9)Lu!=}SQ(*z~0!PJlJ6 zDZtM7<**X2R$er#0D!V$PYP889Y)d-bQlQ$u|bEAbO0SbQUP?>2-k}VI$VVMr#0Zz z-;qE-5CSmU2!P53*l1X1f{viWMmmBH8v#(wgbp9+2s(VEBj_*^0AhnnM%wOP9={$Q z9`S8=+(P@bs-+bjudsQ-3p~uVzkB|;+kLw`9QVsh+>DfW-wA+M1FV;nZytU@lHsm0 zz37Zdk%r9P5{2_2QJCURWci7dWk?u!d{tOOerF zMa>~Qe_TYIs7Z>A3d{mrysxr1Tnes8K!MdWlmZkHC^$V!Dac7#*zU}2ZkMV)RHdj7 zG*i50JheQ8fyqx>qU(>2%RNEq~l%E5OcD`bgr7)dzTZa*diDeTK9h{SmdH& z^}0Ys%jCVh9q)3&^&5Hvwm5+bt(`yxixa3=?F1@1JK_5Mx|^`gM!ic(Cf zUK0G|pLgY6FMA*@GI^mn^GHD1CO&`cmmc#FUeSt2l#240lJ!taX3`X(&Y6OoCOcPW z#Kr+i$~2^?jD_gpy|CHZ6riM>OGt<@r;O-EQI=O=Q_3`;q)bCjDd%az%b3Db2~&vE zL`J4OQxBmLc#5=P`@!JYuaok1G1`m|AxpSLXY*{Y%e$ z=P--z*_xMkMeC#FQ$=@fq!HY#JMn1on&NDoEjq?D$-En48Uf2R^9wsn>!I& z1R^5qnNGxtL`10G&=J9t`mpJRqlymY8Nf<)25zP}1F6(z02`%ypr@y|ca4*JQteEt-_S!Z~e(M@0y+p4~!yG|8 z18Q<4mZ^O6ju+|XowUJW(`(PL*`?~mGi<*044Yn}7tgRq5YJ@8sB@3>pd86FYFH}hA{X~v%eIP+lEg@eA>5sI$K?ok;;k0I(tW)VM1R(j>3ki$z+ zBNUyZ_`_@-KMy!+3{g31gp@`6DAG0ZZv5i;vBOJ=9x+1kQWU@TQq<a?}XgOHreU zm$2i?OB|Nb^JhFxD+LTtJ1GD^H&Vb5+D8Ecv^EMDB)bUUWK)=Dun^xxK?4*Q1>tKK z1q~rC3K~GWC}{ml`Fjz2fa0P&1+JZAb;Tb%gt#a$hwwqfMd(6ok<&Z~9iWyt zDE{Cfw7@|*gb$)tv@2&tU2{2ffFkkOs3YPJ9zp~T%prV`EHdUmYN!z^gbq+74var| z2oX3ihloM&$rRh|ms|a`A^zRG@gxu5$D982D!m|_08*UXNu!Lzr|Z)2`KmJc>;r?~MS2gBHlZ?d_)1|n zRdGa~J~OZBGx*JMHjT%m{f=q*L#qiL0jdh$J)f>-5HNzuz-#-f3GfG)$4@or%s|Z9 zCH(-qg&cH#fqsCGO+i!GK7VX4zQd>HQ@L5pOB07#FV!5P98WRh)Ln`fX#}@C8^Fwe zsOtD*989h}55T0BX!-7{p8<%7IKmIo2ior9?*klWtrUki?*kmF)`!DPkw$QJAK=9r z!JYR3ZnS)NwS9maafBbF4=k-;oBQXF%{_ZZV5Qb-F(ZxOYMq$T#*p>gWyTvrp7#ZI z#1WoodjmV_7;i{_Sn6fgnVeQOKDE}28EFJpE5(d9hOB2XGu{~Tyg#rbj_^d=AJ|dH zctd9LQa@Y#*AN#D>vex%QESau%=z5}i>sAlQQJET7Fo|?7ImoZs?!Cl!{+^g9dU%$ z+Wx?fI>vRgK816hJeR87ZPt?l-cNM*0I5A+OgrvXiHPh z>0b_ZW-|_SsBQX@CaEh2-1wJ+EnWza?Tml9F^34K3bE9GuOIP1{^}u%)g9Oho|S+Z zZxFiiF*WA3o!DGwfP)HBW?&Fr`FUynSK f3Z9vc8E+7}@iR5%xu03^Gk#{F7yP^i|KI-u4f#7n literal 45553 zcmchg-EJE>5{2*c6voW$VgkupK{tE753yJX9$6C$aTG&R#>w07VYAt+V*gmhI^2OV zX(Z{7Rdo(oY__fZ{_yYP`FMG{9iJcm{PpmBK7YO*E~mHK_4DD!?e*z;-G7OH;rrz? 
zuJ*5|>+O%XU!OjSKkk0>;a?B``SgJQeoKFHJ$<|FuV0SezU_BEAJ_QbdVfCepI%R2 z&Z|ETKYshi?BBn>y`FCSulu*B%hT)Y@zo4iZM`F!3AkaqaV2RcJg&W2nhChUnWdSq z-FUM!6QIOuz4MA`M4&LnJ4VyQjzDRgdYTDH(krMPfhfJjwIf29&g$9`@WA??TbFn3 z2-L>!DD=TY9ptQpxR5_>}fzsI280l#ygq6}Ms2zbQ zy~VX7LYU6#+7a*o7UuqF2pf<)krhH^XgxYvTtzaNZuYc``aEH@McEm7GZ{DZv zI{~@r-g(2wMP>qS=V5ur9-Q7ags z{W<+>M}#|^R?T8Sjv$}Q!Jb(o{w(WWw84;+B z@s82DU`K@7ctx#Xg!bq3uN@KYa9YWV_f>m7#*$J?QH3iri zUk)qbZskR@3IHf8)}$aJs4$X_pu$K1WE)iYNC!~iBNae}jc|W9L4}J@|Ih<&^(ztx z2qFW_HUdDn02>YKOwbWj*hoiEVIu&Tn3={js)-Z%TNo6Geu-{>ZJaVH)yxayETUVNPR36Z+z z%y8j~Pw~1<5Xhc{V_3UByqzVwzuqq2-;^5tu54P41LXL+C`Sas1AOS^`1&LGg%%+d z_;P(eU0&qY(`o`q|;CL}SEI8!#B89qMR3WYxr?mCL zM$&qLQrV8hU6={Tq$~j?l_lV^vILk`mH?E>Vybpu;lLnXFy^cmnUo!Z8ZN4$BjU2M zLrB9#RcutCRCe*KrFK20K$9F4n4Gy3U>1ae)3cX?EL6sdFmV9pP}m0)>iR&1vObtn z)rT4>>cc5D&eWC$8r`r+e~JjW_Qg(pW)q0$mmaP|Zht4%>gOIP5*HL@qZ zTfp5rsPNVJ6+7uRggIRJ;%&jrstrL7 z7rtnjQ)yF(GV2aVv^k5Hvt|)x){waSTG-KAyu!C4A5J@9^?IcaSQu<+*z3 z4u>Ibh)c4zMr)2debvAV>0Z zuu}av{aeV-BPacNL}>sXH9LUE)E3}zr3dg(5y4B}qF1#9@W{ynJfidfkD5KeV`>lZ zxP=Ge;sSiU7#;7_Z3YM4a8U^s4&X&fhbsMfaW=q8^}9e@Gyxho2}J|S=}&`ZnP^Bo z`Ds|GK4xb3wQ%fJJ@+}_q&_uLIy9?Kjk5t(s(;tPTAltxbP|h*lq&!co8=-x^(sIF z7f!(XvTzt$?E(f?IDmmG)Mp@-@(kce^&05yc)orQpF)X`c0BENtL^wRDZ|HShbI7h zRxC#7zPh)N@n^vI_*rm#Eal`6gU`T2;L=(20g@T$U}+0FJR5|L)>fh8rI^7fVSr|C zp%0MULI+E?(BauFbhLI09bdX74A9gq)B&VhsNm8qRJe8v6+Lzf6))X_-qW+Og>KM@ z*!<)bHdk7O&CX6?leJOUbm^05l73O?tis)}u=&X+Y_9YPo1J~aCTpLt>Cz|gVE5G@ zZJ-H-(?{CAmW)p>A!DUU$mr}5GFV%L43`cW`V21cT+Ma}8J`?N#!82f(b*wnuyzO; zE*)~`!xYw)@K0(?eB%`NA=x}m@=VJ-T91)k)G~|qqex$czIVrSv{~;kG5#OpB+Xmq z(Rz%mc*`u>k5cL#_f(6gc@$r@7)NT}f-jq0;I|kCLmw9`jV@oq;o8YVfa5nei~(Gb~DjRw$;YBY#+ z6inNFEk5tN4b^6p(%gR(S0=dw`X;WNo5+=Ac6PZm-Pqetir(~LgI8b@eQArU`?9)5mv zqW<{pfbZ6X{zv2Dp)fwC1@m?H^1fR<<5SeI&J2~IlGF6(#>%JaNyDpZ3vQ`Yx5rS2 zC@c?f$Wkp1F?aSj3piA*6^AK1I5XDL19OgDZC~D3+cqnKl`Vb1jJE|>dx05k8*=G~ zIrmyMo9YZKw$zQqymw`=)PF~^TF;4SEym2M!NL~>VPOrfwiaZf o)^iT|xUIT;ZN(S0W=7ozJyGj9hggeY^tBaV6oeUd1^U1L19h@g*8l(j diff --git a/dnn/src/cuda/conv_bias/algo.cpp b/dnn/src/cuda/conv_bias/algo.cpp index 80f2846ec..ed0a2eb73 100644 --- a/dnn/src/cuda/conv_bias/algo.cpp +++ b/dnn/src/cuda/conv_bias/algo.cpp @@ -217,56 +217,68 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() { #if CUDA_VERSION >= 10020 { using AlgoParam = AlgoInt8NCHW32IMMAImplicitGemm::AlgoParam; - int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64}); - int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64}); - int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64}); - int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64}); - int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64}); - int8_nchw32_imma.emplace_back(AlgoParam{64, 64, 64, 32, 32, 64}); - int8_nchw32_imma.emplace_back(AlgoParam{32, 64, 64, 32, 16, 64}); + int8_nchw32_imma.emplace_back(AlgoParam{128, 256, 64, 64, 64, 64, 2}); + int8_nchw32_imma.emplace_back(AlgoParam{256, 128, 64, 64, 64, 64, 2}); + int8_nchw32_imma.emplace_back(AlgoParam{128, 128, 64, 64, 64, 64, 2}); + int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 64, 64, 32, 64, 2}); + int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 64, 32, 64, 64, 2}); + int8_nchw32_imma.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 1}); + int8_nchw32_imma.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 1}); + int8_nchw32_imma.emplace_back(AlgoParam{64, 128, 32, 32, 64, 32, 1}); + int8_nchw32_imma.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 1}); } { using AlgoParam = AlgoInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; 
int4_int4_nchw64_imma.emplace_back( - AlgoParam{128, 128, 128, 64, 64, 128}); + AlgoParam{128, 128, 128, 64, 64, 128, 2}); int4_int4_nchw64_imma.emplace_back( - AlgoParam{256, 128, 128, 64, 64, 128}); + AlgoParam{128, 256, 128, 64, 64, 128, 2}); + int4_int4_nchw64_imma.emplace_back( + AlgoParam{128, 64, 128, 64, 64, 128, 2}); + int4_int4_nchw64_imma.emplace_back( + AlgoParam{128, 64, 64, 64, 64, 64, 1}); } { using AlgoParam = AlgoUInt4Int4NCHW64IMMAImplicitGemm::AlgoParam; uint4_int4_nchw64_imma.emplace_back( - AlgoParam{128, 128, 128, 64, 64, 128}); + AlgoParam{128, 128, 128, 64, 64, 128, 2}); + uint4_int4_nchw64_imma.emplace_back( + AlgoParam{128, 256, 128, 64, 64, 128, 2}); uint4_int4_nchw64_imma.emplace_back( - AlgoParam{256, 128, 128, 64, 64, 128}); + AlgoParam{128, 64, 128, 64, 64, 128, 2}); + uint4_int4_nchw64_imma.emplace_back( + AlgoParam{128, 64, 64, 64, 64, 64, 1}); } { using AlgoParam = AlgoInt4Int4NHWCIMMAImplicitGemm::AlgoParam; int4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 32, 64, 64, 32, 64, 32}); + AlgoParam{128, 32, 64, 64, 32, 64, 1, 32}); + int4_int4_nhwc_imma.emplace_back( + AlgoParam{128, 32, 64, 64, 32, 64, 1, 16}); + int4_int4_nhwc_imma.emplace_back( + AlgoParam{128, 32, 64, 64, 32, 64, 1, 8}); int4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 32, 64, 64, 32, 64, 16}); - int4_int4_nhwc_imma.emplace_back(AlgoParam{128, 32, 64, 64, 32, 64, 8}); + AlgoParam{128, 64, 64, 64, 64, 64, 1, 32}); int4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 64, 64, 64, 64, 64, 32}); + AlgoParam{128, 64, 64, 64, 64, 64, 1, 16}); int4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 64, 64, 64, 64, 64, 16}); - int4_int4_nhwc_imma.emplace_back(AlgoParam{128, 64, 64, 64, 64, 64, 8}); + AlgoParam{128, 64, 64, 64, 64, 64, 1, 8}); } { using AlgoParam = AlgoUInt4Int4NHWCIMMAImplicitGemm::AlgoParam; uint4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 32, 64, 64, 32, 64, 32}); + AlgoParam{128, 32, 64, 64, 32, 64, 1, 32}); uint4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 32, 64, 64, 32, 64, 16}); + AlgoParam{128, 32, 64, 64, 32, 64, 1, 16}); uint4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 32, 64, 64, 32, 64, 8}); + AlgoParam{128, 32, 64, 64, 32, 64, 1, 8}); uint4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 64, 64, 64, 64, 64, 32}); + AlgoParam{128, 64, 64, 64, 64, 64, 1, 32}); uint4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 64, 64, 64, 64, 64, 16}); + AlgoParam{128, 64, 64, 64, 64, 64, 1, 16}); uint4_int4_nhwc_imma.emplace_back( - AlgoParam{128, 64, 64, 64, 64, 64, 8}); + AlgoParam{128, 64, 64, 64, 64, 64, 1, 8}); } #endif } @@ -279,10 +291,8 @@ void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() { int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2}); int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2}); int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 2}); - int8_nchw4_dotprod.emplace_back(AlgoParam{64, 64, 32, 64, 32, 32, 2}); int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 2}); int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 2}); - int8_nchw4_dotprod.emplace_back(AlgoParam{32, 32, 32, 32, 32, 32, 2}); int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1}); int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2}); } diff --git a/dnn/src/cuda/conv_bias/algo.h b/dnn/src/cuda/conv_bias/algo.h index e892ff8cd..fe4f044e9 100644 --- a/dnn/src/cuda/conv_bias/algo.h +++ b/dnn/src/cuda/conv_bias/algo.h @@ -723,6 +723,7 @@ public: int warp_m; int warp_n; int warp_k; + 
int stage;
     };
 
     AlgoInt8NCHW32IMMAImplicitGemm(AlgoParam algo_param)
             : m_algo_param{algo_param} {
@@ -770,6 +771,7 @@ public:
         int warp_m;
         int warp_n;
         int warp_k;
+        int stage;
     };
 
     AlgoInt4NCHW64IMMAImplicitGemmBase(AlgoParam algo_param)
@@ -897,6 +899,7 @@ public:
         int warp_m;
         int warp_n;
         int warp_k;
+        int stage;
         int access_size;
     };
 
diff --git a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh
index 0b695d33f..424cdd61d 100644
--- a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh
+++ b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh
@@ -38,7 +38,7 @@ void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32(
         const convolution::ConvParam& param, uint32_t nonlinear_mode,
         float alpha, float beta, float gamma, float scale,
         const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
-        cudaStream_t stream);
+        int stages, cudaStream_t stream);
 
 template <bool NeedLoadFromConstMem>
 void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4(
@@ -47,7 +47,7 @@ void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4(
         const convolution::ConvParam& param, uint32_t nonlinear_mode,
         float alpha, float beta, float gamma, float scale,
         const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
-        cudaStream_t stream);
+        int stages, cudaStream_t stream);
 
 template <bool NeedLoadFromConstMem>
 void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4(
@@ -83,7 +83,7 @@ void do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64(
         const convolution::ConvParam& param, uint32_t nonlinear_mode,
         float alpha, float beta, float gamma, float scale,
         const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
-        cudaStream_t stream);
+        int stages, cudaStream_t stream);
 
 template <bool NeedLoadFromConstMem>
 void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64(
@@ -92,7 +92,7 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64(
         const convolution::ConvParam& param, uint32_t nonlinear_mode,
         float alpha, float beta, float gamma, float delta, float theta,
         float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape,
-        const GemmCoord& warp_shape, cudaStream_t stream);
+        const GemmCoord& warp_shape, int stages, cudaStream_t stream);
 
 template <bool NeedLoadFromConstMem>
 void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc(
@@ -110,7 +110,7 @@ void do_conv_bias_int4_int4_implicit_gemm_imma_nhwc(
         const convolution::ConvParam& param, uint32_t nonlinear_mode,
         float alpha, float beta, float gamma, float scale,
         const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
-        const int32_t access_size, cudaStream_t stream);
+        const int32_t access_size, int stages, cudaStream_t stream);
 
 template <bool NeedLoadFromConstMem>
 void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc(
@@ -119,7 +119,7 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc(
         const convolution::ConvParam& param, uint32_t nonlinear_mode,
         float alpha, float beta, float gamma, float delta, float theta,
         float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape,
-        const GemmCoord& warp_shape, const int32_t access_size,
+        const GemmCoord& warp_shape, const int32_t access_size, int stages,
         cudaStream_t stream);
 
 }  // namespace cutlass_wrapper
diff --git a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int4.cu b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int4.cu
new file mode 100644
index 000000000..d29a41a51
--- /dev/null
+++ b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int4.cu
@@ -0,0 +1,595 @@
+/**
+ * \file dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int4.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+// ignore warning of cutlass
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#pragma GCC diagnostic ignored "-Wstrict-aliasing"
+
+#if !MEGDNN_TEGRA_X1
+#include "cutlass/convolution/device/convolution.h"
+#endif
+#include "src/common/opr_param_defs_enumv.cuh"
+#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
+#pragma GCC diagnostic pop
+
+using namespace megdnn;
+using namespace cuda;
+using namespace cutlass_wrapper;
+
+/* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */
+
+#if MEGDNN_TEGRA_X1
+template <bool NeedLoadFromConstMem>
+void megdnn::cuda::cutlass_wrapper::
+        do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64(
+                const int8_t* /* d_src */, const int8_t* /* d_filter */,
+                const int32_t* /* d_bias */, const int8_t* /* d_z */,
+                int8_t* /* d_dst */, int* /* workspace */,
+                const convolution::ConvParam& /* param */,
+                uint32_t /* nonlinear_mode */, float /* alpha */,
+                float /* beta */, float /* gamma */, float /* scale */,
+                const GemmCoord& /* threadblock_shape */,
+                const GemmCoord& /* warp_shape */, int /* stages */,
+                cudaStream_t /* stream */) {}
+#else
+template <bool NeedLoadFromConstMem>
+void megdnn::cuda::cutlass_wrapper::
+        do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64(
+                const int8_t* d_src, const int8_t* d_filter,
+                const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
+                int* workspace, const convolution::ConvParam& param,
+                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
+                float scale, const GemmCoord& threadblock_shape,
+                const GemmCoord& warp_shape, int stages, cudaStream_t stream) {
+#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
+                                        threadblock_k_, warp_m_, warp_n_, \
+                                        warp_k_, stage_) \
+    if (threadblock_shape.m() == threadblock_m_ && \
+        threadblock_shape.n() == threadblock_n_ && \
+        threadblock_shape.k() == threadblock_k_ && \
+        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
+        warp_shape.k() == warp_k_ && stages == stage_) { \
+        using ThreadBlockShape = \
+                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
+                                         threadblock_k_>; \
+        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
+        using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \
+        using Convolution = cutlass::conv::device::Convolution< \
+                cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \
+                cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \
+                ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \
+                cutlass::layout::TensorNCxHWx<64>, int32_t, \
+                cutlass::conv::ConvType::kConvolution, \
+                cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
+                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
+                cutlass::conv::threadblock:: \
+                        ConvolutionFpropTransThreadblockSwizzle, \
+                stage_, 32, 32, NeedLoadFromConstMem, \
+                cutlass::arch::OpMultiplyAddSaturate, \
+                cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \
+        typename Convolution::ConvolutionParameter conv_param( \
+                param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
+                param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
+                param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
+        return cutlass_convolution_wrapper<Convolution>( \
+                reinterpret_cast<const cutlass::int4b_t*>(d_src), \
+                reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \
+                reinterpret_cast<const cutlass::int4b_t*>(d_z), \
+                reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \
+                conv_param, epilogue, stream); \
+    }
+#define DISPATCH_KERNEL \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \
+    megdnn_assert(false, \
+                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
+                  "(%dx%dx%d)", \
+                  threadblock_shape.m(), threadblock_shape.n(), \
+                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
+                  warp_shape.k());
+    using ElementOutput = cutlass::int4b_t;
+    using ElementAccumulator = int32_t;
+    using ElementBias = int32_t;
+    using ElementCompute = float;
+    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
+    switch (nonlinear_mode) {
+        case NonlineMode::IDENTITY: {
+            using EpilogueOp =
+                    cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
+                            ElementOutput, 16, ElementAccumulator, ElementBias,
+                            ElementCompute>;
+            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
+            DISPATCH_KERNEL;
+        }
+        case NonlineMode::RELU: {
+            using EpilogueOp = cutlass::epilogue::thread::
+                    BiasAddLinearCombinationReluClamp<
+                            ElementOutput, 16, ElementAccumulator, ElementBias,
+                            ElementCompute>;
+            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
+            DISPATCH_KERNEL;
+        }
+        case NonlineMode::H_SWISH: {
+            using EpilogueOp = cutlass::epilogue::thread::
+                    BiasAddLinearCombinationHSwishClamp<
+                            ElementOutput, 16, ElementAccumulator, ElementBias,
+                            ElementCompute>;
+            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
+            DISPATCH_KERNEL;
+        }
+        default:
+            megdnn_assert(false,
+                          "unsupported nonlinear mode for conv bias operator");
+    }
+#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
+#undef DISPATCH_KERNEL
+}
+#endif
+
+#define INST(need_load_from_const_mem) \
+    template void megdnn::cuda::cutlass_wrapper:: \
+            do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \
+                    need_load_from_const_mem>( \
+                    const int8_t* d_src, const int8_t* d_filter, \
+                    const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
+                    int* workspace, const convolution::ConvParam& param, \
+                    uint32_t nonlinear_mode, float alpha, float beta, \
+                    float gamma, float scale, \
+                    const GemmCoord& threadblock_shape, \
+                    const GemmCoord& warp_shape, int stages, \
+                    cudaStream_t stream);
+INST(true);
+#undef INST
+
+/* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */
+
+#if MEGDNN_TEGRA_X1
+template <bool NeedLoadFromConstMem>
+void megdnn::cuda::cutlass_wrapper::
+        do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64(
+                const uint8_t* /* d_src */, const int8_t* /* d_filter */,
+                const int32_t* /* d_bias */, const uint8_t* /* d_z */,
+                uint8_t* /* d_dst */, int* /* workspace */,
+                const convolution::ConvParam& /* param */,
+                uint32_t /* nonlinear_mode */, float /* alpha */,
+                float /* beta */, float /* gamma */, float /* delta */,
+                float /* theta */, float /* scale */,
+                uint8_t /* src_zero_point */,
+                const GemmCoord& /* threadblock_shape */,
+                const GemmCoord& /* warp_shape */, int /* stages */,
+                cudaStream_t /* stream */) {}
+#else
+template <bool NeedLoadFromConstMem>
+void megdnn::cuda::cutlass_wrapper::
+        do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64(
+                const uint8_t* d_src, const int8_t* d_filter,
+                const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst,
+                int* workspace, const convolution::ConvParam& param,
+                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
+                float delta, float theta, float /* scale */,
+                uint8_t src_zero_point, const GemmCoord& threadblock_shape,
+                const GemmCoord& warp_shape, int stages, cudaStream_t stream) {
+#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
+                                        threadblock_k_, warp_m_, warp_n_, \
+                                        warp_k_, stage_) \
+    if (threadblock_shape.m() == threadblock_m_ && \
+        threadblock_shape.n() == threadblock_n_ && \
+        threadblock_shape.k() == threadblock_k_ && \
+        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
+        warp_shape.k() == warp_k_ && stages == stage_) { \
+        using ThreadBlockShape = \
+                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
+                                         threadblock_k_>; \
+        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
+        using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \
+        using Convolution = cutlass::conv::device::Convolution< \
+                cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \
+                cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \
+                ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \
+                cutlass::layout::TensorNCxHWx<64>, int32_t, \
+                cutlass::conv::ConvType::kConvolution, \
+                cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
+                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
+                cutlass::conv::threadblock:: \
+                        ConvolutionFpropTransThreadblockSwizzle, \
+                stage_, 32, 32, NeedLoadFromConstMem, \
+                cutlass::arch::OpMultiplyAddSaturate, \
+                cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \
+        typename Convolution::ConvolutionParameter conv_param( \
+                param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
+                param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
+                param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
+        return cutlass_convolution_wrapper<Convolution>( \
+                reinterpret_cast<const cutlass::uint4b_t*>(d_src), \
+                reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \
+                reinterpret_cast<const cutlass::uint4b_t*>(d_z), \
+                reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \
+                conv_param, epilogue, stream, {src_zero_point}); \
+    }
+#define DISPATCH_KERNEL \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128, 2); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 128, 64, 64, 128, 2); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 128, 64, 64, 128, 2); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1); \
+    megdnn_assert(false, \
+                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
+                  "(%dx%dx%d)", \
+                  threadblock_shape.m(), threadblock_shape.n(), \
+                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
+                  warp_shape.k());
+    using ElementOutput = cutlass::uint4b_t;
+    using ElementAccumulator = int32_t;
+    using ElementBias = int32_t;
+    using ElementCompute = float;
+    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
+    switch (nonlinear_mode) {
+        case NonlineMode::IDENTITY: {
+            using EpilogueOp =
+                    cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
+                            ElementOutput, 16, ElementAccumulator, ElementBias,
+                            ElementCompute>;
+            typename EpilogueOp::Params epilogue{alpha, beta, gamma,
+                                                 delta + theta};
+            DISPATCH_KERNEL;
+        }
+        case NonlineMode::RELU: {
+            using EpilogueOp = cutlass::epilogue::thread::
+                    BiasAddLinearCombinationReluClamp<
+                            ElementOutput, 16, ElementAccumulator, ElementBias,
+                            ElementCompute>;
+            typename EpilogueOp::Params epilogue{alpha, beta, gamma,
+                                                 0,     delta, theta};
+            DISPATCH_KERNEL;
+        }
+        default:
+            megdnn_assert(false,
+                          "unsupported nonlinear mode for conv bias operator");
+    }
+#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
+#undef DISPATCH_KERNEL
+}
+#endif
+
+#define INST(need_load_from_const_mem) \
+    template void megdnn::cuda::cutlass_wrapper:: \
+            do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \
+                    need_load_from_const_mem>( \
+                    const uint8_t* d_src, const int8_t* d_filter, \
+                    const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \
+                    int* workspace, const convolution::ConvParam& param, \
+                    uint32_t nonlinear_mode, float alpha, float beta, \
+                    float gamma, float delta, float theta, float scale, \
+                    uint8_t src_zero_point, \
+                    const GemmCoord& threadblock_shape, \
+                    const GemmCoord& warp_shape, int stages, \
+                    cudaStream_t stream);
+INST(true);
+#undef INST
+
+/* ====== cutlass kernel wrapper for int4 x int4 nhwc layout ====== */
+
+#if MEGDNN_TEGRA_X1
+template <bool NeedLoadFromConstMem>
+void megdnn::cuda::cutlass_wrapper::
+        do_conv_bias_int4_int4_implicit_gemm_imma_nhwc(
+                const int8_t* /* d_src */, const int8_t* /* d_filter */,
+                const int32_t* /* d_bias */, const int8_t* /* d_z */,
+                int8_t* /* d_dst */, int* /* workspace */,
+                const convolution::ConvParam& /* param */,
+                uint32_t /* nonlinear_mode */, float /* alpha */,
+                float /* beta */, float /* gamma */, float /* scale */,
+                const GemmCoord& /* threadblock_shape */,
+                const GemmCoord& /* warp_shape */,
+                const int32_t /* access_size */, int /* stages */,
+                cudaStream_t /* stream */) {}
+#else
+template <bool NeedLoadFromConstMem>
+void megdnn::cuda::cutlass_wrapper::
+        do_conv_bias_int4_int4_implicit_gemm_imma_nhwc(
+                const int8_t* d_src, const int8_t* d_filter,
+                const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
+                int* workspace, const convolution::ConvParam& param,
+                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
+                float scale, const GemmCoord& threadblock_shape,
+                const GemmCoord& warp_shape, const int32_t access_size,
+                int stages, cudaStream_t stream) {
+    bool without_shared_load =
+            ((param.co % threadblock_shape.n() == 0) &&
+             (threadblock_shape.n() == 32 || threadblock_shape.n() == 64));
+    int out_elements_per_access =
+            without_shared_load ? threadblock_shape.n() / 4 : 8;
+
+#define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \
+    using Convolution = cutlass::conv::device::Convolution< \
+            cutlass::int4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \
+            cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \
+            cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \
+            int32_t, cutlass::conv::ConvType::kConvolution, \
+            cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
+            ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
+            cutlass::conv::threadblock:: \
+                    ConvolutionFpropTransThreadblockSwizzle, \
+            stage_, access_size_, access_size_, NeedLoadFromConstMem, \
+            cutlass::arch::OpMultiplyAddSaturate, \
+            cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \
+    typename Convolution::ConvolutionParameter conv_param( \
+            param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
+            param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
+            param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
+    return cutlass_convolution_wrapper<Convolution>( \
+            reinterpret_cast<const cutlass::int4b_t*>(d_src), \
+            reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \
+            reinterpret_cast<const cutlass::int4b_t*>(d_z), \
+            reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, conv_param, \
+            epilogue, stream);
+#define DISPATCH_KERNEL_WITH_TILE_SHAPE( \
+        threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \
+        warp_k_, stage_, access_size_, out_elements_per_access_, \
+        without_shared_load_) \
+    if (threadblock_shape.m() == threadblock_m_ && \
+        threadblock_shape.n() == threadblock_n_ && \
+        threadblock_shape.k() == threadblock_k_ && \
+        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
+        warp_shape.k() == warp_k_ && stages == stage_ && \
+        access_size == access_size_ && \
+        out_elements_per_access == out_elements_per_access_ && \
+        without_shared_load == without_shared_load_) { \
+        using ThreadBlockShape = \
+                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
+                                         threadblock_k_>; \
+        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
+        using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \
+        using ElementOutput = cutlass::int4b_t; \
+        using ElementAccumulator = int32_t; \
+        using ElementBias = int32_t; \
+        using ElementCompute = float; \
+        using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \
+        switch (nonlinear_mode) { \
+            case NonlineMode::IDENTITY: { \
+                using EpilogueOp = cutlass::epilogue::thread:: \
+                        BiasAddLinearCombinationClamp< \
+                                ElementOutput, out_elements_per_access_, \
+                                ElementAccumulator, ElementBias, \
+                                ElementCompute>; \
+                typename EpilogueOp::Params epilogue{alpha, beta, gamma}; \
+                RUN_CUTLASS_WRAPPER(stage_, access_size_, \
+                                    without_shared_load_); \
+            } \
+            case NonlineMode::RELU: { \
+                using EpilogueOp = cutlass::epilogue::thread:: \
+                        BiasAddLinearCombinationReluClamp< \
+                                ElementOutput, out_elements_per_access_, \
+                                ElementAccumulator, ElementBias, \
+                                ElementCompute>; \
+                typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; \
+                RUN_CUTLASS_WRAPPER(stage_, access_size_, \
+                                    without_shared_load_); \
+            } \
+            case NonlineMode::H_SWISH: { \
+                using EpilogueOp = cutlass::epilogue::thread:: \
+                        BiasAddLinearCombinationHSwishClamp< \
+                                ElementOutput, out_elements_per_access_, \
+                                ElementAccumulator, ElementBias, \
+                                ElementCompute>; \
+                typename EpilogueOp::Params epilogue{alpha, beta, gamma, \
+                                                     scale}; \
+                RUN_CUTLASS_WRAPPER(stage_, access_size_, \
+                                    without_shared_load_); \
+            } \
+            default: \
+                megdnn_assert( \
+                        false, \
+                        "unsupported nonlinear mode for conv bias operator"); \
+        } \
+    }
+#define DISPATCH_KERNEL \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \
+    megdnn_assert(false, \
+                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
+                  "(%dx%dx%d) and access_size (%d)", \
+                  threadblock_shape.m(), threadblock_shape.n(), \
+                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
+                  warp_shape.k(), access_size);
+    DISPATCH_KERNEL;
+
+#undef RUN_CUTLASS_WRAPPER
+#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
+#undef DISPATCH_KERNEL
+}
+#endif
+
+#define INST(need_load_from_const_mem) \
+    template void megdnn::cuda::cutlass_wrapper:: \
+            do_conv_bias_int4_int4_implicit_gemm_imma_nhwc< \
+                    need_load_from_const_mem>( \
+                    const int8_t* d_src, const int8_t* d_filter, \
+                    const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
+                    int* workspace, const convolution::ConvParam& param, \
+                    uint32_t nonlinear_mode, float alpha, float beta, \
+                    float gamma, float scale, \
+                    const GemmCoord& threadblock_shape, \
+                    const GemmCoord& warp_shape, const int32_t access_size, \
+                    int stages, cudaStream_t stream);
+INST(true);
+INST(false);
+#undef INST
+
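Both nhwc wrappers in this file (the int4 x int4 one above and the uint4 x int4 one below) choose between the ordinary epilogue and the new no-shared-load epilogue at runtime, from the launch configuration alone: the reordered-filter path is usable only when param.co is an exact multiple of the threadblock N dimension and that dimension is 32 or 64, and it then widens the output vector from the default 8 elements to threadblock_n / 4. A minimal standalone C++ sketch of just that rule follows; the names co and threadblock_n mirror param.co and threadblock_shape.n(), and the sketch is an illustration, not code from the patch:

    #include <cstdio>

    // Mirrors the `without_shared_load` condition computed by
    // do_conv_bias_*_implicit_gemm_imma_nhwc before kernel dispatch.
    static bool without_shared_load(int co, int threadblock_n) {
        return (co % threadblock_n == 0) &&
               (threadblock_n == 32 || threadblock_n == 64);
    }

    // Mirrors `out_elements_per_access`: wider vectorized stores when the
    // shared-memory round trip in the epilogue is skipped.
    static int out_elements_per_access(int co, int threadblock_n) {
        return without_shared_load(co, threadblock_n) ? threadblock_n / 4 : 8;
    }

    int main() {
        std::printf("%d\n", out_elements_per_access(64, 64));  // 16: no-shared-load path
        std::printf("%d\n", out_elements_per_access(48, 64));  // 8: fallback path
        return 0;
    }
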
+/* ====== cutlass kernel wrapper for uint4 x int4 nhwc layout ====== */
+
+#if MEGDNN_TEGRA_X1
+template <bool NeedLoadFromConstMem>
+void megdnn::cuda::cutlass_wrapper::
+        do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc(
+                const uint8_t* /* d_src */, const int8_t* /* d_filter */,
+                const int32_t* /* d_bias */, const uint8_t* /* d_z */,
+                uint8_t* /* d_dst */, int* /* workspace */,
+                const convolution::ConvParam& /* param */,
+                uint32_t /* nonlinear_mode */, float /* alpha */,
+                float /* beta */, float /* gamma */, float /* delta */,
+                float /* theta */, float /* scale */,
+                uint8_t /* src_zero_point */,
+                const GemmCoord& /* threadblock_shape */,
+                const GemmCoord& /* warp_shape */,
+                const int32_t /* access_size */, int /* stages */,
+                cudaStream_t /* stream */) {}
+#else
+template <bool NeedLoadFromConstMem>
+void megdnn::cuda::cutlass_wrapper::
+        do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc(
+                const uint8_t* d_src, const int8_t* d_filter,
+                const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst,
+                int* workspace, const convolution::ConvParam& param,
+                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
+                float delta, float theta, float /* scale */,
+                uint8_t src_zero_point, const GemmCoord& threadblock_shape,
+                const GemmCoord& warp_shape, const int32_t access_size,
+                int stages, cudaStream_t stream) {
+    bool without_shared_load =
+            ((param.co % threadblock_shape.n() == 0) &&
+             (threadblock_shape.n() == 32 || threadblock_shape.n() == 64));
+    int out_elements_per_access =
+            without_shared_load ? threadblock_shape.n() / 4 : 8;
+
+#define RUN_CUTLASS_WRAPPER(stage_, access_size_, without_shared_load_) \
+    using Convolution = cutlass::conv::device::Convolution< \
+            cutlass::uint4b_t, cutlass::layout::TensorNHWC, cutlass::int4b_t, \
+            cutlass::layout::TensorNCxHWx<access_size_>, ElementOutput, \
+            cutlass::layout::TensorNHWC, int32_t, cutlass::layout::TensorNHWC, \
+            int32_t, cutlass::conv::ConvType::kConvolution, \
+            cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
+            ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
+            cutlass::conv::threadblock:: \
+                    ConvolutionFpropTransThreadblockSwizzle, \
+            stage_, access_size_, access_size_, NeedLoadFromConstMem, \
+            cutlass::arch::OpMultiplyAddSaturate, \
+            cutlass::conv::ImplicitGemmMode::GEMM_TN, without_shared_load_>; \
+    typename Convolution::ConvolutionParameter conv_param( \
+            param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
+            param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
+            param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
+    return cutlass_convolution_wrapper<Convolution>( \
+            reinterpret_cast<const cutlass::uint4b_t*>(d_src), \
+            reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \
+            reinterpret_cast<const cutlass::uint4b_t*>(d_z), \
+            reinterpret_cast<cutlass::uint4b_t*>(d_dst), workspace, \
+            conv_param, epilogue, stream, {src_zero_point});
+
+#define DISPATCH_KERNEL_WITH_TILE_SHAPE( \
+        threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \
+        warp_k_, stage_, access_size_, out_elements_per_access_, \
+        without_shared_load_) \
+    if (threadblock_shape.m() == threadblock_m_ && \
+        threadblock_shape.n() == threadblock_n_ && \
+        threadblock_shape.k() == threadblock_k_ && \
+        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
+        warp_shape.k() == warp_k_ && stages == stage_ && \
+        access_size == access_size_ && \
+        out_elements_per_access == out_elements_per_access_ && \
+        without_shared_load == without_shared_load_) { \
+        using ThreadBlockShape = \
+                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
+                                         threadblock_k_>; \
+        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
+        using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \
+        using ElementOutput =
cutlass::uint4b_t; \ + using ElementAccumulator = int32_t; \ + using ElementBias = int32_t; \ + using ElementCompute = float; \ + using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; \ + switch (nonlinear_mode) { \ + case NonlineMode::IDENTITY: { \ + using EpilogueOp = cutlass::epilogue::thread:: \ + BiasAddLinearCombinationClamp< \ + ElementOutput, out_elements_per_access_, \ + ElementAccumulator, ElementBias, \ + ElementCompute>; \ + typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ + delta + theta}; \ + RUN_CUTLASS_WRAPPER(stage_, access_size_, \ + without_shared_load_); \ + } \ + case NonlineMode::RELU: { \ + using EpilogueOp = cutlass::epilogue::thread:: \ + BiasAddLinearCombinationReluClamp< \ + ElementOutput, out_elements_per_access_, \ + ElementAccumulator, ElementBias, \ + ElementCompute>; \ + typename EpilogueOp::Params epilogue{alpha, beta, gamma, \ + 0, delta, theta}; \ + RUN_CUTLASS_WRAPPER(stage_, access_size_, \ + without_shared_load_); \ + } \ + default: \ + megdnn_assert( \ + false, \ + "unsupported nonlinear mode for conv bias operator"); \ + } \ + } +#define DISPATCH_KERNEL \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, false); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, false); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, false); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 8, false); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 8, false); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 8, false); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 32, 8, true); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 16, 8, true); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 1, 8, 8, true); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 32, 16, true); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 16, 16, true); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 1, 8, 16, true); \ + megdnn_assert(false, \ + "unsupported threadblock shape (%dx%dx%d) and warp shape " \ + "(%dx%dx%d) and access_size (%d)", \ + threadblock_shape.m(), threadblock_shape.n(), \ + threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ + warp_shape.k(), access_size); + + DISPATCH_KERNEL; + +#undef RUN_CUTLASS_WRAPPER +#undef DISPATCH_KERNEL_WITH_TILE_SHAPE +#undef DISPATCH_KERNEL +} +#endif + +#define INST(need_load_from_const_mem) \ + template void megdnn::cuda::cutlass_wrapper:: \ + do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc< \ + need_load_from_const_mem>( \ + const uint8_t* d_src, const int8_t* d_filter, \ + const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ + int* workspace, const convolution::ConvParam& param, \ + uint32_t nonlinear_mode, float alpha, float beta, \ + float gamma, float delta, float theta, float scale, \ + uint8_t src_zero_point, \ + const GemmCoord& threadblock_shape, \ + const GemmCoord& warp_shape, const int32_t access_size, \ + int stages, cudaStream_t stream); +INST(true); +INST(false); +#undef INST + +// vim: syntax=cuda.doxygen diff --git a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int8.cu similarity index 58% rename from dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu rename to dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int8.cu index e77a599ab..208e080bf 100644 --- a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu +++ 
b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper_int8.cu
@@ -38,7 +38,8 @@ void megdnn::cuda::cutlass_wrapper::
                 uint32_t /* nonlinear_mode */, float /* alpha */,
                 float /* beta */, float /* gamma */, float /* scale */,
                 const GemmCoord& /* threadblock_shape */,
-                const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
+                const GemmCoord& /* warp_shape */, int /* stages */,
+                cudaStream_t /* stream */) {}
 #else
 template <bool NeedLoadFromConstMem>
 void megdnn::cuda::cutlass_wrapper::
@@ -48,15 +49,15 @@ void megdnn::cuda::cutlass_wrapper::
                 int* workspace, const convolution::ConvParam& param,
                 uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                 float scale, const GemmCoord& threadblock_shape,
-                const GemmCoord& warp_shape, cudaStream_t stream) {
+                const GemmCoord& warp_shape, int stages, cudaStream_t stream) {
 #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
                                         threadblock_k_, warp_m_, warp_n_, \
-                                        warp_k_) \
+                                        warp_k_, stage_) \
     if (threadblock_shape.m() == threadblock_m_ && \
         threadblock_shape.n() == threadblock_n_ && \
         threadblock_shape.k() == threadblock_k_ && \
         warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
-        warp_shape.k() == warp_k_) { \
+        warp_shape.k() == warp_k_ && stages == stage_) { \
         using ThreadBlockShape = \
                 cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
                                          threadblock_k_>; \
@@ -71,8 +72,10 @@ void megdnn::cuda::cutlass_wrapper::
                 cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
                 ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
                 cutlass::conv::threadblock:: \
-                        ConvolutionFpropNCxHWxThreadblockSwizzle, \
-                2, 16, 16, NeedLoadFromConstMem>; \
+                        ConvolutionFpropTransThreadblockSwizzle, \
+                stage_, 16, 16, NeedLoadFromConstMem, \
+                cutlass::arch::OpMultiplyAddSaturate, \
+                cutlass::conv::ImplicitGemmMode::GEMM_TN, true>; \
         typename Convolution::ConvolutionParameter conv_param( \
                 param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
                 param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
@@ -82,13 +85,15 @@ void megdnn::cuda::cutlass_wrapper::
                 epilogue, stream); \
     }
 #define DISPATCH_KERNEL \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 32, 16, 64); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \
     megdnn_assert(false, \
                   "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                   "(%dx%dx%d)", \
@@ -144,7 +149,8 @@ void megdnn::cuda::cutlass_wrapper::
                     uint32_t nonlinear_mode, float alpha, float beta, \
                     float gamma, float scale, \
                     const GemmCoord& threadblock_shape, \
-                    const GemmCoord& warp_shape, cudaStream_t stream);
+                    const GemmCoord& warp_shape, int stages, \
+                    cudaStream_t stream);
 INST(true);
 INST(false);
 #undef INST
@@ -162,7 +168,8 @@ void megdnn::cuda::cutlass_wrapper::
                 uint32_t /* nonlinear_mode */, float /* alpha */,
                 float /* beta */, float /* gamma */, float /* scale */,
                 const GemmCoord& /* threadblock_shape */,
-                const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
+                const GemmCoord& /* warp_shape */, int /* stages */,
+                cudaStream_t /* stream */) {}
 #else
 template <bool NeedLoadFromConstMem>
 void megdnn::cuda::cutlass_wrapper::
@@ -172,15 +179,15 @@ void megdnn::cuda::cutlass_wrapper::
                 int* workspace, const convolution::ConvParam& param,
                 uint32_t nonlinear_mode, float alpha, float beta, float gamma,
                 float scale, const GemmCoord& threadblock_shape,
-                const GemmCoord& warp_shape, cudaStream_t stream) {
+                const GemmCoord& warp_shape, int stages, cudaStream_t stream) {
 #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
                                         threadblock_k_, warp_m_, warp_n_, \
-                                        warp_k_) \
+                                        warp_k_, stage_) \
     if (threadblock_shape.m() == threadblock_m_ && \
         threadblock_shape.n() == threadblock_n_ && \
         threadblock_shape.k() == threadblock_k_ && \
         warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
-        warp_shape.k() == warp_k_) { \
+        warp_shape.k() == warp_k_ && stages == stage_) { \
         using ThreadBlockShape = \
                 cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
                                          threadblock_k_>; \
@@ -196,7 +203,7 @@ void megdnn::cuda::cutlass_wrapper::
                 ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
                 cutlass::conv::threadblock:: \
                         ConvolutionFpropNCxHWxThreadblockSwizzle, \
-                2, 16, 16, NeedLoadFromConstMem>; \
+                stage_, 16, 16, NeedLoadFromConstMem>; \
         typename Convolution::ConvolutionParameter conv_param( \
                 param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
                 param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
@@ -206,13 +213,15 @@ void megdnn::cuda::cutlass_wrapper::
                 epilogue, stream); \
     }
 #define DISPATCH_KERNEL \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 16, 32, 64); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64, 2); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64, 2); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64, 2); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64, 2); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64, 2); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 1); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 1); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 32, 64, 32, 1); \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 1); \
     megdnn_assert(false, \
                   "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                   "(%dx%dx%d)", \
@@ -268,7 +277,8 @@ void megdnn::cuda::cutlass_wrapper::
                     uint32_t nonlinear_mode, float alpha, float beta, \
                     float gamma, float scale, \
                     const GemmCoord& threadblock_shape, \
-                    const GemmCoord& warp_shape, cudaStream_t stream);
+                    const GemmCoord& warp_shape, int stages, \
+                    cudaStream_t stream);
 INST(true);
 INST(false);
 #undef INST
@@ -337,10 +347,8 @@ void megdnn::cuda::cutlass_wrapper::
     DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \
     DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \
     DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \
     DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \
     DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \
     DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \
     DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \
     megdnn_assert(false, \
@@ -468,10 +476,8 @@ void megdnn::cuda::cutlass_wrapper::
     DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \
     DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \
     DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \
     DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \
     DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \
     DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \
     DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \
     megdnn_assert(false, \
@@ -599,10 +605,8 @@ void megdnn::cuda::cutlass_wrapper::
     DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \
     DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \
     DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \
     DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \
     DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \
     megdnn_assert(false, \
                   "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                   "(%dx%dx%d)", \
@@ -664,246 +668,6 @@ INST(true);
 INST(false);
 #undef INST
 
-/* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */
-
-#if MEGDNN_TEGRA_X1
-template <bool NeedLoadFromConstMem>
-void megdnn::cuda::cutlass_wrapper::
-        do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64(
-                const int8_t* /* d_src */, const int8_t* /* d_filter */,
-                const int32_t* /* d_bias */, const int8_t* /* d_z */,
-                int8_t* /* d_dst */, int* /* workspace */,
-                const convolution::ConvParam& /* param */,
-                uint32_t /* nonlinear_mode */, float /* alpha */,
-                float /* beta */, float /* gamma */, float /* scale */,
-                const GemmCoord& /* threadblock_shape */,
-                const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {}
-#else
-template <bool NeedLoadFromConstMem>
-void megdnn::cuda::cutlass_wrapper::
-        do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64(
-                const int8_t* d_src, const int8_t* d_filter,
-                const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst,
-                int* workspace, const convolution::ConvParam& param,
-                uint32_t nonlinear_mode, float alpha, float beta, float gamma,
-                float scale, const GemmCoord& threadblock_shape,
-                const GemmCoord& warp_shape, cudaStream_t stream) {
-#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \
-                                        threadblock_k_, warp_m_, warp_n_, \
-                                        warp_k_) \
-    if (threadblock_shape.m() == threadblock_m_ && \
-        threadblock_shape.n() == threadblock_n_ && \
-        threadblock_shape.k() == threadblock_k_ && \
-        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
-        warp_shape.k() == warp_k_) { \
-        using ThreadBlockShape = \
-                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
-                                         threadblock_k_>; \
-        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
-        using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \
-        using Convolution = cutlass::conv::device::Convolution< \
-                cutlass::int4b_t, cutlass::layout::TensorNCxHWx<64>, \
-                cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \
-                ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \
-                cutlass::layout::TensorNCxHWx<64>, int32_t, \
-                cutlass::conv::ConvType::kConvolution, \
-                cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \
-                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \
-                cutlass::conv::threadblock:: \
-                        ConvolutionFpropNCxHWxThreadblockSwizzle, \
-                2, 32, 32, NeedLoadFromConstMem>; \
-        typename Convolution::ConvolutionParameter conv_param( \
-                param.n, param.hi, param.wi, param.ci, param.co, param.fh, \
-                param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \
-                param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \
-        return cutlass_convolution_wrapper<Convolution>( \
-                reinterpret_cast<const cutlass::int4b_t*>(d_src), \
-                reinterpret_cast<const cutlass::int4b_t*>(d_filter), d_bias, \
-                reinterpret_cast<const cutlass::int4b_t*>(d_z), \
-                reinterpret_cast<cutlass::int4b_t*>(d_dst), workspace, \
-                conv_param, epilogue, stream); \
-    }
-#define DISPATCH_KERNEL \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128); \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 128, 64, 64, 128); \
-    megdnn_assert(false, \
-                  "unsupported threadblock shape (%dx%dx%d) and warp shape " \
-                  "(%dx%dx%d)", \
-                  threadblock_shape.m(), threadblock_shape.n(), \
-                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
-                  warp_shape.k());
-    using ElementOutput = cutlass::int4b_t;
-    using ElementAccumulator = int32_t;
-    using ElementBias = int32_t;
-    using ElementCompute = float;
-    using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode;
-    switch (nonlinear_mode) {
-        case NonlineMode::IDENTITY: {
-            using EpilogueOp =
-                    cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
-                            ElementOutput, 16, ElementAccumulator, ElementBias,
-                            ElementCompute>;
-            typename EpilogueOp::Params epilogue{alpha, beta, gamma};
-            DISPATCH_KERNEL;
-        }
-        case NonlineMode::RELU: {
-            using EpilogueOp = cutlass::epilogue::thread::
-                    BiasAddLinearCombinationReluClamp<
-                            ElementOutput, 16, ElementAccumulator, ElementBias,
-                            ElementCompute>;
-            typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0};
-            DISPATCH_KERNEL;
-        }
-        case NonlineMode::H_SWISH: {
-            using EpilogueOp = cutlass::epilogue::thread::
-                    BiasAddLinearCombinationHSwishClamp<
-                            ElementOutput, 16, ElementAccumulator, ElementBias,
-                            ElementCompute>;
-            typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale};
-            DISPATCH_KERNEL;
-        }
-        default:
-            megdnn_assert(false,
-                          "unsupported nonlinear mode for conv bias operator");
-    }
-#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
-#undef DISPATCH_KERNEL
-}
-#endif
-
-#define INST(need_load_from_const_mem) \
-    template void megdnn::cuda::cutlass_wrapper:: \
-            do_conv_bias_int4_int4_implicit_gemm_imma_ncdiv64hw64< \
-                    need_load_from_const_mem>( \
-                    const int8_t* d_src, const int8_t* d_filter, \
-                    const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \
-                    int* workspace, const convolution::ConvParam& param, \
-                    uint32_t nonlinear_mode, float alpha, float beta, \
-                    float gamma, float scale, \
-                    const GemmCoord& threadblock_shape, \
-                    const GemmCoord& warp_shape, cudaStream_t stream);
-INST(true);
-#undef INST
-
-/* ====== cutlass kernel wrapper for uint4 x int4 nchw64 layout ====== */
-
-#if MEGDNN_TEGRA_X1
-template <bool NeedLoadFromConstMem>
-void megdnn::cuda::cutlass_wrapper::
-        do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64(
-                const uint8_t* /* d_src */, const int8_t* /* d_filter */,
-                const int32_t* /* d_bias */, const uint8_t* /* d_z */,
-                uint8_t* /* d_dst */, int* /* workspace */,
-                const convolution::ConvParam& /* param */,
-                uint32_t /* nonlinear_mode */, float /* alpha */,
-                float /* beta */, float /* gamma */, float /* delta */,
-                float /* theta
*/, float /* scale */, - uint8_t /* src_zero_point */, - const GemmCoord& /* threadblock_shape */, - const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} -#else -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( - const uint8_t* d_src, const int8_t* d_filter, - const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, - int* workspace, const convolution::ConvParam& param, - uint32_t nonlinear_mode, float alpha, float beta, float gamma, - float delta, float theta, float /* scale */, - uint8_t src_zero_point, const GemmCoord& threadblock_shape, - const GemmCoord& warp_shape, cudaStream_t stream) { -#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ - threadblock_k_, warp_m_, warp_n_, \ - warp_k_) \ - if (threadblock_shape.m() == threadblock_m_ && \ - threadblock_shape.n() == threadblock_n_ && \ - threadblock_shape.k() == threadblock_k_ && \ - warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ - warp_shape.k() == warp_k_) { \ - using ThreadBlockShape = \ - cutlass::gemm::GemmShape; \ - using WarpShape = cutlass::gemm::GemmShape; \ - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ - using Convolution = cutlass::conv::device::Convolution< \ - cutlass::uint4b_t, cutlass::layout::TensorNCxHWx<64>, \ - cutlass::int4b_t, cutlass::layout::TensorCxRSKx<64>, \ - ElementOutput, cutlass::layout::TensorNCxHWx<64>, int32_t, \ - cutlass::layout::TensorNCxHWx<64>, int32_t, \ - cutlass::conv::ConvType::kConvolution, \ - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ - ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ - cutlass::conv::threadblock:: \ - ConvolutionFpropNCxHWxThreadblockSwizzle, \ - 2, 32, 32, NeedLoadFromConstMem>; \ - typename Convolution::ConvolutionParameter conv_param( \ - param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ - param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ - param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ - return cutlass_convolution_wrapper( \ - reinterpret_cast(d_src), \ - reinterpret_cast(d_filter), d_bias, \ - reinterpret_cast(d_z), \ - reinterpret_cast(d_dst), workspace, \ - conv_param, epilogue, stream, {src_zero_point}); \ - } -#define DISPATCH_KERNEL \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 128, 64, 64, 128); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 128, 64, 64, 128); \ - megdnn_assert(false, \ - "unsupported threadblock shape (%dx%dx%d) and warp shape " \ - "(%dx%dx%d)", \ - threadblock_shape.m(), threadblock_shape.n(), \ - threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ - warp_shape.k()); - using ElementOutput = cutlass::uint4b_t; - using ElementAccumulator = int32_t; - using ElementBias = int32_t; - using ElementCompute = float; - using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; - switch (nonlinear_mode) { - case NonlineMode::IDENTITY: { - using EpilogueOp = - cutlass::epilogue::thread::BiasAddLinearCombinationClamp< - ElementOutput, 16, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, - delta + theta}; - DISPATCH_KERNEL; - } - case NonlineMode::RELU: { - using EpilogueOp = cutlass::epilogue::thread:: - BiasAddLinearCombinationReluClamp< - ElementOutput, 16, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, - 0, delta, theta}; - DISPATCH_KERNEL; - } - default: - megdnn_assert(false, - "unsupported nonlinear mode for conv bias operator"); - } 
-#undef DISPATCH_KERNEL_WITH_TILE_SHAPE -#undef DISPATCH_KERNEL -} -#endif - -#define INST(need_load_from_const_mem) \ - template void megdnn::cuda::cutlass_wrapper:: \ - do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64< \ - need_load_from_const_mem>( \ - const uint8_t* d_src, const int8_t* d_filter, \ - const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, \ - int* workspace, const convolution::ConvParam& param, \ - uint32_t nonlinear_mode, float alpha, float beta, \ - float gamma, float delta, float theta, float scale, \ - uint8_t src_zero_point, \ - const GemmCoord& threadblock_shape, \ - const GemmCoord& warp_shape, cudaStream_t stream); -INST(true); -#undef INST - /* ===== cutlass kernel wrapper for nchw4 layout and nhwc output ===== */ #if MEGDNN_TEGRA_X1 template @@ -970,10 +734,8 @@ void megdnn::cuda::cutlass_wrapper:: DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ megdnn_assert(false, \ @@ -1039,262 +801,4 @@ INST(true); INST(false); #undef INST -/* ====== cutlass kernel wrapper for int4 x int4 nchw64 layout ====== */ - -#if MEGDNN_TEGRA_X1 -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( - const int8_t* /* d_src */, const int8_t* /* d_filter */, - const int32_t* /* d_bias */, const int8_t* /* d_z */, - int8_t* /* d_dst */, int* /* workspace */, - const convolution::ConvParam& /* param */, - uint32_t /* nonlinear_mode */, float /* alpha */, - float /* beta */, float /* gamma */, float /* scale */, - const GemmCoord& /* threadblock_shape */, - const GemmCoord& /* warp_shape */, - const int32_t /* access_size */, cudaStream_t /* stream */) {} -#else -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_int4_int4_implicit_gemm_imma_nhwc( - const int8_t* d_src, const int8_t* d_filter, - const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, - int* workspace, const convolution::ConvParam& param, - uint32_t nonlinear_mode, float alpha, float beta, float gamma, - float scale, const GemmCoord& threadblock_shape, - const GemmCoord& warp_shape, const int32_t access_size, - cudaStream_t stream) { -#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ - threadblock_k_, warp_m_, warp_n_, \ - warp_k_, access_size_) \ - if (threadblock_shape.m() == threadblock_m_ && \ - threadblock_shape.n() == threadblock_n_ && \ - threadblock_shape.k() == threadblock_k_ && \ - warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ - warp_shape.k() == warp_k_ && access_size == access_size_) { \ - using ThreadBlockShape = \ - cutlass::gemm::GemmShape; \ - using WarpShape = cutlass::gemm::GemmShape; \ - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ - using Convolution = cutlass::conv::device::Convolution< \ - cutlass::int4b_t, cutlass::layout::TensorNHWC, \ - cutlass::int4b_t, cutlass::layout::TensorNCxHWx, \ - ElementOutput, cutlass::layout::TensorNHWC, int32_t, \ - cutlass::layout::TensorNHWC, int32_t, \ - cutlass::conv::ConvType::kConvolution, \ - 
cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ - ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ - cutlass::conv::threadblock:: \ - ConvolutionFpropNHWCThreadblockSwizzle, \ - 2, access_size_, access_size_, NeedLoadFromConstMem, \ - cutlass::arch::OpMultiplyAddSaturate, \ - cutlass::conv::ImplicitGemmMode::GEMM_TN>; \ - typename Convolution::ConvolutionParameter conv_param( \ - param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ - param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ - param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ - return cutlass_convolution_wrapper( \ - reinterpret_cast(d_src), \ - reinterpret_cast(d_filter), d_bias, \ - reinterpret_cast(d_z), \ - reinterpret_cast(d_dst), workspace, \ - conv_param, epilogue, stream); \ - } -#define DISPATCH_KERNEL \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 32); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 8); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 32); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 8); \ - megdnn_assert(false, \ - "unsupported threadblock shape (%dx%dx%d) and warp shape " \ - "(%dx%dx%d) and access_size (%d)", \ - threadblock_shape.m(), threadblock_shape.n(), \ - threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ - warp_shape.k(), access_size); - using ElementOutput = cutlass::int4b_t; - using ElementAccumulator = int32_t; - using ElementBias = int32_t; - using ElementCompute = float; - using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; - switch (nonlinear_mode) { - case NonlineMode::IDENTITY: { - using EpilogueOp = - cutlass::epilogue::thread::BiasAddLinearCombinationClamp< - ElementOutput, 8, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma}; - DISPATCH_KERNEL; - } - case NonlineMode::RELU: { - using EpilogueOp = cutlass::epilogue::thread:: - BiasAddLinearCombinationReluClamp< - ElementOutput, 8, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; - DISPATCH_KERNEL; - } - case NonlineMode::H_SWISH: { - using EpilogueOp = cutlass::epilogue::thread:: - BiasAddLinearCombinationHSwishClamp< - ElementOutput, 8, ElementAccumulator, ElementBias, - ElementCompute>; - typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; - DISPATCH_KERNEL; - } - default: - megdnn_assert(false, - "unsupported nonlinear mode for conv bias operator"); - } -#undef DISPATCH_KERNEL_WITH_TILE_SHAPE -#undef DISPATCH_KERNEL -} -#endif - -#define INST(need_load_from_const_mem) \ - template void megdnn::cuda::cutlass_wrapper:: \ - do_conv_bias_int4_int4_implicit_gemm_imma_nhwc< \ - need_load_from_const_mem>( \ - const int8_t* d_src, const int8_t* d_filter, \ - const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ - int* workspace, const convolution::ConvParam& param, \ - uint32_t nonlinear_mode, float alpha, float beta, \ - float gamma, float scale, \ - const GemmCoord& threadblock_shape, \ - const GemmCoord& warp_shape, const int32_t access_size, \ - cudaStream_t stream); -INST(true); -INST(false); -#undef INST - -#if MEGDNN_TEGRA_X1 -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( - const uint8_t* /* d_src */, const int8_t* /* d_filter */, - const int32_t* /* d_bias */, const uint8_t* /* d_z */, - 
uint8_t* /* d_dst */, int* /* workspace */, - const convolution::ConvParam& /* param */, - uint32_t /* nonlinear_mode */, float /* alpha */, - float /* beta */, float /* gamma */, float /* delta */, - float /* theta */, float /* scale */, - uint8_t /* src_zero_point */, - const GemmCoord& /* threadblock_shape */, - const GemmCoord& /* warp_shape */, - const int32_t /* access_size */, cudaStream_t /* stream */) {} -#else -template -void megdnn::cuda::cutlass_wrapper:: - do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( - const uint8_t* d_src, const int8_t* d_filter, - const int32_t* d_bias, const uint8_t* d_z, uint8_t* d_dst, - int* workspace, const convolution::ConvParam& param, - uint32_t nonlinear_mode, float alpha, float beta, float gamma, - float delta, float theta, float /* scale */, - uint8_t src_zero_point, const GemmCoord& threadblock_shape, - const GemmCoord& warp_shape, const int32_t access_size, - cudaStream_t stream) { -#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ - threadblock_k_, warp_m_, warp_n_, \ - warp_k_, access_size_) \ - if (threadblock_shape.m() == threadblock_m_ && \ - threadblock_shape.n() == threadblock_n_ && \ - threadblock_shape.k() == threadblock_k_ && \ - warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ - warp_shape.k() == warp_k_ && access_size == access_size_) { \ - using ThreadBlockShape = \ - cutlass::gemm::GemmShape; \ - using WarpShape = cutlass::gemm::GemmShape; \ - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 32>; \ - using Convolution = cutlass::conv::device::Convolution< \ - cutlass::uint4b_t, cutlass::layout::TensorNHWC, \ - cutlass::int4b_t, cutlass::layout::TensorNCxHWx, \ - ElementOutput, cutlass::layout::TensorNHWC, int32_t, \ - cutlass::layout::TensorNHWC, int32_t, \ - cutlass::conv::ConvType::kConvolution, \ - cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ - ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ - cutlass::conv::threadblock:: \ - ConvolutionFpropNHWCThreadblockSwizzle, \ - 2, access_size_, access_size_, NeedLoadFromConstMem, \ - cutlass::arch::OpMultiplyAddSaturate, \ - cutlass::conv::ImplicitGemmMode::GEMM_TN>; \ - typename Convolution::ConvolutionParameter conv_param( \ - param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ - param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ - param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ - return cutlass_convolution_wrapper( \ - reinterpret_cast(d_src), \ - reinterpret_cast(d_filter), d_bias, \ - reinterpret_cast(d_z), \ - reinterpret_cast(d_dst), workspace, \ - conv_param, epilogue, stream, {src_zero_point}); \ - } -#define DISPATCH_KERNEL \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 32); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 64, 64, 32, 64, 8); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 32); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 16); \ - DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 64, 64, 8); \ - megdnn_assert(false, \ - "unsupported threadblock shape (%dx%dx%d) and warp shape " \ - "(%dx%dx%d) and access_size (%d)", \ - threadblock_shape.m(), threadblock_shape.n(), \ - threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ - warp_shape.k(), access_size); - using ElementOutput = cutlass::uint4b_t; - using ElementAccumulator = int32_t; - using ElementBias = int32_t; - using ElementCompute = float; - using NonlineMode = 
megdnn::param_enumv::ConvBias::NonlineMode;
-    switch (nonlinear_mode) {
-        case NonlineMode::IDENTITY: {
-            using EpilogueOp =
-                    cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
-                            ElementOutput, 8, ElementAccumulator, ElementBias,
-                            ElementCompute>;
-            typename EpilogueOp::Params epilogue{alpha, beta, gamma,
-                                                 delta + theta};
-            DISPATCH_KERNEL;
-        }
-        case NonlineMode::RELU: {
-            using EpilogueOp = cutlass::epilogue::thread::
-                    BiasAddLinearCombinationReluClamp<
-                            ElementOutput, 8, ElementAccumulator, ElementBias,
-                            ElementCompute>;
-            typename EpilogueOp::Params epilogue{alpha, beta, gamma,
-                                                 0, delta, theta};
-            DISPATCH_KERNEL;
-        }
-        default:
-            megdnn_assert(false,
-                          "unsupported nonlinear mode for conv bias operator");
-    }
-#undef DISPATCH_KERNEL_WITH_TILE_SHAPE
-#undef DISPATCH_KERNEL
-}
-#endif
-
-#define INST(need_load_from_const_mem)                                      \
-    template void megdnn::cuda::cutlass_wrapper::                           \
-            do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc<                \
-                    need_load_from_const_mem>(                              \
-                    const uint8_t* d_src, const int8_t* d_filter,           \
-                    const int32_t* d_bias, const uint8_t* d_z,              \
-                    uint8_t* d_dst, int* workspace,                         \
-                    const convolution::ConvParam& param,                    \
-                    uint32_t nonlinear_mode, float alpha, float beta,       \
-                    float gamma, float delta, float theta, float scale,     \
-                    uint8_t src_zero_point,                                 \
-                    const GemmCoord& threadblock_shape,                     \
-                    const GemmCoord& warp_shape, const int32_t access_size, \
-                    cudaStream_t stream);
-INST(true);
-INST(false);
-#undef INST
-
 // vim: syntax=cuda.doxygen
diff --git a/dnn/src/cuda/conv_bias/cutlass_reorder_filter.cu b/dnn/src/cuda/conv_bias/cutlass_reorder_filter.cu
new file mode 100644
index 000000000..2f296aca3
--- /dev/null
+++ b/dnn/src/cuda/conv_bias/cutlass_reorder_filter.cu
@@ -0,0 +1,194 @@
+/**
+ * \file dnn/src/cuda/conv_bias/cutlass_reorder_filter.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied.
+ */
+
+#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh"
+#include "src/cuda/query_blocksize.cuh"
+#include "src/cuda/integer_subbyte_utils.cuh"
+
+using namespace megdnn;
+using namespace cuda;
+using namespace cutlass_wrapper;
+
+namespace {
+template <uint32_t size_bits, uint32_t interleaved>
+__device__ __forceinline__ void reorder_ncxhwx_imma_filter_func(
+        int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH,
+        uint32_t FW, uint32_t lane, bool trans_oc) {
+    static constexpr uint32_t elements_per_lane = 128 / size_bits;
+    static constexpr uint32_t threads_per_interleaved =
+            interleaved / elements_per_lane;
+    static constexpr uint32_t instruction_shape_col = 8;
+    // 4 threads per Quad
+    static constexpr uint32_t elements_per_thread = instruction_shape_col / 4;
+    // 4 threads per Quad
+    static constexpr uint32_t reordered_elements_per_thread = interleaved / 4;
+
+    uint32_t id = lane / threads_per_interleaved;
+    uint32_t residue = lane % threads_per_interleaved;
+    uint32_t ICx = IC / interleaved;
+    uint32_t row = id / (ICx * FH * FW);
+    uint32_t col = id - row * ICx * FH * FW;
+    // transpose ncxhwx to cxhwnx
+    uint32_t src_offset = id * interleaved + residue * elements_per_lane;
+
+    row = (trans_oc) ? (row / interleaved) * interleaved +
+                               ((row % reordered_elements_per_thread) /
+                                elements_per_thread) *
+                                       instruction_shape_col +
+                               ((row % interleaved) /
+                                reordered_elements_per_thread) *
+                                       elements_per_thread +
+                               (row % elements_per_thread)
+                     : row;
+
+    uint32_t dst_offset =
+            (col * OC + row) * interleaved + residue * elements_per_lane;
+
+    *(reinterpret_cast<int4*>(dst + dst_offset * size_bits / 8)) =
+            *(reinterpret_cast<const int4*>(src + src_offset * size_bits / 8));
+}
+
+template <uint32_t size_bits, uint32_t interleaved>
+__global__ void reorder_ncxhwx_imma_filter_kernel(
+        int8_t* __restrict__ dst_filter, const int8_t* __restrict__ src_filter,
+        uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc) {
+    static constexpr uint32_t elements_per_lane = 128 / size_bits;
+    const uint32_t size = OC * IC * FH * FW / elements_per_lane;
+    uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x;
+    if (lane < size) {
+        reorder_ncxhwx_imma_filter_func<size_bits, interleaved>(
+                dst_filter, src_filter, OC, IC, FH, FW, lane, trans_oc);
+    }
+}
+
+template <uint32_t size_bits, uint32_t alignbits, uint32_t interleaved>
+__device__ __forceinline__ void reorder_nhwc_imma_filter_func(
+        int8_t* dst, const int8_t* src, uint32_t OC, uint32_t IC, uint32_t FH,
+        uint32_t FW, uint32_t lane, bool trans_oc) {
+    static constexpr uint32_t elements_per_access = alignbits / size_bits;
+    static constexpr uint32_t instruction_shape_col = 8;
+    // 4 threads per Quad
+    static constexpr uint32_t elements_per_thread = instruction_shape_col / 4;
+    // 4 threads per Quad
+    static constexpr uint32_t reordered_elements_per_thread = interleaved / 4;
+    uint32_t ICx = IC / elements_per_access;
+    uint32_t k = lane / (ICx * FH * FW);
+    uint32_t cxrs = lane - k * ICx * FH * FW;
+    uint32_t rs = cxrs / ICx;
+    uint32_t cx = cxrs - rs * ICx;
+    // transpose nhwc to ncxhwx
+    uint32_t src_offset = lane * elements_per_access;
+    // reorder k
+    k = (trans_oc)
+                ? (k / interleaved) * interleaved +
+                          ((k % reordered_elements_per_thread) /
+                           elements_per_thread) *
+                                  instruction_shape_col +
+                          ((k % interleaved) / reordered_elements_per_thread) *
+                                  elements_per_thread +
+                          (k % elements_per_thread)
+                : k;
+    uint32_t dst_offset =
+            (k * ICx * FH * FW + cx * FH * FW + rs) * elements_per_access;
+
+    if (alignbits == 32) {
+        *(reinterpret_cast<int*>(dst + dst_offset * size_bits / 8)) = *(
+                reinterpret_cast<const int*>(src + src_offset * size_bits / 8));
+    } else if (alignbits == 64) {
+        *(reinterpret_cast<int2*>(dst + dst_offset * size_bits / 8)) =
+                *(reinterpret_cast<const int2*>(src +
+                                                src_offset * size_bits / 8));
+    } else {
+        *(reinterpret_cast<int4*>(dst + dst_offset * size_bits / 8)) =
+                *(reinterpret_cast<const int4*>(src +
+                                                src_offset * size_bits / 8));
+    }
+}
+
+template <uint32_t size_bits, uint32_t alignbits, uint32_t interleaved>
+__global__ void reorder_nhwc_imma_filter_kernel(
+        int8_t* __restrict__ dst_filter, const int8_t* __restrict__ src_filter,
+        uint32_t OC, uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc) {
+    static constexpr uint32_t elements_per_access = alignbits / size_bits;
+    const uint32_t size = OC * IC * FH * FW / elements_per_access;
+    uint32_t lane = threadIdx.x + blockIdx.x * blockDim.x;
+    if (lane < size) {
+        reorder_nhwc_imma_filter_func<size_bits, alignbits, interleaved>(
+                dst_filter, src_filter, OC, IC, FH, FW, lane, trans_oc);
+    }
+}
+}  // namespace
+
+template <uint32_t size_bits, uint32_t interleaved>
+void megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter(
+        int8_t* dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC,
+        uint32_t FH, uint32_t FW, bool trans_oc, cudaStream_t stream) {
+    static constexpr uint32_t elements_per_lane = 128 / size_bits;
+    uint32_t nr_threads =
+            query_blocksize_for_kernel(reinterpret_cast<void*>(
+                    reorder_ncxhwx_imma_filter_kernel<size_bits, interleaved>));
+    uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_lane);
+    nr_threads = std::min(nr_threads, vthreads);
+    uint32_t nr_blocks = DIVUP(vthreads, nr_threads);
+    reorder_ncxhwx_imma_filter_kernel<size_bits, interleaved>
+            <<<nr_blocks, nr_threads, 0, stream>>>(dst_filter, src_filter, OC,
+                                                   IC, FH, FW, trans_oc);
+    after_kernel_launch();
+}
+
+template <uint32_t size_bits, uint32_t alignbits>
+void megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter(
+        int8_t* dst_filter, const int8_t* src_filter, uint32_t OC, uint32_t IC,
+        uint32_t FH, uint32_t FW, bool trans_oc, uint32_t oc_interleaved,
+        cudaStream_t stream) {
+    static constexpr uint32_t elements_per_access = alignbits / size_bits;
+    uint32_t nr_threads =
+            query_blocksize_for_kernel(reinterpret_cast<void*>(
+                    reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 32>));
+    uint32_t vthreads = DIVUP(OC * IC * FH * FW, elements_per_access);
+    nr_threads = std::min(nr_threads, vthreads);
+    uint32_t nr_blocks = DIVUP(vthreads, nr_threads);
+    if (oc_interleaved == 32) {
+        reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 32>
+                <<<nr_blocks, nr_threads, 0, stream>>>(
+                        dst_filter, src_filter, OC, IC, FH, FW, trans_oc);
+    } else {
+        reorder_nhwc_imma_filter_kernel<size_bits, alignbits, 64>
+                <<<nr_blocks, nr_threads, 0, stream>>>(
+                        dst_filter, src_filter, OC, IC, FH, FW, trans_oc);
+    }
+    after_kernel_launch();
+}
+
+#define INST(_size_bits, _interleaved)                                       \
+    template void megdnn::cuda::cutlass_wrapper::reorder_ncxhwx_imma_filter< \
+            _size_bits, _interleaved>(int8_t * dst_filter,                   \
+                                      const int8_t* src_filter, uint32_t OC, \
+                                      uint32_t IC, uint32_t FH, uint32_t FW, \
+                                      bool trans_oc, cudaStream_t stream);
+
+INST(8, 32)
+INST(4, 64)
+#undef INST
+
+#define INST(_size_bits, _alignbits)                                        \
+    template void megdnn::cuda::cutlass_wrapper::reorder_nhwc_imma_filter<  \
+            _size_bits, _alignbits>(                                        \
+            int8_t * dst_filter, const int8_t* src_filter, uint32_t OC,     \
+            uint32_t IC, uint32_t FH, uint32_t FW, bool trans_oc,           \
+            uint32_t oc_interleaved, cudaStream_t stream);
+INST(4, 32)
+INST(4, 64)
+INST(4, 128)
+#undef INST
+
+// vim: syntax=cuda.doxygen
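The trans_oc branch in the two reorder functions above permutes output channels so that the pair of channels owned by each quad of the 8x8x32 mma instruction lands on one instruction column, which is what lets the epilogue skip the shared-memory round trip. A host-side restatement of that index math for sanity-checking; reorder_oc is a made-up name and this is not library code.

// Reference of the oc permutation for trans_oc (sketch only).
#include <cstdint>
#include <cstdio>

uint32_t reorder_oc(uint32_t row, uint32_t interleaved) {
    const uint32_t instruction_shape_col = 8;
    const uint32_t elements_per_thread = instruction_shape_col / 4;  // 2
    const uint32_t reordered_elements_per_thread = interleaved / 4;
    return (row / interleaved) * interleaved +
           ((row % reordered_elements_per_thread) / elements_per_thread) *
                   instruction_shape_col +
           ((row % interleaved) / reordered_elements_per_thread) *
                   elements_per_thread +
           (row % elements_per_thread);
}

int main() {
    // For interleaved = 64 (the int4 path) the first rows map to
    // 0 1 8 9 16 17 24 25: consecutive channel pairs are spread across
    // instruction columns, keeping each quad's pair contiguous.
    for (uint32_t r = 0; r < 8; ++r)
        std::printf("%u -> %u\n", r, reorder_oc(r, 64));
}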
diff --git a/dnn/src/cuda/conv_bias/cutlass_reorder_filter.cuh b/dnn/src/cuda/conv_bias/cutlass_reorder_filter.cuh
new file mode 100644
index 000000000..dc46f9c60
--- /dev/null
+++ b/dnn/src/cuda/conv_bias/cutlass_reorder_filter.cuh
@@ -0,0 +1,33 @@
+/**
+ * \file dnn/src/cuda/conv_bias/cutlass_reorder_filter.cuh
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+ * express or implied.
+ */
+
+#pragma once
+#include "src/cuda/utils.cuh"
+
+namespace megdnn {
+namespace cuda {
+namespace cutlass_wrapper {
+
+template <uint32_t size_bits, uint32_t interleaved>
+void reorder_ncxhwx_imma_filter(int8_t* dst_filter, const int8_t* src_filter,
+                                uint32_t OC, uint32_t IC, uint32_t FH,
+                                uint32_t FW, bool trans_oc,
+                                cudaStream_t stream);
+
+template <uint32_t size_bits, uint32_t alignbits>
+void reorder_nhwc_imma_filter(int8_t* dst_filter, const int8_t* src_filter,
+                              uint32_t OC, uint32_t IC, uint32_t FH,
+                              uint32_t FW, bool trans_oc,
+                              uint32_t oc_interleaved, cudaStream_t stream);
+}  // namespace cutlass_wrapper
+}  // namespace cuda
+}  // namespace megdnn
diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp
index d2a8d0c38..47a659350 100644
--- a/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp
+++ b/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nchw64_imma.cpp
@@ -102,7 +102,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NCHW64IMMAImplicitGemm::do_exec(
             reinterpret_cast<int8_t*>(z_ptr),
             reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr,
             kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
-            threadblock_shape, warp_shape, stream);
+            threadblock_shape, warp_shape, m_algo_param.stage, stream);
 }
 #endif
diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nhwc_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nhwc_imma.cpp
index 5a72afc98..30797287a 100644
--- a/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nhwc_imma.cpp
+++ b/dnn/src/cuda/conv_bias/implicit_gemm_int4_int4_nhwc_imma.cpp
@@ -104,7 +104,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec(
             reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr,
             kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
             threadblock_shape, warp_shape, m_algo_param.access_size,
-            stream);
+            m_algo_param.stage, stream);
     } else {
         cutlass_wrapper::do_conv_bias_int4_int4_implicit_gemm_imma_nhwc<true>(
                 reinterpret_cast<int8_t*>(args.src_tensor->raw_ptr),
@@ -114,7 +114,7 @@ void ConvBiasForwardImpl::AlgoInt4Int4NHWCIMMAImplicitGemm::do_exec(
                 reinterpret_cast<int8_t*>(args.dst_tensor->raw_ptr), nullptr,
                 kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale,
                 threadblock_shape, warp_shape, m_algo_param.access_size,
-                stream);
+                m_algo_param.stage, stream);
     }
 }
 #endif
diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int4_nchw64_imma_base.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int4_nchw64_imma_base.cpp
index d99bf90a8..eeb01c62a 100644
--- a/dnn/src/cuda/conv_bias/implicit_gemm_int4_nchw64_imma_base.cpp
+++ b/dnn/src/cuda/conv_bias/implicit_gemm_int4_nchw64_imma_base.cpp
@@ -12,6 +12,7 @@
 #include "./algo.h"
 #include "src/common/conv_bias.h"
+#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh"
 #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
 #include "src/cuda/conv_bias/reduce_filter.cuh"
 #include "src/cuda/convolution_helper/parameter.cuh"
"src/cuda/convolution_helper/parameter.cuh" @@ -121,41 +122,26 @@ void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::exec( std::string ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::to_string( AlgoParam algo_param) { - return ssprintf("%dX%dX%d_%dX%dX%d", algo_param.threadblock_m, + return ssprintf("%dX%dX%d_%dX%dX%d_%d", algo_param.threadblock_m, algo_param.threadblock_n, algo_param.threadblock_k, - algo_param.warp_m, algo_param.warp_n, algo_param.warp_k); + algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, + algo_param.stage); } void ConvBiasForwardImpl::AlgoInt4NCHW64IMMAImplicitGemmBase::reorder_filter( const ExecArgs& args, void* reordered_filter) const { - auto&& param = args.opr->param(); + size_t ci = args.src_layout->operator[](1) * 64; + size_t co = args.dst_layout->operator[](1) * 64; auto&& fm = args.filter_meta; - size_t n = args.src_layout->operator[](0), - ci = args.src_layout->operator[](1) * 64, - hi = args.src_layout->operator[](2), - wi = args.src_layout->operator[](3); - size_t co = args.dst_layout->operator[](1) * 64, - ho = args.dst_layout->operator[](2), - wo = args.dst_layout->operator[](3); - UNPACK_CONV_PARAMETER(fm, param); - MARK_USED_VAR; - - // filter: KCRS64 => CRSK64 - TensorLayout src{{co, ci / 64, fh, fw, 64}, dtype::QuantizedS4()}; - src.init_contiguous_stride(); - TensorLayout dst = src; - dst.stride[0] = 64; - dst.stride[1] = co * fh * fw * 64; - dst.stride[2] = co * fw * 64; - dst.stride[3] = co * 64; - dst.stride[4] = 1; - TensorND ts_src, ts_dst; - ts_src.raw_ptr = args.filter_tensor->raw_ptr; - ts_src.layout = src; - ts_dst.raw_ptr = reordered_filter; - ts_dst.layout = dst; - auto&& transpose = args.opr->handle()->create_operator(); - transpose->exec(ts_src, ts_dst); + size_t fh = fm.spatial[0], fw = fm.spatial[1]; + + cudaStream_t stream = cuda_stream(args.opr->handle()); + + // filter: KCRS64 => CRSK64 and reorder oc + cutlass_wrapper::reorder_ncxhwx_imma_filter<4, 64>( + reinterpret_cast(reordered_filter), + reinterpret_cast(args.filter_tensor->raw_ptr), co, ci, fh, + fw, true, stream); } #endif diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp index 21aca9255..9cb11de60 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp @@ -12,6 +12,7 @@ #include "./algo.h" #include "src/common/conv_bias.h" +#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh" #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" #include "src/cuda/conv_bias/reduce_filter.cuh" #include "src/cuda/convolution_helper/parameter.cuh" @@ -128,10 +129,10 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec( std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string( AlgoParam algo_param) { - return ssprintf("%dX%dX%d_%dX%dX%d_%d", algo_param.threadblock_m, + return ssprintf("%dX%dX%d_%dX%dX%d_%d_%d", algo_param.threadblock_m, algo_param.threadblock_n, algo_param.threadblock_k, algo_param.warp_m, algo_param.warp_n, algo_param.warp_k, - algo_param.access_size); + algo_param.stage, algo_param.access_size); } void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter( @@ -142,17 +143,32 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter( fh = args.filter_layout->operator[](1), fw = args.filter_layout->operator[](2); - // reformat grad from nhwc to ncxhwx - TensorLayout exec_src{{co, fh, fw, ci / 
diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp
index 21aca9255..9cb11de60 100644
--- a/dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp
+++ b/dnn/src/cuda/conv_bias/implicit_gemm_int4_nhwc_imma_base.cpp
@@ -12,6 +12,7 @@
 #include "./algo.h"
 #include "src/common/conv_bias.h"
+#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh"
 #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
 #include "src/cuda/conv_bias/reduce_filter.cuh"
 #include "src/cuda/convolution_helper/parameter.cuh"
@@ -128,10 +129,10 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::exec(

 std::string ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::to_string(
         AlgoParam algo_param) {
-    return ssprintf("%dX%dX%d_%dX%dX%d_%d", algo_param.threadblock_m,
+    return ssprintf("%dX%dX%d_%dX%dX%d_%d_%d", algo_param.threadblock_m,
                     algo_param.threadblock_n, algo_param.threadblock_k,
                     algo_param.warp_m, algo_param.warp_n, algo_param.warp_k,
-                    algo_param.access_size);
+                    algo_param.stage, algo_param.access_size);
 }

 void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter(
@@ -142,17 +143,32 @@ void ConvBiasForwardImpl::AlgoInt4NHWCIMMAImplicitGemmBase::reorder_filter(
            fh = args.filter_layout->operator[](1),
            fw = args.filter_layout->operator[](2);

-    // reformat grad from nhwc to ncxhwx
-    TensorLayout exec_src{{co, fh, fw, ci / iterleaved, (size_t)iterleaved / 2},
-                          dtype::Int8()};
-    TensorLayout exec_dst{{co, ci / iterleaved, fh, fw, (size_t)iterleaved / 2},
-                          dtype::Int8()};
-
-    exec_src = exec_src.dimshuffle({0, 3, 1, 2, 4});
+    cudaStream_t stream = cuda_stream(args.opr->handle());

-    auto&& relayout = args.opr->handle()->create_operator<RelayoutForward>();
-    relayout->exec({args.filter_tensor->raw_ptr, exec_src},
-                   {reordered_filter, exec_dst});
+    // reformat filter from nhwc to ncxhwx and reorder oc;
+    // trans_oc requires threadblock_n to be 32 or 64
+    bool trans_oc = ((co % m_algo_param.threadblock_n == 0) &&
+                     (m_algo_param.threadblock_n == 32 ||
+                      m_algo_param.threadblock_n == 64));
+    uint32_t oc_iterleave = (m_algo_param.threadblock_n == 64) ? 64 : 32;
+
+    if (iterleaved == 8) {
+        cutlass_wrapper::reorder_nhwc_imma_filter<4, 32>(
+                reinterpret_cast<int8_t*>(reordered_filter),
+                reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci,
+                fh, fw, trans_oc, oc_iterleave, stream);
+    } else if (iterleaved == 16) {
+        cutlass_wrapper::reorder_nhwc_imma_filter<4, 64>(
+                reinterpret_cast<int8_t*>(reordered_filter),
+                reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci,
+                fh, fw, trans_oc, oc_iterleave, stream);
+    } else {
+        megdnn_assert(iterleaved == 32);
+        cutlass_wrapper::reorder_nhwc_imma_filter<4, 128>(
+                reinterpret_cast<int8_t*>(reordered_filter),
+                reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci,
+                fh, fw, trans_oc, oc_iterleave, stream);
+    }
 }
 #endif
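The iterleaved-to-template mapping in the dispatch above follows from the element width: an access of `iterleaved` int4 values spans `iterleaved * 4` bits, which is the `alignbits` template argument. Stated as a tiny helper whose name is illustrative only:

// iterleaved int4 elements -> access width in bits (sketch).
#include <cstdint>

uint32_t nhwc_alignbits(uint32_t iterleaved) {
    return iterleaved * 4;  // 8 -> 32, 16 -> 64, 32 -> 128
}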
diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp
index b04b1a01f..f66545801 100644
--- a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp
+++ b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp
@@ -11,6 +11,7 @@
  */

 #include "./algo.h"
+#include "src/cuda/conv_bias/cutlass_reorder_filter.cuh"
 #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh"
 #include "src/cuda/convolution_helper/parameter.cuh"
 #include "src/cuda/utils.h"
@@ -110,11 +111,14 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
     size_t ho = args.dst_layout->operator[](2),
            wo = args.dst_layout->operator[](3);
     size_t co;
+    bool trans_oc;
     if (param.format == Format::NCHW32) {
         co = args.dst_layout->operator[](1) * 32;
+        trans_oc = true;
     } else {
         megdnn_assert(param.format == Format::NCHW32_NCHW4);
         co = args.dst_layout->operator[](1) * 4;
+        trans_oc = false;
     }
     UNPACK_CONV_PARAMETER(fm, param);
     MARK_USED_VAR
@@ -123,23 +127,11 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
     int8_t* filter_ptr = nullptr;
     if (args.preprocessed_filter == nullptr) {
         filter_ptr = reinterpret_cast<int8_t*>(args.workspace.raw_ptr);
-        // reformat filter from nchw32 to chwn32
-        TensorLayout src{{co, ci / 32, fh, fw, 32}, dtype::Int8()};
-        src.init_contiguous_stride();
-        TensorLayout dst = src;
-        dst.stride[0] = 32;
-        dst.stride[1] = co * fh * fw * 32;
-        dst.stride[2] = co * fw * 32;
-        dst.stride[3] = co * 32;
-        dst.stride[4] = 1;
-        TensorND ts_src, ts_dst;
-        ts_src.raw_ptr = args.filter_tensor->raw_ptr;
-        ts_src.layout = src;
-        ts_dst.raw_ptr = args.workspace.raw_ptr;
-        ts_dst.layout = dst;
-        auto&& transpose =
-                args.opr->handle()->create_operator<RelayoutForward>();
-        transpose->exec(ts_src, ts_dst);
+        // filter: KCRS32 => CRSK32 and reorder oc
+        cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>(
+                filter_ptr,
+                reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci,
+                fh, fw, trans_oc, stream);
     } else {
         filter_ptr = reinterpret_cast<int8_t*>(
                 args.preprocessed_filter->tensors[0].raw_ptr);
@@ -182,7 +174,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
                 cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
                                            m_algo_param.warp_n,
                                            m_algo_param.warp_k},
-                stream);
+                m_algo_param.stage, stream);
         } else {
             megdnn_assert(param.format == Format::NCHW32_NCHW4);
             cutlass_wrapper::
@@ -202,7 +194,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
                 cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
                                            m_algo_param.warp_n,
                                            m_algo_param.warp_k},
-                stream);
+                m_algo_param.stage, stream);
         }
     } else {
         if (param.format == Format::NCHW32) {
@@ -218,7 +210,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
                 cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
                                            m_algo_param.warp_n,
                                            m_algo_param.warp_k},
-                stream);
+                m_algo_param.stage, stream);
         } else {
             megdnn_assert(param.format == Format::NCHW32_NCHW4);
             cutlass_wrapper::
@@ -238,7 +230,7 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(
                 cutlass_wrapper::GemmCoord{m_algo_param.warp_m,
                                            m_algo_param.warp_n,
                                            m_algo_param.warp_k},
-                stream);
+                m_algo_param.stage, stream);
         }
     }
     after_kernel_launch();
@@ -246,9 +238,10 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec(

 std::string ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::to_string(
         AlgoParam algo_param) {
-    return ssprintf("%uX%uX%u_%uX%uX%u", algo_param.threadblock_m,
+    return ssprintf("%uX%uX%u_%uX%uX%u_%u", algo_param.threadblock_m,
                     algo_param.threadblock_n, algo_param.threadblock_k,
-                    algo_param.warp_m, algo_param.warp_n, algo_param.warp_k);
+                    algo_param.warp_m, algo_param.warp_n, algo_param.warp_k,
+                    algo_param.stage);
 }

 size_t ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::
@@ -267,36 +260,26 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec_preprocess(
     using Format = Param::Format;
     auto&& param = args.opr->param();
     auto&& fm = args.filter_meta;
-    size_t n = args.src_layout->operator[](0),
-           ci = args.src_layout->operator[](1) * 32,
-           hi = args.src_layout->operator[](2),
-           wi = args.src_layout->operator[](3);
-    size_t ho = args.dst_layout->operator[](2),
-           wo = args.dst_layout->operator[](3);
+    size_t ci = args.src_layout->operator[](1) * 32;
     size_t co;
+    bool trans_oc;
     if (param.format == Format::NCHW32) {
         co = args.dst_layout->operator[](1) * 32;
+        trans_oc = true;
     } else {
         megdnn_assert(param.format == Format::NCHW32_NCHW4);
         co = args.dst_layout->operator[](1) * 4;
+        trans_oc = false;
     }
-    UNPACK_CONV_PARAMETER(fm, param);
-    MARK_USED_VAR
-    TensorLayout src{{co, ci / 32, fh, fw, 32}, dtype::Int8()};
-    src.init_contiguous_stride();
-    TensorLayout dst = src;
-    dst.stride[0] = 32;
-    dst.stride[1] = co * fh * fw * 32;
-    dst.stride[2] = co * fw * 32;
-    dst.stride[3] = co * 32;
-    dst.stride[4] = 1;
-    TensorND ts_src, ts_dst;
-    ts_src.raw_ptr = args.filter_tensor->raw_ptr;
-    ts_src.layout = src;
-    ts_dst.raw_ptr = args.preprocessed_filter->tensors[0].raw_ptr;
-    ts_dst.layout = dst;
-    auto&& transpose = args.opr->handle()->create_operator<RelayoutForward>();
-    transpose->exec(ts_src, ts_dst);
+    size_t fh = fm.spatial[0], fw = fm.spatial[1];
+
+    cudaStream_t stream = cuda_stream(args.opr->handle());
+    // filter: KCRS32 => CRSK32 and reorder oc
+    cutlass_wrapper::reorder_ncxhwx_imma_filter<8, 32>(
+            reinterpret_cast<int8_t*>(
+                    args.preprocessed_filter->tensors[0].raw_ptr),
+            reinterpret_cast<int8_t*>(args.filter_tensor->raw_ptr), co, ci, fh,
+            fw, trans_oc, stream);
 }
 #endif
diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp
index 0d4762025..d94833eef 100644
--- 
a/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nchw64_imma.cpp @@ -144,7 +144,8 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::do_exec( reinterpret_cast(z_ptr), reinterpret_cast(args.dst_tensor->raw_ptr), nullptr, kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, - dst_scale, src_zero, threadblock_shape, warp_shape, stream); + dst_scale, src_zero, threadblock_shape, warp_shape, + m_algo_param.stage, stream); } void ConvBiasForwardImpl::AlgoUInt4Int4NCHW64IMMAImplicitGemm::update_bias( diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nhwc_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nhwc_imma.cpp index ab62e1d98..b074a48da 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nhwc_imma.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_uint4_int4_nhwc_imma.cpp @@ -147,7 +147,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec( reinterpret_cast(args.dst_tensor->raw_ptr), nullptr, kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, dst_scale, src_zero, threadblock_shape, warp_shape, - m_algo_param.access_size, stream); + m_algo_param.access_size, m_algo_param.stage, stream); } else { cutlass_wrapper::do_conv_bias_uint4_int4_implicit_gemm_imma_nhwc( reinterpret_cast(args.src_tensor->raw_ptr), @@ -157,7 +157,7 @@ void ConvBiasForwardImpl::AlgoUInt4Int4NHWCIMMAImplicitGemm::do_exec( reinterpret_cast(args.dst_tensor->raw_ptr), nullptr, kern_param, nonlinear_mode, alpha, beta, gamma, delta, theta, dst_scale, src_zero, threadblock_shape, warp_shape, - m_algo_param.access_size, stream); + m_algo_param.access_size, m_algo_param.stage, stream); } } diff --git a/dnn/test/cuda/conv_bias_int8.cpp b/dnn/test/cuda/conv_bias_int8.cpp index 3c593dcc6..4b551e3dc 100644 --- a/dnn/test/cuda/conv_bias_int8.cpp +++ b/dnn/test/cuda/conv_bias_int8.cpp @@ -840,21 +840,21 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_IMMA) { param.pad_h = param.pad_w = 1; param.stride_h = param.stride_w = 1; param.format = param::ConvBias::Format::NCHW32; - checker.set_param(param).execs({{16, 16, 7, 7, 32}, - {512, 16, 3, 3, 32}, - {1, 16, 1, 1, 32}, + checker.set_param(param).execs({{16, 8, 7, 7, 32}, + {256, 8, 3, 3, 32}, + {1, 8, 1, 1, 32}, {}, {}}); param.nonlineMode = param::ConvBias::NonlineMode::RELU; - checker.set_param(param).execs({{16, 16, 7, 7, 32}, - {512, 16, 1, 1, 32}, - {1, 16, 1, 1, 32}, + checker.set_param(param).execs({{16, 8, 7, 7, 32}, + {256, 8, 1, 1, 32}, + {1, 8, 1, 1, 32}, {}, {}}); param.nonlineMode = param::ConvBias::NonlineMode::H_SWISH; - checker.set_param(param).execs({{16, 16, 7, 7, 32}, - {512, 16, 3, 3, 32}, - {1, 16, 1, 1, 32}, + checker.set_param(param).execs({{16, 8, 7, 7, 32}, + {256, 8, 3, 3, 32}, + {1, 8, 1, 1, 32}, {}, {}}); // use non integer scale @@ -867,18 +867,18 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_IMMA) { .set_epsilon(1 + 1e-3) .set_max_avg_error(1e-1) .set_max_avg_biased_error(1e-1) - .execs({{16, 16, 7, 7, 32}, - {512, 16, 3, 3, 32}, - {1, 16, 1, 1, 32}, - {16, 16, 7, 7, 32}, + .execs({{16, 8, 7, 7, 32}, + {256, 8, 3, 3, 32}, + {1, 8, 1, 1, 32}, + {16, 8, 7, 7, 32}, {}}); }; std::string algo = ConvBias::algo_name( - "INT8_NCHW32_IMMA_IMPLICIT_GEMM_256X128X64_64X64X64", + "INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X128X64_64X64X64_2", ConvBias::DirectParam{}); check(algo); algo = ConvBias::algo_name( - "INT8_NCHW32_IMMA_IMPLICIT_GEMM_32X64X64_32X16X64", + "INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X32X32_64X32X32_1", 
ConvBias::DirectParam{}); check(algo); } @@ -969,7 +969,7 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_NCHW4) { checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker< ConvBiasForward>( ConvBias::algo_name( - "INT8_NCHW32_IMMA_IMPLICIT_GEMM_256X128X64_64X64X64", + "INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X128X64_64X64X64_2", ConvBias::DirectParam{}) .c_str())); checker.set_dtype(0, dtype::QuantizedS8(1.9980618f)) -- GitLab
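A note on the renamed algorithms in these tests: with the stage count appended by the to_string changes above, the tile suffix now reads e.g. 128X128X64_64X64X64_2. A sketch of how such a name is composed, with plain snprintf standing in for MegDNN's ssprintf:

// Composing the stage-suffixed algorithm name (illustrative only).
#include <cstdio>

int main() {
    char tile[64];
    std::snprintf(tile, sizeof(tile), "%uX%uX%u_%uX%uX%u_%u", 128u, 128u, 64u,
                  64u, 64u, 64u, 2u);
    std::printf("INT8_NCHW32_IMMA_IMPLICIT_GEMM_%s\n", tile);
    // -> INT8_NCHW32_IMMA_IMPLICIT_GEMM_128X128X64_64X64X64_2
}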