// conv_bias_int8_implicit_gemm_dp4a_ncdiv4hw4_1x1_128x128x32_64x32x32_hswish.cu
#if !MEGDNN_TEGRA_X1
// generated by gen_cuda_conv_bias_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl"

// Tensor layouts of the convolution operands. The source and destination
// tensors share TensorNCxHWx<4>; the filter uses TensorCxRSKx<4>.
// NOTE(review): the <4> presumably matches the "ncdiv4hw4" packing named in
// this kernel's file name -- confirm against the CUTLASS layout definitions.
using LayoutSrc = cutlass::layout::TensorNCxHWx<4>;
using LayoutFilter = cutlass::layout::TensorCxRSKx<4>;
using LayoutDst = cutlass::layout::TensorNCxHWx<4>;
// Tile shapes handed to the convolution template as GemmShape<...> triples:
// one for the thread block, one for a warp, one for the instruction.
using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 32>;
using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>;
// NOTE(review): <1, 1, 4> presumably reflects the 4-element dot product of
// the "dp4a" path named in this kernel's file name -- confirm.
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 4>;
// Epilogue functor applied by the kernel (bias add + linear combination +
// h-swish + clamp, per the functor's name).
// NOTE(review): the template arguments look like <output element, elements
// per access, accumulator, bias, compute type> -- verify against the
// BiasAddLinearCombinationHSwishClamp declaration.
using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp<
        int8_t, 4, int32_t, int32_t, float>;
// Fully specialised device-level convolution type that this translation unit
// instantiates. It combines the layouts, tile shapes and epilogue declared
// above with int8_t source/filter/destination elements and int32_t for the
// remaining two element types, and targets OpClassSimt on Sm61 with the
// saturating multiply-add operator.
// NOTE(review): the trailing "2, 4, 16, false" are presumably the pipeline
// stage count, source/filter access alignments and a load-from-constant-memory
// switch -- confirm against the cutlass::conv::device::Convolution template
// parameter list before changing any of them.
using Convolution = cutlass::conv::device::Convolution<
    int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t,
    LayoutDst, int32_t, LayoutDst, int32_t,
    cutlass::conv::ConvType::kConvolution, cutlass::arch::OpClassSimt, cutlass::arch::Sm61,
    ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,
    cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle,
    2, 4, 16, false,
    cutlass::arch::OpMultiplyAddSaturate>;
// Explicit instantiation of cutlass_convolution_wrapper for the Convolution
// type configured in this file, so that this translation unit emits the code
// for exactly this kernel configuration.
// NOTE(review): the wrapper template itself is presumably defined in the
// included conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl -- confirm.
//
// Parameters (as declared; semantics are defined by the wrapper):
//   d_src, d_filter, d_bias, d_z  read-only operand pointers
//   d_dst                         writable output pointer
//   workspace                     int scratch pointer
//   conv_param                    Convolution::ConvolutionParameter
//   epilogue                      parameters of Convolution::EpilogueOutputOp
//   stream                        CUDA stream handle passed to the wrapper
template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>(
        const typename Convolution::ElementSrc* d_src,
        const typename Convolution::ElementFilter* d_filter,
        const typename Convolution::ElementBias* d_bias,
        const typename Convolution::ElementDst* d_z,
        typename Convolution::ElementDst* d_dst,
        int* workspace,
        typename Convolution::ConvolutionParameter const& conv_param,
        typename Convolution::EpilogueOutputOp::Params const& epilogue,
        cudaStream_t stream);
#pragma GCC diagnostic pop
#endif