From 739f927c4c1bfb033e947523c780438daf7785d0 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 9 Nov 2020 18:01:02 +0800
Subject: [PATCH] feat(dnn/cuda): opt dp4a conv for small channel based on
 cutlass

GitOrigin-RevId: 2a74c35f27ea8186d79313ad607bb8fa89e846fb
---
 dnn/src/cuda/conv_bias/algo.cpp               | 21 +++++++------
 dnn/src/cuda/conv_bias/algo.h                 |  7 +++--
 .../conv_bias/cutlass_convolution_wrapper.cu  | 25 +++++++--------
 ...4a_ncdiv4hw4_16x128x16_16x128x16_hswish.cu | Bin 0 -> 1686 bytes
 ...m_dp4a_ncdiv4hw4_16x128x16_16x128x16_id.cu | Bin 0 -> 1680 bytes
 ...dp4a_ncdiv4hw4_16x128x16_16x128x16_relu.cu | Bin 0 -> 1684 bytes
 ...cdiv4hw4_1x1_16x128x16_16x128x16_hswish.cu | Bin 0 -> 1687 bytes
 ...4a_ncdiv4hw4_1x1_16x128x16_16x128x16_id.cu | Bin 0 -> 1681 bytes
 ..._ncdiv4hw4_1x1_16x128x16_16x128x16_relu.cu | Bin 0 -> 1685 bytes
 dnn/test/cuda/conv_bias_int8.cpp              | 29 ++++++++++++++++++
 third_party/cutlass                           |  2 +-
 11 files changed, 58 insertions(+), 26 deletions(-)
 create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_hswish.cu
 create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_id.cu
 create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_relu.cu
 create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_1x1_16x128x16_16x128x16_hswish.cu
 create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_1x1_16x128x16_16x128x16_id.cu
 create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_1x1_16x128x16_16x128x16_relu.cu

diff --git a/dnn/src/cuda/conv_bias/algo.cpp b/dnn/src/cuda/conv_bias/algo.cpp
index e100417c..014bbde9 100644
--- a/dnn/src/cuda/conv_bias/algo.cpp
+++ b/dnn/src/cuda/conv_bias/algo.cpp
@@ -260,16 +260,17 @@ void ConvBiasForwardImpl::AlgoPack::fill_imma_algos() {
 void ConvBiasForwardImpl::AlgoPack::fill_dp4a_algos() {
     using AlgoParam = AlgoInt8NCHW4DotProdImplicitGemm::AlgoParam;
-    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 64, 32, 64, 32, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 32, 32, 32, 32, 32});
-    int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 64, 32, 64, 32, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 128, 32, 64, 32, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 128, 32, 32, 64, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 32, 32, 64, 32, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 64, 32, 64, 32, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 64, 32, 32, 64, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{64, 32, 32, 64, 32, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{32, 32, 32, 32, 32, 32, 2});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1});
+    int8_nchw4_dotprod.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8, 2});
 }
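To read the tile lists above: each AlgoParam (see the algo.h hunk below) packs a threadblock tile, a warp tile, and, new in this patch, a software-pipeline stage count. An annotated reading of the two most interesting entries; the comments are editorial and not part of the patch:

    // AlgoParam{threadblock_m, threadblock_n, threadblock_k,
    //           warp_m, warp_n, warp_k, stage}
    // New small-channel config: a 16x128x16 threadblock tile whose warp
    // tile equals the whole threadblock tile, with a single-stage pipeline.
    int8_nchw4_dotprod.emplace_back(AlgoParam{16, 128, 16, 16, 128, 16, 1});
    // Pre-existing configs keep their old behavior by pinning stage to 2,
    // the value hard-coded in the wrapper before this patch.
    int8_nchw4_dotprod.emplace_back(AlgoParam{128, 128, 32, 64, 32, 32, 2});
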
diff --git a/dnn/src/cuda/conv_bias/algo.h b/dnn/src/cuda/conv_bias/algo.h
index 8325548f..a95a96d4 100644
--- a/dnn/src/cuda/conv_bias/algo.h
+++ b/dnn/src/cuda/conv_bias/algo.h
@@ -407,15 +407,16 @@ public:
         int warp_m;
         int warp_n;
         int warp_k;
+        int stage;
         std::string to_string() {
             /// default algorithm
             if (threadblock_m == 128 && threadblock_n == 128 &&
                 threadblock_k == 32 && warp_m == 32 && warp_n == 64 &&
-                warp_k == 32) {
+                warp_k == 32 && stage == 2) {
                 return "";
             }
-            return ssprintf("_%dX%dX%d_%dX%dX%d", threadblock_m, threadblock_n,
-                            threadblock_k, warp_m, warp_n, warp_k);
+            return ssprintf("_%dX%dX%d_%dX%dX%d_%dstage", threadblock_m, threadblock_n,
+                            threadblock_k, warp_m, warp_n, warp_k, stage);
         }
     };
     AlgoInt8NCHW4DotProdImplicitGemm(AlgoParam algo_param)

diff --git a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu
index 832e1228..fd840927 100644
--- a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu
+++ b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu
@@ -172,7 +172,7 @@ void megdnn::cuda::cutlass_wrapper::
         const GemmCoord& warp_shape, cudaStream_t stream) {
 #define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_,   \
                                         threadblock_k_, warp_m_, warp_n_, \
-                                        warp_k_, aligned_)                \
+                                        warp_k_, stage_, aligned_)        \
     if (threadblock_shape.m() == threadblock_m_ &&                        \
         threadblock_shape.n() == threadblock_n_ &&                        \
         threadblock_shape.k() == threadblock_k_ &&                        \
@@ -194,7 +194,7 @@ void megdnn::cuda::cutlass_wrapper::
                 cutlass::convolution::threadblock::                       \
                         ConvolutionNCxHWxThreadblockSwizzle<              \
                                 cutlass::convolution::ConvType::kConvolution>, \
-                2, 4, aligned_, NeedLoadFromConstMem>;                    \
+                stage_, 4, aligned_, NeedLoadFromConstMem>;               \
         typename Convolution::ConvolutionParameter conv_param{            \
                 param.n,  param.ci, param.co, param.hi, param.wi,         \
                 param.fh, param.fw, param.ho, param.wo, param.sh,         \
@@ -204,16 +204,17 @@ void megdnn::cuda::cutlass_wrapper::
                 epilogue, stream);                                        \
     }
 #define DISPATCH_KERNEL                                                   \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 16);        \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 16);         \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 16);         \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 16);         \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 16);         \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 16);          \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 16);          \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 16);          \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 16);          \
-    DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 4);             \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16);     \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16);      \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16);      \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16);      \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16);      \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16);       \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16);       \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16);       \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16);       \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8);      \
+    DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4);          \
     megdnn_assert(false,                                                  \
                   "unsupported threadblock shape (%dx%dx%d) and warp shape " \
                   "(%dx%dx%d)",                                           \
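Each DISPATCH_KERNEL_WITH_TILE_SHAPE line expands to an if that, on a tile-shape match, instantiates the CUTLASS convolution with the given stage count and alignment and dispatches to it; only shapes matching no entry fall through to the megdnn_assert. A simplified sketch of what the new single-stage entry expands to, showing only the parameters visible in the macro above (the elided body is the macro's own):

    // DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8)
    if (threadblock_shape.m() == 16 && threadblock_shape.n() == 128 &&
        threadblock_shape.k() == 16 /* ...matching warp_shape checks... */) {
        // Convolution is instantiated with stage_ = 1 (presumably a
        // single-stage pipeline is enough for the shallow K dimension of
        // small-channel convs) and aligned_ = 8 rather than 16, matching
        // the narrower 16-wide K tile.
        /* ...build conv_param, run the kernel, and hand back control... */
    }

Per the AlgoParam::to_string() change in algo.h, this configuration is advertised under the name suffix _16X128X16_16X128X16_1stage.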
diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_hswish.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ab01f989dadf003d7073fb228910bdc6050460b8
GIT binary patch
literal 1686
[1686 bytes of base85-encoded binary data omitted]
literal 0
HcmV?d00001

diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_id.cu
new file mode 100644
index 0000000000000000000000000000000000000000..
GIT binary patch
literal 1680
[1680 bytes of base85-encoded binary data omitted]
literal 0
HcmV?d00001

diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_16x128x16_16x128x16_relu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..5dfd371c8bdcd07754f633083d69edd5c36850ef
GIT binary patch
literal 1684
[1684 bytes of base85-encoded binary data omitted]
literal 0
HcmV?d00001

diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_1x1_16x128x16_16x128x16_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_1x1_16x128x16_16x128x16_hswish.cu
new file mode 100644
index 0000000000000000000000000000000000000000..1c7115e9d9b4e077f0b265463b4c45710751ec6f
GIT binary patch
literal 1687
[1687 bytes of base85-encoded binary data omitted]
literal 0
HcmV?d00001

diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_1x1_16x128x16_16x128x16_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_1x1_16x128x16_16x128x16_id.cu
new file mode 100644
index 0000000000000000000000000000000000000000..374f51e9660a374dd54912933f94785123c5e458
GIT binary patch
literal 1681
[1681 bytes of base85-encoded binary data omitted]
literal 0
HcmV?d00001

diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_1x1_16x128x16_16x128x16_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_1x1_16x128x16_16x128x16_relu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..
GIT binary patch
literal 1685
[1685 bytes of base85-encoded binary data omitted]
literal 0
HcmV?d00001

diff --git a/dnn/test/cuda/conv_bias_int8.cpp b/dnn/test/cuda/conv_bias_int8.cpp
index 3173be12..f6d588f5 100644
--- a/dnn/test/cuda/conv_bias_int8.cpp
+++ b/dnn/test/cuda/conv_bias_int8.cpp
@@ -97,6 +97,13 @@ std::vector<BenchArgs> get_detection_bench_args(size_t batch = 16) {
     return args;
 }
 
+std::vector<BenchArgs> get_det_first_bench_args(size_t batch = 16) {
+    std::vector<BenchArgs> args;
+    args.emplace_back(BenchArgs{batch, 4, 736, 1280, 16, 3, 2});
+    args.emplace_back(BenchArgs{batch, 16, 384, 640, 16, 3, 1});
+    return args;
+}
+
 void benchmark_target_algo(
         Handle* handle, const std::vector<BenchArgs>& args, DType src_dtype,
         DType filter_dtype, DType bias_dtype, DType dst_dtype,
@@ -1236,6 +1243,28 @@ TEST_F(CUDA, BENCHMARK_CUTLASS_CONV_BIAS_INT8_NCHW4) {
             dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.0f},
             "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM", param::ConvBias::Format::NCHW4);
 }
+
+TEST_F(CUDA, BENCHMARK_SASS_CONV_BIAS_INT8_NCHW4_DET_FIRST) {
+    require_compute_capability(6, 1);
+    std::string algo = ConvBias::algo_name<ConvBias::DirectParam>(
+            "SASS_INT8_NCHW4_DOTPROD_IMPLICIT_GEMM_128X32_64",
+            ConvBias::DirectParam{});
+    benchmark_target_algo(handle_cuda(), get_det_first_bench_args(16),
+                          dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
+                          dtype::QuantizedS32{1.2f * 1.3f},
+                          dtype::QuantizedS8{1.0f}, algo.c_str(),
+                          param::ConvBias::Format::NCHW4);
+}
+
+TEST_F(CUDA, BENCHMARK_CUTLASS_CONV_BIAS_INT8_NCHW4_DET_FIRST) {
+    require_compute_capability(6, 1);
+    benchmark_target_algo(
+            handle_cuda(), get_det_first_bench_args(16),
+            dtype::QuantizedS8{1.2f}, dtype::QuantizedS8{1.3f},
+            dtype::QuantizedS32{1.2f * 1.3f}, dtype::QuantizedS8{1.0f},
+            "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM_16", param::ConvBias::Format::NCHW4);
+}
+
 #endif
 } // namespace test
 } // namespace megdnn

diff --git a/third_party/cutlass b/third_party/cutlass
index 5a7f4bfa..41426ea4 160000
--- a/third_party/cutlass
+++ b/third_party/cutlass
@@ -1 +1 @@
-Subproject commit 5a7f4bfa0e57f92140c8236322a86730132e0847
+Subproject commit 41426ea4074dcfc448b1c9979ea7617407590c04
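The new get_det_first_bench_args() cases pin the first-layer shapes of a detection network. Assuming BenchArgs carries {batch, ci, hi, wi, co, f, s} in the order used throughout this test file, the two cases decode as follows (annotated sketch; comments are editorial):

    // 736x1280 input with only 4 channels, 16 filters of 3x3, stride 2:
    // ci = 4 sits far below the 32-wide K tile of the default dp4a
    // kernels, which is exactly what the 16x128x16 single-stage config
    // in this patch targets.
    args.emplace_back(BenchArgs{batch, 4, 736, 1280, 16, 3, 2});
    // Follow-up layer: 16 channels at 384x640, 3x3 filters, stride 1.
    args.emplace_back(BenchArgs{batch, 16, 384, 640, 16, 3, 1});

Both tests live inside the file's benchmark-only #if block (note the #endif context line) and require compute capability 6.1, i.e. a dp4a-capable GPU; a gtest filter such as --gtest_filter=CUDA.BENCHMARK_CUTLASS_CONV_BIAS_INT8_NCHW4_DET_FIRST (test-runner invocation assumed) selects the CUTLASS case alone.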
--
GitLab