/**
 * \file dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
// ignore warning of cutlass
#include "cuda.h"
#if __CUDACC_VER_MAJOR__ > 9 || \
        (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#include "cutlass/gemm/kernel/default_gemv.h"
#include "src/common/opr_param_defs_enumv.cuh"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#pragma GCC diagnostic pop

using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;

/* ================ cutlass kernel wrapper for f32 matrix mul =============== */

//! Try every supported (threadblock, warp) tile-shape combination; `cb` emits
//! a `return` for the matching one, so falling through to the trailing assert
//! means the runtime-requested shapes are not in the compiled-in set.
#define DISPATCH(cb)                                                       \
    cb(64, 256, 8, 32, 64, 8);                                             \
    cb(256, 64, 8, 64, 32, 8);                                             \
    cb(32, 256, 8, 16, 64, 8);                                             \
    cb(256, 32, 8, 64, 16, 8);                                             \
    cb(128, 128, 8, 32, 64, 8);                                            \
    cb(128, 64, 8, 64, 32, 8);                                             \
    cb(64, 128, 8, 32, 64, 8);                                             \
    cb(128, 32, 8, 64, 32, 8);                                             \
    cb(32, 128, 8, 32, 64, 8);                                             \
    cb(64, 64, 8, 32, 64, 8);                                              \
    cb(32, 64, 8, 32, 64, 8);                                              \
    cb(64, 32, 8, 64, 32, 8);                                              \
    cb(32, 32, 8, 32, 32, 8);                                              \
    cb(8, 32, 8, 8, 32, 8);                                                \
    cb(16, 32, 8, 16, 32, 8);                                              \
    cb(16, 64, 8, 16, 64, 8);                                              \
    cb(16, 128, 8, 16, 64, 8);                                             \
    megdnn_assert(false,                                                   \
                  "unsupported threadblock shape (%dx%dx%d) and warp "     \
                  "shape "                                                 \
                  "(%dx%dx%d)",                                            \
                  threadblock_shape.m(), threadblock_shape.n(),            \
                  threadblock_shape.k(), warp_shape.m(), warp_shape.n(),   \
                  warp_shape.k());

/*!
 * \brief dispatch a float32 SIMT GEMM (C = alpha * op(A) * op(B) + beta * C)
 *        to a statically instantiated cutlass kernel.
 *
 * The runtime-selected \p threadblock_shape / \p warp_shape pair is matched
 * against the compiled-in tile configurations listed in DISPATCH; an
 * unsupported combination triggers megdnn_assert. Row/column-major layouts
 * for A and B are derived from \p transpose_A / \p transpose_B (C is always
 * row-major). When \p split_k_slices > 1, the split-K parallel device GEMM
 * is used, accumulating partial results through \p workspace.
 *
 * \param d_A,d_B,d_C device pointers of the operands (C is read when
 *        beta != 0)
 * \param lda,ldb,ldc leading dimensions of A, B and C
 * \param workspace device scratch buffer (required by the split-K path)
 * \param problem_size GEMM extents (m, n, k)
 * \param alpha,beta scaling factors of the linear-combination epilogue
 * \param stream cuda stream the kernel is enqueued on
 * \param split_k_slices number of K-dimension slices (1 = plain GEMM)
 */
