Commit 2de2222e authored by Megvii Engine Team

feat(dnn/cuda): add cutlass batched gemv kernel for matmul operator

GitOrigin-RevId: 51702c4e79347175a993700be4022bc38102d79f
Parent 973d2a0a
......@@ -9,7 +9,7 @@ ELEMWISE_IMPL := ../src/cuda/cond_take/kimpl \
../src/cuda/elemwise_multi_type/kimpl
CUDA_CONV_IMPL := ../src/cuda/conv_bias/int8/kimpl ../src/cuda/conv_bias/int8_imma/kimpl ../src/cuda/batch_conv_bias/int8/kimpl
CUDA_MATMUL_IMPL := ../src/cuda/matrix_mul/fp32_simt/kimpl
CUDA_MATMUL_IMPL := ../src/cuda/matrix_mul/fp32_simt/kimpl ../src/cuda/matrix_mul/fp32_simt_gemv/kimpl
all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} $(CUDA_MATMUL_IMPL)
......@@ -51,4 +51,7 @@ all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} $(CUDA_MATMUL_IMPL)
../src/cuda/matrix_mul/fp32_simt/kimpl: gen_cutlass_matmul_kern_impls.py
./$^ $@
../src/cuda/matrix_mul/fp32_simt_gemv/kimpl: gen_cutlass_gemv_batched_strided_kern_impls.py
./$^ $@
.PHONY: all
......@@ -33,6 +33,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
#if !MEGDNN_DISABLE_FLOAT16
all_algos.push_back(&bfloat16);
#endif
#if CUDA_VERSION >= 9020
fill_cutlass_algos();
for (auto&& algo : simt_float32) {
all_algos.push_back(&algo);
......@@ -40,12 +41,17 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
for (auto&& algo : simt_float32_split_k) {
all_algos.push_back(&algo);
}
for (auto&& algo : simt_float32_gemv_batched_strided) {
all_algos.push_back(&algo);
}
#endif
for (auto&& algo : all_algos) {
m_all_algos_map.emplace(algo->info().desc, algo);
}
}
#if CUDA_VERSION >= 9020
void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
using AlgoParam = AlgoFloat32SIMT::AlgoParam;
simt_float32.emplace_back(AlgoParam{64, 256, 8, 32, 64, 8});
......@@ -82,7 +88,11 @@ void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
simt_float32_split_k.emplace_back(AlgoParam{16, 32, 8, 16, 32, 8});
simt_float32_split_k.emplace_back(AlgoParam{16, 64, 8, 16, 64, 8});
simt_float32_split_k.emplace_back(AlgoParam{16, 128, 8, 16, 64, 8});
simt_float32_gemv_batched_strided.emplace_back(128);
simt_float32_gemv_batched_strided.emplace_back(64);
simt_float32_gemv_batched_strided.emplace_back(32);
}
#endif
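The three emplace_back(128/64/32) calls above differ only in the thread-block width along N; each becomes a separately named algorithm that the new tests later select by name. A minimal standalone sketch of that naming, assuming only that megdnn's ssprintf formats like snprintf (the resulting names match the test cases added at the bottom of this commit):

#include <cstdio>
#include <vector>

// Illustration only: reproduce the algorithm names registered above.
int main() {
    for (int threadblock_n : std::vector<int>{128, 64, 32}) {
        char buf[64];
        std::snprintf(buf, sizeof(buf),
                      "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_%d",
                      threadblock_n);
        std::printf("%s\n", buf);  // ..._128, ..._64, ..._32
    }
    return 0;
}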
MatrixMulForwardImpl::AlgoPack MatrixMulForwardImpl::sm_algo_pack;
......
......@@ -42,8 +42,11 @@ public:
CUDA_CUBLASLT,
CUDA_NAIVE,
CUDA_BFLOAT16,
#if CUDA_VERSION >= 9020
CUDA_FLOAT32_SIMT,
CUDA_FLOAT32_SIMT_SPLIT_K,
CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED,
#endif
};
using Mapper = std::unordered_map<AlgorithmDesc, AlgoBase*>;
......@@ -167,6 +170,7 @@ private:
};
#endif
#if CUDA_VERSION >= 9020
class MatrixMulForwardImpl::AlgoFloat32SIMT final : public AlgoBase {
public:
struct AlgoParam {
......@@ -224,6 +228,32 @@ private:
std::string m_name;
};
class MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided final
: public AlgoBase {
public:
AlgoFloat32SIMTGemvBatchedStrided(int threadblock_n)
: m_threadblock_n{threadblock_n},
m_name{ssprintf("CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_%d",
m_threadblock_n)} {}
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
const char* name() const override { return m_name.c_str(); }
void exec(const ExecArgs& args) const override;
bool is_reproducible() const override { return true; }
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED)
std::string param() const override {
std::string ret;
serialize_write_pod(m_threadblock_n, ret);
return ret;
}
private:
int m_threadblock_n;
std::string m_name;
};
#endif
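param() in the class above serializes just m_threadblock_n, so the algorithm descriptor is fully determined by that one integer. A sketch of the equivalent encoding, under the assumption that serialize_write_pod appends the raw bytes of a POD value to the string (the real helper is megdnn's; this stand-in is only illustrative):

#include <string>

// Hypothetical stand-in for megdnn's serialize_write_pod.
template <typename T>
void append_pod_bytes(const T& val, std::string& out) {
    out.append(reinterpret_cast<const char*>(&val), sizeof(T));
}

std::string encode_param(int threadblock_n) {
    std::string ret;
    append_pod_bytes(threadblock_n, ret);  // sizeof(int) raw bytes
    return ret;
}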
class MatrixMulForwardImpl::AlgoPack : NonCopyableObj {
private:
AlgoBase::Mapper m_all_algos_map;
......@@ -241,8 +271,12 @@ public:
#if !MEGDNN_DISABLE_FLOAT16
AlgoBFloat16 bfloat16;
#endif
#if CUDA_VERSION >= 9020
std::vector<AlgoFloat32SIMT> simt_float32;
std::vector<AlgoFloat32SIMTSplitK> simt_float32_split_k;
std::vector<AlgoFloat32SIMTGemvBatchedStrided>
simt_float32_gemv_batched_strided;
#endif
std::vector<AlgoBase*> all_algos;
const AlgoBase::Mapper& all_algos_map() const { return m_all_algos_map; }
......
......@@ -15,20 +15,17 @@
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 9020
using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available(
const SizeArgs& args) const {
#if CUDA_VERSION >= 9200
return args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
args.layout_a.dtype == dtype::Float32() &&
args.layout_b.dtype == dtype::Float32() &&
args.layout_c.dtype == dtype::Float32();
#else
return false;
#endif
}
size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes(
......@@ -69,5 +66,6 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const {
m_algo_param.warp_k},
stream);
}
#endif
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cuda/matrix_mul/cutlass_float32_simt_gemv_batched_strided.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 9020
using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
bool MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided::is_available(
const SizeArgs& args) const {
auto&& param = args.opr->param();
bool ta = param.transposeA, tb = param.transposeB;
return args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
args.layout_a.dtype == dtype::Float32() &&
args.layout_b.dtype == dtype::Float32() &&
args.layout_c.dtype == dtype::Float32() && ((!ta) && (!tb));
}
size_t
MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided::get_workspace_in_bytes(
const SizeArgs& /* args */) const {
return 0;
}
void MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided::exec(
const ExecArgs& args) const {
size_t lda = args.tensor_a.layout.stride[0],
ldb = args.tensor_b.layout.stride[0],
ldc = args.tensor_c.layout.stride[0];
auto&& param = args.opr->param();
int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];
// m is always 1 in gemv batched strided case
BatchedGemmCoord problem_size{1, n, k, m};
auto&& stream = cuda_stream(args.opr->handle());
return cutlass_matrix_mul_float32_simt_gemv_batched_strided(
args.tensor_a.ptr<dt_float32>(), lda, lda,
args.tensor_b.ptr<dt_float32>(), ldb, 0,
args.tensor_c.ptr<dt_float32>(), ldc, ldc, problem_size,
m_threadblock_n, stream);
}
#endif
// vim: syntax=cpp.doxygen
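To make the batching scheme in exec above concrete: for DEFAULT-format fp32 matmul with no transposes, every row of A (and of C) is treated as one batch of a 1×n gemv against a shared B, which is why batch_stride_a = lda, batch_stride_b = 0, and batch_stride_c = ldc. A plain CPU reference of that mapping (illustration only, not part of the kernel):

#include <cstddef>

// Reference loop for the reinterpretation used by the batched-strided gemv:
// an (m x k) * (k x n) matmul becomes m gemv problems of shape (1 x k) * (k x n).
void gemv_batched_strided_ref(const float* A, size_t lda, const float* B,
                              size_t ldb, float* C, size_t ldc, int m, int n,
                              int k) {
    for (int batch = 0; batch < m; ++batch) {       // one batch per row of A/C
        const float* a = A + batch * lda;           // batch_stride_a = lda
        float* c = C + batch * ldc;                 // batch_stride_c = ldc
        for (int j = 0; j < n; ++j) {
            float acc = 0.f;
            for (int p = 0; p < k; ++p)
                acc += a[p] * B[p * ldb + j];       // batch_stride_b = 0
            c[j] = acc;
        }
    }
}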
......@@ -15,6 +15,7 @@
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 9020
using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
......@@ -22,12 +23,12 @@ using namespace cutlass_wrapper;
bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available(
const SizeArgs& args) const {
auto&& param = args.opr->param();
int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
int n = args.layout_c.shape[1],
k = args.layout_a.shape[param.transposeA ? 0 : 1];
return args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
args.layout_a.dtype == dtype::Float32() &&
args.layout_b.dtype == dtype::Float32() &&
args.layout_c.dtype == dtype::Float32() && k > std::max(m, n);
args.layout_c.dtype == dtype::Float32() && k > n;
}
size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes(
......@@ -38,7 +39,7 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes(
int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
k = args.layout_a.shape[param.transposeA ? 0 : 1];
GemmCoord problem_size{m, n, k};
int split_k_slices = k / std::max(m, n);
int split_k_slices = k / n;
return cutlass_matrix_mul_float32_simt_get_workspace_size(
param.transposeA, lda, param.transposeB, ldb, ldc, problem_size,
1.f, 0.f,
......@@ -58,7 +59,7 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];
GemmCoord problem_size{m, n, k};
int split_k_slices = k / std::max(m, n);
int split_k_slices = k / n;
auto&& stream = cuda_stream(args.opr->handle());
int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr);
return cutlass_matrix_mul_float32_simt(
......@@ -72,5 +73,6 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
m_algo_param.warp_k},
stream, split_k_slices);
}
#endif
// vim: syntax=cpp.doxygen
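The split-k change above relaxes availability from k > max(m, n) to k > n and computes the slice count as k / n, so tall outputs no longer disqualify the kernel. A small worked check of the new heuristic (the helper names here are only for illustration): with m = 1024, n = 32, k = 256 the old rule rejected the kernel (256 <= 1024), while the new one accepts it with 256 / 32 = 8 slices.

#include <cstdio>

// Sketch of the updated split-k heuristic from the diff above.
bool split_k_available(int n, int k) { return k > n; }
int split_k_slices(int n, int k) { return k / n; }

int main() {
    std::printf("m=1024 n=32 k=256 -> available=%d, slices=%d\n",
                split_k_available(32, 256), split_k_slices(32, 256));
    return 0;
}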
......@@ -10,16 +10,16 @@
* implied.
*/
// ignore warning of cutlass
#include "cuda.h"
#if __CUDACC_VER_MAJOR__ > 9 || \
(__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "cuda.h"
#if __CUDACC_VER_MAJOR__ > 9 || \
(__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#endif
#include "cutlass/gemm/kernel/default_gemv.h"
#include "src/common/opr_param_defs_enumv.cuh"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#pragma GCC diagnostic pop
......@@ -54,18 +54,6 @@ using namespace cutlass_wrapper;
threadblock_shape.m(), threadblock_shape.n(), \
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
warp_shape.k());
#if __CUDACC_VER_MAJOR__ < 9 || \
(__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ <= 2)
void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt(
const float* /* d_A */, bool /* transpose_A */, size_t /* lda */,
const float* /* d_B */, bool /* transpose_B */, size_t /* ldb */,
float* /* d_C */, size_t /* ldc */, int* /* workspace */,
GemmCoord const& /* problem_size */, float /* alpha */,
float /* beta */, const GemmCoord& /* threadblock_shape */,
const GemmCoord& /* warp_shape */, cudaStream_t /* stream */,
int /* split_k_slices */) {}
#else
void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt(
const float* d_A, bool transpose_A, size_t lda, const float* d_B,
bool transpose_B, size_t ldb, float* d_C, size_t ldc, int* workspace,
......@@ -162,20 +150,7 @@ void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt(
#undef cb
}
}
#endif
#if __CUDACC_VER_MAJOR__ < 9 || \
(__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ <= 2)
size_t megdnn::cuda::cutlass_wrapper::
cutlass_matrix_mul_float32_simt_get_workspace_size(
bool /* transpose_A */, size_t /* lda */,
bool /* transpose_B */, size_t /* ldb */, size_t /* ldc */,
GemmCoord const& /* problem_size */, float /* alpha */,
float /* beta */, const GemmCoord& /* threadblock_shape */,
const GemmCoord& /* warp_shape */, int /* split_k_slices */) {
return 0;
}
#else
size_t megdnn::cuda::cutlass_wrapper::
cutlass_matrix_mul_float32_simt_get_workspace_size(
bool transpose_A, size_t lda, bool transpose_B, size_t ldb,
......@@ -294,7 +269,86 @@ size_t megdnn::cuda::cutlass_wrapper::
#undef cb
}
}
#endif
#undef DISPATCH
/* ============ cutlass kernel wrapper for f32 vector-matrix mul batched strided
* ===========
*/
#define DISPATCH(cb) \
cb(128, 4, 4); \
cb(128, 4, 2); \
cb(128, 4, 1); \
cb(128, 2, 4); \
cb(128, 1, 4); \
cb(128, 2, 2); \
cb(128, 1, 2); \
cb(128, 2, 1); \
cb(128, 1, 1); \
cb(64, 4, 4); \
cb(64, 4, 2); \
cb(64, 4, 1); \
cb(64, 2, 4); \
cb(64, 1, 4); \
cb(64, 2, 2); \
cb(64, 1, 2); \
cb(64, 2, 1); \
cb(64, 1, 1); \
cb(32, 4, 4); \
cb(32, 4, 2); \
cb(32, 4, 1); \
cb(32, 2, 4); \
cb(32, 1, 4); \
cb(32, 2, 2); \
cb(32, 1, 2); \
cb(32, 2, 1); \
cb(32, 1, 1); \
megdnn_assert(false, \
"unsupported gemv batched strided A=%dX%dX%d, B=%dX%dX%d", \
problem_size.batch(), problem_size.m(), problem_size.k(), \
problem_size.batch(), problem_size.k(), problem_size.n());
void megdnn::cuda::cutlass_wrapper::
cutlass_matrix_mul_float32_simt_gemv_batched_strided(
const float* d_A, size_t lda, size_t batch_stride_a,
const float* d_B, size_t ldb, size_t batch_stride_b, float* d_C,
size_t ldc, size_t batch_stride_c,
BatchedGemmCoord const& problem_size, int threadblock_n,
cudaStream_t stream) {
int LDG_K, LDG_N;
if (lda % 4 == 0)
LDG_K = 4;
else if (lda % 2 == 0)
LDG_K = 2;
else
LDG_K = 1;
if (ldb % 4 == 0)
LDG_N = 4;
else if (ldb % 2 == 0)
LDG_N = 2;
else
LDG_N = 1;
#define cb(threadblock_n_, LDG_K_, LDG_N_) \
if (threadblock_n == threadblock_n_ && LDG_K == LDG_K_ && \
LDG_N == LDG_N_) { \
using ThreadBlockShape = \
cutlass::gemm::GemmShape<1, threadblock_n_, \
(256 * LDG_K_) / \
(threadblock_n_ / LDG_N_)>; \
using ThreadShape = cutlass::gemm::GemmShape<1, LDG_N_, LDG_K_>; \
using GemvKernel = cutlass::gemm::kernel::DefaultGemv< \
ThreadBlockShape, ThreadShape, float, \
cutlass::layout::RowMajor, float, cutlass::layout::RowMajor, \
float, cutlass::layout::RowMajor>; \
return cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( \
problem_size, d_A, lda, batch_stride_a, d_B, ldb, \
batch_stride_b, d_C, ldc, batch_stride_c, stream); \
}
DISPATCH(cb)
#undef cb
}
#undef DISPATCH
#endif
// vim: syntax=cuda.doxygen
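The dispatch macro above picks LDG_K/LDG_N from the alignment of lda/ldb and derives the tile as ThreadBlockShape<1, N, (256 * LDG_K) / (N / LDG_N)> with ThreadShape<1, LDG_N, LDG_K>; assuming each thread covers LDG_N columns and LDG_K depth elements (as the launch code suggests), that keeps every block at 256 threads. A worked check for one instantiation that also appears in the generated kimpl files (N = 128, LDG_K = LDG_N = 4):

#include <cassert>
#include <cstdio>

// Verify the tile-shape arithmetic used by the DISPATCH macro above.
int main() {
    const int threadblock_n = 128, ldg_k = 4, ldg_n = 4;
    const int tile_k = (256 * ldg_k) / (threadblock_n / ldg_n);  // = 32
    const int threads = (threadblock_n / ldg_n) * (tile_k / ldg_k);
    std::printf("ThreadBlockShape<1,%d,%d>, ThreadShape<1,%d,%d>, threads=%d\n",
                threadblock_n, tile_k, ldg_n, ldg_k, threads);
    assert(threads == 256);  // holds whenever the divisions are exact
    return 0;
}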
......@@ -13,11 +13,13 @@
#include "cutlass/gemm/gemm.h"
#include "src/cuda/utils.cuh"
#if CUDA_VERSION >= 9020
namespace megdnn {
namespace cuda {
namespace cutlass_wrapper {
using GemmCoord = cutlass::gemm::GemmCoord;
using BatchedGemmCoord = cutlass::gemm::BatchedGemmCoord;
template <typename Gemm>
void cutlass_matrix_mul_wrapper(
......@@ -38,10 +40,26 @@ void cutlass_matrix_mul_float32_simt(
size_t cutlass_matrix_mul_float32_simt_get_workspace_size(
bool transpose_A, size_t lda, bool transpose_B, size_t ldb, size_t ldc,
GemmCoord const& problem_size, float alpha, float beta,
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape, int split_k_slices = 1);
const GemmCoord& threadblock_shape, const GemmCoord& warp_shape,
int split_k_slices = 1);
template <typename GemvKernel>
void cutlass_vector_matrix_mul_batched_strided_wrapper(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda,
size_t batch_stride_a, const typename GemvKernel::ElementB* d_B,
size_t ldb, size_t batch_stride_b, typename GemvKernel::ElementCD* d_C,
size_t ldc, size_t batch_stride_c, cudaStream_t stream);
void cutlass_matrix_mul_float32_simt_gemv_batched_strided(
const float* d_A, size_t lda, size_t batch_stride_a, const float* d_B,
size_t ldb, size_t batch_stride_b, float* d_C, size_t ldc,
size_t batch_stride_c, BatchedGemmCoord const& problem_size,
int threadblock_n, cudaStream_t stream);
} // namespace cutlass_wrapper
} // namespace cuda
} // namespace megdnn
#endif
// vim: syntax=cuda.doxygen
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 2>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 4>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 4>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 128>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 64>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 64>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 64>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
/**
* \file
* dnn/src/cuda/matrix_mul/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "cutlass/gemm/kernel/default_gemv.h"
#include "cutlass/gemm/kernel/gemv_batched_strided.h"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#include "src/cuda/query_blocksize.cuh"
using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;
template <typename GemvKernel>
void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda,
size_t batch_stride_a, const typename GemvKernel::ElementB* d_B,
size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc,
size_t batch_stride_c, cudaStream_t stream) {
typename GemvKernel::IteratorA::TensorRef tensor_a{
const_cast<typename GemvKernel::ElementA*>(d_A),
typename GemvKernel::LayoutA{static_cast<int>(lda)}};
typename GemvKernel::IteratorB::TensorRef tensor_b{
const_cast<typename GemvKernel::ElementB*>(d_B),
typename GemvKernel::LayoutB{static_cast<int>(ldb)}};
typename GemvKernel::IteratorCD::TensorRef tensor_c{
d_C, typename GemvKernel::LayoutCD{static_cast<int>(ldc)}};
static int constexpr kThreadsPerN = GemvKernel::Core::kThreadsPerN;
static int constexpr kThreadsPerK = GemvKernel::Core::kThreadsPerK;
void (*kern)(BatchedGemmCoord, typename GemvKernel::IteratorA::TensorRef,
typename GemvKernel::IteratorA::TensorRef::LongIndex,
typename GemvKernel::IteratorB::TensorRef,
typename GemvKernel::IteratorB::TensorRef::LongIndex,
typename GemvKernel::IteratorCD::TensorRef,
typename GemvKernel::IteratorCD::TensorRef::LongIndex);
kern = cutlass::gemm::kernel::GemvBatchedStrided<GemvKernel>;
// int nr_threads = static_cast<int>(
// query_blocksize_for_kernel(reinterpret_cast<const void*>(kern)));
// nr_threads = std::max(nr_threads, kThreadsPerN);
// megdnn_assert(nr_threads % kThreadsPerN == 0);
// int batch = nr_threads / kThreadsPerN;
// batch = std::min(batch, problem_size.batch());
auto tile_size = BatchedGemmCoord(GemvKernel::ThreadBlockShape::kM,
GemvKernel::ThreadBlockShape::kN,
GemvKernel::ThreadBlockShape::kK, 1);
typename GemvKernel::ThreadBlockSwizzle swizzler;
auto tiled_shape = swizzler.get_tiled_shape(problem_size, tile_size);
dim3 grid = swizzler.get_grid_shape(tiled_shape);
dim3 block(kThreadsPerN, kThreadsPerK, 1);
int smem_size =
int(sizeof(typename GemvKernel::ThreadBlockGemv::SharedStorage));
megdnn_assert(smem_size < (48 << 10));
kern<<<grid, block, smem_size, stream>>>(
problem_size, tensor_a, batch_stride_a, tensor_b, batch_stride_b,
tensor_c, batch_stride_c);
after_kernel_launch();
}
// vim: syntax=cuda.doxygen
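For orientation, the launch configuration in the wrapper above is fully determined by the kernel's shapes: assuming cutlass's DefaultGemv core defines kThreadsPerN = ThreadBlockShape::kN / ThreadShape::kN and kThreadsPerK = ThreadBlockShape::kK / ThreadShape::kK, the <1,128,16> / <1,2,4> instantiation launches block(64, 4, 1), i.e. 256 threads, and its shared-storage size is checked against the 48 KB default per-block limit. A sketch of that arithmetic:

#include <cstdio>

// Block-dimension arithmetic for one generated instantiation; the kThreadsPer*
// derivation is an assumption about cutlass::gemm::kernel::DefaultGemv.
int main() {
    const int threads_per_n = 128 / 2;  // ThreadBlockShape::kN / ThreadShape::kN
    const int threads_per_k = 16 / 4;   // ThreadBlockShape::kK / ThreadShape::kK
    std::printf("block = (%d, %d, 1), %d threads, smem limit = %d bytes\n",
                threads_per_n, threads_per_k, threads_per_n * threads_per_k,
                48 << 10);
    return 0;
}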
......@@ -41,8 +41,11 @@ public:
#if !MEGDNN_DISABLE_FLOAT16
class AlgoBFloat16;
#endif
#if CUDA_VERSION >= 9020
class AlgoFloat32SIMT;
class AlgoFloat32SIMTSplitK;
class AlgoFloat32SIMTGemvBatchedStrided;
#endif
class AlgoPack;
static const AlgoPack& algo_pack() {
......
......@@ -90,7 +90,7 @@ void test_multibatchsize(
if (std::regex_match(
i.name.c_str(),
std::regex("(" + std::string(algo) + ")(.*)"))) {
opr_reference->execution_policy().algo = i;
opr_reference->execution_policy().algo = i.desc;
break;
}
}
......@@ -119,7 +119,7 @@ void test_multibatchsize(
if (std::regex_match(
i.name.c_str(),
std::regex("(" + std::string(algo) + ")(.*)"))) {
opr_reference->execution_policy().algo = i;
opr_reference->execution_policy().algo = i.desc;
break;
}
}
......@@ -292,6 +292,30 @@ TEST_F(CUDA, CUTLASS_GEMM_SPLIT_K_MULTI_BATCHSIZE) {
[](const matrix_mul::TestArg& arg) { return arg.k <= arg.n; });
}
TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_128_MULTI_BATCHSIZE) {
auto args = matrix_mul::get_matmul_args_no_mask();
test_multibatchsize(handle_cuda(), dtype::Float32(), dtype::Float32(),
dtype::Float32(),
"CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_128", args,
param::MatrixMul::Format::DEFAULT);
}
TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_64_MULTI_BATCHSIZE) {
auto args = matrix_mul::get_matmul_args_no_mask();
test_multibatchsize(handle_cuda(), dtype::Float32(), dtype::Float32(),
dtype::Float32(),
"CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_64", args,
param::MatrixMul::Format::DEFAULT);
}
TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_32_MULTI_BATCHSIZE) {
auto args = matrix_mul::get_matmul_args_no_mask();
test_multibatchsize(handle_cuda(), dtype::Float32(), dtype::Float32(),
dtype::Float32(),
"CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_32", args,
param::MatrixMul::Format::DEFAULT);
}
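Beyond the multi-batchsize coverage above, a single-shape smoke test that pins one of the new kernels could look roughly like the sketch below; Checker/AlgoChecker usage is assumed to match the existing matrix-mul tests, so treat it as illustrative rather than as part of this commit:

TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_128_SMOKE) {
    // Hypothetical test: force the 128-wide gemv algorithm by name and run
    // one fp32 case with m = 1 (the gemv shape), k = 256, n = 128.
    Checker<MatrixMulForward> checker(handle_cuda());
    checker.set_before_exec_callback(AlgoChecker<MatrixMulForward>(
            "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_128"));
    checker.set_dtype(0, dtype::Float32())
            .set_dtype(1, dtype::Float32())
            .set_dtype(2, dtype::Float32())
            .execs({{1, 256}, {256, 128}, {}});
}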
#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
cb(1, 64, 256, 8, 32, 64, 8); \
cb(2, 256, 64, 8, 64, 32, 8); \
......