diff --git a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu index f9dd4c45a64f4c4bcffa25f5c63c5fdbf4e51073..e9e056e2ce4c9ec511f4e599fa60aed5424d38f5 100644 --- a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu +++ b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cu @@ -25,6 +25,8 @@ using namespace megdnn; using namespace cuda; using namespace cutlass_wrapper; +/* ================= cutlass kernel wrapper for nchw32 layout ================ + */ #if MEGDNN_TEGRA_X1 template void megdnn::cuda::cutlass_wrapper:: @@ -148,6 +150,131 @@ INST(true); INST(false); #undef INST +/* ==== cutlass kernel wrapper for nchw32 layout and nchw4 output ===== */ +#if MEGDNN_TEGRA_X1 +template +void megdnn::cuda::cutlass_wrapper:: + do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( + const int8_t* /* d_src */, const int8_t* /* d_filter */, + const int32_t* /* d_bias */, const int8_t* /* d_z */, + int8_t* /* d_dst */, int* /* workspace */, + const convolution::ConvParam& /* param */, + uint32_t /* nonlinear_mode */, float /* alpha */, + float /* beta */, float /* gamma */, float /* scale */, + const GemmCoord& /* threadblock_shape */, + const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} +#else +template +void megdnn::cuda::cutlass_wrapper:: + do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( + const int8_t* d_src, const int8_t* d_filter, + const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, + int* workspace, const convolution::ConvParam& param, + uint32_t nonlinear_mode, float alpha, float beta, float gamma, + float scale, const GemmCoord& threadblock_shape, + const GemmCoord& warp_shape, cudaStream_t stream) { +#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ + threadblock_k_, warp_m_, warp_n_, \ + warp_k_) \ + if (threadblock_shape.m() == threadblock_m_ && \ + threadblock_shape.n() == threadblock_n_ && \ + threadblock_shape.k() == threadblock_k_ && \ + warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ + warp_shape.k() == warp_k_) { \ + using ThreadBlockShape = \ + cutlass::gemm::GemmShape; \ + using WarpShape = cutlass::gemm::GemmShape; \ + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; \ + using Convolution = cutlass::convolution::device::Convolution< \ + int8_t, cutlass::layout::TensorNCxHWx<32>, int8_t, \ + cutlass::layout::TensorCxRSKx<32>, ElementOutput, \ + cutlass::layout::TensorNCxHWx<4>, int32_t, \ + cutlass::layout::TensorNCxHWx<4>, int32_t, \ + cutlass::convolution::ConvType::kConvolution, \ + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75, \ + ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ + cutlass::convolution::threadblock:: \ + ConvolutionNCxHWxThreadblockSwizzle< \ + cutlass::convolution::ConvType::kConvolution>, \ + 2, 16, 16, NeedLoadFromConstMem>; \ + typename Convolution::ConvolutionParameter conv_param{ \ + param.n, param.ci, param.co, param.hi, param.wi, \ + param.fh, param.fw, param.ho, param.wo, param.sh, \ + param.sw, param.ph, param.pw, 1, 1}; \ + return cutlass_convolution_wrapper( \ + d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ + epilogue, stream); \ + } +#define DISPATCH_KERNEL \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(256, 128, 64, 64, 64, 64); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 256, 64, 64, 64, 64); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 64, 64, 64, 64); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 64, 32, 64, 64); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 64, 64, 32, 64); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 64, 32, 32, 64); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 64, 16, 32, 64); \ + megdnn_assert(false, \ + "unsupported threadblock shape (%dx%dx%d) and warp shape " \ + "(%dx%dx%d)", \ + threadblock_shape.m(), threadblock_shape.n(), \ + threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ + warp_shape.k()); + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementBias = int32_t; + using ElementCompute = float; + using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; + switch (nonlinear_mode) { + case NonlineMode::IDENTITY: { + using EpilogueOp = + cutlass::epilogue::thread::BiasAddLinearCombinationClamp< + ElementOutput, 4, ElementAccumulator, ElementBias, + ElementCompute>; + typename EpilogueOp::Params epilogue{alpha, beta, gamma}; + DISPATCH_KERNEL; + } + case NonlineMode::RELU: { + using EpilogueOp = cutlass::epilogue::thread:: + BiasAddLinearCombinationReluClamp< + ElementOutput, 4, ElementAccumulator, ElementBias, + ElementCompute>; + typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; + DISPATCH_KERNEL; + } + case NonlineMode::H_SWISH: { + using EpilogueOp = cutlass::epilogue::thread:: + BiasAddLinearCombinationHSwishClamp< + ElementOutput, 4, ElementAccumulator, ElementBias, + ElementCompute>; + typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; + DISPATCH_KERNEL; + } + default: + megdnn_assert(false, + "unsupported nonlinear mode for conv bias operator"); + } +#undef DISPATCH_KERNEL_WITH_TILE_SHAPE +#undef DISPATCH_KERNEL +} +#endif + +#define INST(need_load_from_const_mem) \ + template void megdnn::cuda::cutlass_wrapper:: \ + do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< \ + need_load_from_const_mem>( \ + const int8_t* d_src, const int8_t* d_filter, \ + const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ + int* workspace, const convolution::ConvParam& param, \ + uint32_t nonlinear_mode, float alpha, float beta, \ + float gamma, float scale, \ + const GemmCoord& threadblock_shape, \ + const GemmCoord& warp_shape, cudaStream_t stream); +INST(true); +INST(false); +#undef INST + +/* ================ cutlass kernel wrapper for nchw4 layout ================= */ #if MEGDNN_TEGRA_X1 template void megdnn::cuda::cutlass_wrapper:: @@ -275,6 +402,7 @@ INST(true); INST(false); #undef INST +/* ===== cutlass kernel wrapper for nchw4 layout and nchw output ===== */ #if MEGDNN_TEGRA_X1 template void megdnn::cuda::cutlass_wrapper:: @@ -401,4 +529,131 @@ void megdnn::cuda::cutlass_wrapper:: INST(true); INST(false); #undef INST + +/* ====== cutlass kernel wrapper for nchw4 layout and nchw32 output ====== */ +#if MEGDNN_TEGRA_X1 +template +void megdnn::cuda::cutlass_wrapper:: + do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32( + const int8_t* /* d_src */, const int8_t* /* d_filter */, + const int32_t* /* d_bias */, const int8_t* /* d_z */, + int8_t* /* d_dst */, int* /* workspace */, + const convolution::ConvParam& /* param */, + uint32_t /* nonlinear_mode */, float /* alpha */, + float /* beta */, float /* gamma */, float /* scale */, + const GemmCoord& /* threadblock_shape */, + const GemmCoord& /* warp_shape */, cudaStream_t /* stream */) {} +#else +template +void megdnn::cuda::cutlass_wrapper:: + do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32( + const int8_t* d_src, const int8_t* d_filter, + const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, + int* workspace, const convolution::ConvParam& param, + uint32_t nonlinear_mode, float alpha, float beta, float gamma, + float scale, const GemmCoord& threadblock_shape, + const GemmCoord& warp_shape, cudaStream_t stream) { +#define DISPATCH_KERNEL_WITH_TILE_SHAPE(threadblock_m_, threadblock_n_, \ + threadblock_k_, warp_m_, warp_n_, \ + warp_k_, aligned_) \ + if (threadblock_shape.m() == threadblock_m_ && \ + threadblock_shape.n() == threadblock_n_ && \ + threadblock_shape.k() == threadblock_k_ && \ + warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \ + warp_shape.k() == warp_k_) { \ + using ThreadBlockShape = \ + cutlass::gemm::GemmShape; \ + using WarpShape = cutlass::gemm::GemmShape; \ + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>; \ + using Convolution = cutlass::convolution::device::Convolution< \ + int8_t, cutlass::layout::TensorNCxHWx<4>, int8_t, \ + cutlass::layout::TensorCxRSKx<4>, ElementOutput, \ + cutlass::layout::TensorNCxHWx<32>, int32_t, \ + cutlass::layout::TensorNCxHWx<32>, int32_t, \ + cutlass::convolution::ConvType::kConvolution, \ + cutlass::arch::OpClassSimt, cutlass::arch::Sm61, \ + ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp, \ + cutlass::convolution::threadblock:: \ + ConvolutionNCxHWxThreadblockSwizzle< \ + cutlass::convolution::ConvType::kConvolution>, \ + 2, 4, aligned_, NeedLoadFromConstMem>; \ + typename Convolution::ConvolutionParameter conv_param{ \ + param.n, param.ci, param.co, param.hi, param.wi, \ + param.fh, param.fw, param.ho, param.wo, param.sh, \ + param.sw, param.ph, param.pw, 1, 1}; \ + return cutlass_convolution_wrapper( \ + d_src, d_filter, d_bias, d_z, d_dst, workspace, conv_param, \ + epilogue, stream); \ + } +#define DISPATCH_KERNEL \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 128, 32, 64, 32, 32, 16); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 64, 32, 64, 32, 32, 16); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 128, 32, 64, 32, 32, 16); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(128, 32, 32, 64, 32, 32, 16); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 128, 32, 32, 64, 32, 16); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 64, 32, 64, 32, 32, 16); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 64, 32, 32, 64, 32, 16); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(64, 32, 32, 64, 32, 32, 16); \ + DISPATCH_KERNEL_WITH_TILE_SHAPE(32, 32, 32, 32, 32, 32, 16); \ + megdnn_assert(false, \ + "unsupported threadblock shape (%dx%dx%d) and warp shape " \ + "(%dx%dx%d)", \ + threadblock_shape.m(), threadblock_shape.n(), \ + threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \ + warp_shape.k()); + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementBias = int32_t; + using ElementCompute = float; + using NonlineMode = megdnn::param_enumv::ConvBias::NonlineMode; + switch (nonlinear_mode) { + case NonlineMode::IDENTITY: { + using EpilogueOp = + cutlass::epilogue::thread::BiasAddLinearCombinationClamp< + ElementOutput, 4, ElementAccumulator, ElementBias, + ElementCompute>; + typename EpilogueOp::Params epilogue{alpha, beta, gamma}; + DISPATCH_KERNEL; + } + case NonlineMode::RELU: { + using EpilogueOp = cutlass::epilogue::thread:: + BiasAddLinearCombinationReluClamp< + ElementOutput, 4, ElementAccumulator, ElementBias, + ElementCompute>; + typename EpilogueOp::Params epilogue{alpha, beta, gamma, 0}; + DISPATCH_KERNEL; + } + case NonlineMode::H_SWISH: { + using EpilogueOp = cutlass::epilogue::thread:: + BiasAddLinearCombinationHSwishClamp< + ElementOutput, 4, ElementAccumulator, ElementBias, + ElementCompute>; + typename EpilogueOp::Params epilogue{alpha, beta, gamma, scale}; + DISPATCH_KERNEL; + } + default: + megdnn_assert(false, + "unsupported nonlinear mode for conv bias operator"); + } +#undef DISPATCH_KERNEL_WITH_TILE_SHAPE +#undef DISPATCH_KERNEL +} +#endif + +#define INST(need_load_from_const_mem) \ + template void megdnn::cuda::cutlass_wrapper:: \ + do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< \ + need_load_from_const_mem>( \ + const int8_t* d_src, const int8_t* d_filter, \ + const int32_t* d_bias, const int8_t* d_z, int8_t* d_dst, \ + int* workspace, const convolution::ConvParam& param, \ + uint32_t nonlinear_mode, float alpha, float beta, \ + float gamma, float scale, \ + const GemmCoord& threadblock_shape, \ + const GemmCoord& warp_shape, cudaStream_t stream); +INST(true); +INST(false); +#undef INST + // vim: syntax=cuda.doxygen diff --git a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh index 2d78e8c37a199c73b6b3b9af628cd74d82906f1d..85fdd29e6ea09d262ff62f771f745b153939bc79 100644 --- a/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh +++ b/dnn/src/cuda/conv_bias/cutlass_convolution_wrapper.cuh @@ -40,6 +40,15 @@ void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, cudaStream_t stream); +template +void do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4( + const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, + const int8_t* d_z, int8_t* d_dst, int* workspace, + const convolution::ConvParam& param, uint32_t nonlinear_mode, + float alpha, float beta, float gamma, float scale, + const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, + cudaStream_t stream); + template void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4( const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, @@ -58,6 +67,15 @@ void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_nchw( const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, cudaStream_t stream); +template +void do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32( + const int8_t* d_src, const int8_t* d_filter, const int32_t* d_bias, + const int8_t* d_z, int8_t* d_dst, int* workspace, + const convolution::ConvParam& param, uint32_t nonlinear_mode, + float alpha, float beta, float gamma, float scale, + const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, + cudaStream_t stream); + } // namespace cutlass_wrapper } // namespace cuda } // namespace megdnn diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp index 80fd2d35f0ac915a21a837bb2e1462441f29b1f4..b02e9027a1d42341acd565d9a5348690406e259a 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw32_imma.cpp @@ -35,10 +35,23 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::is_available( if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout), param.format)) return false; - if (param.format != Format::NCHW32) + if (param.format != Format::NCHW32 && param.format != Format::NCHW32_NCHW4) return false; - UNPACK_CONV_BIAS_NCHW32_PARAM(*(args.src_layout), fm, *(args.dst_layout), - param); + size_t n = args.src_layout->operator[](0), + ci = args.src_layout->operator[](1) * 32, + hi = args.src_layout->operator[](2), + wi = args.src_layout->operator[](3); + size_t ho = args.dst_layout->operator[](2), + wo = args.dst_layout->operator[](3); + size_t co; + if (param.format == Format::NCHW32) { + co = args.dst_layout->operator[](1) * 32; + } else { + megdnn_assert(param.format == Format::NCHW32_NCHW4); + co = args.dst_layout->operator[](1) * 4; + } + UNPACK_CONV_PARAMETER(fm, param); + MARK_USED_VAR // TODO support group conv available &= param.sparse == Sparse::DENSE; // mode must be cross correlation @@ -84,8 +97,21 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( using Format = Param::Format; auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - UNPACK_CONV_BIAS_NCHW32_PARAM(*(args.src_layout), fm, *(args.dst_layout), - param); + size_t n = args.src_layout->operator[](0), + ci = args.src_layout->operator[](1) * 32, + hi = args.src_layout->operator[](2), + wi = args.src_layout->operator[](3); + size_t ho = args.dst_layout->operator[](2), + wo = args.dst_layout->operator[](3); + size_t co; + if (param.format == Format::NCHW32) { + co = args.dst_layout->operator[](1) * 32; + } else { + megdnn_assert(param.format == Format::NCHW32_NCHW4); + co = args.dst_layout->operator[](1) * 4; + } + UNPACK_CONV_PARAMETER(fm, param); + MARK_USED_VAR auto&& stream = cuda_stream(args.opr->handle()); int8_t* filter_ptr = nullptr; @@ -137,33 +163,79 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec( } uint32_t nonlinear_mode = static_cast(param.nonlineMode); if (fh == 1 && fw == 1) { - cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< - false>(args.src_tensor->compatible_ptr(), filter_ptr, - args.bias_tensor->compatible_ptr(), z_dev_ptr, - args.dst_tensor->compatible_ptr(), nullptr, - kern_param, nonlinear_mode, alpha, beta, gamma, - dst_scale, - cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k}, - cutlass_wrapper::GemmCoord{m_algo_param.warp_m, - m_algo_param.warp_n, - m_algo_param.warp_k}, - stream); + if (param.format == Format::NCHW32) { + cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< + false>( + args.src_tensor->compatible_ptr(), filter_ptr, + args.bias_tensor->compatible_ptr(), z_dev_ptr, + args.dst_tensor->compatible_ptr(), nullptr, + kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, + cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k}, + cutlass_wrapper::GemmCoord{m_algo_param.warp_m, + m_algo_param.warp_n, + m_algo_param.warp_k}, + stream); + } else { + megdnn_assert(param.format == Format::NCHW32_NCHW4); + cutlass_wrapper:: + do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< + false>( + args.src_tensor->compatible_ptr(), + filter_ptr, + args.bias_tensor->compatible_ptr(), + z_dev_ptr, + args.dst_tensor->compatible_ptr(), nullptr, + kern_param, nonlinear_mode, alpha, beta, gamma, + dst_scale, + cutlass_wrapper::GemmCoord{ + m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k}, + cutlass_wrapper::GemmCoord{m_algo_param.warp_m, + m_algo_param.warp_n, + m_algo_param.warp_k}, + stream); + } } else { - cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32( - args.src_tensor->compatible_ptr(), filter_ptr, - args.bias_tensor->compatible_ptr(), z_dev_ptr, - args.dst_tensor->compatible_ptr(), nullptr, kern_param, - nonlinear_mode, alpha, beta, gamma, dst_scale, - cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, - m_algo_param.threadblock_n, - m_algo_param.threadblock_k}, - cutlass_wrapper::GemmCoord{m_algo_param.warp_m, - m_algo_param.warp_n, - m_algo_param.warp_k}, - stream); + if (param.format == Format::NCHW32) { + cutlass_wrapper::do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32< + true>( + args.src_tensor->compatible_ptr(), filter_ptr, + args.bias_tensor->compatible_ptr(), z_dev_ptr, + args.dst_tensor->compatible_ptr(), nullptr, + kern_param, nonlinear_mode, alpha, beta, gamma, dst_scale, + cutlass_wrapper::GemmCoord{m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k}, + cutlass_wrapper::GemmCoord{m_algo_param.warp_m, + m_algo_param.warp_n, + m_algo_param.warp_k}, + stream); + } else { + megdnn_assert(param.format == Format::NCHW32_NCHW4); + cutlass_wrapper:: + do_conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4< + true>( + args.src_tensor->compatible_ptr(), + filter_ptr, + args.bias_tensor->compatible_ptr(), + z_dev_ptr, + args.dst_tensor->compatible_ptr(), nullptr, + kern_param, nonlinear_mode, alpha, beta, gamma, + dst_scale, + cutlass_wrapper::GemmCoord{ + m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k}, + cutlass_wrapper::GemmCoord{m_algo_param.warp_m, + m_algo_param.warp_n, + m_algo_param.warp_k}, + stream); + } } + after_kernel_launch(); } std::string ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::to_string( @@ -189,8 +261,21 @@ void ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::exec_preprocess( using Format = Param::Format; auto&& param = args.opr->param(); auto&& fm = args.filter_meta; - UNPACK_CONV_BIAS_NCHW32_PARAM(*(args.src_layout), fm, *(args.dst_layout), - param); + size_t n = args.src_layout->operator[](0), + ci = args.src_layout->operator[](1) * 32, + hi = args.src_layout->operator[](2), + wi = args.src_layout->operator[](3); + size_t ho = args.dst_layout->operator[](2), + wo = args.dst_layout->operator[](3); + size_t co; + if (param.format == Format::NCHW32) { + co = args.dst_layout->operator[](1) * 32; + } else { + megdnn_assert(param.format == Format::NCHW32_NCHW4); + co = args.dst_layout->operator[](1) * 4; + } + UNPACK_CONV_PARAMETER(fm, param); + MARK_USED_VAR TensorLayout src{{co, ci / 32, fh, fw, 32}, dtype::Int8()}; src.init_contiguous_stride(); TensorLayout dst = src; diff --git a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp index 3eb7bb9fc67924564732657171c9eab4f2d20f88..58c6314440de978a4145d603d3f56a703dfe1065 100644 --- a/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp +++ b/dnn/src/cuda/conv_bias/implicit_gemm_int8_nchw4_dp4a.cpp @@ -208,6 +208,24 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( stream); } else { megdnn_assert(param.format == Format::NCHW4_NCHW32); + cutlass_wrapper:: + do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< + false>( + args.src_tensor->compatible_ptr(), + filter_ptr, + args.bias_tensor->compatible_ptr(), + args.z_tensor->compatible_ptr(), + args.dst_tensor->compatible_ptr(), nullptr, + kern_param, nonlinear_mode, alpha, beta, gamma, + dst_scale, + cutlass_wrapper::GemmCoord{ + m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k}, + cutlass_wrapper::GemmCoord{m_algo_param.warp_m, + m_algo_param.warp_n, + m_algo_param.warp_k}, + stream); } } else { if (param.format == Format::NCHW4) { @@ -246,6 +264,24 @@ void ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::exec( } else { megdnn_assert(param.format == Format::NCHW4_NCHW32); + cutlass_wrapper:: + do_conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32< + true>( + args.src_tensor->compatible_ptr(), + filter_ptr, + args.bias_tensor->compatible_ptr(), + args.z_tensor->compatible_ptr(), + args.dst_tensor->compatible_ptr(), nullptr, + kern_param, nonlinear_mode, alpha, beta, gamma, + dst_scale, + cutlass_wrapper::GemmCoord{ + m_algo_param.threadblock_m, + m_algo_param.threadblock_n, + m_algo_param.threadblock_k}, + cutlass_wrapper::GemmCoord{m_algo_param.warp_m, + m_algo_param.warp_n, + m_algo_param.warp_k}, + stream); } } after_kernel_launch(); diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x128x32_64x32x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x128x32_64x32x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..b2c9c4628f4f1084d10517b099b89824b4beb5c9 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x128x32_64x32x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x128x32_64x32x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x128x32_64x32x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..e758d0da7b78dc2c888d8b1bcc587eddc394e681 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x128x32_64x32x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x128x32_64x32x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x128x32_64x32x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..f707002d6b15af94eb308bad30475306f8ae1e49 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x128x32_64x32x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x32x32_64x32x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x32x32_64x32x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..d1c44bec4e93537c1c512eec4b703cbd8decf0ad Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x32x32_64x32x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x32x32_64x32x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x32x32_64x32x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..5158b527d76fc935e73de1de686870785afc514e Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x32x32_64x32x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x32x32_64x32x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x32x32_64x32x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..21c01d3675fc3eeec3f93b2bcb86fe02c95c25be Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x32x32_64x32x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x64x32_64x32x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x64x32_64x32x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..b28f68247219e07bca34c94fcb166aacd470fab4 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x64x32_64x32x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x64x32_64x32x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x64x32_64x32x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..f106af213e1889d6744bb87bfc84f268b48ee5bf Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x64x32_64x32x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x64x32_64x32x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x64x32_64x32x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..7f45ef6209a6aa7d3a35034a79017ed3a4df73c7 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_128x64x32_64x32x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x128x32_64x32x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x128x32_64x32x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..130b04b863e3fa754e9fb4b24127f227e3c05b27 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x128x32_64x32x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x128x32_64x32x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x128x32_64x32x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..428024430f170a0ccf0fd834633ba37896839800 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x128x32_64x32x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x128x32_64x32x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x128x32_64x32x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..0a7d3c990753413a214bc256be4e40b13aebc819 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x128x32_64x32x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x32x32_64x32x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x32x32_64x32x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..716a540b711bef86ac13934b1de12ba137433287 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x32x32_64x32x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x32x32_64x32x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x32x32_64x32x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..35b97bcda599aed2b2404d4968942c8f36d5f7b4 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x32x32_64x32x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x32x32_64x32x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x32x32_64x32x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..8ea93a1c4d9f0621c4d5fcb17b5e95defec5b051 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x32x32_64x32x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x64x32_64x32x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x64x32_64x32x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..3681959837ff5df0c0583c4b36c43592099f1aab Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x64x32_64x32x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x64x32_64x32x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x64x32_64x32x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..ebeb82911b4e51fbc55541f23a10a9ab2fb6a9e5 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x64x32_64x32x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x64x32_64x32x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x64x32_64x32x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..d9b73fa2dbdf5da5653e58f66b25ce534fd27ae5 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_128x64x32_64x32x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x128x32_32x64x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x128x32_32x64x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..0cd9b194f6a68d2ccd1b833fc7b3925d87b2161a Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x128x32_32x64x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x128x32_32x64x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x128x32_32x64x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..4a89381a8c71bc9e076d8028880be7f2f0788428 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x128x32_32x64x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x128x32_32x64x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x128x32_32x64x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..0011b40b18136af0a689e5b900e523b826363b6f Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x128x32_32x64x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x32x32_32x32x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x32x32_32x32x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..a60e3b181f2bc9f11f21596075924161e45cb4cf Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x32x32_32x32x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x32x32_32x32x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x32x32_32x32x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..4bee34e863810b9e00e6e2eabbe6f433f11396a8 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x32x32_32x32x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x32x32_32x32x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x32x32_32x32x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..69ab4e102db4ca1372b6cc177fc6efc6d0107e78 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x32x32_32x32x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x64x32_32x64x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x64x32_32x64x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..eb4d164e9f163ad737e9d2aba221c5ca5afa2c67 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x64x32_32x64x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x64x32_32x64x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x64x32_32x64x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..576e237832cdb914ddea1b42c64d27ac69be9e68 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x64x32_32x64x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x64x32_32x64x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x64x32_32x64x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..32917d0042db7617b6fffa77bd88ab2f754230d7 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_32x64x32_32x64x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x128x32_64x32x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x128x32_64x32x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..367f6012d731606ffcd95b2905785dbdf4b7c483 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x128x32_64x32x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x128x32_64x32x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x128x32_64x32x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..e5ba977b18faf2a326b89835afe139284cb16cd9 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x128x32_64x32x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x128x32_64x32x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x128x32_64x32x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..c832049c8becbf04d6a5767cbeb08aada92afceb Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x128x32_64x32x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x32x32_64x32x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x32x32_64x32x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..8a81f58713a332d785f5e84de8641d6a5c09a08d Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x32x32_64x32x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x32x32_64x32x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x32x32_64x32x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..e391dd646f6040f53bb180d2bf83e4c005adc911 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x32x32_64x32x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x32x32_64x32x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x32x32_64x32x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..27309b701e1460f4a00f2f1dd9233a24b05aa254 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x32x32_64x32x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x64x32_64x32x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x64x32_64x32x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..026080c88204d95408fa592b3091cc96dc6ab669 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x64x32_64x32x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x64x32_64x32x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x64x32_64x32x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..4deac474aac6e0e6de47e28f64325e71250240c0 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x64x32_64x32x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x64x32_64x32x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x64x32_64x32x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..b0a42c147bbc697cb0d15d746c2a6c3b0b71fcf7 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_1x1_64x64x32_64x32x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x128x32_32x64x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x128x32_32x64x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..48036be9e0b7f96f188f1afa782f508b00914e18 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x128x32_32x64x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x128x32_32x64x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x128x32_32x64x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..6d1a69d287cb30995e7139d4d6080ac0f842f18e Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x128x32_32x64x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x128x32_32x64x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x128x32_32x64x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..91ce847d7c0a6f30c87ea54a3e76be8ff3fd3214 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x128x32_32x64x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x32x32_32x32x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x32x32_32x32x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..6d4a698f7f9e66e88afa75556a565dcec5b86437 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x32x32_32x32x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x32x32_32x32x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x32x32_32x32x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..aeee09ac7e4bf9e66e7f8319788848958d48dd86 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x32x32_32x32x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x32x32_32x32x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x32x32_32x32x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..d601e01df0257b3fb27ae757184487f0edb5abf3 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x32x32_32x32x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x64x32_32x64x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x64x32_32x64x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..fd235e0aab8e34cbb3112af60c674eaff31b41d5 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x64x32_32x64x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x64x32_32x64x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x64x32_32x64x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..ac180bd42d3ae8a8f9c97de1caae4d54199d02a2 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x64x32_32x64x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x64x32_32x64x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x64x32_32x64x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..352535b74739e956adb292b2200735c356386e8c Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_32x64x32_32x64x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x128x32_64x32x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x128x32_64x32x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..ccaca6e8ea2fc21d02487839ff0076b2b603ea2c Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x128x32_64x32x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x128x32_64x32x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x128x32_64x32x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..35915b26724c1ec1b1f622f34c6359e9b8f1516d Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x128x32_64x32x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x128x32_64x32x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x128x32_64x32x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..1c4b308e8d2717e9c4177af54197f14eb6b939ad Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x128x32_64x32x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x32x32_64x32x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x32x32_64x32x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..c97b5e918f1362e796388aafc53503ca016b112e Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x32x32_64x32x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x32x32_64x32x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x32x32_64x32x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..8327475ad764906b99f5c5499abb507f208ef675 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x32x32_64x32x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x32x32_64x32x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x32x32_64x32x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..1b0ba002d56de04a7f64cc51c3be396caccb687a Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x32x32_64x32x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x64x32_64x32x32_hswish.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x64x32_64x32x32_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..cabcf1ae65705db3bbe6701f06a589efed9ddbba Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x64x32_64x32x32_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x64x32_64x32x32_id.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x64x32_64x32x32_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..2f111fd96779010b8e6353ed1d8c33679a68cb46 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x64x32_64x32x32_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x64x32_64x32x32_relu.cu b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x64x32_64x32x32_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..b49faba5573568eac9a748fce52fbe5f28522d9a Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8/kimpl/conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_ncdiv32hw32_64x64x32_64x32x32_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x128x64_64x64x64_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x128x64_64x64x64_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..e11f92ff8df3965d38121d3f17051e112d851d40 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x128x64_64x64x64_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x128x64_64x64x64_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x128x64_64x64x64_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..bf638233d3e95f19637a082b05fc27dcdb385e96 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x128x64_64x64x64_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x128x64_64x64x64_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x128x64_64x64x64_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..5ef27dfd260d4f04676311b442ca52b132f3a29d Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x128x64_64x64x64_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x256x64_64x64x64_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x256x64_64x64x64_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..f46311f22deda0d9f8f4c555751febbcaaa9a562 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x256x64_64x64x64_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x256x64_64x64x64_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x256x64_64x64x64_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..f9b97cd47327d336721a342f29e89a98b8869f4d Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x256x64_64x64x64_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x256x64_64x64x64_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x256x64_64x64x64_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..3ac0bd364f159da54e205f74bd38156d57adc41e Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x256x64_64x64x64_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x64x64_64x32x64_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x64x64_64x32x64_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..c426dddf71e88169516ba2ae61ee62748fab4b3a Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x64x64_64x32x64_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x64x64_64x32x64_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x64x64_64x32x64_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..9fe447f3bc5cdc66ae5068f5948b67f71386064b Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x64x64_64x32x64_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x64x64_64x32x64_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x64x64_64x32x64_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..198293e2e5f9c383d75bd9efb258a50edf7d25ae Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_128x64x64_64x32x64_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x128x64_64x64x64_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x128x64_64x64x64_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..1d4a974f24ef62589264e938ebf6209341f57c17 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x128x64_64x64x64_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x128x64_64x64x64_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x128x64_64x64x64_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..069e96db4230c0ee9b4b0ac6cb33f58b603d7775 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x128x64_64x64x64_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x128x64_64x64x64_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x128x64_64x64x64_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..80a947c22da4cd66991b3ece5122ecb7077089a4 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x128x64_64x64x64_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x256x64_64x64x64_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x256x64_64x64x64_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..5fe0f19c1f92f7e50c037772dcd2332a25e71d31 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x256x64_64x64x64_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x256x64_64x64x64_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x256x64_64x64x64_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..ca80dda287c9f82ab8714da8f4c50493a4e27953 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x256x64_64x64x64_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x256x64_64x64x64_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x256x64_64x64x64_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..812cdc8124b4ae6df3ade810bdd81473d1cfd6ce Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x256x64_64x64x64_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x64x64_64x32x64_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x64x64_64x32x64_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..0b0da908e235ad22e5185606a2b5ca3c8eb6d73f Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x64x64_64x32x64_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x64x64_64x32x64_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x64x64_64x32x64_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..89608b5bb926679024612b54664113a32342319d Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x64x64_64x32x64_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x64x64_64x32x64_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x64x64_64x32x64_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..80aac1fcaccd6edd6950694a98200d426f993b2b Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_128x64x64_64x32x64_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_256x128x64_64x64x64_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_256x128x64_64x64x64_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..338226f8164bf67916a24a709ffaaaf55f377e42 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_256x128x64_64x64x64_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_256x128x64_64x64x64_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_256x128x64_64x64x64_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..bde7fed6b1ff0937607d55bcdbd8728171d48cc4 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_256x128x64_64x64x64_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_256x128x64_64x64x64_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_256x128x64_64x64x64_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..5cc092b65c3a41ba1f4f1f3fb97eb88b397fd44e Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_256x128x64_64x64x64_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_32x64x64_16x32x64_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_32x64x64_16x32x64_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..b49cdc6be5ee137b764147452395fc5246d08226 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_32x64x64_16x32x64_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_32x64x64_16x32x64_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_32x64x64_16x32x64_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..6957270fbc22a31e00543ef6ed17e920eff1b6a8 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_32x64x64_16x32x64_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_32x64x64_16x32x64_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_32x64x64_16x32x64_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..28960bc48cdb995ceaa4be6248359addb367e25d Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_32x64x64_16x32x64_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x128x64_32x64x64_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x128x64_32x64x64_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..99ff27cd0c4e2602dee58c53d6a8ec8ce0139248 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x128x64_32x64x64_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x128x64_32x64x64_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x128x64_32x64x64_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..c913c159c6ef972b80ade9bf46476ec6028c08b0 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x128x64_32x64x64_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x128x64_32x64x64_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x128x64_32x64x64_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..5144b3bcbcbaa55a6d08d60acb92f61bffcda14c Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x128x64_32x64x64_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x64x64_32x32x64_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x64x64_32x32x64_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..ef7e3b9de85be365046298f8a05750a26e325e72 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x64x64_32x32x64_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x64x64_32x32x64_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x64x64_32x32x64_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..adb02359854a66feaeeda63ae1dfe13c4081b394 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x64x64_32x32x64_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x64x64_32x32x64_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x64x64_32x32x64_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..b3df1baf675f1eb03a7f5de7fa28df496279c315 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_1x1_64x64x64_32x32x64_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_256x128x64_64x64x64_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_256x128x64_64x64x64_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..a0a2a8bee1689b7b418bdc6234b429dac2db6975 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_256x128x64_64x64x64_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_256x128x64_64x64x64_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_256x128x64_64x64x64_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..89645723cd8f980f20255bda375f4e4e5540674c Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_256x128x64_64x64x64_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_256x128x64_64x64x64_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_256x128x64_64x64x64_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..513ff2695ac3b18f7c6cb0880eee43999971493f Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_256x128x64_64x64x64_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_32x64x64_16x32x64_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_32x64x64_16x32x64_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..c9e96d49978360eb544ea2f40ac508315800ec32 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_32x64x64_16x32x64_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_32x64x64_16x32x64_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_32x64x64_16x32x64_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..20b9cddc891a1bc261ad5cc3910615fde50ec3dd Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_32x64x64_16x32x64_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_32x64x64_16x32x64_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_32x64x64_16x32x64_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..41cd9ae5656ec339616b4163c7cf0f4a69f36b37 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_32x64x64_16x32x64_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x128x64_32x64x64_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x128x64_32x64x64_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..3151e3dc38a1dc57985f575c35008ab43c303bb3 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x128x64_32x64x64_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x128x64_32x64x64_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x128x64_32x64x64_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..c20fe3d06c6a28e861f168a35f98de11e7f743b0 Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x128x64_32x64x64_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x128x64_32x64x64_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x128x64_32x64x64_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..dddbed50f612d50841807e11adcaa5f2a21a4b8d Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x128x64_32x64x64_relu.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x64x64_32x32x64_hswish.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x64x64_32x32x64_hswish.cu new file mode 100644 index 0000000000000000000000000000000000000000..6028fba165cc043a786b248a69ad88ba773c9ddd Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x64x64_32x32x64_hswish.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x64x64_32x32x64_id.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x64x64_32x32x64_id.cu new file mode 100644 index 0000000000000000000000000000000000000000..0ea06e92e5d98051ef2238bbb2b03113a92216cf Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x64x64_32x32x64_id.cu differ diff --git a/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x64x64_32x32x64_relu.cu b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x64x64_32x32x64_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..688064da3aee8198af4f42d9d05da5e753446aba Binary files /dev/null and b/dnn/src/cuda/conv_bias/int8_imma/kimpl/conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x64x64_32x32x64_relu.cu differ diff --git a/dnn/test/cuda/conv_bias_int8.cpp b/dnn/test/cuda/conv_bias_int8.cpp index 41592a08b99d8f8a89673e36d9431908b0a746f8..01a3b8d31bfe352e37d42f2b9dc832618fad0896 100644 --- a/dnn/test/cuda/conv_bias_int8.cpp +++ b/dnn/test/cuda/conv_bias_int8.cpp @@ -1232,6 +1232,73 @@ TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW4_NCHW) { run({{16, 4, 46, 80, 4}, {4, 4, 3, 3, 4}, {1, 4, 1, 1}}); } +TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW4_NCHW32) { + require_compute_capability(6, 1); + using namespace conv_bias; + Checker checker(handle_cuda()); + UniformIntRNG int_rng{-3, 3}; + UniformIntRNG bias_rng{-50, 50}; + ConvBias::Param param; + param.format = ConvBias::Param::Format::NCHW4_NCHW32; + param.nonlineMode = ConvBias::Param::NonlineMode::IDENTITY; + checker.set_before_exec_callback( + conv_bias::ConvBiasAlgoChecker( + "INT8_NCHW4_DOTPROD_IMPLICIT_GEMM")); + checker.set_dtype(0, dtype::QuantizedS8(1.9980618f)) + .set_dtype(1, dtype::QuantizedS8(1.9980927f)) + .set_dtype(2, dtype::QuantizedS32(1.9980618f * 1.9980927f)) + .set_dtype(3, dtype::QuantizedS8(1.9980618f)) + .set_dtype(4, dtype::QuantizedS8(1.9980618f)) + .set_rng(0, &int_rng) + .set_rng(1, &int_rng) + .set_rng(2, &bias_rng) + .set_rng(3, &int_rng) + .set_param(param); + auto run = [&](const TensorShapeArray& shapes) { + checker.execs({shapes[0], shapes[1], shapes[2], {}, {}}); + }; + + run({{16, 4, 23, 40, 4}, {32, 4, 3, 3, 4}, {1, 1, 1, 1, 32}}); + run({{16, 4, 92, 160, 4}, {32, 4, 3, 3, 4}, {1, 1, 1, 1, 32}}); + run({{16, 4, 46, 80, 4}, {32, 4, 3, 3, 4}, {1, 1, 1, 1, 32}}); +} + +#if CUDA_VERSION >= 10020 +TEST_F(CUDA, CUTLASS_CONV_BIAS_INT8_NCHW32_NCHW4) { + require_compute_capability(7, 5); + using namespace conv_bias; + Checker checker(handle_cuda()); + UniformIntRNG int_rng{-3, 3}; + UniformIntRNG bias_rng{-50, 50}; + ConvBias::Param param; + param.format = ConvBias::Param::Format::NCHW32_NCHW4; + param.nonlineMode = ConvBias::Param::NonlineMode::IDENTITY; + checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker< + ConvBiasForward>( + ConvBias::algo_name( + "INT8_NCHW32_IMMA_IMPLICIT_GEMM_256X128X64_64X64X64", + ConvBias::DirectParam{}) + .c_str())); + checker.set_dtype(0, dtype::QuantizedS8(1.9980618f)) + .set_dtype(1, dtype::QuantizedS8(1.9980927f)) + .set_dtype(2, dtype::QuantizedS32(1.9980618f * 1.9980927f)) + .set_dtype(3, dtype::QuantizedS8(1.9980618f)) + .set_dtype(4, dtype::QuantizedS8(1.9980618f)) + .set_rng(0, &int_rng) + .set_rng(1, &int_rng) + .set_rng(2, &bias_rng) + .set_rng(3, &int_rng) + .set_param(param); + auto run = [&](const TensorShapeArray& shapes) { + checker.execs({shapes[0], shapes[1], shapes[2], {}, {}}); + }; + + run({{16, 2, 23, 40, 32}, {20, 2, 3, 3, 32}, {1, 5, 1, 1, 4}}); + run({{16, 1, 92, 160, 32}, {24, 1, 3, 3, 32}, {1, 6, 1, 1, 4}}); + run({{16, 2, 46, 80, 32}, {4, 2, 3, 3, 32}, {1, 1, 1, 1, 4}}); +} +#endif + #if MEGDNN_WITH_BENCHMARK TEST_F(CUDA, BENCHMARK_CONV_BIAS_INT8_CHWN4) { require_compute_capability(6, 1);