void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt(
        const float* d_A, bool transpose_A, size_t lda, const float* d_B,
        bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace,
        GemmCoord const& problem_size, float alpha, float beta,
        const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
        cudaStream_t stream, int split_k_slices) {
    static constexpr int kEpilogueElementsPerAccess = 1;
    using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
            float, kEpilogueElementsPerAccess, float, float>;
    typename EpilogueOp::Params epilogue{alpha, beta};
    if (split_k_slices == 1) {
// NOTE(review): the extracted source had the GemmShape template-argument
// lists and the <Gemm> call argument stripped (leaving `Gemm` unused and the
// file uncompilable); they are restored here.
#define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \
           warp_k_)                                                          \
    if (threadblock_shape.m() == threadblock_m_ &&                           \
        threadblock_shape.n() == threadblock_n_ &&                           \
        threadblock_shape.k() == threadblock_k_ &&                           \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ &&            \
        warp_shape.k() == warp_k_) {                                         \
        using ThreadBlockShape =                                             \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_,     \
                                         threadblock_k_>;                    \
        using WarpShape =                                                    \
                cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>;         \
        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;          \
        using Gemm = cutlass::gemm::device::Gemm<                            \
                float, LayoutA, float, LayoutB, float,                       \
                cutlass::layout::RowMajor, float,                            \
                cutlass::arch::OpClassSimt, cutlass::arch::Sm50,             \
                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp,   \
                cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,\
                2>;                                                          \
        return cutlass_matrix_mul_wrapper<Gemm>(d_A, lda, d_B, ldb, d_C,     \
                                                ldc, workspace,              \
                                                problem_size, epilogue,      \
                                                stream);                     \
    }
        if (!transpose_A && !transpose_B) {
            using LayoutA = cutlass::layout::RowMajor;
            using LayoutB = cutlass::layout::RowMajor;
            DISPATCH(cb)
        } else if (!transpose_A && transpose_B) {
            using LayoutA = cutlass::layout::RowMajor;
            using LayoutB = cutlass::layout::ColumnMajor;
            DISPATCH(cb)
        } else if (transpose_A && !transpose_B) {
            using LayoutA = cutlass::layout::ColumnMajor;
            using LayoutB = cutlass::layout::RowMajor;
            DISPATCH(cb)
        } else {
            megdnn_assert(transpose_A && transpose_B);
            using LayoutA = cutlass::layout::ColumnMajor;
            using LayoutB = cutlass::layout::ColumnMajor;
            DISPATCH(cb)
        }
#undef cb
    } else {
// split-K path: same tile dispatch, but the device-level GemmSplitKParallel
// kernel partitions the K dimension across `split_k_slices` slices.
#define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \
           warp_k_)                                                          \
    if (threadblock_shape.m() == threadblock_m_ &&                           \
        threadblock_shape.n() == threadblock_n_ &&                           \
        threadblock_shape.k() == threadblock_k_ &&                           \
        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ &&            \
        warp_shape.k() == warp_k_) {                                         \
        using ThreadBlockShape =                                             \
                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_,     \
                                         threadblock_k_>;                    \
        using WarpShape =                                                    \
                cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>;         \
        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;          \
        using Gemm = cutlass::gemm::device::GemmSplitKParallel<              \
                float, LayoutA, float, LayoutB, float,                       \
                cutlass::layout::RowMajor, float,                            \
                cutlass::arch::OpClassSimt, cutlass::arch::Sm50,             \
                ThreadBlockShape, WarpShape, InstructionShape, EpilogueOp>;  \
        return cutlass_matrix_mul_wrapper<Gemm>(                             \
                d_A, lda, d_B, ldb, d_C, ldc, workspace, problem_size,       \
                epilogue, stream, split_k_slices);                           \
    }
        if (!transpose_A && !transpose_B) {
            using LayoutA = cutlass::layout::RowMajor;
            using LayoutB = cutlass::layout::RowMajor;
            DISPATCH(cb)
        } else if (!transpose_A && transpose_B) {
            using LayoutA = cutlass::layout::RowMajor;
            using LayoutB = cutlass::layout::ColumnMajor;
            DISPATCH(cb)
        } else if (transpose_A && !transpose_B) {
            using LayoutA = cutlass::layout::ColumnMajor;
            using LayoutB = cutlass::layout::RowMajor;
            DISPATCH(cb)
        } else {
            megdnn_assert(transpose_A && transpose_B);
            using LayoutA = cutlass::layout::ColumnMajor;
            using LayoutB = cutlass::layout::ColumnMajor;
            DISPATCH(cb)
        }
#undef cb
    }
}

#undef DISPATCH

#endif

// vim: syntax=cuda.doxygen