From 66f70578c2b29bbd5cecbcdfcbdaf0ba333e1c3f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 18 May 2021 19:08:22 +0800 Subject: [PATCH] feat(dnn/cuda): add convolution with i8 input and i4 output GitOrigin-RevId: 10512645d5d5ac3d985720788760bf8a3855c1f1 --- dnn/scripts/Makefile | 5 +- dnn/scripts/opr_param_defs.py | 2 + dnn/src/common/conv_bias.cpp | 3 +- dnn/src/common/convolution.cpp | 29 +- .../conv_bias/cudnn_conv_bias_activation.cpp | 3 + .../conv_bias/cutlass_convolution_wrapper.cu | 136 +++++++++ .../conv_bias/cutlass_convolution_wrapper.cuh | 9 + ...licit_gemm_conv_bias_cutlass_wrapper.cuinl | 65 +++++ .../implicit_gemm_int8_nchw4_dp4a.cpp | 264 +++++++++--------- ...s_int8_implicit_gemm_cutlass_wrapper.cuinl | 66 +---- ...sh_s8_128x128x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1804 bytes ...ish_s8_128x32x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1802 bytes ...ish_s8_128x64x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1802 bytes ...sh_s8_16x128x16_16x128x16_1_nc4hw4_nhwc.cu | Bin 0 -> 1803 bytes ...hswish_s8_16x64x8_16x64x8_2_nc4hw4_nhwc.cu | Bin 0 -> 1795 bytes ...ish_s8_32x128x32_32x64x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1802 bytes ...wish_s8_32x32x32_32x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1800 bytes ...wish_s8_32x64x32_32x64x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1800 bytes ...ish_s8_64x128x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1802 bytes ...wish_s8_64x32x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1800 bytes ...wish_s8_64x64x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1800 bytes ...ty_s8_128x128x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1800 bytes ...ity_s8_128x32x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1798 bytes ...ity_s8_128x64x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1798 bytes ...ty_s8_16x128x16_16x128x16_1_nc4hw4_nhwc.cu | Bin 0 -> 1799 bytes ...entity_s8_16x64x8_16x64x8_2_nc4hw4_nhwc.cu | Bin 0 -> 1791 bytes ...ity_s8_32x128x32_32x64x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1798 bytes ...tity_s8_32x32x32_32x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1796 bytes ...tity_s8_32x64x32_32x64x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1796 bytes ...ity_s8_64x128x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1798 bytes ...tity_s8_64x32x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1796 bytes ...tity_s8_64x64x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1796 bytes ...lu_s8_128x128x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1800 bytes ...elu_s8_128x32x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1798 bytes ...elu_s8_128x64x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1798 bytes ...lu_s8_16x128x16_16x128x16_1_nc4hw4_nhwc.cu | Bin 0 -> 1799 bytes ...p_relu_s8_16x64x8_16x64x8_2_nc4hw4_nhwc.cu | Bin 0 -> 1791 bytes ...elu_s8_32x128x32_32x64x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1798 bytes ...relu_s8_32x32x32_32x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1796 bytes ...relu_s8_32x64x32_32x64x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1796 bytes ...elu_s8_64x128x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1798 bytes ...relu_s8_64x32x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1796 bytes ...relu_s8_64x64x32_64x32x32_2_nc4hw4_nhwc.cu | Bin 0 -> 1796 bytes dnn/src/naive/convolution/helper.h | 26 +- 44 files changed, 401 insertions(+), 207 deletions(-) create mode 100644 dnn/src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl mode change 100644 => 120000 dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_128x64x32_64x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_16x128x16_16x128x16_1_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_16x64x8_16x64x8_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_32x32x32_32x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_64x32x32_64x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_64x64x32_64x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_128x32x32_64x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_16x128x16_16x128x16_1_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_32x128x32_32x64x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_64x128x32_64x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_64x32x32_64x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_16x128x16_16x128x16_1_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_16x64x8_16x64x8_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_32x32x32_32x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_32x64x32_32x64x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_64x32x32_64x32x32_2_nc4hw4_nhwc.cu create mode 100644 dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_nhwc.cu diff --git a/dnn/scripts/Makefile b/dnn/scripts/Makefile index b5b9532e9..bd219e486 100644 --- a/dnn/scripts/Makefile +++ b/dnn/scripts/Makefile @@ -37,15 +37,16 @@ all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} $(CUDA_MATMUL_IMPL) ../src/cuda/elemwise_multi_type/kimpl: gen_elemwise_multi_type_kern_impls.py ./$^ --type cuda $@ -../src/cuda/conv_bias/int8/kimpl: gen_cuda_conv_bias_kern_impls.py gen_cutlass_conv_bias_kern_impls.py +../src/cuda/conv_bias/int8/kimpl: gen_cuda_conv_bias_kern_impls.py gen_cutlass_conv_bias_kern_impls.py cutlass_generator/generator.py ./gen_cuda_conv_bias_kern_impls.py --type dp4a $@ ./gen_cutlass_conv_bias_kern_impls.py --type dp4a $@ + python3 ./cutlass_generator/generator.py --operations all --type simt $@ ../src/cuda/conv_bias/int8_imma/kimpl: gen_cuda_conv_bias_kern_impls.py gen_cutlass_conv_bias_kern_impls.py ./gen_cuda_conv_bias_kern_impls.py --type imma $@ ./gen_cutlass_conv_bias_kern_impls.py --type imma $@ -../src/cuda/batch_conv_bias/int8/kimpl: gen_cuda_batch_conv_bias_kern_impls.py +../src/cuda/batch_conv_bias/int8/kimpl: gen_cuda_batch_conv_bias_kern_impls.py ./$^ --type dp4a $@ ../src/cuda/matrix_mul/fp32_simt/kimpl: gen_cutlass_matmul_kern_impls.py diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index cb9ec7d55..3a3fa37b1 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -43,6 +43,7 @@ pdef('Axis').add_fields('int32', 'axis', 0) Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), + Doc('NCHW4_NHWC', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'), Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, ' 'output tensor is nchw layout'), Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, ' @@ -99,6 +100,7 @@ pdef('Axis').add_fields('int32', 'axis', 0) Doc('NCHW4_NCHW32', 'NCHW4_NCHW32 means input tensors are nchw4 layout, output tensor is nchw32 layout'), Doc('NCHW32_NCHW4', 'NCHW32_NCHW4 means input tensors are nchw32 layout, output tensor is nchw4 layout'), Doc('NCHW4_NCHW', 'NCHW4_NCHW means input tensors are nchw4 layout, output tensor is nchw layout'), + Doc('NCHW4_NHWC', 'NCHW4_NHWC means input tensors are nchw4 layout, output tensor is nhwc layout'), Doc('NHWC_NCHW', 'NHWC_NCHW means input tensors are nhwc layout, ' 'output tensor is nchw layout'), Doc('NHWC_NCHW4_IC_SMALL', 'NHWC_NCHW4_IC_SMALL means input tensors are nhwc(c < 4) layout, ' diff --git a/dnn/src/common/conv_bias.cpp b/dnn/src/common/conv_bias.cpp index 76ef5c960..2fb2395c5 100644 --- a/dnn/src/common/conv_bias.cpp +++ b/dnn/src/common/conv_bias.cpp @@ -65,7 +65,8 @@ void do_check_exec_common( bias.to_string().c_str(), dst.to_string().c_str()); megdnn_assert(bias.shape[2] == 1); megdnn_assert(bias.shape[3] == 1); - } else if (opr->param().format == param::ConvBias::Format::NHWC) { + } else if (param().format == param::ConvBias::Format::NHWC || + param().format == param::ConvBias::Format::NCHW4_NHWC) { megdnn_assert(bias.shape[0] == 1); megdnn_assert(bias.shape[1] == 1); megdnn_assert(bias.shape[2] == 1); diff --git a/dnn/src/common/convolution.cpp b/dnn/src/common/convolution.cpp index fc81aab79..04934ac39 100644 --- a/dnn/src/common/convolution.cpp +++ b/dnn/src/common/convolution.cpp @@ -368,7 +368,8 @@ void make_canonized_filter_meta_nchwx( megdnn_assert(param.format == Param::Format::NCHW4 || param.format == Param::Format::NCHW8 || param.format == Param::Format::NCHW32 || - param.format == Param::Format::NCHW4_NCHW || + param.format == Param::Format::NCHW4_NCHW || + param.format == Param::Format::NCHW4_NHWC || param.format == Param::Format::NCHW4_NCHW32 || param.format == Param::Format::NCHW32_NCHW4 || param.format == Param::Format::NCHW64); @@ -498,6 +499,7 @@ ConvolutionBase::make_canonized_filter_meta( } } else if (param().format == Param::Format::NCHW4 || param().format == Param::Format::NCHW4_NCHW || + param().format == Param::Format::NCHW4_NHWC || param().format == Param::Format::NCHW4_NCHW32) { make_canonized_filter_meta_nchwx<4, Parameter>(src_ndim, filter, param(), ret); @@ -547,7 +549,12 @@ void ConvolutionBase::check_or_deduce_dtype_fwd(DType src, src.enumv() == DTypeEnum::Quantized4Asymm) { supported_dst_dtype.push_back( dtype::QuantizedS32(mul_scale(src, filter))); - if (dst.valid() && dst.enumv() == src.enumv()) { + bool cond_dst = + dst.valid() && (dst.enumv() == src.enumv() || + ((dst.enumv() == DTypeEnum::QuantizedS4 || + dst.enumv() == DTypeEnum::Quantized4Asymm) && + src.enumv() == DTypeEnum::QuantizedS8)); + if (cond_dst) { supported_dst_dtype.push_back(dst); } if (src.enumv() == DTypeEnum::QuantizedS8) { @@ -611,7 +618,8 @@ ConvolutionBase::deduce_layout_fwd(const TensorLayout& src, } else { megdnn_assert(param().format == Param::Format::NHWCD4 || param().format == Param::Format::NCHW4 || - param().format == Param::Format::NCHW4_NCHW || + param().format == Param::Format::NCHW4_NCHW || + param().format == Param::Format::NCHW4_NHWC || param().format == Param::Format::NCHW4_NCHW32 || param().format == Param::Format::NCHW44 || param().format == Param::Format::NCHW44_DOT || @@ -879,6 +887,21 @@ ConvolutionBase::deduce_layout_fwd(const TensorLayout& src, cflt.stride[0], cflt.padding[0]); dst[3] = infer_conv_shape(src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]); + } else if (param().format == Param::Format::NCHW4_NHWC) { + megdnn_assert(src.ndim == 5, + "invalid src ndim for NCHW4_NHWC, expected=5, got=%zu", + src.ndim); + megdnn_assert(cflt.icpg * cflt.group == src[1] * 4, + "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, + cflt.group); + dst.ndim = 4; + dst[0] = src[0]; + dst[1] = infer_conv_shape(src[2], cflt.dilated_spatial[0], + cflt.stride[0], cflt.padding[0]); + dst[2] = infer_conv_shape(src[3], cflt.dilated_spatial[1], + cflt.stride[1], cflt.padding[1]); + auto oc = cflt.ocpg * cflt.group; + dst[3] = oc; } else if (param().format == Param::Format::NCHW4_NCHW32) { megdnn_assert(src.ndim == 5, "invalid src ndim for NCHW4_NCHW32, expected=5, got=%zu", diff --git a/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp b/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp index 674dd629f..9b6d6d677 100644 --- a/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp +++ b/dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp @@ -35,6 +35,9 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( args.src_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) && args.filter_layout->dtype.enumv() == DTypeEnum::QuantizedS4) return false; + if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 || + args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) + return false; if (args.src_layout->dtype == args.filter_layout->dtype && args.src_layout->dtype == dtype::BFloat16()) { return false; diff --git a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu index cb77f6175..4e3cba3e1 100644 --- a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu +++ b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu @@ -911,4 +911,140 @@ void megdnn::cuda::cutlass_wrapper:: INST(true); #undef INST +/* ===== cutlass kernel wrapper for nchw4 layout and nhwc output ===== */ +#if MEGDNN_TEGRA_X1 +template +void megdnn::cuda::cutlass_wrapper:: + do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( + const int8_t* /* d_src */, const int8_t* /* d_filter */, + const int32_t* /* d_bias */, const int8_t* /* d_z */, + int8_t* /* d_dst */, int* /* workspace */, + const convolution::ConvParam& /* param */, + uint32_t /* nonlinear_mode */, float /* alpha */, + float /* beta */, float /* gamma */, float /* delta */, + float /* theta */, float /* scale */, + const GemmCoord& /* threadblock_shape */, + const GemmCoord& /* warp_shape */, int /* stages */, + cudaStream_t /* stream */) {} +#else +template +void megdnn::cuda::cutlass_wrapper:: + do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( + const int8_t* d_src, const int8_t* d_filter, + const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, + int* workspace, const convolution::ConvParam& param, + uint32_t nonlinear_mode, float alpha, float beta, float gamma, + float delta, float theta, float scale, + const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, + int stages, cudaStream_t stream) { +#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ + threadblock_k_, warp_m_, warp_n_, \ + warp_k_, stages_, aligned_) \ + if (threadblock_shape.m() == threadblock_m_ && \ + threadblock_shape.n() == threadblock_n_ && \ + threadblock_shape.k() == threadblock_k_ && \ + warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ + warp_shape.k() == warp_k_ && stages == stages_) { \ + using ThreadBlockShape = \ + cutlass::gemm::GemmShape; \ + using WarpShape = cutlass::gemm::GemmShape; \ + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ + using Convolution = cutlass::conv::device::Convolution< \ + int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ + cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ + cutlass::layout::TensorNHWC, int32_t, \ + cutlass::layout::TensorNHWC, int32_t, \ + cutlass::conv::ConvType::kConvolution, \ + cutlass::arch::OpClassSimt, cutlass::arch::Sm75, \ + ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ + cutlass::conv::threadblock:: \ + ConvolutionFpropNCxHWxThreadblockSwizzle, \ + stages_, 4, aligned_, NeedLoadFromConstMem, \ + cutlass::arch::OpMultiplyAddSaturate>; \ + typename Convolution::ConvolutionParameter conv_param( \ + param.n, param.hi, param.wi, param.ci, param.co, param.fh, \ + param.fw, param.ho, param.wo, param.ph, param.pw, param.sh, \ + param.sw, 1, 1, cutlass::conv::Mode::kCrossCorrelation); \ + return cutlass_convolution_wrapper( \ + d_src, d_filter, d_bias, \ + reinterpret_cast(d_z), \ + reinterpret_cast(d_dst), workspace, \ + conv_param, epilogue, stream); \ + } +#define DISPATCH_KERNEL \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 2, 16); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 2, 16); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 2, 16); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 2, 16); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 2, 16); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 2, 16); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 2, 16); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 2, 16); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 2, 16); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 128, 16, 16, 128, 16, 1, 8); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(16, 64, 8, 16, 64, 8, 2, 4); \ + megdnn_assert(false, \ + "unsupported threadblock shape (%dx%dx%d) and warp shape " \ + "(%dx%dx%d)", \ + threadblock_shape.m(), threadblock_shape.n(), \ + threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ + warp_shape.k()); + using ElementOutput = cutlass::integer_subbyte<4, signedness>; + using ElementAccumulator = int32_t; + using ElementBias = int32_t; + using ElementCompute = float; + using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; + switch (nonlinear_mode) { + case NonlineMode::IDENTITY: { + using EpilogueOp = + cutlass::epilogue::thread::BiasAddLinearCombinationClamp< + ElementOutput, 8, ElementAccumulator, ElementBias, + ElementCompute>; + typename EpilogueOp::Params epilogue{alpha, beta, gamma, + delta + theta}; + DISPATCH_KERNEL; + } + case NonlineMode::RELU: { + using EpilogueOp = cutlass::epilogue::thread:: + BiasAddLinearCombinationReluClamp< + ElementOutput, 8, ElementAccumulator, ElementBias, + ElementCompute>; + typename EpilogueOp::Params epilogue{alpha, beta, gamma, + 0, delta, theta}; + DISPATCH_KERNEL; + } + case NonlineMode::H_SWISH: { + using EpilogueOp = cutlass::epilogue::thread:: + BiasAddLinearCombinationHSwishClamp< + ElementOutput, 8, ElementAccumulator, ElementBias, + ElementCompute>; + typename EpilogueOp::Params epilogue{alpha, beta, gamma, + scale, detla, theta}; + DISPATCH_KERNEL; + } + default: + megdnn_assert(false, + "unsupported nonlinear mode for conv bias operator"); + } +#undef DISPATCH_KERNEL_WITH_TILE_SHAPE +#undef DISPATCH_KERNEL +} +#endif + +#define INST(signedness) \ + template void megdnn::cuda::cutlass_wrapper:: \ + do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( \ + const int8_t* d_src, const int8_t* d_filter, \ + const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ + int* workspace, const convolution::ConvParam& param, \ + uint32_t nonlinear_mode, float alpha, float beta, \ + float gamma, float delta, float theta, float scale, \ + const GemmCoord& threadblock_shape, \ + const GemmCoord& warp_shape, int stages, \ + cudaStream_t stream); +INST(true); +INST(false); +#undef INST + // vim: syntax=cuda.doxygen diff --git a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh index c97f2bc7a..f2d7370de 100644 --- a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh +++ b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh @@ -94,6 +94,15 @@ void do_conv_bias_uint4_int4_implicit_gemm_imma_ncdiv64hw64( float scale, uint8_t src_zero_point, const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, cudaStream_t stream); +template +void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc( + const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, + const int8_t* d_z, int8_t* d_dst, int* workspace, + const convolution::ConvParam& param, uint32_t nonlinear_mode, + float alpha, float beta, float gamma, float delta, float theta, + float scale, const GemmCoord& threadblock_shape, + const GemmCoord& warp_shape, int stages, cudaStream_t stream); + } // namespace cutlass_wrapper } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl b/dnn/src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl new file mode 100644 index 000000000..9f09ce41d --- /dev/null +++ b/dnn/src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl @@ -0,0 +1,65 @@ +/** + * \file + * dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include "cutlass/convolution/device/convolution.h" +#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" + +using namespace megdnn; +using namespace cuda; +using namespace cutlass_wrapper; + +template +void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( + const typename Convolution::ElementSrc* d_src, + const typename Convolution::ElementFilter* d_filter, + const typename Convolution::ElementBias* d_bias, + const typename Convolution::ElementDst* d_z, + typename Convolution::ElementDst* d_dst, int* workspace, + typename Convolution::ConvolutionParameter const& conv_param, + typename Convolution::EpilogueOutputOp::Params const& epilogue, + cudaStream_t stream, typename Convolution::ExtraParam extra_param) { + typename Convolution::TensorRefSrc tensor_src{ + const_cast(d_src), + Convolution::LayoutSrc::packed( + {conv_param.N, conv_param.H, conv_param.W, conv_param.C})}; + typename Convolution::TensorRefFilter tensor_filter{ + const_cast(d_filter), + Convolution::LayoutFilter::packed( + {conv_param.K, conv_param.R, conv_param.S, conv_param.C})}; + typename Convolution::TensorRefBias tensor_bias{ + const_cast(d_bias), + Convolution::LayoutBias::packed({1, 1, 1, conv_param.K})}; + typename Convolution::TensorRefDst tensor_z{ + const_cast(d_z), + Convolution::LayoutDst::packed( + {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; + typename Convolution::TensorRefDst tensor_dst{ + d_dst, + Convolution::LayoutDst::packed( + {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; + typename Convolution::Arguments arguments{conv_param, + tensor_src.non_const_ref(), + tensor_filter.non_const_ref(), + tensor_bias.non_const_ref(), + tensor_z.non_const_ref(), + tensor_dst.non_const_ref(), + epilogue, + {}, + {}, + extra_param}; + Convolution conv_op; + cutlass_check(conv_op.initialize(arguments, workspace)); + cutlass_check(conv_op(stream)); + after_kernel_launch(); +} + +// vim: syntax=cuda.doxygen diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp index 3b7e2e704..87672047a 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp @@ -37,27 +37,40 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( if (!check_bias_share_in_channel(*(args.bias_layout), param.format)) return false; - if (param.format == Format::NCHW4_NCHW32) { - if (m_algo_param.threadblock_m % 32 != 0) - return false; - } else if (param.format != Format::NCHW4_NCHW && - param.format != Format::NCHW4) - return false; + bool valid_format = param.format == Format::NCHW4_NCHW32 && + m_algo_param.threadblock_m % 32 == 0; + valid_format |= param.format == Format::NCHW4_NCHW && + args.bias_layout->dtype.enumv() == DTypeEnum::Float32 && + args.dst_layout->dtype.enumv() == DTypeEnum::Float32; + valid_format |= + param.format == Format::NCHW4_NHWC && + args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32 && + (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 || + args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm); + valid_format |= param.format == Format::NCHW4; + if (!valid_format) return false; size_t n = args.src_layout->operator[](0), ci = args.src_layout->operator[](1) * 4, hi = args.src_layout->operator[](2), wi = args.src_layout->operator[](3); - size_t ho = args.dst_layout->operator[](2), - wo = args.dst_layout->operator[](3); size_t co; + size_t dst_spatial_pos; if (param.format == Format::NCHW4) { co = args.dst_layout->operator[](1) * 4; + dst_spatial_pos = 2; } else if (param.format == Format::NCHW4_NCHW) { co = args.dst_layout->operator[](1); + dst_spatial_pos = 2; + } else if (param.format == Format::NCHW4_NHWC) { + co = args.dst_layout->operator[](3); + dst_spatial_pos = 1; } else { megdnn_assert(param.format == Format::NCHW4_NCHW32); + dst_spatial_pos = 2; co = args.dst_layout->operator[](1) * 32; } + size_t ho = args.dst_layout->operator[](dst_spatial_pos), + wo = args.dst_layout->operator[](dst_spatial_pos + 1); UNPACK_CONV_PARAMETER(fm, param); MARK_USED_VAR // TODO support group conv @@ -72,7 +85,9 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( available &= (src_dtype.enumv() == DTypeEnum::QuantizedS8 && filter_dtype.enumv() == DTypeEnum::QuantizedS8); available &= (bias_dtype.enumv() == DTypeEnum::QuantizedS32 && - dst_dtype.enumv() == DTypeEnum::QuantizedS8) || + (dst_dtype.enumv() == DTypeEnum::QuantizedS8 || + dst_dtype.enumv() == DTypeEnum::QuantizedS4 || + dst_dtype.enumv() == DTypeEnum::Quantized4Asymm)) || (bias_dtype.enumv() == DTypeEnum::Float32 && dst_dtype.enumv() == DTypeEnum::Float32); // TODO: support dialtion @@ -111,17 +126,23 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( ci = args.src_layout->operator[](1) * 4, hi = args.src_layout->operator[](2), wi = args.src_layout->operator[](3); - size_t ho = args.dst_layout->operator[](2), - wo = args.dst_layout->operator[](3); - size_t co; + size_t co, dst_spatial_pos; if (param.format == Format::NCHW4) { co = args.dst_layout->operator[](1) * 4; + dst_spatial_pos = 2; } else if (param.format == Format::NCHW4_NCHW) { co = args.dst_layout->operator[](1); + dst_spatial_pos = 2; + } else if (param.format == Format::NCHW4_NHWC) { + co = args.dst_layout->operator[](3); + dst_spatial_pos = 1; } else { megdnn_assert(param.format == Format::NCHW4_NCHW32); + dst_spatial_pos = 2; co = args.dst_layout->operator[](1) * 32; } + size_t ho = args.dst_layout->operator[](dst_spatial_pos), + wo = args.dst_layout->operator[](dst_spatial_pos + 1); UNPACK_CONV_PARAMETER(fm, param); MARK_USED_VAR auto&& stream = cuda_stream(args.opr->handle()); @@ -161,136 +182,107 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( float beta = 1.f; float dst_scale = 1.f; if (args.bias_layout->dtype.enumv() == DTypeEnum::QuantizedS32) { - megdnn_assert(args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS8); + megdnn_assert(args.dst_layout->dtype.category() == + DTypeCategory::QUANTIZED); float bias_scale = args.bias_layout->dtype.param() - .scale, - dst_scale = - args.dst_layout->dtype.param().scale; + .scale; + dst_scale = get_scale(args.dst_layout->dtype); alpha /= dst_scale, beta = bias_scale / dst_scale; } float gamma = 0.f; if (args.z_layout->ndim > 0) { gamma = 1.f; - if (args.z_layout->dtype.enumv() == DTypeEnum::QuantizedS8) { - megdnn_assert(args.dst_layout->dtype.enumv() == - DTypeEnum::QuantizedS8); - float z_scale = args.z_layout->dtype.param() - .scale; + if (args.z_layout->dtype.category() == DTypeCategory::QUANTIZED) { + megdnn_assert(args.dst_layout->dtype.category() == + DTypeCategory::QUANTIZED); + float z_scale = get_scale(args.z_layout->dtype); gamma = z_scale / dst_scale; } } uint32_t nonlinear_mode = static_cast(param.nonlineMode); - if (fh == 1 && fw == 1) { - if (param.format == Format::NCHW4) { - cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< - false>( - args.src_tensor->compatible_ptr(), filter_ptr, - args.bias_tensor->compatible_ptr(), - args.z_tensor->compatible_ptr(), - args.dst_tensor->compatible_ptr(), nullptr, - kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, - cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k}, - cutlass_wrapper::GemmCoord{m_algo_param.warp_m, - m_algo_param.warp_n, - m_algo_param.warp_k}, - m_algo_param.stage, stream); - } else if (param.format == Format::NCHW4_NCHW) { - cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw( - args.src_tensor->compatible_ptr(), - filter_ptr, - args.bias_tensor->compatible_ptr(), - args.z_tensor->compatible_ptr(), - args.dst_tensor->compatible_ptr(), nullptr, - kern_param, nonlinear_mode, alpha, beta, gamma, - dst_scale, - cutlass_wrapper::GemmCoord{ - m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k}, - cutlass_wrapper::GemmCoord{m_algo_param.warp_m, - m_algo_param.warp_n, - m_algo_param.warp_k}, - m_algo_param.stage, stream); - } else { - megdnn_assert(param.format == Format::NCHW4_NCHW32); - cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< - false>( - args.src_tensor->compatible_ptr(), - filter_ptr, - args.bias_tensor->compatible_ptr(), - args.z_tensor->compatible_ptr(), - args.dst_tensor->compatible_ptr(), nullptr, - kern_param, nonlinear_mode, alpha, beta, gamma, - dst_scale, - cutlass_wrapper::GemmCoord{ - m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k}, - cutlass_wrapper::GemmCoord{m_algo_param.warp_m, - m_algo_param.warp_n, - m_algo_param.warp_k}, - m_algo_param.stage, stream); - } + bool nonunity_kernel = !(fh == 1 && fw == 1); +#define DISPATCH(_nonunity_kernel) \ + if (nonunity_kernel == _nonunity_kernel) { \ + cb(_nonunity_kernel) \ + } + if (param.format == Format::NCHW4) { +#define cb(_nonunity_kernel) \ + cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< \ + _nonunity_kernel>( \ + args.src_tensor->compatible_ptr(), filter_ptr, \ + args.bias_tensor->compatible_ptr(), \ + args.z_tensor->compatible_ptr(), \ + args.dst_tensor->compatible_ptr(), nullptr, kern_param, \ + nonlinear_mode, alpha, beta, gamma, dst_scale, \ + cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ + m_algo_param.threadblock_n, \ + m_algo_param.threadblock_k}, \ + cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ + m_algo_param.warp_n, \ + m_algo_param.warp_k}, \ + m_algo_param.stage, stream); + DISPATCH(true); + DISPATCH(false); +#undef cb + } else if (param.format == Format::NCHW4_NCHW) { +#define cb(_nonunity_kernel) \ + cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw< \ + _nonunity_kernel>( \ + args.src_tensor->compatible_ptr(), filter_ptr, \ + args.bias_tensor->compatible_ptr(), \ + args.z_tensor->compatible_ptr(), \ + args.dst_tensor->compatible_ptr(), nullptr, kern_param, \ + nonlinear_mode, alpha, beta, gamma, dst_scale, \ + cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ + m_algo_param.threadblock_n, \ + m_algo_param.threadblock_k}, \ + cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ + m_algo_param.warp_n, \ + m_algo_param.warp_k}, \ + m_algo_param.stage, stream); + DISPATCH(true); + DISPATCH(false); +#undef cb + } else if (param.format == Format::NCHW4_NHWC) { +#define cb(_nonunity_kernel) \ + cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nhwc< \ + _nonunity_kernel>( \ + args.src_tensor->compatible_ptr(), filter_ptr, \ + args.bias_tensor->compatible_ptr(), \ + reinterpret_cast(args.z_tensor->raw_ptr), \ + reinterpret_cast(args.dst_tensor->raw_ptr), nullptr, \ + kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, \ + cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ + m_algo_param.threadblock_n, \ + m_algo_param.threadblock_k}, \ + cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ + m_algo_param.warp_n, \ + m_algo_param.warp_k}, \ + m_algo_param.stage, stream); + cb(true); +#undef cb } else { - if (param.format == Format::NCHW4) { - cutlass_wrapper::do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4< - true>( - args.src_tensor->compatible_ptr(), filter_ptr, - args.bias_tensor->compatible_ptr(), - args.z_tensor->compatible_ptr(), - args.dst_tensor->compatible_ptr(), nullptr, - kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, - cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k}, - cutlass_wrapper::GemmCoord{m_algo_param.warp_m, - m_algo_param.warp_n, - m_algo_param.warp_k}, + megdnn_assert(param.format == Format::NCHW4_NCHW32); +#define cb(_nonunity_kernel) \ + cutlass_wrapper:: \ + do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< \ + _nonunity_kernel>( \ + args.src_tensor->compatible_ptr(), filter_ptr, \ + args.bias_tensor->compatible_ptr(), \ + args.z_tensor->compatible_ptr(), \ + args.dst_tensor->compatible_ptr(), nullptr, \ + kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, \ + cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, \ + m_algo_param.threadblock_n, \ + m_algo_param.threadblock_k}, \ + cutlass_wrapper::GemmCoord{m_algo_param.warp_m, \ + m_algo_param.warp_n, \ + m_algo_param.warp_k}, \ m_algo_param.stage, stream); - } else if (param.format == Format::NCHW4_NCHW) { - cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw( - args.src_tensor->compatible_ptr(), - filter_ptr, - args.bias_tensor->compatible_ptr(), - args.z_tensor->compatible_ptr(), - args.dst_tensor->compatible_ptr(), nullptr, - kern_param, nonlinear_mode, alpha, beta, gamma, - dst_scale, - cutlass_wrapper::GemmCoord{ - m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k}, - cutlass_wrapper::GemmCoord{m_algo_param.warp_m, - m_algo_param.warp_n, - m_algo_param.warp_k}, - m_algo_param.stage, stream); - - } else { - megdnn_assert(param.format == Format::NCHW4_NCHW32); - cutlass_wrapper:: - do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< - true>( - args.src_tensor->compatible_ptr(), - filter_ptr, - args.bias_tensor->compatible_ptr(), - args.z_tensor->compatible_ptr(), - args.dst_tensor->compatible_ptr(), nullptr, - kern_param, nonlinear_mode, alpha, beta, gamma, - dst_scale, - cutlass_wrapper::GemmCoord{ - m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k}, - cutlass_wrapper::GemmCoord{m_algo_param.warp_m, - m_algo_param.warp_n, - m_algo_param.warp_k}, - m_algo_param.stage, stream); - } + DISPATCH(true); + DISPATCH(false); +#undef cb +#undef DISPATCH } after_kernel_launch(); } @@ -315,17 +307,23 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec_preprocess( ci = args.src_layout->operator[](1) * 4, hi = args.src_layout->operator[](2), wi = args.src_layout->operator[](3); - size_t ho = args.dst_layout->operator[](2), - wo = args.dst_layout->operator[](3); - size_t co; + size_t co, dst_spatial_pos; if (param.format == Format::NCHW4) { co = args.dst_layout->operator[](1) * 4; + dst_spatial_pos = 2; } else if (param.format == Format::NCHW4_NCHW) { co = args.dst_layout->operator[](1); + dst_spatial_pos = 2; + } else if (param.format == Format::NCHW4_NHWC) { + co = args.dst_layout->operator[](3); + dst_spatial_pos = 1; } else { megdnn_assert(param.format == Format::NCHW4_NCHW32); + dst_spatial_pos = 2; co = args.dst_layout->operator[](1) * 32; } + size_t ho = args.dst_layout->operator[](dst_spatial_pos), + wo = args.dst_layout->operator[](dst_spatial_pos + 1); UNPACK_CONV_PARAMETER(fm, param); MARK_USED_VAR TensorLayout src{{co, ci / 4 * fh * fw}, dtype::Int32()}; diff --git a/dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl b/dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl deleted file mode 100644 index 9f09ce41d..000000000 --- a/dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl +++ /dev/null @@ -1,65 +0,0 @@ -/** - * \file - * dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl - * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") - * - * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or - * implied. - */ -#include "cutlass/convolution/device/convolution.h" -#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" - -using namespace megdnn; -using namespace cuda; -using namespace cutlass_wrapper; - -template -void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper( - const typename Convolution::ElementSrc* d_src, - const typename Convolution::ElementFilter* d_filter, - const typename Convolution::ElementBias* d_bias, - const typename Convolution::ElementDst* d_z, - typename Convolution::ElementDst* d_dst, int* workspace, - typename Convolution::ConvolutionParameter const& conv_param, - typename Convolution::EpilogueOutputOp::Params const& epilogue, - cudaStream_t stream, typename Convolution::ExtraParam extra_param) { - typename Convolution::TensorRefSrc tensor_src{ - const_cast(d_src), - Convolution::LayoutSrc::packed( - {conv_param.N, conv_param.H, conv_param.W, conv_param.C})}; - typename Convolution::TensorRefFilter tensor_filter{ - const_cast(d_filter), - Convolution::LayoutFilter::packed( - {conv_param.K, conv_param.R, conv_param.S, conv_param.C})}; - typename Convolution::TensorRefBias tensor_bias{ - const_cast(d_bias), - Convolution::LayoutBias::packed({1, 1, 1, conv_param.K})}; - typename Convolution::TensorRefDst tensor_z{ - const_cast(d_z), - Convolution::LayoutDst::packed( - {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; - typename Convolution::TensorRefDst tensor_dst{ - d_dst, - Convolution::LayoutDst::packed( - {conv_param.N, conv_param.P, conv_param.Q, conv_param.K})}; - typename Convolution::Arguments arguments{conv_param, - tensor_src.non_const_ref(), - tensor_filter.non_const_ref(), - tensor_bias.non_const_ref(), - tensor_z.non_const_ref(), - tensor_dst.non_const_ref(), - epilogue, - {}, - {}, - extra_param}; - Convolution conv_op; - cutlass_check(conv_op.initialize(arguments, workspace)); - cutlass_check(conv_op(stream)); - after_kernel_launch(); -} - -// vim: syntax=cuda.doxygen diff --git a/dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl b/dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl new file mode 120000 index 000000000..74e039d95 --- /dev/null +++ b/dnn/src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl @@ -0,0 +1 @@ +../implicit_gemm_conv_bias_cutlass_wrapper.cuinl \ No newline at end of file diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_nhwc.cu b/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_128x128x32_64x32x32_2_nc4hw4_nhwc.cu new file mode 100644 index 0000000000000000000000000000000000000000..42e715a0b9d61076178dfb86fc5cf4f51dba2169 GIT binary patch literal 1804 zcmb7FZBN@U5dNND;T37xP+8T|vC4EvtU!@AU?P>+H_LV0#;Rjmw$m4W{myx5vkp}F zP}hm?Ip00^+$A~^GxXu^dVF<%PbSyn$4m038+AHJ%$2qrt*BK(&5@oVmU~H^i;j$? z^NgbLXoM0$!OnTXkjY*8+F9u=L9LUgT;+}@Cx%*@anEh*4H)OGVE%+kK^+7G5TTgN z6OLNWvJT4=+F@EP$rKDaA~RAj;mMq58QH!kRs3XSsWIGsVYyJU6-6Q21-FVzB$V@1 zG4NX(JCS+f`b5l()rO>QC0t6}fOLCUn~u zo0`4)rM6KXazP`ATg%)Ft=oPE8wO^U>s7wIN?jdIK~aY5X=z-?FYoS0~YN- zl5TYFF`jVcw7nm#Z=cq~{zd!EkkR@v{<(uxgkTKU^`{5#yL}pMuL5Qcg%`}r1i<}E z_9X`)_X+TW@uES)e(Kb+6ypao3UZC1fMq1x49fF3`}TG3AP5HIF|2V)4Ie^%+o)Jh zdlhqAC;PT%{Wd!7h2J5|F8FMtx)fW7+z6@XIppS3%V~n~4^Z=ElKc{iQ#;bxR45vB zavKK`LUx%c70@VqU7e!L2o1^%3z0Nv{r`C;HT4B(Km3yUGN!52Z1Il-Z$ev?>Y5Z` z7bXzk^RwjbpvbbUy)qCWsko$jR&Bh^!K`?+H_LV0#;Rjmw$lf{e%E=lSqGv> zY3juHT;FrGlTe(;i&ptY0m9M6rIl>Jyc&SPRX$;+A@zpy{R- zuYRd*lm{wk1Yv8Ld!ZHjjv|Eo#&8AF)Ci0N1B?^C6pUls01Uw#LCk(Yd^=##4kYPD z=N{t;S5Dje(famjJ?vk!4~C4^kMYkPq|yUpIIcf^^Ss-q(e@~i=&kUAc}W1=zhqx( z5V%i3AB-0b8TP4D%TkOV%qUcA4EZY~*)k~i+H_LV0#%f|)e)Pew-*p~s)`2Ke zl{)r4*Y}+JNOUA-=)>Lh`0Da|m*h`3>U5BpOJz7(Q6q(%BQ-=tpR5eknwu}o3z@Z|D6m~{BYB2|w2n#! zersbZaz|{Rh?zD@lhm$+O^F?lZf~$Y>#t9HqFh6m)uat6Exk_ z;?tX+*eK2ISevE(aAe9~%!*%`XoA=#5jkZ^TGH-jV(XA=k*T={YA!XLCK&&KTE0w@UqW(fMk=2QNkf_3#-Rj( zT&~LkXk@)kPLXAV1|`9wMjERA|31qUbp>djXeoRd(==0T@lOTbgr+EqYf^+=m_UHf z&lzurf-JMzD+2*3i%Ys^)xsGMX2srCi8_6EK66568_*!8&WBFAIEQX^9Qqqx&{By6 z<$Rt<2~7esAUul^Rk!$VBhaYYtUeqzO6cNKr5I?bU+o8>xgW7V-O+v$a0zvEomjDY~{ zC2`{O9iQ(z7e@zTj^15gOwMm^$n;`zcSe5qqh1e*h0>OzHML5p1=4fG3NNX1(Sfma zkyA8@W0VRCcFqfiOyRQU_9|y7>Kxq{s&G6#GSt$Xdu}_gz&LLO^G8$)>L3_^2*qTP za@28_^;nV89@AVu*j^FE2HAQ*RJLXvwBU+c>UoN0+eW?S zr?ycMxS$x~*0S(IEA$OT2>Gqy3bbhu7zYX%r+g(C$G8O;gE@kr{fPK(z@j}!(ycB$ z##64Gwm0$S>V7jGoOWLgi8pu2k3FOo1Y z|CoZmE^1P-^T>^mdQkv3pIJ^*jK4!QpQY(fp*Xd%&Syf=P$ySOC_-SDo2mj@X>YPq zrWv78#jq@rmaPAN&!wim1U+V2DPP4jlbS96D#6Rp7FBgk%dkrm2=L)S@@}ZeDy{u0 zkU^a*(6l_v+XAf04#b{kW#}C*mq&e2;&RMbFO@1*&#{^%$_>-K=T#2OCe$6 z1?aq7y$v18Az;IhdmHlH`2c)qp>|W|DSowm>YAw%_(5eIqP4b5XDDlu7B-%C)$;qR zne~-M@55_F${XP2i0ZDmExalGt--kT<{DQ`*R&jkN=Q7kP)+o8>xgV>PiYUwYx!?>Lt>V<2EJ zablnE_sgXi1keVarT~4iy4z!_* zf}%+rqf}6^vraIiz0ICGtDL2%b9C>ew>&-4)X;)EZaS~PSZ4%tM>H4ILNEXklI1?- zsACQ5F`v>NQ*uRSV9*nV&IJ>WEO=3n-9D+~Cu>8s=H?UgLgt+)3T&6$NS-4ht)r5G z-^SRA!Vx67bF;4kPFphBxFa~o30lN`#-GD`VP)@hXJB+7X zT4iqH&DH&8JUH#X8WL~rk{^3WEe6JL-C*|OeOLGK?yAqnuQIO)fcvNHD-8noDd>ao zvL(Y~>eR3d<69kvQcWOvRU|tG)p=5U{qpD_6br@^cqSRueB2*)(Qwd3qgu9|hyOS1 zqJ9@OY1sMWTI6csftSk+rzytYp^DGa^rw)VnphPxA!(?Qt0WX3Fe`MG{Z`7G%#oeekE=jH z$nuhAb;g5LxwR9bN!_*2ozVFfBuJ?9p^HvWpiLcz#)c2HQX)kKU!+n(gFy3#XE~di z-rjEn7Ilj?hvQcHoPMYU11S|W|GooTMi^h@yx`I$h8?1mz|5&*4>VtayA(N0x&WP* zZ*N0~atPQk+}?&fw=Mu5TBzNGd5T|cpW0@k1b$FihiI+L(rU_@q=kd0UA6rFYG!<; z(fhEhNICgmj;QX6Tko`Yw;JQpn{8Y*UDI+DDj{*uLIrUMei`uc@#ML`i&i>1;4&5S F=nvTiOo{*i literal 0 HcmV?d00001 diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_nhwc.cu b/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_32x128x32_32x64x32_2_nc4hw4_nhwc.cu new file mode 100644 index 0000000000000000000000000000000000000000..7f956905bd0d23049ecaf75b6dc9adaa4486e06c GIT binary patch literal 1802 zcmb7FTTk0C6n@XI@QO5TsI2N`tTG)ED^R2jmftaO&3*2z<@a>tVsLoLm?=eBhK#(684KcP}k2f+YDC?@lS zqn5L*!}5f7m{v=tpR6o3hTAVJ7fQCGD6m~{tGGl$IZqV> zzqPRwnI~>Q#LQT2Na|L?rNj+Mzqju8hUdbj4a|&K0-R^8`&d zwRrWWwox9apfQB4W$uMm=sSuK@*BeyNK+#)4h%3(_);*AaRV>{a|AK_A@S{iMLUqB z8=ZTMAGvbc-i_BcPwUa(ynQfayncv(?jV&O7{hgg>6`c6JdL+kfiiD}7tBin;C`Qd zsX^fW2>M{WXvpxAI<+jt_}+{|wZ@RYGLkKW@;uJIeSL8d>ILHoJmZuaKI-?{2y)y; z{az*9*2g~6)4?H@UGUjPbtbkBxe-#&bD-u^%V~n~52)pfB>5#2r*^EfsZcbO$xR$e z5XfbwEPzJV>*N$!Mrc?PENY~o>i^#}si`kOFNv1ImoZJHW{ZC+@H#X_SzM1r*o6rM z`1~w+I}~J@)m|A02vl6sJ*^hr=3rLrZI!6gcbiKujNE_*G4*-qr1LZAR?nfo;R7wT zNKnS-iBixcFayG~7*TbL?=}LBs?F-dQKN*;KUIo>mKtWjt_vGR7+*=Aaphynj!{Bj z^i;72o(JHrg@l0@ptEB2K6I#ufDJS5eaK7a1Ms+H_LV0#;Rjme(4Lpe&@Wjfq{TY zY3juHobS1N?vm_@8G3hhKK^liO(y5#+cWa#FzNM>m@92LT2ZTnnj<|!EO<$sOZJSV z^Ma!BXoNCB!OnTXkO?k-Zm)8dq0YfwP{Hx+z)(vI?z!!}0^__D%pXuGsDoetA{3J$ z#>m09@A<`reM$$g^_{@Pv*QR$o4&{<0mUijp6na3qr|Gl0>!(ZWWhEDCenS z;I}b$qVU8Gh?p6x4awa~xSY6Sa@1d!*KW}#iVgDBfT(=MI%v)nx770tO*c*MnwQ!p zAu>TDh+50wg;wYrN)YlJ!xdzvL0}vSV4U%#U>xHXUU3ZX#|@@0o_BdS+8%Y(f0cPf0Ng)iU*#ZjpFn;v zUbNEim^!sA$N0vKqF7TXUlqxgLA9S2Uq3(Eh;qSr4EHprh7XVWT~z(*ylq|lzvW>U z!POLE>yR5E^*jJGpIc5djK71D&$8^7P@LM4E~Y}!sFKSx3J|yzrmBCd=uK_PA|rHM z@hc_LO6tGQGpVUBL63=6d9PxcO3fAz1b7i!qSDr+47)Ue0PpW5??ydVMeSFC4C+c9 zeyJDU1~7Xl+@^Y0L3&~226B*6A7Ue&o)2f*yeh5bx!HY8rg^ zBGRb4tl1p46mhaM`3NATT%#mAH9J~2j{WRI&%%#y#W C{!LW? literal 0 HcmV?d00001 diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_nhwc.cu b/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_32x64x32_32x64x32_2_nc4hw4_nhwc.cu new file mode 100644 index 0000000000000000000000000000000000000000..130ca582290979e90805f914773e7e64ca788d62 GIT binary patch literal 1800 zcmbVNYfsxS6#brG;T37xP+3LWF=RR<)xgW7V-OKldh*V<2Es zl{)r4*Y}+JNU|qp=-t)%3XW$7hFV&1&u!-w80W2E{(wqB9Rvdqp_mLA zM;&KbkA;l(m{vw)>=xpR6o3hTBgp2qim764@@eRa_#WoTrL` z-^SR9!V@>#xq7{hhL>5KPW-jBCefkZEb7tJdI;QlH5N`uJ# z4)noz(URdYb!u6T@r@bBVof1^RU}&m)p=Tc`TXc0mJ7xcc&0fud^G5H(b2Gr2K`#M zt&9J+JnSO)HHp|dD zsr@Ptkg~kw@TgvR8^EmG+bYpy?vL-DS<;s8vCyA8N%wOAQlX*Mu!2jL#)6xbmrG`zRwYca8OZOHG=N8m#XwH-1~@vH4K*K}><2c@--R@yF{p{z+-Jb2nw z%kQsd&{rD0kAF*2j{lbe4sNImpWK|F%*1}r|FJokw) M+9Z2iWnz~60Sp*TX#fBK literal 0 HcmV?d00001 diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_nhwc.cu b/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_hswish_s8_64x128x32_64x32x32_2_nc4hw4_nhwc.cu new file mode 100644 index 0000000000000000000000000000000000000000..78daff4fe4ee11797e75e189d63420daaf9ca532 GIT binary patch literal 1802 zcmb7FTTk0C6n@XI@QO5TsI2N`tTG)ED^R2jmftaO&3*2z<@a>tVsLoLm?=eBhK#(684KcP}k2f+YDC?@lS zqn5L*!}5f7m{v=tpR6o3hTAVJ7fQCGD6m~{tGGl$IZqV> zzqPRwnI~>Q#LQT2Na|L?rNj-%>0sUM4cGl%SoM2h(Ibit($#>dbj4a|&K0-R^8`&d zwRrWWwox9apfQB4W$uMm=sSuK@*BeyNK+#)4h%3(_);*AaRV>{a|AK_A@S{iMLUqB z8=ZTMAGvbc-i_BcPwUa(ynQfayncv(?jV&O7{hgg>6`c6JdL+kfiiD}7tBin;C`Qd zsX^fW2>M{WXvpxAI<+jt_}+{|wZ@RYGLkKW@;uJIeSL8d>ILHoJmZuaJ~|z=5#+dy z`n^iHt&e@C{oWy#UGUjPbtbkBxe-#&bD-u^%V~n~52)pfB>5#2r*^EfsZcbO$xR$e z5XfbwEPzJV>*N$!Mrc?PENY~o>i^#}si`kOFNv1ImoZJHW{ZC+@H#X_SzM1r*o6rM z`1~w+I}~J@)m|A0=viFSJ*^hr=3rLrZI!6gcbiKujNE_*G4*-qr1LZAR?nfo;R7wT zNKnS-iBixcFayG~7*TbL?=}LBs?F-dQKN*;KUIo>mKtWjt_vGR7+*=Aaphynj!{Bj z^i;72o(JHrg@l0@ptEB2K6I#ufDJS5eaK7a1Ms!3&*Fp)~^o8>xgW7V-OKldh*V<2Es znmX}4*Y}+JNU|qp=-u_j`1{QbnOuzT&dKk?q}M}YuC(Q7MXeHQj`R$%;3ah~*)x{T z3yQ|05y}JwJLd&MCb;~$y~giV}qU#&8A7Gzg3%0gN-g6pUls0t~?%LCF4u_-@3a9Z1rR4j$tP zS5DiT(faCsJsg~MUkw?p@6sPTNG%7(aNJ<};(1s1qwP^3(M#b)^NIktf6BhnAab98 zJ{T`rGCZbEEz2>!HKSOpDWtE8WXqu1Pm3>~A8o{P!FUY!G^d6Sj|W|J)Q`Vfx2=o+ zw><13xSB+49daY2o(Ev&bIWOl@i(aC^DO%*6sLBii>XjFR>@Tw3lO*!rpkXS=}m6R zBqMZE@heNDCF;M=GpVUBL63=6x>qqxrDlu23h*+vL{(dpGVIa>0(^Loyc_GWN@~9f z1mr9aIXtcx-Ucu$-))s>vUeS%7e;PCgOvIZ8|my6dew92Zg@aTEizQ_d8QQf2uy%@ zF9%f9;JX)*M%`u2=CD;kXCG?CKuZl1VAq5#BaAO3FSzolW&0>2FnH?NBhOdhE`@|~ z7ooFq^)_^@hlq_6?rq5L&PU)w3$-0GPw}hmGuL!&FX znmX}4*Y}+JNOT~k=)>*R`10!2A|+)~d|G+8&= zH7~V|iogWL5Ve+t7h0k3C_>1u4OgH{gTOctz&Pa#!8pb(zzEC{gzQg=?*=T|f+St* z!eji%mDBbvUfn#cM#J;&-jI0pko??2YB?~5h()k(ms^`$%@PHOtq$uaJ zR4M2Ym;m8k4ydNVw=V*Xy33l)QLBQ^Kh=tXmKr9&wh3EC7+*=AbLA7u4pB;A@YJyf zo_pY~g@kbzpwn{oK6I#ufDIGweaK7a1Msye`TgAt`bwjZ;cr38@&9&2bz9sQ-W2}cU|f20jjN_@T8=^`Bp!MwCmz9f0~Vjp QUi-utZK4CNQZbGG0zBhQa{vGU literal 0 HcmV?d00001 diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_nhwc.cu b/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_128x128x32_64x32x32_2_nc4hw4_nhwc.cu new file mode 100644 index 0000000000000000000000000000000000000000..aeec4c86193c2b572ca0358886ded75f581e05e1 GIT binary patch literal 1800 zcmb7FTTk0C6n@XI@QO5TsH|$~SY(5txqUI~sLNlR+rCx9} z+0^>gFSU*Gzy`$-xlVI0wL;%fgpl7Dp+KD)fpH*#F&9h81jY@(5X=#T?FYoS0~YN- zDc$JYV>}khX?q{9Z=cq~{zd!Eka+!={M#^FB&rJQ>T{AF@7*{DA)v2SVppCP@X5*x37B#p=2-~!J5pe5ktst8x_lG zuTpO7WZ(9z-$tjs@H=GL1)ptHr()|+7@6u>4&40Q3d%A50rh;z`7fyiwXx17QqfQ+ zz$-Il?i&fO^HL-jp+U*5D36Aq|36PtO??5{54_a8jAoK*w)m&`ZbD0xbu}&mE=-b@ ze11+vJJe#C&|Vn`$WvU>J*zg}=3rL5+iFl}?k-QgH0cJ|C)DSmi7w8eSG|Dlh6l9N zlA}z_xKhv`Fag3^45+%ncP|2ss>ABTVWWC3K2?f=mKr9&t^pfH7++H_Nr%)@ou~w$lf{e%E=lSqGv> zY3juHT;Fr(d@w1xJ%j zEnodo+b9oYPz-_VH1|>~^c_VA`Hc|@#HkS&2M!o>v6M_;+yD&096{86Kzuu3(GHZ- zjm|yBW1*b3_woAnX+7*;v=4^F>&N8h4pJ$CF&x*Qym{X3Q@lM2Bzh~nU|tdc_b=I( z8U*fR&X-rc8Y! z-*sAw{30|csTI}H(DVQ2X{xC&K>HL+$;)Uasb-6RYVRgAL|InjBH+R#S;^<;RJ21W zmig?Jfq*>4A>FfT;cX6P#k;Krb>i;w)Jv0YfPF%J9(w5F96Hqt=xumFOD#Fd#EdHi z?E(EC?!|bj+k5vS(5M=$-W)cH=i*bP7-*@X|9Ab@Fv9pc6&dt?V%ag`1jbGkd*FEh z?nb6C>;iOJtlozXl@PFDzP%55<$M4>H&EO8@)Ezwp1ZnX13xIOW3l4AJ39Z}T~54ktFe=ryq-dydfZkUFnPzi~L?#YNp@YR6D=kwP- NFh-l`NGL9+(O>RaOi2I$ literal 0 HcmV?d00001 diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_nhwc.cu b/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_128x64x32_64x32x32_2_nc4hw4_nhwc.cu new file mode 100644 index 0000000000000000000000000000000000000000..15fc0d1bf9b2a33f1b0da0bb790c6d11f0733890 GIT binary patch literal 1798 zcmb7FZBN@U5dNND;T37xP+8T|vC4EvEI^SqU?P>+H_LV0)~aJ$w$m4W{myx5vkpX& zs?@RXIp1@4&m}q%Q}p5PdUSPvPsZ1y$4m038+AHJ%#^kqt*BK(&5)iVmU~H^i;j$? zvy7rq9HUfFuybB8WO6ru?W}Z`qSnb%u5!oI6GJV{xaYQY0LFPMm_MOXPzS*PL?|Zn zl%tljti$q@c9>R6G692*$cz+BcrxQzMz;H;il3}3HHOrdXi@AfI)UIm$XE4*M{5&-uv z*_Sy8+{cg~j2DeG>{F+f%`tv3ap>0sDp*FcWl)|c*|)EI2ccsy9>Fu2Q^SW);x;<# zx6x^@O5E1RzG-Q7h-J5p;8#~->yR5E^(+T!er`EUG5!JFe3_=dgyPi3I-3YZLz@7t z%#@{X)VnTAQD1}xCAFeE8u|SHeJVBe1!$jQnesB4iPUWIPw(9XL6mJZE&?u0l9hOV zmb@KWv8-pW34R&jS!aLD0_NX=Bv z1!{T6IxOe3!;D^$DHwEQW|d?zkh#b*vinY|_{rK)YlZv7a;a4-Ng~^&a9SuNwGUJ? z@LL;unFZo|M9!=;mPjsikYP*wfOPwt?(tyL?KgKsv)*FeBYLrBEi@NeI2r^;(`~I^ z{n5B2k8IEgA~z}z(rEMzB?yJB6&lp35g11j7;~|bOkmsq48a^h*kM3IJ7Uour0CY< z0pp3#-ng65=IVYk?47n>4H<3j(jR+BB?-oGU2pp0eOLFR-Bq8DUu9kr0QXPXml{Ow z6VM0aWkZI?)Tv_&jBo8I7HkSBEF;-5D9_XE>z79dv1BkF!iLZGpHd6zMkbp|O=Fz^ugsRYZzR0VOOarN z1|_$mJQ{-j`#w{Kh641MW2t!=%~Tn-{HypbV@s5EH7NovOwzS{cu=ApYq3mdw+y6L zWeLz@*@aAh*b-y$m#Y#v6G*3nXg5L&gKAt@Hhq1;c2SRf> GOa1_-j!cRG literal 0 HcmV?d00001 diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_nhwc.cu b/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_16x64x8_16x64x8_2_nc4hw4_nhwc.cu new file mode 100644 index 0000000000000000000000000000000000000000..78624fb9a267297e67b0231d6a33b729567404a1 GIT binary patch literal 1791 zcmbVN+fExX5PjdTaD!B}C^QO9)5xSqG(eDApkgD@H_P!li50JH`H~A?-|=3OEd>F6 zS?}6&#&gbGHaQS8^zQm%e13C7CKuzoGxEEebUH}Pm9`wMsZ~PFk)9zIyrj-02gcHQ zPSJQYLYbgo=e%IZ1h;tZtaO&4*3o@X!SU?KP)l>}xoy1yR;Q{Fb-@>Gn7M-e$iZ6UBOqb&sgUnzhiJD{iUh8Jccu?dnZ! zlMtDp5k#$J@Iou}4J8Qqt>FrksSy}Q0vKm}B^bxJ0T_Zgf{^`y_;$piJxJ264j$tP zS5DiT(dO!YGwhwVUkw>;?$RH7NF@ixa9wZu;(b^5quo`Xk6&e85&-v4*_Rqb?i0`l z<7GpJ$JD813yg2gC>Co9=_@1IF(}W|{OgxT2eDi*9>X(TP{W7aejD|BZ8WH4+j;nZ z!#3)+QJsdJKW>E7^8md3!g88n{2l7}EX#fh#i^dh!auFJo zyo%Cji23jLOls;2&|`+B-eojXsoC2)Hmw*W%$p@^-AlGMUF^ARu9J zNxga-z^d5VNl+*4I!G^!+yeQO`Vjl)^aNVeb7*S#Kr1aWl=FF}6to63etZ^lsqX9j zMr2VnSA95a6wc{~N-&U8L*wszuVIAoh2%LjeQMbu$_UJxD)vb86}U?wVai45tay7H zIu=62#=-VB?lTN*@o zNxbVh<2h&Mc%vgZMIUaj#+P?@U$dCo_>{Wc!{}^^=vQ#t8d`shp>Z zLEbvp$;=ZsAaZJ~Hbiovy!0F5hNRzHcYDKizZX`ggRtlk#Rl`$fT;P3wa`o`VW}4! zO*XZB^-FD|Jdi;#6t2_UORdm%6d~j{Mko-cMqnH`V9do*GJ$aeFoJLdRr?|F?SMr) zP)avC_ZUA4<+Qzv*Edh=(crv&FeF|-BtLhMN)e3VxWVMj^KPEv?NOl2TjhoDk^rRN z7hh@+xIcnE7%v(!yrfPon`3-$;xMlXbg-;s%b?s(vTt8sY=nuycnrVEoEkCe_u2@0 z+(!LgCEPa0zR=UbA(36k*+z9Iwh4uish;IP&Cji%9OEA_%@>^il1fk;>ue$w4Pyed zGE=s`(eJt~MSl?*meh*rXz2O>^EB1e7oe9EOUcV>CaGqNe`fDG7@{1j$D-gOBw5Ml z=Tx-AD3<-~m4$#b#Ub6(YT<1TVa2RLOi2I$ literal 0 HcmV?d00001 diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_nhwc.cu b/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_32x32x32_32x32x32_2_nc4hw4_nhwc.cu new file mode 100644 index 0000000000000000000000000000000000000000..a50f3970bba155b65ca6558565a2969e1091fe4c GIT binary patch literal 1796 zcmbVNZBN@U5dNND;T37xP+3LkHe@;^mO+s=U?P>+H_Nr%)@ou~e(4Lpe&@Wjfq{TY zY3juHobS22=aTHnDSCHxKKgNeO~&V=+cWa#FzNM>oGEPuT2ZT{njt+!Ecl!{m+To! zX9Y#0G(}ueuybBAWP+PNw^upisB>@^RB(bH7;0%Dys({DV4Sy-`2(6u>L3_^NX7EN z1?o7e1YNg~^Yuu9}eD(9(U z;I}b$vhc(Wh@2X$4Ut?ZFa3tNV{+7Am)CC5CyEW`s{v8-73-jxP{L9#IGSvl>@_d7 zO+w^?QixmU!Aq^sH*3(E`)WwKzRiB_Aaz18hT{g47tg!AOSeZI^3L6$}#>9y?n;`FR28zsV*i`(by%xsxVdMTjg%5QkEB?$nh;pd9YFq|fnq(_^f1itP?8K^^{VI?_U8BP< z^}^c#W)EAosoPb^y)^j-*k{y-D5296(5V-oYq*Nl_M4_a#e4w$ULa0ARfUt0~Q}op8LEQZIV5qxSS?`0XH2> A(*OVf literal 0 HcmV?d00001 diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_nhwc.cu b/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_32x64x32_32x64x32_2_nc4hw4_nhwc.cu new file mode 100644 index 0000000000000000000000000000000000000000..5005718657f201ae62d6adf67e17ac5a987f9398 GIT binary patch literal 1796 zcmbVN+fExX5PjdTaD!B}C^RBTDKaS%-9nIBpkk5eo8@?&#ERFpe948c?|3iCmV$sv z(Z*}f8PA!SGn?#*DSCHxKK_1vP43Rew`b({Vbbd%F;m)dw4zoCHA8xeSn!fMm+To! zX9Y#$G)1|fVCTGG$OJck?yPc_qt3y7P{Hy1z)(vI?z!!}0^__D%pXuGsDoetA{3J$ z=cwZ>>#>m29@A<`CScGLg^_{@PiDL*$abI9@spLM#&G+I1)*doNg~??w~9+7l=D0NkkA~|(f4Arp#fJ0Mkf`~JbFphBxFamP~LHlFkyAg|aAW1hm zc#QA3a@t;}>&yG~Xn5LvH6&f%WQI{o$663a^tkLrcD0nEz1Z4R2cT?OfdksDy2Q6Hj&PESCmo`bUC1ueD6QNd@qQlJh{ ze|(n1sS)>XBhsk#)f|pmy>t4Z&KTrUgZg*!YZ+mDE_ngU&n(+VIf0>5#~yjU0(T)K zjJgP&maDg+V<$vxoNjMJes?|sA6lsGczKFnZJ)bFuaO_L);?NkyKsiGrfBitX;&@3 zznURm<>-C?lTN*@o zNxbVh<2h&Mc%vgZMIUaj#+P?@U$dCo_>{Wc!{}^^=vQ#t8d`shp>Z zLEbvp$;=ZsAaZJ~Hbiovy!0F5hU9dx?)HZ3elM*0y|CyJ#Rl`$fT;P3wa`o`VW}4! zO*XZB^-FD|Jdi;#6t2_UORdm%6d~j{Mko-cMqnH`V9do*GJ$aeFoJLdRr?|F?SMr) zP)avC_ZUA4<+Qzv*Edh=(crv&FeF|-BtLhMN)e3VxWVMj^KPEv?NOl2TjhoDk^rRN z7hh@+xIcnE7%v(!yrfPon`3-$;xMlXbg-;s%b?s(vTt8sY=nuycnrVEoEkAY9kdbj zxQ+U~O1N!~eWCr{A(36k*+z9Iwh4uish;IP&Cji%9OEA_%@>^il1fk;>ue$w4Pyed zGE=s`(eJt~MSl?*meh*rXz2O>^EB1e7oe9EOUcV>CaGqNe`fDG7@{1j$D-gOBw5Ml z=Tx-AD3<-~m4$$k#Ub6(YT<1TVa2dh*V<2Es znmX}4*Y}+JNU|rV=-u_j==;qL8DEU<&dKk?q}M}ornD7kMXi!*hV&G%;B)F+vS%!v z6%>uq6mdzx&Uwj@32y$}UgeCV&cS_9!3lm~sHKJQ!ggMPao$Si4`?o_gJ1w670Uw` zsN*c_vA}7MX|*I1FzCs`oUwdGWlf`*eF0Nc2*8(Yzu6?w_)+G>F{C zpby52mJE-nQ_JQU-=0y-NW&{9i|3NhnKL3cp+$9p-P zn&#fUh&1Z{YBq=>A>%wTv*n$VCC2pINq#IDw&4#~yjU0(U8M7>6e>FqC z(&&BsTat40zZ_BB5Vygb;BO7ar8n2OY8s~HC{{w|p?eDA5qvXX@$vMz&x_F}*%OM( GY4Qh__Dkmg literal 0 HcmV?d00001 diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_nhwc.cu b/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_identity_s8_64x64x32_64x32x32_2_nc4hw4_nhwc.cu new file mode 100644 index 0000000000000000000000000000000000000000..ca3cc8195ba3814106d80a410809f741287b2a28 GIT binary patch literal 1796 zcmb7FYfsxS6#brG;T37xP+3LWF=RRFX znmX}4*Y}+JNOB-2=)>*R=<@E4JYJ0+F36wbq|-ris*Mw9Nu8E@ip&JDJSggYa$p^u zW)zKvL&PNoJ0B!NHutmF_DW|QwT_;0oqNHLEOj&!LAcf)7$2Nu;fN|pJp=;~shP^T zKrQcBhvl4hn9&O|27`{wtddLyG8I`yHt$IlKUq3zt#DshF12bUNn|@0P78&k_JL{! zersbdvp{^G$cc5v63K-QGOUR|B`5tAT)l2w^ty4;C7ShT%RbSwC2OIn(8AFmI2x~O z_3D?#C3)n6h7h+=d5}hLLBPg;auI49E4yZ=QGaG~65o%Dfd`G%pE&`(5^> z29f(C=!5aRA;U}R)Ug@H_jVY|HH8G0k!%>0`)T&=>x+$8Fc^>Ep3bNhgOh$6#b2*m z>9%pP%d*$q$MU$1;Hnd`aVV@*W|{-DFmr-(jDJ8aUvU0QYC+x5WMipmtP)_A*)sEu zbk|ua(u>e($*d@jhMfODPn4md0KKGGDqco2R))?0DZT614rNt676BI~=~6yFE76XX zSf;aE1_B}$ha8_&i{NrFE8cB1sPlH4t01jf1N)SQJhssJ8T4ro(An^S7DjTEi7D3_ zx&yjD-izT>H~02Mq*3)(y*X&q&iSWGG0@UL_ix*;VTAFO5*c)U>ewOT1cpu(d*rzX z?pi7sbrCu#R_{Z{N{HAv-QI`1^dSPD8>r2A*~PE2=f3XO$PY^E5G{?HdrMiJw79X` zRm1P^ddQa=eT;t#QjGq$BdQwWJ`Xkz_ZH*Ao3CBf4byNGDPs N#TuI&2+idr`3uKCOYr~z literal 0 HcmV?d00001 diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_nhwc.cu b/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_128x128x32_64x32x32_2_nc4hw4_nhwc.cu new file mode 100644 index 0000000000000000000000000000000000000000..1a0bdfc4118cb5ca093161ddeca3c7574600d851 GIT binary patch literal 1800 zcmb7F+in^$5PjdTaEVm45h;Su6gf0fN)jTqNveR9zFEd=z`|==zHrIccf1$a(nKT= z81H({c+Q#2CP!k5KHOaoukP>3_jn1C>0d!oEHq4;O4KL70y!BIC%;xIG&yuYH7wjw~aSooVSAc6DkFD5DY+sVlt#0 zHJoKF7E;<`S}n;03|b;HQZV7kjAt3yz9(h;WM!!_+n?bp13X%Q)9Iumdij~kF-0zb^M)n$XOTGxaknZy7N_+sQHRD(2Of?spl!0Y%1-l zm)a&FGC>1~TFc;tR_HrQ5b_(t6)007FpdN;PWe(Wj&Tjp2Xh1=`yTPlh($Y)q#GSP z#$&FWw)cbe?bEv7y=cA}GFU&3e(oTp92moO-O0iGZl4C*t3a7U;YIU;0Jwk2zR)0Y zAA>#^FKROEQ>T{AF@7+ESga#RUlGZcL2*9HzJ1+0h~8 zPWEljx=nQ2iN8aZZSdJdRVua)xe-#&0&w$l%V~=752)qKH2ozMr*@#TiBL4w$s?2? zaLY`Q|6096CbmRTTjMU7SO&dJf$U4```H ziZVV+m4Y6D2@uzOKvfOCdl6}rT~;0TYZY|ysZ^{Qj;6eWB6E__rka_&*#`*%S}Ko8TV|#<@3FxvH9`<|tOe$U_HZ#3T4_ Tz~=M$Yo8dSO>)FlDyGR_6^cxo literal 0 HcmV?d00001 diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_nhwc.cu b/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_128x32x32_64x32x32_2_nc4hw4_nhwc.cu new file mode 100644 index 0000000000000000000000000000000000000000..2a588246467af144362244e3f48c4ccbf1ab11be GIT binary patch literal 1798 zcmb7F+in^$5PjdTaEVm45h;Su6gf0fN)jTqNveR9zFEd=z`|==zHrIccf1$a(nKNB zvUt~X#xrN;%qB--iay+353la;$@qHscuD@WlU56fnbMY{6}3vJ8PZe4f|t~}O|1Qg85H}Y?AzD9jo2+158*zVQ^WgE;wC!n#9v9= z*2li-S@%%Ob`!x>RbuOq8zJ>9Kx%$&IZZMC0iArAroV*Z)DCnu5sJn(d4vXp)G||) zzgF|AGJ{nb9a7`;|L3XH)aRi6Gz;2`h$d3A#XtRb69rMU);JG3HyN$O^Rwj5 z*p5X_J4GNsQGQ7KtXy~-z$|~a<)JFxU65WFxd8@7)Q6~}i*rz_=b&wPKuaxBl<`@r z6i5X0AMg2is^qb&bRj=ubhv-=Nf7|V_wRyv`<|n+L#}-)-hUXyKsiGs%UZJ zWmh%7zpF7{aP%?$ElEE7-;Ssh#Y6BW_y>b=?#)%MD$&#&#ZDM`P*6rZg0BWFKA*q# Ofic=7M_i?1n*0ThbWAS* literal 0 HcmV?d00001 diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_nhwc.cu b/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_128x64x32_64x32x32_2_nc4hw4_nhwc.cu new file mode 100644 index 0000000000000000000000000000000000000000..c802ed9243360f1c6c4b35446f1cbe3431b0a7a0 GIT binary patch literal 1798 zcmb7F+in^$5PjdTaEVm45h;Su6gf0fN)jTqNveR9zFEd=z`|==zHrIccf1$a(nKK= zTD5j23HwG3Wpg}$Q%A-^$PfiM*U;}}22DPIc4F|GmnV2&VS-y^;mv1kX9bfbgE zc+8d4_I|LweOmXs7tMnqgZ1O+=MGY;fiYayoxFM9?bBd;6)5voc+tEd0PbJ1FEohU z$Dj|!i<%7k)Tw22j33M(*6IkdS46U9P@IplZ(sKgVzppAgy(2Z4evvWo9L|DM5mon zxUG+U)57Wy%XSmNugb*MAvZ$mSpaH&ZaGac{sEPInWn#l;?xdwHW7-(GI@jo1ZtTn z(qGGYm6|-q2=xkjd4<$8{r`O`HT5}YpJpL@5z$0yw)m(1Zel|er8Ulj&P_%u@%$`# zGnQkK(@qfx2$^5fJ}VdA1~AL_wmMYFy9?3_BR8PHi24xw=;9nY)pO`=ctJ}oQk3yo zsuZ*c^nZNj^{Qj=Se4)|D__rka@P9j^vMC;dH^Dy`jB{_Ua#b}=%~7m`k%tb-h)3|%fW_zY P*FG>to8*YAR7{hGU zi`ToJGoEwivdMv%p?B98 zzqPRwxhJki#LQT2h~+X6Hz3{qrh7ctboB0^t5B`%;6*eFFMm zyllwum^!s=f$^;w#bQk%ePtv&2IYC0fBo|4AeIZpV_4G#HGJ6Zw-F>6mVPDN&d2|l zg1Z{#AgNu_emdniN47Ch1x{JV@S-^;jmgTL#jr zG6jkjmw6k&s@U#isFQacq!&hRL4cI{5c}x#1X|T|Xl{5xD=jjV^LeHevF6 zS?}6&#&gbGHaQS8^zQm%e13C7CKuzoGxEEebUH}Pm9`wMsZ~PFk)9zIyrj-02gcHQ zPSJQYLYbgo=e%IZ1h;tZtaO&4*3o@X!SU?KP)l>}xoy1y|0ds8S;^j=^J`@s(g4;|5>|<_N;|1LE5ei}oN%w>o%?CtNvg zZ$_J|`^~U-+I}@;w7E-v>>-sD7{hhF>5KPW-H&!xeLjAbc}W1=KV@HP5V=o4AB>j` z86H!omMt*8HKSOlDWtB9WXGU9PxG%|9v#F|!FUYMbU_UtcKdDA@3qmOl5OYV{|(!y z+eUR7cK)~#QqKeM@(asphVgf(;j=9JDHNx6r1PmzG}g!+lpgTPO_}>f!t1;g2}Wp8 zax2QCA?Ux~GpVUBK#y6LnwJqxrDn^&itjSELRnXnBIv>-U5keY$=k6O%Y+`6fq;y~ zCH3lU0IOnaCqtdM>ma=_ati{a)Q8wZrzg;&oo@~LHqC?ha$s@NmVSKuy%gee!Hv*PV- z=vWL98wcClkmt@v;6nqon=eoCtL!sZ&y&axD(eufwOu+xS)H_a@U*Lj-(U57FEx4} zmlY|d{>u?nJ#ibn3I5h#TzGS}tGZ_zj$$RG9vUbo9>MPdUOt{Y_h&KMBnMn&VwU^? D3=c}~ literal 0 HcmV?d00001 diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_nhwc.cu b/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_32x128x32_32x64x32_2_nc4hw4_nhwc.cu new file mode 100644 index 0000000000000000000000000000000000000000..3283d1d0a3f743f3c153779e9efaca5ad2b11e59 GIT binary patch literal 1798 zcmb7F+fExX5PjdTaD!B}C^X8YG%_g?4G^Rjs0b2$vmCFJSn=AHFS+pb9q%RC(jdai zuGgM3o;fpTHaQYg^x^huba{729Y^2jgzOKg5&9lp_XRcbK5uoR1Mv0Mh?`lQ=gw>$lHw-Z;Vy}0NQ#d`BqkEr>IHPDPJZmH)fnrtfS zs!eT^5EDT|2wKbFg;wZ0N)YlJ!xa!yAux{dW1RA(U>xHbU;yR_BKCdan-Pn4AW1hm zc#I#pa@yVv*Edh=LGQeIFl4xX82{WsN;WWt>w1$n@4I;#Zm$An-U=_87YTs-ef9+h zG51H{2jfMJhL@>R%jOv0n_=wLF;uUJWSc>8KF+>+Gs$8&1jANx#Cdxu!I!DkaymDoDuMo2vikeZ)cPE(A3Kqp_M=`W!;wL_gvgrc!c9-sjs zwagUduhqONOoE^HHyN+Q z^Rwj5*p5X_J4GNsPku@Jv|M-_z%1X}@=z7;Hb^gw+yDb(>O<7g`57qHbI>+?prsZm z%J?i*3M2yhkI#HORr20##578gRfmIG1D$^=83Qgg=zpidni0lVl4o4`v1P|7B`|i% z*khgt;I4&)VHcs(eDywb?1qSq^X+}eOXnl-xrW-#n0@(`_Nl8x8}ozKIz}sP7tT;t z6)hg@cUANIyBhNaM<3&BN%GR}{54?l R`RuhHjL{}J;wlx>+H_LV0)~aJ$e(4Lpe&@Wjfq{TY zY3juHobS1N?vm_@DSCHxKKgNeO~&V=+cWa#FzNM>m?>>JT2ZTnnjt+!EO<$sOZJSV zvx1^gnxb4#uybB8WP+PNw^upKQRm<;sNi^hV5p@9_uO_~fpOjn<`1Y8)Il%+5sJx> zbJTH`^;pPhk7>0e6ENtB!brh{Co^6YWc!}f@spLM#&G+I1)*doNg~??w~9+7l=D!cqE8eX%vS@V<}21gGp@L$p66(?X)@Qm)HVr` z2TCDkErSHTb$Zrf+ke3F5aSR{hoG%6A7`FgJFh>xuKPJ8#v1kX9bfbgEc+8d4 z_Bvf(-mQm&)9$Mw>H0SNxr5Znz!;7jOkO@jY*?aGbmma$(BL2pA}y}KiY_5!FUAsY)%ay9`(DZ`qg>cy7+&~!!ClW zDa6(xH$v)J0A_w}In6Qt4mv)|^It-7YExZIgrZR;w;(@YR+y^#t&%sDDNBsdamB6_ zNGqlPK2N2lz63p{S!KP7Xd*RRJdoc-Y==r&<1*;dBwLC5d&#>|ja5neRUm`9PKRIW zg|`9B9?G_<-BplY7`cHAWYmY)MyDsxr=CM+!vk7sk)wjoa;2a{p!?&!98OJh?_NY2 zb&oZh!xgYt^wWKldh*V<2Es zl{)r4*Y}+JNU|rU=-t)%`1|!WxjP@ur6y<_~o%4br6Wsi{v&vbHItTYb1;_IPLoF@1=eF|-jPq77e?X<64uS!QP)vrL zqmHwz$3jkfOsgfCfI&|bMhYf8nen0^+kH~UPga&1!|f*)gp!>kiEJ0#DlU;w&Qry} zZ)5C4;fWg(F*Q~jVz~^&9g{(SeKcGT`nyG+C^nq0hD6O*tb=A;aZ5eV(PYyo*KBH= zgvbM>5VMxS3$4&slpy3chAYseL0}xi$2jLp!8pb(zzEC{1niHA??x=zfh67N;4!}A z%4vI@t}pM`qv2`y)sS?3oBh~9YB4Z|>xPpT@4LKDw^xBgFNGJ)D+1vDDf>!;$o&rV z!FbV<;W2e;*&O2=lg3ibAbC|JTL#s6R($#V=pYsg#$$M9b87f#(C?z7VHXYhwQgG% z|8IHNMeu79v31Ceka`w?nV(xubBw=19iQd-PoX%qsV*i$(O4z7P=3IyFje+jDQ_}U zrWm2)id|VCElK}J(x`i^IUKd>=k!CZ7-*@X`*$taGQ#*=@`5X$S+cJ@R}7?m|cy zbrCu(S8qedYKYi4-QI@$?tBD3v{2g-^Ax|@K6g#eMt)FQ`)H-@!Wqh%q{V}$UA6rF zYKDBJ(fjzfB<1LTIik8LZh|+#-x!QbZ?19GG)>DWM7o$zG J$5k$-$sahUOU(cP literal 0 HcmV?d00001 diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_nhwc.cu b/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_64x128x32_64x32x32_2_nc4hw4_nhwc.cu new file mode 100644 index 0000000000000000000000000000000000000000..a1cbf38b007ff5de0af8a35030999dcda8225c7a GIT binary patch literal 1798 zcmb7F+fExX5PjdTaD!B}C^X8YG%_g?4G^Rjs0b2$vmCFJSn=AHFS+pb9q%RC(jdai zuGgM3o;fpTHaQYg^x^huba{729Y^2jgzOKg5&9lp_XRcbK5uoR1Mv0Mh?`sB2?Zg=|YZYQp~ow(=_#d`BqkEr>IHPDPJZmH)fnrtfS zs!eT^5EDT|2wKbFg;wZ0N)YlJ!xa!yAux{dW1RA(U>xHbU;yR_BKCdan-Pn4AW1hm zc#I#pa@yVv*Edh=LGQeIFl4xX82{WsN;WWt>w1$n@4I;#Zm$An-U=_87YTs-ef9+h zG51H{2jfMJhL@>R%jOv0n_=wLF;uUJWSc>8KF+>+Gs$8&1<;I!97P~#@* zc1q&5KK7Y*JBL`d!DkaymDoDuMo2vikeZ)cPE(A3Kqp_M=`W!;wL_gvgrc!c9-sjs zwagUduhqONOoE^HHyN+Q z^Rwj5*p5X_J4GPCW`0Tgv|M-_z%1X}@=z7;Hb^gw+yDb(>O<7g`57qHbI>+?prsZm z%J?i*3M2yhkI#HORr20##578gRfmIG1D$^=83Qgg=zpidni0lVl4o4`v1P|7B`|i% z*khgt;I4&)VHcs(eDywb?1qSq^X+}eOXnl-xrW-#n0@(`_Nl8x8}ozKIz}sP7tT;t z6)hg@cUANIyBhNaM<3&BN%GR}{54?l R`RuhHjL{}J;wlx>xgYt^wWzx0J)zjI#NjDdhj zY3juHobR5y=aTG+DSCH(G5UUUL&g`QyL0mUFzNM>m?>>JT2ZTnnjt+!EO<$sOZJSV zvx1^gnxb4#uybB8WP+PNw^upKQRm=3sNi^hV5p@9_uO_~fpOjn<`1Y8)Il%+5sJx> zbJTH`^;pPhk7>0e6ENtB!brh{Co^6YWc!}f@spLM#&G+I1)*doNn*AOZWWhEDCenS z;I}b$qVU8Gh?p9y4Y6DX;!eo%V13kwYq#hV#Rl`$fT;P3bxHXU#O_qaB$XrH6&f%Wj}V1nhcELxWVMb^RDjG?NLC{OW{THDgki+lzqiP%zX^} zV7zG2@Hlm9*&O3rlg3ufpmicU2+HQ2bmS7s(!2FO=Zdw zBXm+_S2jqC(tn?)Qd3`o9;aEcUPUyKnl1inzso3xshsX88+W=JMX?nO+a)>yMSZ1vCChng|qQiJ+;5^Nb^d?9(kmCr2OM>&C^Q^y|jd8qLQ1ptbhVO524qlr=?*8&A7x`Tf-l z`HG|W@o!1W(f@KpwJ2_bH^JW;j7x8>an*>Xt;lk9Po Hi)r!)IPFW# literal 0 HcmV?d00001 diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_nhwc.cu b/dnn/src/cuda/conv_bias/int8/kimpl/cutlass_simt_s4_ifprop_relu_s8_64x64x32_64x32x32_2_nc4hw4_nhwc.cu new file mode 100644 index 0000000000000000000000000000000000000000..00a266e755c907b6193a6f7c77de833edee1c271 GIT binary patch literal 1796 zcmb7FYfsxS6#brG;T37xP+3LWF=RR2#2oDs4GhQmcfTB0WJYcuAd04veMK zoTAZih%!OJ&UwL*32ye;r0s)LdjN=M7DEo6_-dT=c!`g zw>EYn_r&#ym>8=Kv0Mh?PRU7s1y`>d7rky=bctg9*|Ja6Y{^<^$`!ZN^9+sGwQ}`K zZIcjrpdrMpW$;2P^c^J#`L*E+bg2;-$M7-E_(Cv_aRV>_a|8kVQ{vkZi?$$1*E)EN zAGvbc-VIkbPpd)yyuCMMxOzx`ZXuNz7{hV>@tfz}JPkKTfiiD}7tKoo;C`2VsX^ra z2>M_=Z^-bHI<;(u@x2+wQcWRwWh5I0<$juf`}$%d77NBBxTiB}_~4}9M)B9{R=RCm z?6T~2_pv-~Be?2BY#eeUq@D&~=4Y1E4C5bA#}`@lODIn5Q0HTzXsnV4C_iA9n=<>2 zl-HRlQjE}P$*w4nhNS;LPo$>40KKGHs$NDkmYU7~DZlI34rOIM7C{##=~6sDOWuyv zSf;dF1_DABha8_&3vUCM74J4W)VaG2(hDQkpg>A}h;4L!27T%|bT&Mog%%mg`7~1s zIt02u-izT>H~02Mq*3))y*X&q&-tfHG0;*&_itOUVTAFOr2Q*~PE2&s^QJksp-SAzErXcZRY$X>nt> ztA^j-^^h+$`WXKfq!|5gM^rV%eefpudxLS|&DE~zrfE2em5_SqpqzLF-warMK6~x+ NVzfyPxXQ#N`3q@YOV> 5) * layout.stride[1] + @@ -370,6 +387,7 @@ void compute2d(_megdnn_tensor_in src, ftype* __restrict fptr, size_t fh, size_t fw) { if (filter_meta.format == Format::NCHW4 || filter_meta.format == Format::NCHW4_NCHW || + filter_meta.format == Format::NCHW4_NHWC || filter_meta.format == Format::NCHW4_NCHW32) { return gc_out.cur_grp * FS_G + gc_out.cur_off * FS_OC + (ic - ic0) / 4 * FS_IC * 4 + @@ -695,6 +713,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, case param::Convolution::Format::NHWC: case param::Convolution::Format::NCHW4: case param::Convolution::Format::NCHW4_NCHW: + case param::Convolution::Format::NCHW4_NHWC: case param::Convolution::Format::NCHW4_NCHW32: case param::Convolution::Format::NCHW8: case param::Convolution::Format::NCHW32: @@ -820,6 +839,7 @@ void forward_bias(_megdnn_tensor_in src, _megdnn_tensor_in filter, BIAS_ADD_CHWNx(4); break; } + case Format::NCHW4_NHWC: case Format::NHWC: { int dst_nhw = dst.layout.shape[0] * dst.layout.shape[1] * dst.layout.shape[2]; -- GitLab