From 2de2222e467a1ac1b4292c72922c5b68cebb362a Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 19 Jan 2021 16:22:45 +0800
Subject: [PATCH] feat(dnn/cuda): add cutlass batched gemv kernel for matmul
 operator

GitOrigin-RevId: 51702c4e79347175a993700be4022bc38102d79f
---
 dnn/scripts/Makefile                          |   5 +-
 dnn/src/cuda/matrix_mul/algos.cpp             |  10 ++
 dnn/src/cuda/matrix_mul/algos.h               |  34 ++++++
 .../cuda/matrix_mul/cutlass_float32_simt.cpp  |   6 +-
 ...lass_float32_simt_gemv_batched_strided.cpp |  58 +++++++++
 .../cutlass_float32_simt_split_k.cpp          |  10 +-
 .../matrix_mul/cutlass_matrix_mul_wrapper.cu  | 114 +++++++++++++-----
 .../matrix_mul/cutlass_matrix_mul_wrapper.cuh |  20 ++-
 ...imt_gemv_batched_strided_1x128x16_1x2x4.cu |  26 ++++
 ...imt_gemv_batched_strided_1x128x16_1x4x2.cu |  26 ++++
 ...simt_gemv_batched_strided_1x128x2_1x1x1.cu |  26 ++++
 ...imt_gemv_batched_strided_1x128x32_1x4x4.cu |  26 ++++
 ...simt_gemv_batched_strided_1x128x4_1x1x2.cu |  26 ++++
 ...simt_gemv_batched_strided_1x128x4_1x2x1.cu |  26 ++++
 ...simt_gemv_batched_strided_1x128x8_1x1x4.cu |  26 ++++
 ...simt_gemv_batched_strided_1x128x8_1x2x2.cu |  26 ++++
 ...simt_gemv_batched_strided_1x128x8_1x4x1.cu |  26 ++++
 ...imt_gemv_batched_strided_1x32x128_1x4x4.cu |  26 ++++
 ...simt_gemv_batched_strided_1x32x16_1x1x2.cu |  26 ++++
 ...simt_gemv_batched_strided_1x32x16_1x2x1.cu |  26 ++++
 ...simt_gemv_batched_strided_1x32x32_1x1x4.cu |  26 ++++
 ...simt_gemv_batched_strided_1x32x32_1x2x2.cu |  26 ++++
 ...simt_gemv_batched_strided_1x32x32_1x4x1.cu |  26 ++++
 ...simt_gemv_batched_strided_1x32x64_1x2x4.cu |  26 ++++
 ...simt_gemv_batched_strided_1x32x64_1x4x2.cu |  26 ++++
 ..._simt_gemv_batched_strided_1x32x8_1x1x1.cu |  26 ++++
 ...simt_gemv_batched_strided_1x64x16_1x1x4.cu |  26 ++++
 ...simt_gemv_batched_strided_1x64x16_1x2x2.cu |  26 ++++
 ...simt_gemv_batched_strided_1x64x16_1x4x1.cu |  26 ++++
 ...simt_gemv_batched_strided_1x64x32_1x2x4.cu |  26 ++++
 ...simt_gemv_batched_strided_1x64x32_1x4x2.cu |  26 ++++
 ..._simt_gemv_batched_strided_1x64x4_1x1x1.cu |  26 ++++
 ...simt_gemv_batched_strided_1x64x64_1x4x4.cu |  26 ++++
 ..._simt_gemv_batched_strided_1x64x8_1x1x2.cu |  26 ++++
 ..._simt_gemv_batched_strided_1x64x8_1x2x1.cu |  26 ++++
 ...gemv_batched_strided_cutlass_wrapper.cuinl |  70 +++++++++++
 dnn/src/cuda/matrix_mul/opr_impl.h            |   3 +
 dnn/test/cuda/cutlass_matmul.cpp              |  28 ++++-
 38 files changed, 1018 insertions(+), 42 deletions(-)
 create mode 100644 dnn/src/cuda/matrix_mul/cutlass_float32_simt_gemv_batched_strided.cpp
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x16_1x2x4.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x16_1x4x2.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x2_1x1x1.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x32_1x4x4.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x4_1x1x2.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x4_1x2x1.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x1x4.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x2x2.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x4x1.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x128_1x4x4.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x16_1x1x2.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x16_1x2x1.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x1x4.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x2x2.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x4x1.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x64_1x2x4.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x64_1x4x2.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x8_1x1x1.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x1x4.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x2x2.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x4x1.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x32_1x2x4.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x32_1x4x2.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x4_1x1x1.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x64_1x4x4.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x8_1x1x2.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x8_1x2x1.cu
 create mode 100644 dnn/src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl

diff --git a/dnn/scripts/Makefile b/dnn/scripts/Makefile
index b093b4aec..b5b9532e9 100644
--- a/dnn/scripts/Makefile
+++ b/dnn/scripts/Makefile
@@ -9,7 +9,7 @@ ELEMWISE_IMPL := ../src/cuda/cond_take/kimpl \
 		../src/cuda/elemwise_multi_type/kimpl
 
 CUDA_CONV_IMPL := ../src/cuda/conv_bias/int8/kimpl ../src/cuda/conv_bias/int8_imma/kimpl ../src/cuda/batch_conv_bias/int8/kimpl
-CUDA_MATMUL_IMPL := ../src/cuda/matrix_mul/fp32_simt/kimpl
+CUDA_MATMUL_IMPL := ../src/cuda/matrix_mul/fp32_simt/kimpl ../src/cuda/matrix_mul/fp32_simt_gemv/kimpl
 
 all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} $(CUDA_MATMUL_IMPL)
 
@@ -51,4 +51,7 @@ all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} $(CUDA_MATMUL_IMPL)
 ../src/cuda/matrix_mul/fp32_simt/kimpl: gen_cutlass_matmul_kern_impls.py
 	./$^ $@
 
+../src/cuda/matrix_mul/fp32_simt_gemv/kimpl: gen_cutlass_gemv_batched_strided_kern_impls.py
+	./$^ $@
+
 .PHONY: all
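In the new rule, make's automatic variables `$^` (all prerequisites) and `$@` (the target) mean the recipe is equivalent to running the generator once against the output directory, which is how the kimpl files below are produced:

    ./gen_cutlass_gemv_batched_strided_kern_impls.py ../src/cuda/matrix_mul/fp32_simt_gemv/kimpl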
diff --git a/dnn/src/cuda/matrix_mul/algos.cpp b/dnn/src/cuda/matrix_mul/algos.cpp
index fa1909805..490b32bdc 100644
--- a/dnn/src/cuda/matrix_mul/algos.cpp
+++ b/dnn/src/cuda/matrix_mul/algos.cpp
@@ -33,6 +33,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
 #if !MEGDNN_DISABLE_FLOAT16
     all_algos.push_back(&bfloat16);
 #endif
+#if CUDA_VERSION >= 9020
     fill_cutlass_algos();
     for (auto&& algo : simt_float32) {
         all_algos.push_back(&algo);
@@ -40,12 +41,17 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
     for (auto&& algo : simt_float32_split_k) {
         all_algos.push_back(&algo);
     }
+    for (auto&& algo : simt_float32_gemv_batched_strided) {
+        all_algos.push_back(&algo);
+    }
+#endif
 
     for (auto&& algo : all_algos) {
         m_all_algos_map.emplace(algo->info().desc, algo);
     }
 }
 
+#if CUDA_VERSION >= 9020
 void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
     using AlgoParam = AlgoFloat32SIMT::AlgoParam;
     simt_float32.emplace_back(AlgoParam{64, 256, 8, 32, 64, 8});
@@ -82,7 +88,11 @@ void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
     simt_float32_split_k.emplace_back(AlgoParam{16, 32, 8, 16, 32, 8});
     simt_float32_split_k.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8});
     simt_float32_split_k.emplace_back(AlgoParam{16, 128, 8, 16, 64, 8});
+    simt_float32_gemv_batched_strided.emplace_back(128);
+    simt_float32_gemv_batched_strided.emplace_back(64);
+    simt_float32_gemv_batched_strided.emplace_back(32);
 }
+#endif
 
 MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack;

diff --git a/dnn/src/cuda/matrix_mul/algos.h b/dnn/src/cuda/matrix_mul/algos.h
index d647c6613..e55e0b20f 100644
--- a/dnn/src/cuda/matrix_mul/algos.h
+++ b/dnn/src/cuda/matrix_mul/algos.h
@@ -42,8 +42,11 @@ public:
         CUDA_CUBLASLT,
         CUDA_NAIVE,
         CUDA_BFLOAT16,
+#if CUDA_VERSION >= 9020
         CUDA_FLOAT32_SIMT,
         CUDA_FLOAT32_SIMT_SPLIT_K,
+        CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED,
+#endif
     };
     using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
@@ -167,6 +170,7 @@ private:
 };
 #endif
 
+#if CUDA_VERSION >= 9020
 class MatrixMulForwardImpl::AlgoFloat32SIMT final : public AlgoBase {
 public:
     struct AlgoParam {
@@ -224,6 +228,32 @@ private:
     std::string m_name;
 };
 
+class MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided final
+        : public AlgoBase {
+public:
+    AlgoFloat32SIMTGemvBatchedStrided(int threadblock_n)
+            : m_threadblock_n{threadblock_n},
+              m_name{ssprintf("CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_%d",
+                              m_threadblock_n)} {}
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    const char* name() const override { return m_name.c_str(); }
+    void exec(const ExecArgs& args) const override;
+    bool is_reproducible() const override { return true; }
+    MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED)
+
+    std::string param() const override {
+        std::string ret;
+        serialize_write_pod(m_threadblock_n, ret);
+        return ret;
+    }
+
+private:
+    int m_threadblock_n;
+    std::string m_name;
+};
+#endif
+
 class MatrixMulForwardImpl::AlgoPack : NonCopyableObj {
 private:
     AlgoBase::Mapper m_all_algos_map;
@@ -241,8 +271,12 @@ public:
 #if !MEGDNN_DISABLE_FLOAT16
     AlgoBFloat16 bfloat16;
 #endif
+#if CUDA_VERSION >= 9020
     std::vector<AlgoFloat32SIMT> simt_float32;
     std::vector<AlgoFloat32SIMTSplitK> simt_float32_split_k;
+    std::vector<AlgoFloat32SIMTGemvBatchedStrided>
+            simt_float32_gemv_batched_strided;
+#endif
     std::vector<AlgoBase*> all_algos;
 
     const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
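The `CUDA_VERSION >= 9020` guard that this patch adds throughout keys on the toolkit version macro from cuda.h; a minimal sketch of how that constant decodes (the macro encoding is standard CUDA, not something introduced by this patch):

// CUDA_VERSION is defined by cuda.h as major * 1000 + minor * 10,
// so 9020 means CUDA 9.2, the first toolkit these CUTLASS kernels
// are compiled against. Note the pre-existing guard removed below
// used 9200 by mistake, which no real CUDA_VERSION value reaches.
#include <cuda.h>
#if CUDA_VERSION >= 9020  // 9 * 1000 + 2 * 10
// ... CUTLASS-backed matmul algorithms are compiled here ...
#endif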
diff --git a/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp b/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
index a219271aa..16d80c6a1 100644
--- a/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
+++ b/dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
@@ -15,20 +15,17 @@
 #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
 #include "src/cuda/utils.h"
 
+#if CUDA_VERSION >= 9020
 using namespace megdnn;
 using namespace cuda;
 using namespace cutlass_wrapper;
 
 bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available(
         const SizeArgs& args) const {
-#if CUDA_VERSION >= 9200
     return args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
            args.layout_a.dtype == dtype::Float32() &&
            args.layout_b.dtype == dtype::Float32() &&
            args.layout_c.dtype == dtype::Float32();
-#else
-    return false;
-#endif
 }
 
 size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes(
@@ -69,5 +66,6 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const {
                     m_algo_param.warp_k},
             stream);
 }
+#endif
 
 // vim: syntax=cpp.doxygen

diff --git a/dnn/src/cuda/matrix_mul/cutlass_float32_simt_gemv_batched_strided.cpp b/dnn/src/cuda/matrix_mul/cutlass_float32_simt_gemv_batched_strided.cpp
new file mode 100644
index 000000000..6d581a8ed
--- /dev/null
+++ b/dnn/src/cuda/matrix_mul/cutlass_float32_simt_gemv_batched_strided.cpp
@@ -0,0 +1,58 @@
+/**
+ * \file dnn/src/cuda/matrix_mul/cutlass_float32_simt_gemv_batched_strided.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied.
+ */
+
+#include "src/cuda/handle.h"
+#include "src/cuda/matrix_mul/algos.h"
+#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
+#include "src/cuda/utils.h"
+
+#if CUDA_VERSION >= 9020
+using namespace megdnn;
+using namespace cuda;
+using namespace cutlass_wrapper;
+
+bool MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided::is_available(
+        const SizeArgs& args) const {
+    auto&& param = args.opr->param();
+    bool ta = param.transposeA, tb = param.transposeB;
+    return args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
+           args.layout_a.dtype == dtype::Float32() &&
+           args.layout_b.dtype == dtype::Float32() &&
+           args.layout_c.dtype == dtype::Float32() && ((!ta) && (!tb));
+}
+
+size_t
+MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided::get_workspace_in_bytes(
+        const SizeArgs& /* args */) const {
+    return 0;
+}
+
+void MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided::exec(
+        const ExecArgs& args) const {
+    size_t lda = args.tensor_a.layout.stride[0],
+           ldb = args.tensor_b.layout.stride[0],
+           ldc = args.tensor_c.layout.stride[0];
+    auto&& param = args.opr->param();
+    int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
+        k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];
+    // m is always 1 in gemv batched strided case
+    BatchedGemmCoord problem_size{1, n, k, m};
+    auto&& stream = cuda_stream(args.opr->handle());
+    return cutlass_matrix_mul_float32_simt_gemv_batched_strided(
+            args.tensor_a.ptr<dt_float32>(), lda, lda,
+            args.tensor_b.ptr<dt_float32>(), ldb, 0,
+            args.tensor_c.ptr<dt_float32>(), ldc, ldc, problem_size,
+            m_threadblock_n, stream);
+}
+#endif
+
+// vim: syntax=cpp.doxygen
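To make the batched-GEMV mapping in `exec` above concrete: a row-major m-by-n = (m-by-k)(k-by-n) matmul is reinterpreted as m independent 1-by-n gemvs, one per row of A, with the leading strides doubling as batch strides. A minimal sketch, with illustrative shapes that are not from the patch:

// Hypothetical shapes: C[256, 512] = A[256, 128] * B[128, 512],
// fp32, row-major, non-transposed, as required by is_available().
int m = 256, n = 512, k = 128;
// Each of the m rows of A becomes one 1xN gemv, so m plays the role
// of the batch count; BatchedGemmCoord packs (m, n, k, batch).
BatchedGemmCoord problem_size{1, n, k, m};
// Per batch, A advances by lda elements (one row) and C by ldc
// elements, while B is shared by every batch; that is why the call
// above passes (lda, lda), (ldb, 0) and (ldc, ldc).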
diff --git a/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp b/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp
index 50ccb67db..82e903411 100644
--- a/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp
+++ b/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp
@@ -15,6 +15,7 @@
 #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
 #include "src/cuda/utils.h"
 
+#if CUDA_VERSION >= 9020
 using namespace megdnn;
 using namespace cuda;
 using namespace cutlass_wrapper;
@@ -22,12 +23,12 @@ using namespace cutlass_wrapper;
 bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available(
         const SizeArgs& args) const {
     auto&& param = args.opr->param();
-    int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
+    int n = args.layout_c.shape[1],
         k = args.layout_a.shape[param.transposeA ? 0 : 1];
     return args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
            args.layout_a.dtype == dtype::Float32() &&
            args.layout_b.dtype == dtype::Float32() &&
-           args.layout_c.dtype == dtype::Float32() && k > std::max(m, n);
+           args.layout_c.dtype == dtype::Float32() && k > n;
 }
 
 size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes(
@@ -38,7 +39,7 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes(
     int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
         k = args.layout_a.shape[param.transposeA ? 0 : 1];
     GemmCoord problem_size{m, n, k};
-    int split_k_slices = k / std::max(m, n);
+    int split_k_slices = k / n;
     return cutlass_matrix_mul_float32_simt_get_workspace_size(
             param.transposeA, lda, param.transposeB, ldb, ldc, problem_size,
             1.f, 0.f,
@@ -58,7 +59,7 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
     int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
         k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];
     GemmCoord problem_size{m, n, k};
-    int split_k_slices = k / std::max(m, n);
+    int split_k_slices = k / n;
     auto&& stream = cuda_stream(args.opr->handle());
     int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr);
     return cutlass_matrix_mul_float32_simt(
@@ -72,5 +73,6 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
                     m_algo_param.warp_k},
             stream, split_k_slices);
 }
+#endif
 
 // vim: syntax=cpp.doxygen
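The split-k heuristic change is easiest to see with numbers; a worked example under assumed shapes, with the rationale hedged (presumably the heuristic no longer needs to consider m now that a dedicated gemv path exists):

// Illustrative shapes m = 32, n = 32, k = 4096:
//   old rule: available when k > max(m, n); slices = 4096 / 32  = 128
//   new rule: available when k > n;         slices = 4096 / 32  = 128
// The rules differ only when m > n, e.g. m = 256, n = 32, k = 4096:
//   old: slices = 4096 / 256 = 16;  new: slices = 4096 / 32 = 128,
// i.e. the slice count is now driven by n alone.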
diff --git a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu
index 4907b4fa3..c48182d21 100644
--- a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu
+++ b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu
@@ -10,16 +10,16 @@
  * implied.
  */
 // ignore warning of cutlass
+#include "cuda.h"
+#if __CUDACC_VER_MAJOR__ > 9 || \
+        (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wunused-parameter"
 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
-#include "cuda.h"
-#if __CUDACC_VER_MAJOR__ > 9 || \
-    (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
 #include "cutlass/gemm/device/gemm.h"
 #include "cutlass/gemm/device/gemm_splitk_parallel.h"
-#endif
+#include "cutlass/gemm/kernel/default_gemv.h"
 #include "src/common/opr_param_defs_enumv.cuh"
 #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
 #pragma GCC diagnostic pop
@@ -54,18 +54,6 @@ using namespace cutlass_wrapper;
                       threadblock_shape.m(), threadblock_shape.n(),          \
                       threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
                       warp_shape.k());
-
-#if __CUDACC_VER_MAJOR__ < 9 || \
-    (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ <= 2)
-void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt(
-        const float* /* d_A */, bool /* transpose_A */, size_t /* lda */,
-        const float* /* d_B */, bool /* transpose_B */, size_t /* ldb */,
-        float* /* d_C */, size_t /* ldc */, int* /* workspace */,
-        GemmCoord const& /* problem_size */, float /* alpha */,
-        float /* beta */, const GemmCoord& /* threadblock_shape */,
-        const GemmCoord& /* warp_shape */, cudaStream_t /* stream */,
-        int /* split_k_slices */) {}
-#else
 void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt(
         const float* d_A, bool transpose_A, size_t lda, const float* d_B,
         bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace,
@@ -162,20 +150,7 @@ void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt(
 #undef cb
     }
 }
-#endif
 
-#if __CUDACC_VER_MAJOR__ < 9 || \
-    (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ <= 2)
-size_t megdnn::cuda::cutlass_wrapper::
-        cutlass_matrix_mul_float32_simt_get_workspace_size(
-                bool /* transpose_A */, size_t /* lda */,
-                bool /* transpose_B */, size_t /* ldb */, size_t /* ldc */,
-                GemmCoord const& /* problem_size */, float /* alpha */,
-                float /* beta */, const GemmCoord& /* threadblock_shape */,
-                const GemmCoord& /* warp_shape */, int /* split_k_slices */) {
-    return 0;
-}
-#else
 size_t megdnn::cuda::cutlass_wrapper::
         cutlass_matrix_mul_float32_simt_get_workspace_size(
                 bool transpose_A, size_t lda, bool transpose_B, size_t ldb,
@@ -294,7 +269,86 @@ size_t megdnn::cuda::cutlass_wrapper::
 #undef cb
     }
 }
-#endif
+#undef DISPATCH
+
+/* ============ cutlass kernel wrapper for f32 vector-matrix mul batched strided
+ * ===========
+ */
+#define DISPATCH(cb)                                                          \
+    cb(128, 4, 4);                                                            \
+    cb(128, 4, 2);                                                            \
+    cb(128, 4, 1);                                                            \
+    cb(128, 2, 4);                                                            \
+    cb(128, 1, 4);                                                            \
+    cb(128, 2, 2);                                                            \
+    cb(128, 1, 2);                                                            \
+    cb(128, 2, 1);                                                            \
+    cb(128, 1, 1);                                                            \
+    cb(64, 4, 4);                                                             \
+    cb(64, 4, 2);                                                             \
+    cb(64, 4, 1);                                                             \
+    cb(64, 2, 4);                                                             \
+    cb(64, 1, 4);                                                             \
+    cb(64, 2, 2);                                                             \
+    cb(64, 1, 2);                                                             \
+    cb(64, 2, 1);                                                             \
+    cb(64, 1, 1);                                                             \
+    cb(32, 4, 4);                                                             \
+    cb(32, 4, 2);                                                             \
+    cb(32, 4, 1);                                                             \
+    cb(32, 2, 4);                                                             \
+    cb(32, 1, 4);                                                             \
+    cb(32, 2, 2);                                                             \
+    cb(32, 1, 2);                                                             \
+    cb(32, 2, 1);                                                             \
+    cb(32, 1, 1);                                                             \
+    megdnn_assert(false,                                                      \
+                  "unsupported gemv batched strided A=%dX%dX%d, B=%dX%dX%d",  \
+                  problem_size.batch(), problem_size.m(), problem_size.k(),   \
+                  problem_size.batch(), problem_size.k(), problem_size.n());
+
+void megdnn::cuda::cutlass_wrapper::
+        cutlass_matrix_mul_float32_simt_gemv_batched_strided(
+                const float* d_A, size_t lda, size_t batch_stride_a,
+                const float* d_B, size_t ldb, size_t batch_stride_b,
+                float* d_C, size_t ldc, size_t batch_stride_c,
+                BatchedGemmCoord const& problem_size, int threadblock_n,
+                cudaStream_t stream) {
+    int LDG_K, LDG_N;
+    if (lda % 4 == 0)
+        LDG_K = 4;
+    else if (lda % 2 == 0)
+        LDG_K = 2;
+    else
+        LDG_K = 1;
+    if (ldb % 4 == 0)
+        LDG_N = 4;
+    else if (ldb % 2 == 0)
+        LDG_N = 2;
+    else
+        LDG_N = 1;
+#define cb(threadblock_n_, LDG_K_, LDG_N_)                                    \
+    if (threadblock_n == threadblock_n_ && LDG_K == LDG_K_ &&                 \
+        LDG_N == LDG_N_) {                                                    \
+        using ThreadBlockShape =                                              \
+                cutlass::gemm::GemmShape<1, threadblock_n_,                   \
+                                         (256 * LDG_K_) /                     \
+                                                 (threadblock_n_ / LDG_N_)>;  \
+        using ThreadShape = cutlass::gemm::GemmShape<1, LDG_N_, LDG_K_>;      \
+        using GemvKernel = cutlass::gemm::kernel::DefaultGemv<                \
+                ThreadBlockShape, ThreadShape, float,                         \
+                cutlass::layout::RowMajor, float, cutlass::layout::RowMajor,  \
+                float, cutlass::layout::RowMajor>;                            \
+        return cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( \
+                problem_size, d_A, lda, batch_stride_a, d_B, ldb,             \
+                batch_stride_b, d_C, ldc, batch_stride_c, stream);            \
+    }
+    DISPATCH(cb)
+#undef cb
+}
 #undef DISPATCH
+
+#endif
 
 // vim: syntax=cuda.doxygen
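The threadblock K extent computed inside the `cb` macro above appears to fix the threadblock at 256 threads, with each thread covering an LDG_N-by-LDG_K tile; a worked instance of the formula, which matches one of the generated kernels below:

// threads per threadblock = (threadblock_n / LDG_N) * (threadblock_k / LDG_K);
// fixing this at 256 gives
//     threadblock_k = (256 * LDG_K) / (threadblock_n / LDG_N).
// Example: threadblock_n = 128, LDG_K = 4, LDG_N = 4
//     -> threadblock_k = (256 * 4) / (128 / 4) = 1024 / 32 = 32,
// i.e. ThreadBlockShape = GemmShape<1, 128, 32>, ThreadShape =
// GemmShape<1, 4, 4>: exactly the generated file
// matrix_mul_fp32_simt_gemv_batched_strided_1x128x32_1x4x4.cu.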
diff --git a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh
index 1947f773d..63cecc129 100644
--- a/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh
+++ b/dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh
@@ -13,11 +13,13 @@
 #include "cutlass/gemm/gemm.h"
 #include "src/cuda/utils.cuh"
 
+#if CUDA_VERSION >= 9020
 namespace megdnn {
 namespace cuda {
 namespace cutlass_wrapper {
 
 using GemmCoord = cutlass::gemm::GemmCoord;
+using BatchedGemmCoord = cutlass::gemm::BatchedGemmCoord;
 
 template <typename Gemm>
 void cutlass_matrix_mul_wrapper(
@@ -38,10 +40,26 @@ void cutlass_matrix_mul_float32_simt(
 size_t cutlass_matrix_mul_float32_simt_get_workspace_size(
         bool transpose_A, size_t lda, bool transpose_B, size_t ldb, size_t ldc,
         GemmCoord const& problem_size, float alpha, float beta,
-        const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, int split_k_slices = 1);
+        const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
+        int split_k_slices = 1);
+
+template <typename GemvKernel>
+void cutlass_vector_matrix_mul_batched_strided_wrapper(
+        BatchedGemmCoord const& problem_size,
+        const typename GemvKernel::ElementA* d_A, size_t lda,
+        size_t batch_stride_a, const typename GemvKernel::ElementB* d_B,
+        size_t ldb, size_t batch_stride_b, typename GemvKernel::ElementCD* d_C,
+        size_t ldc, size_t batch_stride_c, cudaStream_t stream);
+
+void cutlass_matrix_mul_float32_simt_gemv_batched_strided(
+        const float* d_A, size_t lda, size_t batch_stride_a, const float* d_B,
+        size_t ldb, size_t batch_stride_b, float* d_C, size_t ldc,
+        size_t batch_stride_c, BatchedGemmCoord const& problem_size,
+        int threadblock_n, cudaStream_t stream);
 
 }  // namespace cutlass_wrapper
 }  // namespace cuda
 }  // namespace megdnn
+#endif
 
 // vim: syntax=cuda.doxygen
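A hedged sketch of how host code could drive the non-templated entry point declared above; the function name and the packed row-major strides mirror what AlgoFloat32SIMTGemvBatchedStrided::exec does, while `run_gemv` and its argument values are purely illustrative:

// Illustrative host-side call: C[m, n] = A[m, k] * B[k, n], fp32,
// row-major, no transpose; each row of A becomes one gemv in the batch.
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
using megdnn::cuda::cutlass_wrapper::BatchedGemmCoord;

void run_gemv(const float* d_A, const float* d_B, float* d_C, int m, int n,
              int k, cudaStream_t stream) {
    size_t lda = k, ldb = n, ldc = n;           // densely packed strides
    BatchedGemmCoord problem_size{1, n, k, m};  // (m, n, k, batch)
    megdnn::cuda::cutlass_wrapper::
            cutlass_matrix_mul_float32_simt_gemv_batched_strided(
                    d_A, lda, /* batch_stride_a */ lda,  // next row of A
                    d_B, ldb, /* batch_stride_b */ 0,    // B shared by all
                    d_C, ldc, /* batch_stride_c */ ldc,  // next row of C
                    problem_size, /* threadblock_n */ 128, stream);
}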
+#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 16>; +using ThreadShape = cutlass::gemm::GemmShape<1, 2, 4>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x16_1x4x2.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x16_1x4x2.cu new file mode 100644 index 000000000..a620831fd --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x16_1x4x2.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 16>; +using ThreadShape = cutlass::gemm::GemmShape<1, 4, 2>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x2_1x1x1.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x2_1x1x1.cu new file mode 100644 index 000000000..b3c0e76d1 --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x2_1x1x1.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 2>; +using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; +using GemvKernel 
= cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x32_1x4x4.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x32_1x4x4.cu new file mode 100644 index 000000000..0870613f8 --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x32_1x4x4.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 32>; +using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x4_1x1x2.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x4_1x1x2.cu new file mode 100644 index 000000000..bcdfb0eb3 --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x4_1x1x2.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 4>; +using ThreadShape = cutlass::gemm::GemmShape<1, 1, 2>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename 
GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x4_1x2x1.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x4_1x2x1.cu new file mode 100644 index 000000000..af6d0e492 --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x4_1x2x1.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 4>; +using ThreadShape = cutlass::gemm::GemmShape<1, 2, 1>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x1x4.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x1x4.cu new file mode 100644 index 000000000..37bc33c8c --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x1x4.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 8>; +using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git 
a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x2x2.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x2x2.cu new file mode 100644 index 000000000..4ea842e8d --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x2x2.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 8>; +using ThreadShape = cutlass::gemm::GemmShape<1, 2, 2>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x4x1.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x4x1.cu new file mode 100644 index 000000000..bc916cfc5 --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x4x1.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 8>; +using ThreadShape = cutlass::gemm::GemmShape<1, 4, 1>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x128_1x4x4.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x128_1x4x4.cu new file mode 100644 index 000000000..5ed9df149 --- /dev/null +++ 
b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x128_1x4x4.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 128>; +using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x16_1x1x2.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x16_1x1x2.cu new file mode 100644 index 000000000..d38317f27 --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x16_1x1x2.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 16>; +using ThreadShape = cutlass::gemm::GemmShape<1, 1, 2>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x16_1x2x1.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x16_1x2x1.cu new file mode 100644 index 000000000..7ebe415c8 --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x16_1x2x1.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC 
diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 16>; +using ThreadShape = cutlass::gemm::GemmShape<1, 2, 1>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x1x4.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x1x4.cu new file mode 100644 index 000000000..e7647be11 --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x1x4.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 32>; +using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x2x2.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x2x2.cu new file mode 100644 index 000000000..2e0f05752 --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x2x2.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 32>; +using ThreadShape = cutlass::gemm::GemmShape<1, 2, 2>; +using GemvKernel = 
cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x4x1.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x4x1.cu new file mode 100644 index 000000000..c8252f5f8 --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x4x1.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 32>; +using ThreadShape = cutlass::gemm::GemmShape<1, 4, 1>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x64_1x2x4.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x64_1x2x4.cu new file mode 100644 index 000000000..bc53eefba --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x64_1x2x4.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 64>; +using ThreadShape = cutlass::gemm::GemmShape<1, 2, 4>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* 
d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x64_1x4x2.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x64_1x4x2.cu new file mode 100644 index 000000000..2c818beb5 --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x64_1x4x2.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 64>; +using ThreadShape = cutlass::gemm::GemmShape<1, 4, 2>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x8_1x1x1.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x8_1x1x1.cu new file mode 100644 index 000000000..4efb152ba --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x8_1x1x1.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 8>; +using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git 
a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x1x4.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x1x4.cu new file mode 100644 index 000000000..1ed408b3e --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x1x4.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 16>; +using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x2x2.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x2x2.cu new file mode 100644 index 000000000..ddf70bb8e --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x2x2.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 16>; +using ThreadShape = cutlass::gemm::GemmShape<1, 2, 2>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x4x1.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x4x1.cu new file mode 100644 index 000000000..f35e9ed30 --- /dev/null +++ 
b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x4x1.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 16>; +using ThreadShape = cutlass::gemm::GemmShape<1, 4, 1>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x32_1x2x4.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x32_1x2x4.cu new file mode 100644 index 000000000..4b2e2fdf7 --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x32_1x2x4.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 32>; +using ThreadShape = cutlass::gemm::GemmShape<1, 2, 4>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x32_1x4x2.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x32_1x4x2.cu new file mode 100644 index 000000000..c77711337 --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x32_1x4x2.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic 
ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 32>; +using ThreadShape = cutlass::gemm::GemmShape<1, 4, 2>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x4_1x1x1.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x4_1x1x1.cu new file mode 100644 index 000000000..8ab75b3e7 --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x4_1x1x1.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>; +using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>; +using GemvKernel = cutlass::gemm::kernel::DefaultGemv< + ThreadBlockShape, + ThreadShape, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor, + float, cutlass::layout::RowMajor>; +template void megdnn::cuda::cutlass_wrapper:: + cutlass_vector_matrix_mul_batched_strided_wrapper( + BatchedGemmCoord const& problem_size, + const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a, + const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b, + typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c, + cudaStream_t stream); + +#pragma GCC diagnostic pop +#endif diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x64_1x4x4.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x64_1x4x4.cu new file mode 100644 index 000000000..2d281ce7a --- /dev/null +++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x64_1x4x4.cu @@ -0,0 +1,26 @@ +#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2) +// generated by gen_cutlass_gemv_batched_strided_kern_impls.py +// ignore warning of cutlass +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl" + +using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 64>; +using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>; +using GemvKernel = 
diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x8_1x1x2.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x8_1x1x2.cu
new file mode 100644
index 000000000..eeab0c509
--- /dev/null
+++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x8_1x1x2.cu
@@ -0,0 +1,26 @@
+#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
+// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
+// ignore warning of cutlass
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#pragma GCC diagnostic ignored "-Wstrict-aliasing"
+#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
+
+using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 8>;
+using ThreadShape = cutlass::gemm::GemmShape<1, 1, 2>;
+using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
+        ThreadBlockShape,
+        ThreadShape,
+        float, cutlass::layout::RowMajor,
+        float, cutlass::layout::RowMajor,
+        float, cutlass::layout::RowMajor>;
+template void megdnn::cuda::cutlass_wrapper::
+        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
+                BatchedGemmCoord const& problem_size,
+                const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
+                const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
+                typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
+                cudaStream_t stream);
+
+#pragma GCC diagnostic pop
+#endif
diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x8_1x2x1.cu b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x8_1x2x1.cu
new file mode 100644
index 000000000..aef942ea0
--- /dev/null
+++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x8_1x2x1.cu
@@ -0,0 +1,26 @@
+#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
+// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
+// ignore warning of cutlass
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#pragma GCC diagnostic ignored "-Wstrict-aliasing"
+#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
+
+using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 8>;
+using ThreadShape = cutlass::gemm::GemmShape<1, 2, 1>;
+using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
+        ThreadBlockShape,
+        ThreadShape,
+        float, cutlass::layout::RowMajor,
+        float, cutlass::layout::RowMajor,
+        float, cutlass::layout::RowMajor>;
+template void megdnn::cuda::cutlass_wrapper::
+        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
+                BatchedGemmCoord const& problem_size,
+                const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
+                const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
+                typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
+                cudaStream_t stream);
+
+#pragma GCC diagnostic pop
+#endif
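[Editor's note] Every ThreadBlockShape in these generated files fixes M = 1: the algorithm treats A as a batch of row vectors, so the threadblock only tiles N (outputs per CTA) and K (reduction depth). A hedged sketch of how such a problem size would be built; `make_problem` is illustrative and not part of the patch:

    #include "cutlass/gemm_coord.h"

    // Hypothetical helper: one 1xK row vector times one KxN matrix per batch.
    cutlass::gemm::BatchedGemmCoord make_problem(int n, int k, int batch) {
        return cutlass::gemm::BatchedGemmCoord{/*m=*/1, n, k, batch};
    }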
diff --git a/dnn/src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl
new file mode 100644
index 000000000..2e7f6a687
--- /dev/null
+++ b/dnn/src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl
@@ -0,0 +1,70 @@
+/**
+ * \file
+ * dnn/src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+#include "cutlass/gemm/kernel/default_gemv.h"
+#include "cutlass/gemm/kernel/gemv_batched_strided.h"
+#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
+#include "src/cuda/query_blocksize.cuh"
+
+using namespace megdnn;
+using namespace cuda;
+using namespace cutlass_wrapper;
+
+template <typename GemvKernel>
+void megdnn::cuda::cutlass_wrapper::
+        cutlass_vector_matrix_mul_batched_strided_wrapper(
+                BatchedGemmCoord const& problem_size,
+                const typename GemvKernel::ElementA* d_A, size_t lda,
+                size_t batch_stride_a, const typename GemvKernel::ElementB* d_B,
+                size_t ldb, size_t batch_stride_b,
+                typename GemvKernel::ElementCD* d_C, size_t ldc,
+                size_t batch_stride_c, cudaStream_t stream) {
+    typename GemvKernel::IteratorA::TensorRef tensor_a{
+            const_cast<typename GemvKernel::ElementA*>(d_A),
+            typename GemvKernel::LayoutA{static_cast<int>(lda)}};
+    typename GemvKernel::IteratorB::TensorRef tensor_b{
+            const_cast<typename GemvKernel::ElementB*>(d_B),
+            typename GemvKernel::LayoutB{static_cast<int>(ldb)}};
+    typename GemvKernel::IteratorCD::TensorRef tensor_c{
+            d_C, typename GemvKernel::LayoutCD{static_cast<int>(ldc)}};
+    static int constexpr kThreadsPerN = GemvKernel::Core::kThreadsPerN;
+    static int constexpr kThreadsPerK = GemvKernel::Core::kThreadsPerK;
+    void (*kern)(BatchedGemmCoord, typename GemvKernel::IteratorA::TensorRef,
+                 typename GemvKernel::IteratorA::TensorRef::LongIndex,
+                 typename GemvKernel::IteratorB::TensorRef,
+                 typename GemvKernel::IteratorB::TensorRef::LongIndex,
+                 typename GemvKernel::IteratorCD::TensorRef,
+                 typename GemvKernel::IteratorCD::TensorRef::LongIndex);
+    kern = cutlass::gemm::kernel::GemvBatchedStrided<GemvKernel>;
+    // int nr_threads = static_cast<int>(
+    //         query_blocksize_for_kernel(reinterpret_cast<const void*>(kern)));
+    // nr_threads = std::max(nr_threads, kThreadsPerN);
+    // megdnn_assert(nr_threads % kThreadsPerN == 0);
+    // int batch = nr_threads / kThreadsPerN;
+    // batch = std::min(batch, problem_size.batch());
+    auto tile_size = BatchedGemmCoord(GemvKernel::ThreadBlockShape::kM,
+                                      GemvKernel::ThreadBlockShape::kN,
+                                      GemvKernel::ThreadBlockShape::kK, 1);
+    typename GemvKernel::ThreadBlockSwizzle swizzler;
+    auto tiled_shape = swizzler.get_tiled_shape(problem_size, tile_size);
+    dim3 grid = swizzler.get_grid_shape(tiled_shape);
+    dim3 block(kThreadsPerN, kThreadsPerK, 1);
+    int smem_size =
+            int(sizeof(typename GemvKernel::ThreadBlockGemv::SharedStorage));
+    megdnn_assert(smem_size < (48 << 10));
+    kern<<<grid, block, smem_size, stream>>>(
+            problem_size, tensor_a, batch_stride_a, tensor_b, batch_stride_b,
+            tensor_c, batch_stride_c);
+    after_kernel_launch();
+}
+
+// vim: syntax=cuda.doxygen
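[Editor's note] A hedged usage sketch of the wrapper above, dispatching the 1x64x8_1x2x1 instantiation from host code. The `run_batched_gemv` helper, the stride choices, and the densely packed row-major layout are illustrative assumptions, and the unqualified `BatchedGemmCoord` alias inside `cutlass_wrapper` is inferred from the wrapper's own signature; the real dispatch lives in cutlass_float32_simt_gemv_batched_strided.cpp earlier in this patch (not shown in this section):

    #include "cutlass/gemm/kernel/default_gemv.h"
    #include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"

    using GemvKernel64x8 = cutlass::gemm::kernel::DefaultGemv<
            cutlass::gemm::GemmShape<1, 64, 8>,  // threadblock tile
            cutlass::gemm::GemmShape<1, 2, 1>,   // per-thread tile
            float, cutlass::layout::RowMajor,
            float, cutlass::layout::RowMajor,
            float, cutlass::layout::RowMajor>;

    // Hypothetical host-side call: C[b] (1xn) = A[b] (1xk) * B[b] (kxn) for
    // each batch b, with contiguous device buffers; strides are in elements.
    void run_batched_gemv(const float* d_A, const float* d_B, float* d_C,
                          int n, int k, int batch, cudaStream_t stream) {
        using namespace megdnn::cuda::cutlass_wrapper;
        BatchedGemmCoord problem(/*m=*/1, n, k, batch);
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel64x8>(
                problem,
                d_A, /*lda=*/k, /*batch_stride_a=*/k,
                d_B, /*ldb=*/n, /*batch_stride_b=*/size_t(k) * n,
                d_C, /*ldc=*/n, /*batch_stride_c=*/n,
                stream);
    }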
diff --git a/dnn/src/cuda/matrix_mul/opr_impl.h b/dnn/src/cuda/matrix_mul/opr_impl.h
index b554a9ea2..2a7e00a70 100644
--- a/dnn/src/cuda/matrix_mul/opr_impl.h
+++ b/dnn/src/cuda/matrix_mul/opr_impl.h
@@ -41,8 +41,11 @@ public:
 #if !MEGDNN_DISABLE_FLOAT16
     class AlgoBFloat16;
 #endif
+#if CUDA_VERSION >= 9020
     class AlgoFloat32SIMT;
     class AlgoFloat32SIMTSplitK;
+    class AlgoFloat32SIMTGemvBatchedStrided;
+#endif
     class AlgoPack;
 
     static const AlgoPack& algo_pack() {
diff --git a/dnn/test/cuda/cutlass_matmul.cpp b/dnn/test/cuda/cutlass_matmul.cpp
index ae04cd028..c76955a3f 100644
--- a/dnn/test/cuda/cutlass_matmul.cpp
+++ b/dnn/test/cuda/cutlass_matmul.cpp
@@ -90,7 +90,7 @@ void test_multibatchsize(
         if (std::regex_match(
                     i.name.c_str(),
                     std::regex("(" + std::string(algo) + ")(.*)"))) {
-            opr_reference->execution_policy().algo = i;
+            opr_reference->execution_policy().algo = i.desc;
             break;
         }
     }
@@ -119,7 +119,7 @@ void test_multibatchsize(
         if (std::regex_match(
                     i.name.c_str(),
                     std::regex("(" + std::string(algo) + ")(.*)"))) {
-            opr_reference->execution_policy().algo = i;
+            opr_reference->execution_policy().algo = i.desc;
             break;
         }
     }
@@ -292,6 +292,30 @@ TEST_F(CUDA, CUTLASS_GEMM_SPLIT_K_MULTI_BATCHSIZE) {
             [](const matrix_mul::TestArg& arg) { return arg.k <= arg.n; });
 }
 
+TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_128_MULTI_BATCHSIZE) {
+    auto args = matrix_mul::get_matmul_args_no_mask();
+    test_multibatchsize(handle_cuda(), dtype::Float32(), dtype::Float32(),
+                        dtype::Float32(),
+                        "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_128", args,
+                        param::MatrixMul::Format::DEFAULT);
+}
+
+TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_64_MULTI_BATCHSIZE) {
+    auto args = matrix_mul::get_matmul_args_no_mask();
+    test_multibatchsize(handle_cuda(), dtype::Float32(), dtype::Float32(),
+                        dtype::Float32(),
+                        "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_64", args,
+                        param::MatrixMul::Format::DEFAULT);
+}
+
+TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_32_MULTI_BATCHSIZE) {
+    auto args = matrix_mul::get_matmul_args_no_mask();
+    test_multibatchsize(handle_cuda(), dtype::Float32(), dtype::Float32(),
+                        dtype::Float32(),
+                        "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_32", args,
+                        param::MatrixMul::Format::DEFAULT);
+}
+
 #define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
     cb(1, 64, 256, 8, 32, 64, 8);         \
     cb(2, 256, 64, 8, 64, 32, 8);         \
--
GitLab