// cutlass_simt_u4_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_nhwc.cu

#if !MEGDNN_TEGRA_X1
// Suppress compiler warnings triggered by the CUTLASS headers.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/conv_bias/implicit_gemm_conv_bias_cutlass_wrapper.cuinl"


// Kernel instance
// "cutlass_simt_u4_ifprop_hswish_s8_128x32x32_64x32x32_2_nc4hw4_nhwc",
// as emitted by the cutlass generator.
//
// The three GEMM tile shapes and the epilogue functor are spelled out as
// separate aliases, named after their literal template arguments, so that the
// main `Convolution` alias below stays short. The first two shapes and the
// literal `2` further down match the "128x32x32_64x32x32_2" part of the
// kernel name.
// NOTE(review): by CUTLASS convention these are presumably the threadblock,
// warp and instruction tiles (in that order) -- confirm against the
// cutlass::conv::device::Convolution template parameter list.
using GemmTile128x32x32 = cutlass::gemm::GemmShape<128, 32, 32>;
using GemmTile64x32x32 = cutlass::gemm::GemmShape<64, 32, 32>;
using GemmTile1x1x4 = cutlass::gemm::GemmShape<1, 1, 4>;

// Per-thread epilogue functor instantiated for cutlass::uint4b_t.
// NOTE(review): the meaning of the remaining arguments (8, int32_t, int32_t,
// float) is defined by BiasAddLinearCombinationHSwishClamp -- confirm there.
using HSwishClampEpilogueU4 =
        cutlass::epilogue::thread::BiasAddLinearCombinationHSwishClamp<
                cutlass::uint4b_t, 8, int32_t, int32_t, float>;

// Fully configured device-level convolution type for this kernel instance.
// NOTE(review): the leading element/layout pairs presumably describe src,
// filter, dst and bias in that order, followed by the accumulator type --
// verify against the Convolution template declaration.
using Convolution = cutlass::conv::device::Convolution<
        // element type / layout pairs
        int8_t, cutlass::layout::TensorNCxHWx<4>,
        int8_t, cutlass::layout::TensorCxRSKx<4>,
        cutlass::uint4b_t, cutlass::layout::TensorNHWC,
        int32_t, cutlass::layout::TensorNHWC,
        int32_t,
        // convolution kind, operator class and architecture tag
        cutlass::conv::ConvType::kConvolution,
        cutlass::arch::OpClassSimt,
        cutlass::arch::Sm75,
        // GEMM tile shapes
        GemmTile128x32x32,
        GemmTile64x32x32,
        GemmTile1x1x4,
        // epilogue functor and threadblock swizzle
        HSwishClampEpilogueU4,
        cutlass::conv::threadblock::ConvolutionFpropNCxHWxThreadblockSwizzle,
        // trailing numeric / boolean options and the math operator tag;
        // semantics are defined by the Convolution template -- TODO confirm
        2, 4, 16, true,
        cutlass::arch::OpMultiplyAddSaturate>;


// Explicit instantiation of
// megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper for the
// `Convolution` type configured in this file, so that the specialization is
// emitted in this translation unit.
//
// `Convolution` is a concrete alias here (not a dependent name), so its nested
// types are written without the `typename` disambiguator.
// NOTE(review): what each argument means (e.g. the d_z / workspace buffers) is
// defined by the wrapper template in the included .cuinl -- see there.
template void
megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>(
        Convolution::ElementSrc const* d_src,
        Convolution::ElementFilter const* d_filter,
        Convolution::ElementBias const* d_bias,
        Convolution::ElementDst const* d_z,
        Convolution::ElementDst* d_dst,
        int* workspace,
        const Convolution::ConvolutionParameter& conv_param,
        const Convolution::EpilogueOutputOp::Params& epilogue,
        cudaStream_t stream,
        Convolution::ExtraParam extra_param);
#pragma GCC diagnostic pop
#endif