// conv_bias_int8_implicit_gemm_imma_ncdiv32hw32_128x256x64_64x64x64_id.cu
#if !MEGDNN_TEGRA_X1
// Generated by gen_cuda_conv_bias_kern_impls.py.
// Silence the GCC warnings triggered by the CUTLASS code pulled in through the
// wrapper include below; the matching `#pragma GCC diagnostic pop` sits at the
// end of this file.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/conv_bias/int8/conv_bias_int8_implicit_gemm_cutlass_wrapper.cuinl"

// Tensor layout tags used by the Convolution type in this file. Source and
// destination share TensorNCxHWx<32>; the filter uses TensorCxRSKx<32>.
using LayoutSrc = cutlass::layout::TensorNCxHWx<32>;
using LayoutFilter = cutlass::layout::TensorCxRSKx<32>;
using LayoutDst = cutlass::layout::TensorNCxHWx<32>;
// GEMM tile shapes at three granularities. Judging by the alias names these
// are the threadblock tile, the warp tile and the tensor-core instruction tile
// (NOTE(review): confirm dimension ordering against cutlass::gemm::GemmShape).
using ThreadBlockShape = cutlass::gemm::GemmShape<128, 256, 64>;
using WarpShape        = cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;

// Epilogue functor type handed to the Convolution template.
// NOTE(review): the meaning of each template argument (int8_t, 8, int32_t,
// int32_t, float) is defined by BiasAddLinearCombinationClamp -- verify there.
using EpilogueOp =
        cutlass::epilogue::thread::BiasAddLinearCombinationClamp<
                int8_t, 8, int32_t, int32_t, float>;
// Fully specified device-level convolution type for this kernel variant.
// It combines the layout, tile-shape and epilogue aliases declared above with
// the tensor-op operator class and the Sm75 architecture tag.
// NOTE(review): the element/layout arguments presumably follow the order
// src, filter, dst, bias, accumulator -- confirm against the declaration of
// cutlass::convolution::device::Convolution.
using Convolution = cutlass::convolution::device::Convolution<
        int8_t, LayoutSrc, int8_t, LayoutFilter, int8_t,
        LayoutDst, int32_t, LayoutDst, int32_t,
        cutlass::convolution::ConvType::kConvolution,
        cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
        ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,
        cutlass::convolution::threadblock::ConvolutionNCxHWxThreadblockSwizzle<
                cutlass::convolution::ConvType::kConvolution>,
        // NOTE(review): the meaning of these scalars is fixed by the
        // Convolution template's parameter list, which is not visible here --
        // verify there before changing them.
        2, 16, 16, true,
        cutlass::arch::OpMultiplyAddSaturate>;
// Explicit instantiation of the wrapper template for the Convolution type
// declared in this file, so that the code for this configuration is emitted
// in this translation unit. The parameter list must match the wrapper
// template's declaration exactly.
template void megdnn::cuda::cutlass_wrapper::cutlass_convolution_wrapper<Convolution>(
        const typename Convolution::ElementSrc* d_src,
        const typename Convolution::ElementFilter* d_filter,
        const typename Convolution::ElementBias* d_bias,
        const typename Convolution::ElementDst* d_z,
        typename Convolution::ElementDst* d_dst,
        int* workspace,
        typename Convolution::ConvolutionParameter const& conv_param,
        typename Convolution::EpilogueOutputOp::Params const& epilogue,
        cudaStream_t stream);
#pragma GCC diagnostic pop
#endif