// matrix_mul_fp32_simt_32x128x8_32x64x8_tn.cu
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_matrix_mul_kern_impls.py
//
// Everything up to the matching #endif is compiled only with nvcc 9.2 or
// newer; with older compilers the guarded code is skipped entirely.
// NOTE(review): presumably the CUTLASS kernels instantiated here need
// CUDA 9.2+ -- confirm against the generator script.
//
// ignore warning of cutlass: suppress the listed host-compiler warnings for
// the headers pulled in below; the diagnostic state is pushed here and popped
// again just before the closing #endif.
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/fp32_simt/matrix_mul_float_simt_cutlass_wrapper.cuinl"

// Kernel configuration for this instantiation: fp32 SIMT GEMM with a
// 32x128x8 threadblock tile and a 32x64x8 warp tile.
//
// Operand layouts: A is column-major and B is row-major. C/D are row-major
// (see the Gemm typedef below).
using LayoutA = cutlass::layout::ColumnMajor;
using LayoutB = cutlass::layout::RowMajor;
// Threadblock tile (M x N x K) and per-warp tile. Dividing one by the other
// gives (32/32) x (128/64) x (8/8) = 2 warps per threadblock.
using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, 8>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>;
// SIMT (CUDA-core) path: each instruction is a 1x1x1 scalar multiply-add.
using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
// Linear-combination epilogue with float output, accumulator and compute
// types, accessing one output element at a time.
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<float, 1, float, float>;
// Device-level GEMM. The template arguments, in order, are:
// - A (float, LayoutA)
// - B (float, LayoutB)
// - C (float, row-major)
// - float accumulator
// - SIMT operator class with the Sm50 architecture tag
// - the tile shapes and epilogue above
// - the identity threadblock swizzle
// - 2 pipeline stages
using Gemm = cutlass::gemm::device::Gemm<
    float, LayoutA,
    float, LayoutB,
    float, cutlass::layout::RowMajor, float,
    cutlass::arch::OpClassSimt, cutlass::arch::Sm50,
    ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    2>;
// Explicit instantiation definition: forces cutlass_matrix_mul_wrapper<Gemm>
// to be compiled in this translation unit. The parameter list must match the
// template's declaration exactly.
// NOTE(review): the template itself is presumably defined in the .cuinl
// included above -- confirm.
template void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_wrapper<Gemm>(
        const typename Gemm::ElementA* d_A, size_t lda,
        const typename Gemm::ElementB* d_B, size_t ldb,
        typename Gemm::ElementC* d_C, size_t ldc,
        int* workspace,
        cutlass::gemm::GemmCoord const& problem_size,
        typename Gemm::EpilogueOutputOp::Params const& epilogue,
        cudaStream_t stream, int split_k_slices);

// Restore the diagnostic state saved by the `#pragma GCC diagnostic push`
// at the top of the file.
#pragma GCC diagnostic pop
// Closes the nvcc >= 9.2 version guard opened at the top of the file.
#endif