diff --git a/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp b/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
index c79f81fbeea694036b0b6c64b1a67e52960afb51..51a8a7eeadf480ba428030c2329b5cbbf88cd95c 100644
--- a/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
+++ b/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
@@ -40,19 +40,7 @@ bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available(
 
 size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes(
         const SizeArgs& args) const {
-    size_t lda = args.layout_a.stride[0], ldb = args.layout_b.stride[0],
-           ldc = args.layout_c.stride[0];
-    auto&& param = args.opr->param();
-    int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
-        k = args.layout_a.shape[param.transposeA ? 0 : 1];
-    GemmCoord problem_size{m, n, k};
-    return cutlass_matrix_mul_float32_simt_get_workspace_size(
-            param.transposeA, lda, param.transposeB, ldb, ldc, problem_size,
-            1.f, 0.f,
-            GemmCoord{m_algo_param.threadblock_m, m_algo_param.threadblock_n,
-                      m_algo_param.threadblock_k},
-            GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n,
-                      m_algo_param.warp_k});
+    return 0_z;
 }
 
 void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const {
diff --git a/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp b/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp
index 82e9034110bbc62f4525cbfc6ad397d3a6fab6b9..02d028da6a6a2f248e2067215d868ecf1018892d 100644
--- a/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp
+++ b/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp
@@ -33,21 +33,11 @@ bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available(
 
 size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes(
         const SizeArgs& args) const {
-    size_t lda = args.layout_a.stride[0], ldb = args.layout_b.stride[0],
-           ldc = args.layout_c.stride[0];
     auto&& param = args.opr->param();
     int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
         k = args.layout_a.shape[param.transposeA ? 0 : 1];
-    GemmCoord problem_size{m, n, k};
     int split_k_slices = k / n;
-    return cutlass_matrix_mul_float32_simt_get_workspace_size(
-            param.transposeA, lda, param.transposeB, ldb, ldc, problem_size,
-            1.f, 0.f,
-            GemmCoord{m_algo_param.threadblock_m, m_algo_param.threadblock_n,
-                      m_algo_param.threadblock_k},
-            GemmCoord{m_algo_param.warp_m, m_algo_param.warp_n,
-                      m_algo_param.warp_k},
-            split_k_slices);
+    return args.layout_c.dtype.size(m * n * split_k_slices);
 }
 
 void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
diff --git a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu
index c48182d21263d09de895c4c1a658548380cd4283..fe478b832e13de30b151940e17cdcc4dc3fc0513 100644
--- a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu
+++ b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu
@@ -150,203 +150,6 @@ void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt(
 #undef cb
     }
 }
-
-size_t megdnn::cuda::cutlass_wrapper::
-        cutlass_matrix_mul_float32_simt_get_workspace_size(
-                bool transpose_A, size_t lda, bool transpose_B, size_t ldb,
-                size_t ldc, GemmCoord const& problem_size, float alpha,
-                float beta, const GemmCoord& threadblock_shape,
-                const GemmCoord& warp_shape, int split_k_slices) {
-    static constexpr int kEpilogueElementsPerAccess = 1;
-    using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
-            float, kEpilogueElementsPerAccess, float, float>;
-    typename EpilogueOp::Params epilogue{alpha, beta};
-    if (split_k_slices == 1) {
-#define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \
-           warp_k_) \
-    if (threadblock_shape.m() == threadblock_m_ && \
-        threadblock_shape.n() == threadblock_n_ && \
-        threadblock_shape.k() == threadblock_k_ && \
-        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
-        warp_shape.k() == warp_k_) { \
-        using ThreadBlockShape = \
-                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
-                                         threadblock_k_>; \
-        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
-        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; \
-        using Gemm = cutlass::gemm::device::Gemm< \
-                float, LayoutA, float, LayoutB, float, \
-                cutlass::layout::RowMajor, float, cutlass::arch::OpClassSimt, \
-                cutlass::arch::Sm50, ThreadBlockShape, WarpShape, \
-                InstructionShape, EpilogueOp, \
-                cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, \
-                2>; \
-        typename Gemm::TensorRefA tensor_A{ \
-                nullptr, Gemm::LayoutA{static_cast<int>(lda)}}; \
-        typename Gemm::TensorRefB tensor_B{ \
-                nullptr, Gemm::LayoutB{static_cast<int>(ldb)}}; \
-        typename Gemm::TensorRefC tensor_C{ \
-                nullptr, Gemm::LayoutC{static_cast<int>(ldc)}}; \
-        typename Gemm::TensorRefD tensor_D{ \
-                nullptr, Gemm::LayoutC{static_cast<int>(ldc)}}; \
-        typename Gemm::Arguments arguments{problem_size, tensor_A, tensor_B, \
-                                           tensor_C, tensor_D, epilogue, \
-                                           split_k_slices}; \
-        return Gemm::get_workspace_size(arguments); \
-    }
-        if (!transpose_A && !transpose_B) {
-            using LayoutA = cutlass::layout::RowMajor;
-            using LayoutB = cutlass::layout::RowMajor;
-            DISPATCH(cb)
-        } else if (!transpose_A && transpose_B) {
-            using LayoutA = cutlass::layout::RowMajor;
-            using LayoutB = cutlass::layout::ColumnMajor;
-            DISPATCH(cb)
-        } else if (transpose_A && !transpose_B) {
-            using LayoutA = cutlass::layout::ColumnMajor;
-            using LayoutB = cutlass::layout::RowMajor;
-            DISPATCH(cb)
-        } else {
-            megdnn_assert(transpose_A && transpose_B);
-            using LayoutA = cutlass::layout::ColumnMajor;
-            using LayoutB = cutlass::layout::ColumnMajor;
-            DISPATCH(cb)
-        }
-#undef cb
-    } else {
-#define cb(threadblock_m_, threadblock_n_, threadblock_k_, warp_m_, warp_n_, \
-           warp_k_) \
-    if (threadblock_shape.m() == threadblock_m_ && \
-        threadblock_shape.n() == threadblock_n_ && \
-        threadblock_shape.k() == threadblock_k_ && \
-        warp_shape.m() == warp_m_ && warp_shape.n() == warp_n_ && \
-        warp_shape.k() == warp_k_) { \
-        using ThreadBlockShape = \
-                cutlass::gemm::GemmShape<threadblock_m_, threadblock_n_, \
-                                         threadblock_k_>; \
-        using WarpShape = cutlass::gemm::GemmShape<warp_m_, warp_n_, warp_k_>; \
-        using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; \
-        using Gemm = cutlass::gemm::device::GemmSplitKParallel< \
-                float, LayoutA, float, LayoutB, float, \
-                cutlass::layout::RowMajor, float, cutlass::arch::OpClassSimt, \
-                cutlass::arch::Sm50, ThreadBlockShape, WarpShape, \
-                InstructionShape, EpilogueOp>; \
-        using TensorRefA = cutlass::TensorRef<float, LayoutA>; \
-        using TensorRefB = cutlass::TensorRef<float, LayoutB>; \
-        using TensorRefC = cutlass::TensorRef<float, cutlass::layout::RowMajor>; \
-        using TensorRefD = cutlass::TensorRef<float, cutlass::layout::RowMajor>; \
-        TensorRefA tensor_A{nullptr, Gemm::LayoutA{static_cast<int>(lda)}}; \
-        TensorRefB tensor_B{nullptr, Gemm::LayoutB{static_cast<int>(ldb)}}; \
-        TensorRefC tensor_C{nullptr, Gemm::LayoutC{static_cast<int>(ldc)}}; \
-        TensorRefD tensor_D{nullptr, Gemm::LayoutC{static_cast<int>(ldc)}}; \
-        typename Gemm::Arguments arguments{problem_size, tensor_A, tensor_B, \
-                                           tensor_C, tensor_D, epilogue, \
-                                           split_k_slices}; \
-        return Gemm::get_workspace_size(arguments); \
-    }
-        if (!transpose_A && !transpose_B) {
-            using LayoutA = cutlass::layout::RowMajor;
-            using LayoutB = cutlass::layout::RowMajor;
-            DISPATCH(cb)
-        } else if (!transpose_A && transpose_B) {
-            using LayoutA = cutlass::layout::RowMajor;
-            using LayoutB = cutlass::layout::ColumnMajor;
-            DISPATCH(cb)
-        } else if (transpose_A && !transpose_B) {
-            using LayoutA = cutlass::layout::ColumnMajor;
-            using LayoutB = cutlass::layout::RowMajor;
-            DISPATCH(cb)
-        } else {
-            megdnn_assert(transpose_A && transpose_B);
-            using LayoutA = cutlass::layout::ColumnMajor;
-            using LayoutB = cutlass::layout::ColumnMajor;
-            DISPATCH(cb)
-        }
-#undef cb
-    }
-}
-#undef DISPATCH
-
-/* ============ cutlass kernel wrapper for f32 vector-matrix mul batched strided
- * ===========
- */
-#define DISPATCH(cb) \
-    cb(128, 4, 4); \
-    cb(128, 4, 2); \
-    cb(128, 4, 1); \
-    cb(128, 2, 4); \
-    cb(128, 1, 4); \
-    cb(128, 2, 2); \
-    cb(128, 1, 2); \
-    cb(128, 2, 1); \
-    cb(128, 1, 1); \
-    cb(64, 4, 4); \
-    cb(64, 4, 2); \
-    cb(64, 4, 1); \
-    cb(64, 2, 4); \
-    cb(64, 1, 4); \
-    cb(64, 2, 2); \
-    cb(64, 1, 2); \
-    cb(64, 2, 1); \
-    cb(64, 1, 1); \
-    cb(32, 4, 4); \
-    cb(32, 4, 2); \
-    cb(32, 4, 1); \
-    cb(32, 2, 4); \
-    cb(32, 1, 4); \
-    cb(32, 2, 2); \
-    cb(32, 1, 2); \
-    cb(32, 2, 1); \
-    cb(32, 1, 1); \
-    megdnn_assert(false, \
-                  "unsupported gemv batched strided A=%dX%dX%d, B=%dX%dX%d", \
-                  problem_size.batch(), problem_size.m(), problem_size.k(), \
-                  problem_size.batch(), problem_size.k(), problem_size.n());
-
-void megdnn::cuda::cutlass_wrapper::
-        cutlass_matrix_mul_float32_simt_gemv_batched_strided(
-                const float* d_A, size_t lda, size_t batch_stride_a,
-                const float* d_B, size_t ldb, size_t batch_stride_b, float* d_C,
-                size_t ldc, size_t batch_stride_c,
-                BatchedGemmCoord const& problem_size, int threadblock_n,
-                cudaStream_t stream) {
-    int LDG_K, LDG_N;
-    if (lda % 4 == 0)
-        LDG_K = 4;
-    else if (lda % 2 == 0)
-        LDG_K = 2;
-    else
-        LDG_K = 1;
-
-    if (ldb % 4 == 0)
-        LDG_N = 4;
-    else if (ldb % 2 == 0)
-        LDG_N = 2;
-    else
-        LDG_N = 1;
-#define cb(threadblock_n_, LDG_K_, LDG_N_) \
-    if (threadblock_n == threadblock_n_ && LDG_K == LDG_K_ && \
-        LDG_N == LDG_N_) { \
-        using ThreadBlockShape = \
-                cutlass::gemm::GemmShape<1, threadblock_n_, \
-                                         (256 * LDG_K_) / \
-                                                 (threadblock_n_ / LDG_N_)>; \
-        using ThreadShape = cutlass::gemm::GemmShape<1, LDG_N_, LDG_K_>; \
-        using GemvKernel = cutlass::gemm::kernel::DefaultGemv< \
-                ThreadBlockShape, ThreadShape, float, \
-                cutlass::layout::RowMajor, float, cutlass::layout::RowMajor, \
-                float, cutlass::layout::RowMajor>; \
-        return cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( \
-                problem_size, d_A, lda, batch_stride_a, d_B, ldb, \
-                batch_stride_b, d_C, ldc, batch_stride_c, stream); \
-    }
-    DISPATCH(cb)
-#undef cb
-}
 #undef DISPATCH
 
 #endif
diff --git a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh
index 63cecc12904b6f591b1202263a349570c8ef3763..86144b912d218d69d9231120b308d07d03f0ab07 100644
--- a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh
+++ b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh
@@ -37,12 +37,6 @@ void cutlass_matrix_mul_float32_simt(
         const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
         cudaStream_t stream, int split_k_slices = 1);
 
-size_t cutlass_matrix_mul_float32_simt_get_workspace_size(
-        bool transpose_A, size_t lda, bool transpose_B, size_t ldb, size_t ldc,
-        GemmCoord const& problem_size, float alpha, float beta,
-        const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
-        int split_k_slices = 1);
-
 template <typename GemvKernel>
 void cutlass_vector_matrix_mul_batched_strided_wrapper(
         BatchedGemmCoord const& problem_size,
diff --git a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cu b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9366b25da04a63b3869daf3cf4ef2603f0df8e77
--- /dev/null
+++ b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cu
@@ -0,0 +1,111 @@
+/**
+ * \file dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cu
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+// ignore warning of cutlass
+#include "cuda.h"
+#if __CUDACC_VER_MAJOR__ > 9 || \
+        (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#pragma GCC diagnostic ignored "-Wstrict-aliasing"
+
+#include "cutlass/gemm/device/gemm.h"
+#include "cutlass/gemm/device/gemm_splitk_parallel.h"
+#include "cutlass/gemm/kernel/default_gemv.h"
+#include "src/common/opr_param_defs_enumv.cuh"
+#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
+#pragma GCC diagnostic pop
+
+using namespace megdnn;
+using namespace cuda;
+using namespace cutlass_wrapper;
+
+/* ============ cutlass kernel wrapper for f32 vector-matrix mul batched strided
+ * ===========
+ */
+#define DISPATCH(cb) \
+    cb(128, 4, 4); \
+    cb(128, 4, 2); \
+    cb(128, 4, 1); \
+    cb(128, 2, 4); \
+    cb(128, 1, 4); \
+    cb(128, 2, 2); \
+    cb(128, 1, 2); \
+    cb(128, 2, 1); \
+    cb(128, 1, 1); \
+    cb(64, 4, 4); \
+    cb(64, 4, 2); \
+    cb(64, 4, 1); \
+    cb(64, 2, 4); \
+    cb(64, 1, 4); \
+    cb(64, 2, 2); \
+    cb(64, 1, 2); \
+    cb(64, 2, 1); \
+    cb(64, 1, 1); \
+    cb(32, 4, 4); \
+    cb(32, 4, 2); \
+    cb(32, 4, 1); \
+    cb(32, 2, 4); \
+    cb(32, 1, 4); \
+    cb(32, 2, 2); \
+    cb(32, 1, 2); \
+    cb(32, 2, 1); \
+    cb(32, 1, 1); \
+    megdnn_assert(false, \
+                  "unsupported gemv batched strided A=%dX%dX%d, B=%dX%dX%d", \
+                  problem_size.batch(), problem_size.m(), problem_size.k(), \
+                  problem_size.batch(), problem_size.k(), problem_size.n());
+
+void megdnn::cuda::cutlass_wrapper::
+        cutlass_matrix_mul_float32_simt_gemv_batched_strided(
+                const float* d_A, size_t lda, size_t batch_stride_a,
+                const float* d_B, size_t ldb, size_t batch_stride_b, float* d_C,
+                size_t ldc, size_t batch_stride_c,
+                BatchedGemmCoord const& problem_size, int threadblock_n,
+                cudaStream_t stream) {
+    int LDG_K, LDG_N;
+    if (lda % 4 == 0)
+        LDG_K = 4;
+    else if (lda % 2 == 0)
+        LDG_K = 2;
+    else
+        LDG_K = 1;
+
+    if (ldb % 4 == 0)
+        LDG_N = 4;
+    else if (ldb % 2 == 0)
+        LDG_N = 2;
+    else
+        LDG_N = 1;
+#define cb(threadblock_n_, LDG_K_, LDG_N_) \
+    if (threadblock_n == threadblock_n_ && LDG_K == LDG_K_ && \
+        LDG_N == LDG_N_) { \
+        using ThreadBlockShape = \
+                cutlass::gemm::GemmShape<1, threadblock_n_, \
+                                         (256 * LDG_K_) / \
+                                                 (threadblock_n_ / LDG_N_)>; \
+        using ThreadShape = cutlass::gemm::GemmShape<1, LDG_N_, LDG_K_>; \
+        using GemvKernel = cutlass::gemm::kernel::DefaultGemv< \
+                ThreadBlockShape, ThreadShape, float, \
+                cutlass::layout::RowMajor, float, cutlass::layout::RowMajor, \
+                float, cutlass::layout::RowMajor>; \
+        return cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( \
+                problem_size, d_A, lda, batch_stride_a, d_B, ldb, \
+                batch_stride_b, d_C, ldc, batch_stride_c, stream); \
+    }
+    DISPATCH(cb)
+#undef cb
+}
+#undef DISPATCH
+
+#endif
+
+// vim: syntax=cuda.doxygen
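Note on the split-k workspace change above: instead of querying CUTLASS through the removed cutlass_matrix_mul_float32_simt_get_workspace_size, the algorithm now sizes the workspace directly as one partial m x n accumulator per k-slice, which the final reduction then collapses. A minimal sketch of the arithmetic, assuming the k / n slice heuristic visible in the diff (the helper name is illustrative, not part of MegDNN):

#include <cstddef>

// One m*n partial-result tile per split-k slice, in the output dtype.
// Mirrors `args.layout_c.dtype.size(m * n * split_k_slices)` in the diff.
size_t splitk_workspace_bytes(size_t m, size_t n, size_t k,
                              size_t dtype_bytes) {
    size_t split_k_slices = k / n;  // slice-count heuristic used above
    return dtype_bytes * m * n * split_k_slices;
}

// e.g. m = 128, n = 128, k = 8192, float32: 64 slices ->
// 64 * 128 * 128 * 4 bytes of workspace.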
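Similarly, the relocated gemv wrapper picks its load widths from stride alignment: LDG_K and LDG_N are the widest of 4/2/1 elements that evenly divide lda and ldb, and the threadblock's k extent then falls out of the thread count, since threadblock_n / LDG_N thread positions cover the n dimension and each of the remaining positions loads LDG_K elements along k. A sketch of that arithmetic, assuming the 256-thread block implied by the constant in the cb macro (helper names are illustrative):

#include <cstddef>

// Widest vectorized load (in elements) permitted by the stride alignment.
int load_width(size_t stride) {
    if (stride % 4 == 0) return 4;
    if (stride % 2 == 0) return 2;
    return 1;
}

// k extent of GemmShape<1, threadblock_n, k_extent> for a 256-thread block:
// threadblock_n / ldg_n threads span n, and the 256 / (threadblock_n / ldg_n)
// threads stacked along k each load ldg_k elements. Equivalent to the macro's
// (256 * LDG_K_) / (threadblock_n_ / LDG_N_) when the division is exact.
int threadblock_k_extent(int threadblock_n, int ldg_k, int ldg_n) {
    int threads_along_n = threadblock_n / ldg_n;
    return (256 / threads_along_n) * ldg_k;
}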