// conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_ncdiv4hw4_64x128x64_32x64x64_id.cu
#if !MEGDNN_TEGRA_X1
// generated by gen_cuda_conv_bias_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl"

// Kernel configuration for an int8 implicit-GEMM convolution that runs on
// tensor cores (OpClassTensorOp, Sm75 arch tag below).
//
// Tensor layouts:
//   source      : NCxHWx, channels interleaved in groups of 32
//   filter      : CxRSKx, channels interleaved in groups of 32
//   destination : NCxHWx, channels interleaved in groups of 4 (also used as
//                 the bias layout in the Convolution alias below)
using LayoutSrc = cutlass::layout::TensorNCxHWx<32>;
using LayoutFilter = cutlass::layout::TensorCxRSKx<32>;
using LayoutDst = cutlass::layout::TensorNCxHWx<4>;

// Tiling hierarchy: a 64x128x64 threadblock tile is split into 32x64x64 warp
// tiles (2 x 2 = 4 warps per threadblock). Each warp issues 8x8x16 MMA
// instructions.
using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;

// Epilogue functor, which by its name performs bias-add + linear combination
// + clamp. The `4` matches the 4-channel interleave of LayoutDst.
// NOTE(review): the argument order is presumably
// <ElementOutput, Count, ElementAccumulator, ElementBias, ElementCompute>;
// confirm against the cutlass fork this file is built with.
using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
        int8_t, 4, int32_t, int32_t, float>;

// Device-level convolution operator assembled from the pieces above.
// Element types follow the pattern (element, layout) for src, filter, dst and
// bias, then the int32 accumulator type.
// NOTE(review): the trailing `2, 16, 16, true` are presumably the pipeline
// stage count, the src/filter access alignments (in elements), and a
// "load from constant memory" flag. These are positional parameters of the
// cutlass fork; confirm before changing them.
using Convolution = cutlass::conv::device::Convolution<
        int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t,
        LayoutDst, int32_t, LayoutDst, int32_t,
        cutlass::conv::ConvType::kConvolution, cutlass::arch::OpClassTensorOp,
        cutlass::arch::Sm75,
        ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,
        cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle,
        2, 16, 16, true,
        cutlass::arch::OpMultiplyAddSaturate>;
// Explicit instantiation of the cutlass convolution launcher for the
// `Convolution` configuration in this file. This forces the wrapper template
// (presumably defined in the included conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl)
// to be compiled for this kernel in this translation unit.
// `Convolution` is a concrete (non-dependent) type here, so its nested type
// names need no `typename` disambiguator.
template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<
        Convolution>(
        Convolution::ElementSrc const* d_src,        // source tensor data
        Convolution::ElementFilter const* d_filter,  // filter tensor data
        Convolution::ElementBias const* d_bias,      // bias data
        Convolution::ElementDst const* d_z,          // extra dst-typed input
        Convolution::ElementDst* d_dst,              // output tensor data
        int* workspace,
        Convolution::ConvolutionParameter const& conv_param,
        Convolution::EpilogueOutputOp::Params const& epilogue,
        cudaStream_t stream);
#pragma GCC diagnostic pop
#endif