提交 5a14a892 编写于 作者: M Megvii Engine Team 提交者: huangxinda

refactor(dnn/cuda): refactor cutlass kernel generator for gemm and gemv

GitOrigin-RevId: 11d78ab2270f0720d7d79e186124a1254c467980
上级 b33217d8
......@@ -37,21 +37,21 @@ all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} $(CUDA_MATMUL_IMPL)
../src/cuda/elemwise_multi_type/kimpl: gen_elemwise_multi_type_kern_impls.py
./$^ --type cuda $@
../src/cuda/conv_bias/int8/kimpl: gen_cuda_conv_bias_kern_impls.py cutlass_generator
../src/cuda/conv_bias/int8/kimpl: gen_cuda_conv_bias_kern_impls.py cutlass_generator/generator.py
./gen_cuda_conv_bias_kern_impls.py --type dp4a $@
python3 ./cutlass_generator/generator.py --operations all --type simt $@
python3 ./cutlass_generator/generator.py --operations conv2d --type simt $@
../src/cuda/conv_bias/int8_imma/kimpl: gen_cuda_conv_bias_kern_impls.py cutlass_generator
../src/cuda/conv_bias/int8_imma/kimpl: gen_cuda_conv_bias_kern_impls.py cutlass_generator/generator.py
./gen_cuda_conv_bias_kern_impls.py --type imma $@
python3 ./cutlass_generator/generator.py --operations conv2d --type tensorop8816 $@
../src/cuda/batch_conv_bias/int8/kimpl: gen_cuda_batch_conv_bias_kern_impls.py
./$^ --type dp4a $@
../src/cuda/matrix_mul/fp32_simt/kimpl: gen_cutlass_matmul_kern_impls.py
./$^ $@
../src/cuda/matrix_mul/fp32_simt/kimpl: cutlass_generator/generator.py
python3 ./cutlass_generator/generator.py --operations gemm --type simt $@
../src/cuda/matrix_mul/fp32_simt_gemv/kimpl: gen_cutlass_gemv_batched_strided_kern_impls.py
./$^ $@
../src/cuda/matrix_mul/fp32_simt_gemv/kimpl: cutlass_generator
python3 ./cutlass_generator/generator.py --operations gemv --type simt $@
.PHONY: all
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x128_16_tt_align2x4
using Operation_cutlass_simt_sgemv_batched_strided_1x128_16_tt_align2x4 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 128, 16>,
cutlass::gemm::GemmShape<1, 4, 2>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x128_16_tt_align2x4>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_16_tt_align2x4::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_16_tt_align2x4::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x128_16_tt_align2x4::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2
using Operation_cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 128, 16>,
cutlass::gemm::GemmShape<1, 2, 4>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x128_16_tt_align4x2::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x128_2_tt_align1x1
using Operation_cutlass_simt_sgemv_batched_strided_1x128_2_tt_align1x1 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 128, 2>,
cutlass::gemm::GemmShape<1, 1, 1>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x128_2_tt_align1x1>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_2_tt_align1x1::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_2_tt_align1x1::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x128_2_tt_align1x1::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4
using Operation_cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 128, 32>,
cutlass::gemm::GemmShape<1, 4, 4>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x128_32_tt_align4x4::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x128_4_tt_align1x2
using Operation_cutlass_simt_sgemv_batched_strided_1x128_4_tt_align1x2 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 128, 4>,
cutlass::gemm::GemmShape<1, 2, 1>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x128_4_tt_align1x2>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_4_tt_align1x2::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_4_tt_align1x2::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x128_4_tt_align1x2::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x128_4_tt_align2x1
using Operation_cutlass_simt_sgemv_batched_strided_1x128_4_tt_align2x1 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 128, 4>,
cutlass::gemm::GemmShape<1, 1, 2>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x128_4_tt_align2x1>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_4_tt_align2x1::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_4_tt_align2x1::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x128_4_tt_align2x1::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x128_8_tt_align1x4
using Operation_cutlass_simt_sgemv_batched_strided_1x128_8_tt_align1x4 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 128, 8>,
cutlass::gemm::GemmShape<1, 4, 1>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x128_8_tt_align1x4>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_8_tt_align1x4::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_8_tt_align1x4::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x128_8_tt_align1x4::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x128_8_tt_align2x2
using Operation_cutlass_simt_sgemv_batched_strided_1x128_8_tt_align2x2 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 128, 8>,
cutlass::gemm::GemmShape<1, 2, 2>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x128_8_tt_align2x2>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_8_tt_align2x2::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_8_tt_align2x2::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x128_8_tt_align2x2::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1
using Operation_cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 128, 8>,
cutlass::gemm::GemmShape<1, 1, 4>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x128_8_tt_align4x1::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x32_128_tt_align4x4
using Operation_cutlass_simt_sgemv_batched_strided_1x32_128_tt_align4x4 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 32, 128>,
cutlass::gemm::GemmShape<1, 4, 4>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x32_128_tt_align4x4>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_128_tt_align4x4::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_128_tt_align4x4::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x32_128_tt_align4x4::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x32_16_tt_align1x2
using Operation_cutlass_simt_sgemv_batched_strided_1x32_16_tt_align1x2 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 32, 16>,
cutlass::gemm::GemmShape<1, 2, 1>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x32_16_tt_align1x2>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_16_tt_align1x2::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_16_tt_align1x2::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x32_16_tt_align1x2::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x32_16_tt_align2x1
using Operation_cutlass_simt_sgemv_batched_strided_1x32_16_tt_align2x1 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 32, 16>,
cutlass::gemm::GemmShape<1, 1, 2>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x32_16_tt_align2x1>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_16_tt_align2x1::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_16_tt_align2x1::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x32_16_tt_align2x1::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x32_32_tt_align1x4
using Operation_cutlass_simt_sgemv_batched_strided_1x32_32_tt_align1x4 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 32, 32>,
cutlass::gemm::GemmShape<1, 4, 1>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x32_32_tt_align1x4>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_32_tt_align1x4::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_32_tt_align1x4::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x32_32_tt_align1x4::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x32_32_tt_align2x2
using Operation_cutlass_simt_sgemv_batched_strided_1x32_32_tt_align2x2 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 32, 32>,
cutlass::gemm::GemmShape<1, 2, 2>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x32_32_tt_align2x2>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_32_tt_align2x2::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_32_tt_align2x2::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x32_32_tt_align2x2::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x32_32_tt_align4x1
using Operation_cutlass_simt_sgemv_batched_strided_1x32_32_tt_align4x1 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 32, 32>,
cutlass::gemm::GemmShape<1, 1, 4>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x32_32_tt_align4x1>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_32_tt_align4x1::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_32_tt_align4x1::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x32_32_tt_align4x1::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x32_64_tt_align2x4
using Operation_cutlass_simt_sgemv_batched_strided_1x32_64_tt_align2x4 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 32, 64>,
cutlass::gemm::GemmShape<1, 4, 2>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x32_64_tt_align2x4>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_64_tt_align2x4::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_64_tt_align2x4::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x32_64_tt_align2x4::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x32_64_tt_align4x2
using Operation_cutlass_simt_sgemv_batched_strided_1x32_64_tt_align4x2 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 32, 64>,
cutlass::gemm::GemmShape<1, 2, 4>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x32_64_tt_align4x2>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_64_tt_align4x2::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_64_tt_align4x2::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x32_64_tt_align4x2::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x32_8_tt_align1x1
using Operation_cutlass_simt_sgemv_batched_strided_1x32_8_tt_align1x1 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 32, 8>,
cutlass::gemm::GemmShape<1, 1, 1>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x32_8_tt_align1x1>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_8_tt_align1x1::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x32_8_tt_align1x1::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x32_8_tt_align1x1::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x64_16_tt_align1x4
using Operation_cutlass_simt_sgemv_batched_strided_1x64_16_tt_align1x4 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 64, 16>,
cutlass::gemm::GemmShape<1, 4, 1>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x64_16_tt_align1x4>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_16_tt_align1x4::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_16_tt_align1x4::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x64_16_tt_align1x4::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x64_16_tt_align2x2
using Operation_cutlass_simt_sgemv_batched_strided_1x64_16_tt_align2x2 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 64, 16>,
cutlass::gemm::GemmShape<1, 2, 2>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x64_16_tt_align2x2>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_16_tt_align2x2::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_16_tt_align2x2::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x64_16_tt_align2x2::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x64_16_tt_align4x1
using Operation_cutlass_simt_sgemv_batched_strided_1x64_16_tt_align4x1 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 64, 16>,
cutlass::gemm::GemmShape<1, 1, 4>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x64_16_tt_align4x1>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_16_tt_align4x1::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_16_tt_align4x1::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x64_16_tt_align4x1::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x64_32_tt_align2x4
using Operation_cutlass_simt_sgemv_batched_strided_1x64_32_tt_align2x4 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 64, 32>,
cutlass::gemm::GemmShape<1, 4, 2>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x64_32_tt_align2x4>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_32_tt_align2x4::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_32_tt_align2x4::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x64_32_tt_align2x4::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x64_32_tt_align4x2
using Operation_cutlass_simt_sgemv_batched_strided_1x64_32_tt_align4x2 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 64, 32>,
cutlass::gemm::GemmShape<1, 2, 4>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x64_32_tt_align4x2>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_32_tt_align4x2::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_32_tt_align4x2::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x64_32_tt_align4x2::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x64_4_tt_align1x1
using Operation_cutlass_simt_sgemv_batched_strided_1x64_4_tt_align1x1 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 64, 4>,
cutlass::gemm::GemmShape<1, 1, 1>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x64_4_tt_align1x1>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_4_tt_align1x1::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_4_tt_align1x1::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x64_4_tt_align1x1::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x64_64_tt_align4x4
using Operation_cutlass_simt_sgemv_batched_strided_1x64_64_tt_align4x4 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 64, 64>,
cutlass::gemm::GemmShape<1, 4, 4>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x64_64_tt_align4x4>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_64_tt_align4x4::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_64_tt_align4x4::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x64_64_tt_align4x4::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x64_8_tt_align1x2
using Operation_cutlass_simt_sgemv_batched_strided_1x64_8_tt_align1x2 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 64, 8>,
cutlass::gemm::GemmShape<1, 2, 1>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x64_8_tt_align1x2>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_8_tt_align1x2::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_8_tt_align1x2::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x64_8_tt_align1x2::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#pragma GCC diagnostic ignored "-Wuninitialized"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper_batched_gemv_strided.cuinl"
// Gemm operator cutlass_simt_sgemv_batched_strided_1x64_8_tt_align2x1
using Operation_cutlass_simt_sgemv_batched_strided_1x64_8_tt_align2x1 = cutlass::gemm::kernel::DefaultGemv<
cutlass::gemm::GemmShape<1, 64, 8>,
cutlass::gemm::GemmShape<1, 1, 2>,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor
>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<Operation_cutlass_simt_sgemv_batched_strided_1x64_8_tt_align2x1>(
BatchedGemmCoord const& problem_size,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_8_tt_align2x1::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename Operation_cutlass_simt_sgemv_batched_strided_1x64_8_tt_align2x1::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename Operation_cutlass_simt_sgemv_batched_strided_1x64_8_tt_align2x1::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 2>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 4>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 4>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 128>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 64>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 64>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 64>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
ThreadBlockShape,
ThreadShape,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor,
float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
BatchedGemmCoord const& problem_size,
const typename GemvKernel::ElementA* d_A, size_t lda, size_t batch_stride_a,
const typename GemvKernel::ElementB* d_B, size_t ldb, size_t batch_stride_b,
typename GemvKernel::ElementCD* d_C, size_t ldc, size_t batch_stride_c,
cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册