Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2de2222e
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2de2222e
编写于
1月 19, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/cuda): add cutlass batched gemv kernel for matmul operator
GitOrigin-RevId: 51702c4e79347175a993700be4022bc38102d79f
上级
973d2a0a
变更
38
隐藏空白更改
内联
并排
Showing
38 changed file
with
1018 addition
and
42 deletion
+1018
-42
dnn/scripts/Makefile
dnn/scripts/Makefile
+4
-1
dnn/src/cuda/matrix_mul/algos.cpp
dnn/src/cuda/matrix_mul/algos.cpp
+10
-0
dnn/src/cuda/matrix_mul/algos.h
dnn/src/cuda/matrix_mul/algos.h
+34
-0
dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
+2
-4
dnn/src/cuda/matrix_mul/cutlass_float32_simt_gemv_batched_strided.cpp
.../matrix_mul/cutlass_float32_simt_gemv_batched_strided.cpp
+58
-0
dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp
dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp
+6
-4
dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu
dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu
+84
-30
dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh
dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh
+19
-1
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x16_1x2x4.cu
...trix_mul_fp32_simt_gemv_batched_strided_1x128x16_1x2x4.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x16_1x4x2.cu
...trix_mul_fp32_simt_gemv_batched_strided_1x128x16_1x4x2.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x2_1x1x1.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x128x2_1x1x1.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x32_1x4x4.cu
...trix_mul_fp32_simt_gemv_batched_strided_1x128x32_1x4x4.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x4_1x1x2.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x128x4_1x1x2.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x4_1x2x1.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x128x4_1x2x1.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x1x4.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x1x4.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x2x2.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x2x2.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x4x1.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x4x1.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x128_1x4x4.cu
...trix_mul_fp32_simt_gemv_batched_strided_1x32x128_1x4x4.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x16_1x1x2.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x32x16_1x1x2.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x16_1x2x1.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x32x16_1x2x1.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x1x4.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x1x4.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x2x2.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x2x2.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x4x1.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x4x1.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x64_1x2x4.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x32x64_1x2x4.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x64_1x4x2.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x32x64_1x4x2.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x8_1x1x1.cu
...matrix_mul_fp32_simt_gemv_batched_strided_1x32x8_1x1x1.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x1x4.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x1x4.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x2x2.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x2x2.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x4x1.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x4x1.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x32_1x2x4.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x64x32_1x2x4.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x32_1x4x2.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x64x32_1x4x2.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x4_1x1x1.cu
...matrix_mul_fp32_simt_gemv_batched_strided_1x64x4_1x1x1.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x64_1x4x4.cu
...atrix_mul_fp32_simt_gemv_batched_strided_1x64x64_1x4x4.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x8_1x1x2.cu
...matrix_mul_fp32_simt_gemv_batched_strided_1x64x8_1x1x2.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x8_1x2x1.cu
...matrix_mul_fp32_simt_gemv_batched_strided_1x64x8_1x2x1.cu
+26
-0
dnn/src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl
...mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl
+70
-0
dnn/src/cuda/matrix_mul/opr_impl.h
dnn/src/cuda/matrix_mul/opr_impl.h
+3
-0
dnn/test/cuda/cutlass_matmul.cpp
dnn/test/cuda/cutlass_matmul.cpp
+26
-2
未找到文件。
dnn/scripts/Makefile
浏览文件 @
2de2222e
...
...
@@ -9,7 +9,7 @@ ELEMWISE_IMPL := ../src/cuda/cond_take/kimpl \
../src/cuda/elemwise_multi_type/kimpl
CUDA_CONV_IMPL
:=
../src/cuda/conv_bias/int8/kimpl ../src/cuda/conv_bias/int8_imma/kimpl ../src/cuda/batch_conv_bias/int8/kimpl
CUDA_MATMUL_IMPL
:=
../src/cuda/matrix_mul/fp32_simt/kimpl
CUDA_MATMUL_IMPL
:=
../src/cuda/matrix_mul/fp32_simt/kimpl
../src/cuda/matrix_mul/fp32_simt_gemv/kimpl
all
:
${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} $(CUDA_MATMUL_IMPL)
...
...
@@ -51,4 +51,7 @@ all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} $(CUDA_MATMUL_IMPL)
../src/cuda/matrix_mul/fp32_simt/kimpl
:
gen_cutlass_matmul_kern_impls.py
./
$^
$@
../src/cuda/matrix_mul/fp32_simt_gemv/kimpl
:
gen_cutlass_gemv_batched_strided_kern_impls.py
./
$^
$@
.PHONY
:
all
dnn/src/cuda/matrix_mul/algos.cpp
浏览文件 @
2de2222e
...
...
@@ -33,6 +33,7 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
#if !MEGDNN_DISABLE_FLOAT16
all_algos
.
push_back
(
&
bfloat16
);
#endif
#if CUDA_VERSION >= 9020
fill_cutlass_algos
();
for
(
auto
&&
algo
:
simt_float32
)
{
all_algos
.
push_back
(
&
algo
);
...
...
@@ -40,12 +41,17 @@ MatrixMulForwardImpl::AlgoPack::AlgoPack() {
for
(
auto
&&
algo
:
simt_float32_split_k
)
{
all_algos
.
push_back
(
&
algo
);
}
for
(
auto
&&
algo
:
simt_float32_gemv_batched_strided
)
{
all_algos
.
push_back
(
&
algo
);
}
#endif
for
(
auto
&&
algo
:
all_algos
)
{
m_all_algos_map
.
emplace
(
algo
->
info
().
desc
,
algo
);
}
}
#if CUDA_VERSION >= 9020
void
MatrixMulForwardImpl
::
AlgoPack
::
fill_cutlass_algos
()
{
using
AlgoParam
=
AlgoFloat32SIMT
::
AlgoParam
;
simt_float32
.
emplace_back
(
AlgoParam
{
64
,
256
,
8
,
32
,
64
,
8
});
...
...
@@ -82,7 +88,11 @@ void MatrixMulForwardImpl::AlgoPack::fill_cutlass_algos() {
simt_float32_split_k
.
emplace_back
(
AlgoParam
{
16
,
32
,
8
,
16
,
32
,
8
});
simt_float32_split_k
.
emplace_back
(
AlgoParam
{
16
,
64
,
8
,
16
,
64
,
8
});
simt_float32_split_k
.
emplace_back
(
AlgoParam
{
16
,
128
,
8
,
16
,
64
,
8
});
simt_float32_gemv_batched_strided
.
emplace_back
(
128
);
simt_float32_gemv_batched_strided
.
emplace_back
(
64
);
simt_float32_gemv_batched_strided
.
emplace_back
(
32
);
}
#endif
MatrixMulForwardImpl
::
AlgoPack
MatrixMulForwardImpl
::
sm_algo_pack
;
...
...
dnn/src/cuda/matrix_mul/algos.h
浏览文件 @
2de2222e
...
...
@@ -42,8 +42,11 @@ public:
CUDA_CUBLASLT
,
CUDA_NAIVE
,
CUDA_BFLOAT16
,
#if CUDA_VERSION >= 9020
CUDA_FLOAT32_SIMT
,
CUDA_FLOAT32_SIMT_SPLIT_K
,
CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED
,
#endif
};
using
Mapper
=
std
::
unordered_map
<
AlgorithmDesc
,
AlgoBase
*>
;
...
...
@@ -167,6 +170,7 @@ private:
};
#endif
#if CUDA_VERSION >= 9020
class
MatrixMulForwardImpl
::
AlgoFloat32SIMT
final
:
public
AlgoBase
{
public:
struct
AlgoParam
{
...
...
@@ -224,6 +228,32 @@ private:
std
::
string
m_name
;
};
/**
 * \brief matrix mul algorithm dispatching to the cutlass float32 simt gemv
 * batched strided kernel.
 *
 * Every instance is bound to one \p threadblock_n value. That value is
 * embedded in the algorithm name and is the only state serialized by param().
 */
class MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided final
        : public AlgoBase {
public:
    /**
     * \param threadblock_n N extent of the threadblock tile used to select
     *      the kernel instantiation
     *
     * The constructor is explicit so that a plain int is never silently
     * converted into an algorithm object; in-place construction such as
     * emplace_back(128) is unaffected.
     */
    explicit AlgoFloat32SIMTGemvBatchedStrided(int threadblock_n)
            : m_threadblock_n{threadblock_n},
              m_name{ssprintf("CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_%d",
                              m_threadblock_n)} {}

    bool is_available(const SizeArgs& args) const override;
    size_t get_workspace_in_bytes(const SizeArgs& args) const override;

    //! name of the form CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_<n>
    const char* name() const override { return m_name.c_str(); }

    void exec(const ExecArgs& args) const override;

    bool is_reproducible() const override { return true; }

    MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_GEMV_BATCHED_STRIDED)

    //! serialized algorithm parameter: the raw bytes of m_threadblock_n
    std::string param() const override {
        std::string ret;
        serialize_write_pod(m_threadblock_n, ret);
        return ret;
    }

private:
    // must stay declared before m_name: the initializer of m_name reads it
    int m_threadblock_n;
    std::string m_name;
};
#endif
class
MatrixMulForwardImpl
::
AlgoPack
:
NonCopyableObj
{
private:
AlgoBase
::
Mapper
m_all_algos_map
;
...
...
@@ -241,8 +271,12 @@ public:
#if !MEGDNN_DISABLE_FLOAT16
AlgoBFloat16
bfloat16
;
#endif
#if CUDA_VERSION >= 9020
std
::
vector
<
AlgoFloat32SIMT
>
simt_float32
;
std
::
vector
<
AlgoFloat32SIMTSplitK
>
simt_float32_split_k
;
std
::
vector
<
AlgoFloat32SIMTGemvBatchedStrided
>
simt_float32_gemv_batched_strided
;
#endif
std
::
vector
<
AlgoBase
*>
all_algos
;
const
AlgoBase
::
Mapper
&
all_algos_map
()
const
{
return
m_all_algos_map
;
}
...
...
dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp
浏览文件 @
2de2222e
...
...
@@ -15,20 +15,17 @@
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 9020
using
namespace
megdnn
;
using
namespace
cuda
;
using
namespace
cutlass_wrapper
;
bool
MatrixMulForwardImpl
::
AlgoFloat32SIMT
::
is_available
(
const
SizeArgs
&
args
)
const
{
#if CUDA_VERSION >= 9200
return
args
.
opr
->
param
().
format
==
param
::
MatrixMul
::
Format
::
DEFAULT
&&
args
.
layout_a
.
dtype
==
dtype
::
Float32
()
&&
args
.
layout_b
.
dtype
==
dtype
::
Float32
()
&&
args
.
layout_c
.
dtype
==
dtype
::
Float32
();
#else
return
false
;
#endif
}
size_t
MatrixMulForwardImpl
::
AlgoFloat32SIMT
::
get_workspace_in_bytes
(
...
...
@@ -69,5 +66,6 @@ void MatrixMulForwardImpl::AlgoFloat32SIMT::exec(const ExecArgs& args) const {
m_algo_param
.
warp_k
},
stream
);
}
#endif
// vim: syntax=cpp.doxygen
dnn/src/cuda/matrix_mul/cutlass_float32_simt_gemv_batched_strided.cpp
0 → 100644
浏览文件 @
2de2222e
/**
* \file dnn/src/cuda/matrix_mul/cutlass_float32_simt_gemv_batched_strided.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/handle.h"
#include "src/cuda/matrix_mul/algos.h"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 9020
using
namespace
megdnn
;
using
namespace
cuda
;
using
namespace
cutlass_wrapper
;
bool
MatrixMulForwardImpl
::
AlgoFloat32SIMTGemvBatchedStrided
::
is_available
(
const
SizeArgs
&
args
)
const
{
auto
&&
param
=
args
.
opr
->
param
();
bool
ta
=
param
.
transposeA
,
tb
=
param
.
transposeB
;
return
args
.
opr
->
param
().
format
==
param
::
MatrixMul
::
Format
::
DEFAULT
&&
args
.
layout_a
.
dtype
==
dtype
::
Float32
()
&&
args
.
layout_b
.
dtype
==
dtype
::
Float32
()
&&
args
.
layout_c
.
dtype
==
dtype
::
Float32
()
&&
((
!
ta
)
&&
(
!
tb
));
}
/**
 * Workspace requirement of this algorithm: always zero bytes, regardless of
 * the problem described by the (unused) argument.
 */
size_t MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided::
        get_workspace_in_bytes(const SizeArgs& /* args */) const {
    return 0;
}
/**
 * Run the matrix multiplication through the batched strided gemv wrapper.
 *
 * The m rows of A are passed as the batch count of the problem, with the row
 * stride of A (resp. C) used as both the leading dimension and the batch
 * stride of A (resp. C). B is handed over with a batch stride of 0.
 */
void MatrixMulForwardImpl::AlgoFloat32SIMTGemvBatchedStrided::exec(
        const ExecArgs& args) const {
    const auto& layout_a = args.tensor_a.layout;
    const auto& layout_b = args.tensor_b.layout;
    const auto& layout_c = args.tensor_c.layout;

    size_t lda = layout_a.stride[0];
    size_t ldb = layout_b.stride[0];
    size_t ldc = layout_c.stride[0];

    auto&& param = args.opr->param();
    int m = layout_c.shape[0];
    int n = layout_c.shape[1];
    int k = layout_a.shape[param.transposeA ? 0 : 1];

    // m is always 1 in gemv batched strided case; the original m is given as
    // the last coordinate instead
    BatchedGemmCoord problem_size{1, n, k, m};

    auto&& stream = cuda_stream(args.opr->handle());
    cutlass_matrix_mul_float32_simt_gemv_batched_strided(
            args.tensor_a.ptr<dt_float32>(), lda, /* batch_stride_a */ lda,
            args.tensor_b.ptr<dt_float32>(), ldb, /* batch_stride_b */ 0,
            args.tensor_c.ptr<dt_float32>(), ldc, /* batch_stride_c */ ldc,
            problem_size, m_threadblock_n, stream);
}
#endif
// vim: syntax=cpp.doxygen
dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp
浏览文件 @
2de2222e
...
...
@@ -15,6 +15,7 @@
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 9020
using
namespace
megdnn
;
using
namespace
cuda
;
using
namespace
cutlass_wrapper
;
...
...
@@ -22,12 +23,12 @@ using namespace cutlass_wrapper;
bool
MatrixMulForwardImpl
::
AlgoFloat32SIMTSplitK
::
is_available
(
const
SizeArgs
&
args
)
const
{
auto
&&
param
=
args
.
opr
->
param
();
int
m
=
args
.
layout_c
.
shape
[
0
],
n
=
args
.
layout_c
.
shape
[
1
],
int
n
=
args
.
layout_c
.
shape
[
1
],
k
=
args
.
layout_a
.
shape
[
param
.
transposeA
?
0
:
1
];
return
args
.
opr
->
param
().
format
==
param
::
MatrixMul
::
Format
::
DEFAULT
&&
args
.
layout_a
.
dtype
==
dtype
::
Float32
()
&&
args
.
layout_b
.
dtype
==
dtype
::
Float32
()
&&
args
.
layout_c
.
dtype
==
dtype
::
Float32
()
&&
k
>
std
::
max
(
m
,
n
)
;
args
.
layout_c
.
dtype
==
dtype
::
Float32
()
&&
k
>
n
;
}
size_t
MatrixMulForwardImpl
::
AlgoFloat32SIMTSplitK
::
get_workspace_in_bytes
(
...
...
@@ -38,7 +39,7 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes(
int
m
=
args
.
layout_c
.
shape
[
0
],
n
=
args
.
layout_c
.
shape
[
1
],
k
=
args
.
layout_a
.
shape
[
param
.
transposeA
?
0
:
1
];
GemmCoord
problem_size
{
m
,
n
,
k
};
int
split_k_slices
=
k
/
std
::
max
(
m
,
n
)
;
int
split_k_slices
=
k
/
n
;
return
cutlass_matrix_mul_float32_simt_get_workspace_size
(
param
.
transposeA
,
lda
,
param
.
transposeB
,
ldb
,
ldc
,
problem_size
,
1.
f
,
0.
f
,
...
...
@@ -58,7 +59,7 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
int
m
=
args
.
tensor_c
.
layout
.
shape
[
0
],
n
=
args
.
tensor_c
.
layout
.
shape
[
1
],
k
=
args
.
tensor_a
.
layout
.
shape
[
param
.
transposeA
?
0
:
1
];
GemmCoord
problem_size
{
m
,
n
,
k
};
int
split_k_slices
=
k
/
std
::
max
(
m
,
n
)
;
int
split_k_slices
=
k
/
n
;
auto
&&
stream
=
cuda_stream
(
args
.
opr
->
handle
());
int
*
workspace
=
reinterpret_cast
<
int
*>
(
args
.
workspace
.
raw_ptr
);
return
cutlass_matrix_mul_float32_simt
(
...
...
@@ -72,5 +73,6 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
m_algo_param
.
warp_k
},
stream
,
split_k_slices
);
}
#endif
// vim: syntax=cpp.doxygen
dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cu
浏览文件 @
2de2222e
...
...
@@ -10,16 +10,16 @@
* implied.
*/
// ignore warning of cutlass
#include "cuda.h"
#if __CUDACC_VER_MAJOR__ > 9 || \
(__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "cuda.h"
#if __CUDACC_VER_MAJOR__ > 9 || \
(__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_splitk_parallel.h"
#
endif
#
include "cutlass/gemm/kernel/default_gemv.h"
#include "src/common/opr_param_defs_enumv.cuh"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#pragma GCC diagnostic pop
...
...
@@ -54,18 +54,6 @@ using namespace cutlass_wrapper;
threadblock_shape.m(), threadblock_shape.n(), \
threadblock_shape.k(), warp_shape.m(), warp_shape.n(), \
warp_shape.k());
#if __CUDACC_VER_MAJOR__ < 9 || \
(__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ <= 2)
void
megdnn
::
cuda
::
cutlass_wrapper
::
cutlass_matrix_mul_float32_simt
(
const
float
*
/* d_A */
,
bool
/* transpose_A */
,
size_t
/* lda */
,
const
float
*
/* d_B */
,
bool
/* transpose_B */
,
size_t
/* ldb */
,
float
*
/* d_C */
,
size_t
/* ldc */
,
int
*
/* workspace */
,
GemmCoord
const
&
/* problem_size */
,
float
/* alpha */
,
float
/* beta */
,
const
GemmCoord
&
/* threadblock_shape */
,
const
GemmCoord
&
/* warp_shape */
,
cudaStream_t
/* stream */
,
int
/* split_k_slices */
)
{}
#else
void
megdnn
::
cuda
::
cutlass_wrapper
::
cutlass_matrix_mul_float32_simt
(
const
float
*
d_A
,
bool
transpose_A
,
size_t
lda
,
const
float
*
d_B
,
bool
transpose_B
,
size_t
ldb
,
float
*
d_C
,
size_t
ldc
,
int
*
workspace
,
...
...
@@ -162,20 +150,7 @@ void megdnn::cuda::cutlass_wrapper::cutlass_matrix_mul_float32_simt(
#undef cb
}
}
#endif
#if __CUDACC_VER_MAJOR__ < 9 || \
(__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ <= 2)
size_t
megdnn
::
cuda
::
cutlass_wrapper
::
cutlass_matrix_mul_float32_simt_get_workspace_size
(
bool
/* transpose_A */
,
size_t
/* lda */
,
bool
/* transpose_B */
,
size_t
/* ldb */
,
size_t
/* ldc */
,
GemmCoord
const
&
/* problem_size */
,
float
/* alpha */
,
float
/* beta */
,
const
GemmCoord
&
/* threadblock_shape */
,
const
GemmCoord
&
/* warp_shape */
,
int
/* split_k_slices */
)
{
return
0
;
}
#else
size_t
megdnn
::
cuda
::
cutlass_wrapper
::
cutlass_matrix_mul_float32_simt_get_workspace_size
(
bool
transpose_A
,
size_t
lda
,
bool
transpose_B
,
size_t
ldb
,
...
...
@@ -294,7 +269,86 @@ size_t megdnn::cuda::cutlass_wrapper::
#undef cb
}
}
#endif
#undef DISPATCH
/* ============ cutlass kernel wrapper for f32 vector-matrix mul batched strided
* ===========
*/
// Every supported (threadblock_n, LDG_K, LDG_N) combination is tried in turn;
// when none matches, the trailing assertion fires.
#define DISPATCH(cb)                                                          \
    cb(128, 4, 4);                                                            \
    cb(128, 4, 2);                                                            \
    cb(128, 4, 1);                                                            \
    cb(128, 2, 4);                                                            \
    cb(128, 1, 4);                                                            \
    cb(128, 2, 2);                                                            \
    cb(128, 1, 2);                                                            \
    cb(128, 2, 1);                                                            \
    cb(128, 1, 1);                                                            \
    cb(64, 4, 4);                                                             \
    cb(64, 4, 2);                                                             \
    cb(64, 4, 1);                                                             \
    cb(64, 2, 4);                                                             \
    cb(64, 1, 4);                                                             \
    cb(64, 2, 2);                                                             \
    cb(64, 1, 2);                                                             \
    cb(64, 2, 1);                                                             \
    cb(64, 1, 1);                                                             \
    cb(32, 4, 4);                                                             \
    cb(32, 4, 2);                                                             \
    cb(32, 4, 1);                                                             \
    cb(32, 2, 4);                                                             \
    cb(32, 1, 4);                                                             \
    cb(32, 2, 2);                                                             \
    cb(32, 1, 2);                                                             \
    cb(32, 2, 1);                                                             \
    cb(32, 1, 1);                                                             \
    megdnn_assert(false,                                                      \
                  "unsupported gemv batched strided A=%dX%dX%d, B=%dX%dX%d",  \
                  problem_size.batch(), problem_size.m(), problem_size.k(),   \
                  problem_size.batch(), problem_size.k(), problem_size.n());

/**
 * Select and launch the gemv batched strided kernel instantiation matching
 * \p threadblock_n and the divisibility of the leading dimensions.
 *
 * LDG_K is the largest of {4, 2, 1} dividing \p lda and LDG_N the largest of
 * {4, 2, 1} dividing \p ldb; they become the K and N extents of the per-thread
 * shape. NOTE(review): presumably these are the per-thread global load widths
 * in elements -- confirm against the cutlass gemv kernel.
 */
void megdnn::cuda::cutlass_wrapper::
        cutlass_matrix_mul_float32_simt_gemv_batched_strided(
                const float* d_A, size_t lda, size_t batch_stride_a,
                const float* d_B, size_t ldb, size_t batch_stride_b,
                float* d_C, size_t ldc, size_t batch_stride_c,
                BatchedGemmCoord const& problem_size, int threadblock_n,
                cudaStream_t stream) {
    const int LDG_K = (lda % 4 == 0) ? 4 : ((lda % 2 == 0) ? 2 : 1);
    const int LDG_N = (ldb % 4 == 0) ? 4 : ((ldb % 2 == 0) ? 2 : 1);
#define cb(threadblock_n_, LDG_K_, LDG_N_)                                    \
    if (threadblock_n == threadblock_n_ && LDG_K == LDG_K_ &&                 \
        LDG_N == LDG_N_) {                                                    \
        using ThreadBlockShape =                                              \
                cutlass::gemm::GemmShape<1, threadblock_n_,                   \
                                         (256 * LDG_K_) /                     \
                                                 (threadblock_n_ / LDG_N_)>;  \
        using ThreadShape = cutlass::gemm::GemmShape<1, LDG_N_, LDG_K_>;      \
        using GemvKernel = cutlass::gemm::kernel::DefaultGemv<                \
                ThreadBlockShape, ThreadShape, float,                         \
                cutlass::layout::RowMajor, float, cutlass::layout::RowMajor,  \
                float, cutlass::layout::RowMajor>;                            \
        return cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>( \
                problem_size, d_A, lda, batch_stride_a, d_B, ldb,             \
                batch_stride_b, d_C, ldc, batch_stride_c, stream);            \
    }
    DISPATCH(cb)
#undef cb
}
#undef DISPATCH
#endif
// vim: syntax=cuda.doxygen
dnn/src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh
浏览文件 @
2de2222e
...
...
@@ -13,11 +13,13 @@
#include "cutlass/gemm/gemm.h"
#include "src/cuda/utils.cuh"
#if CUDA_VERSION >= 9020
namespace
megdnn
{
namespace
cuda
{
namespace
cutlass_wrapper
{
using
GemmCoord
=
cutlass
::
gemm
::
GemmCoord
;
using
BatchedGemmCoord
=
cutlass
::
gemm
::
BatchedGemmCoord
;
template
<
typename
Gemm
>
void
cutlass_matrix_mul_wrapper
(
...
...
@@ -38,10 +40,26 @@ void cutlass_matrix_mul_float32_simt(
size_t
cutlass_matrix_mul_float32_simt_get_workspace_size
(
bool
transpose_A
,
size_t
lda
,
bool
transpose_B
,
size_t
ldb
,
size_t
ldc
,
GemmCoord
const
&
problem_size
,
float
alpha
,
float
beta
,
const
GemmCoord
&
threadblock_shape
,
const
GemmCoord
&
warp_shape
,
int
split_k_slices
=
1
);
const
GemmCoord
&
threadblock_shape
,
const
GemmCoord
&
warp_shape
,
int
split_k_slices
=
1
);
/**
 * \brief launcher for one gemv batched strided kernel type
 *
 * Declaration only. The element types of the A, B and C/D pointers are taken
 * from \p GemvKernel; each operand is described by a leading dimension and a
 * batch stride.
 */
template <typename GemvKernel>
void cutlass_vector_matrix_mul_batched_strided_wrapper(
        BatchedGemmCoord const& problem_size,
        const typename GemvKernel::ElementA* d_A, size_t lda,
        size_t batch_stride_a, const typename GemvKernel::ElementB* d_B,
        size_t ldb, size_t batch_stride_b,
        typename GemvKernel::ElementCD* d_C, size_t ldc,
        size_t batch_stride_c, cudaStream_t stream);
/**
 * \brief float32 simt gemv batched strided entry point
 *
 * \param threadblock_n selects which threadblock shape is dispatched to
 * \param stream cuda stream handed on to the kernel wrapper
 *
 * Each of A, B and C is given as a device pointer followed by its leading
 * dimension and its batch stride.
 */
void cutlass_matrix_mul_float32_simt_gemv_batched_strided(
        const float* d_A, size_t lda, size_t batch_stride_a,
        const float* d_B, size_t ldb, size_t batch_stride_b, float* d_C,
        size_t ldc, size_t batch_stride_c,
        BatchedGemmCoord const& problem_size, int threadblock_n,
        cudaStream_t stream);
}
// namespace cutlass_wrapper
}
// namespace cuda
}
// namespace megdnn
#endif
// vim: syntax=cuda.doxygen
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x16_1x2x4.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// explicit instantiation of the gemv batched strided wrapper for
// threadblock shape 1x128x16 and thread shape 1x2x4
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape, float, cutlass::layout::RowMajor, float,
        cutlass::layout::RowMajor, float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                BatchedGemmCoord const& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b, typename GemvKernel::ElementCD* d_C,
                size_t ldc, size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x16_1x4x2.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// explicit instantiation of the gemv batched strided wrapper for
// threadblock shape 1x128x16 and thread shape 1x4x2
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape, float, cutlass::layout::RowMajor, float,
        cutlass::layout::RowMajor, float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                BatchedGemmCoord const& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b, typename GemvKernel::ElementCD* d_C,
                size_t ldc, size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x2_1x1x1.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// explicit instantiation of the gemv batched strided wrapper for
// threadblock shape 1x128x2 and thread shape 1x1x1
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 2>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape, float, cutlass::layout::RowMajor, float,
        cutlass::layout::RowMajor, float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                BatchedGemmCoord const& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b, typename GemvKernel::ElementCD* d_C,
                size_t ldc, size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x32_1x4x4.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// explicit instantiation of the gemv batched strided wrapper for
// threadblock shape 1x128x32 and thread shape 1x4x4
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape, float, cutlass::layout::RowMajor, float,
        cutlass::layout::RowMajor, float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                BatchedGemmCoord const& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b, typename GemvKernel::ElementCD* d_C,
                size_t ldc, size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x4_1x1x2.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// explicit instantiation of the gemv batched strided wrapper for
// threadblock shape 1x128x4 and thread shape 1x1x2
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 4>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 2>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape, float, cutlass::layout::RowMajor, float,
        cutlass::layout::RowMajor, float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                BatchedGemmCoord const& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b, typename GemvKernel::ElementCD* d_C,
                size_t ldc, size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x4_1x2x1.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// explicit instantiation of the gemv batched strided wrapper for
// threadblock shape 1x128x4 and thread shape 1x2x1
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 4>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 1>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape, float, cutlass::layout::RowMajor, float,
        cutlass::layout::RowMajor, float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                BatchedGemmCoord const& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b, typename GemvKernel::ElementCD* d_C,
                size_t ldc, size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x1x4.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// explicit instantiation of the gemv batched strided wrapper for
// threadblock shape 1x128x8 and thread shape 1x1x4
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>;
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape, float, cutlass::layout::RowMajor, float,
        cutlass::layout::RowMajor, float, cutlass::layout::RowMajor>;
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                BatchedGemmCoord const& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b, typename GemvKernel::ElementCD* d_C,
                size_t ldc, size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x2x2.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x128x8, thread 1x2x2.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 2>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x128x8_1x4x1.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x128x8, thread 1x4x1.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 128, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 1>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x128_1x4x4.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x32x128, thread 1x4x4.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 128>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x16_1x1x2.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x32x16, thread 1x1x2.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 2>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x16_1x2x1.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x32x16, thread 1x2x1.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 1>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x1x4.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x32x32, thread 1x1x4.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x2x2.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x32x32, thread 1x2x2.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 2>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x32_1x4x1.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x32x32, thread 1x4x1.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 1>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x64_1x2x4.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x32x64, thread 1x2x4.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 64>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 4>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x64_1x4x2.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x32x64, thread 1x4x2.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 64>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 2>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x32x8_1x1x1.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x32x8, thread 1x1x1.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 32, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x1x4.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x64x16, thread 1x1x4.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 4>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x2x2.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x64x16, thread 1x2x2.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 2>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x16_1x4x1.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x64x16, thread 1x4x1.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 16>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 1>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x32_1x2x4.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x64x32, thread 1x2x4.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 4>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x32_1x4x2.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x64x32, thread 1x4x2.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 32>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 2>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x4_1x1x1.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x64x4, thread 1x1x1.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 4>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 1>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x64_1x4x4.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x64x64, thread 1x4x4.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 64>;
using ThreadShape = cutlass::gemm::GemmShape<1, 4, 4>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x8_1x1x2.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x64x8, thread 1x1x2.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 1, 2>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/kimpl/matrix_mul_fp32_simt_gemv_batched_strided_1x64x8_1x2x1.cu
0 → 100644
浏览文件 @
2de2222e
#if __CUDACC_VER_MAJOR__ > 9 || (__CUDACC_VER_MAJOR__ == 9 && __CUDACC_VER_MINOR__ >= 2)
// generated by gen_cutlass_gemv_batched_strided_kern_impls.py
// The guard above compiles this unit only when the CUDA compiler reports
// version 9.2 or newer.
// ignore warning of cutlass
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#include "src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl"

// Tile configuration: threadblock 1x64x8, thread 1x2x1.
using ThreadBlockShape = cutlass::gemm::GemmShape<1, 64, 8>;
using ThreadShape = cutlass::gemm::GemmShape<1, 2, 1>;
// float elements with row-major layouts for all three operands.
using GemvKernel = cutlass::gemm::kernel::DefaultGemv<
        ThreadBlockShape, ThreadShape,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor,
        float, cutlass::layout::RowMajor>;

// Explicit instantiation of the launch wrapper for this kernel type.
template void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper<GemvKernel>(
                const BatchedGemmCoord& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream);
#pragma GCC diagnostic pop
#endif
dnn/src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl
0 → 100644
浏览文件 @
2de2222e
/**
* \file
* dnn/src/cuda/matrix_mul/fp32_simt_gemv/matrix_mul_float_simt_gemv_batched_strided_cutlass_wrapper.cuinl
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "cutlass/gemm/kernel/default_gemv.h"
#include "cutlass/gemm/kernel/gemv_batched_strided.h"
#include "src/cuda/matrix_mul/cutlass_matrix_mul_wrapper.cuh"
#include "src/cuda/query_blocksize.cuh"
using namespace megdnn;
using namespace cuda;
using namespace cutlass_wrapper;

/*!
 * \brief Configure and launch cutlass::gemm::kernel::GemvBatchedStrided for
 * the kernel description \p GemvKernel on \p stream.
 *
 * \param problem_size problem extents; forwarded unchanged to the threadblock
 *      swizzler and to the kernel
 * \param d_A, d_B input operand pointers (presumably device memory, judging by
 *      the d_ prefix -- confirm with callers); const is cast away only to
 *      build the TensorRef objects
 * \param d_C output operand pointer
 * \param lda, ldb, ldc leading dimensions; narrowed to int when the layout
 *      objects are constructed
 * \param batch_stride_a, batch_stride_b, batch_stride_c per-batch strides,
 *      passed to the kernel as the TensorRef LongIndex arguments
 * \param stream CUDA stream the kernel is enqueued on
 */
template <typename GemvKernel>
void megdnn::cuda::cutlass_wrapper::
        cutlass_vector_matrix_mul_batched_strided_wrapper(
                BatchedGemmCoord const& problem_size,
                const typename GemvKernel::ElementA* d_A, size_t lda,
                size_t batch_stride_a,
                const typename GemvKernel::ElementB* d_B, size_t ldb,
                size_t batch_stride_b,
                typename GemvKernel::ElementCD* d_C, size_t ldc,
                size_t batch_stride_c, cudaStream_t stream) {
    using RefA = typename GemvKernel::IteratorA::TensorRef;
    using RefB = typename GemvKernel::IteratorB::TensorRef;
    using RefCD = typename GemvKernel::IteratorCD::TensorRef;
    // Exact signature of the kernel entry point; spelling it out selects the
    // matching GemvBatchedStrided instantiation.
    using KernelFn = void (*)(BatchedGemmCoord, RefA,
                              typename RefA::LongIndex, RefB,
                              typename RefB::LongIndex, RefCD,
                              typename RefCD::LongIndex);

    // Wrap the raw pointers together with their layouts.
    RefA tensor_a{const_cast<typename GemvKernel::ElementA*>(d_A),
                  typename GemvKernel::LayoutA{static_cast<int>(lda)}};
    RefB tensor_b{const_cast<typename GemvKernel::ElementB*>(d_B),
                  typename GemvKernel::LayoutB{static_cast<int>(ldb)}};
    RefCD tensor_c{d_C, typename GemvKernel::LayoutCD{static_cast<int>(ldc)}};

    KernelFn kern = cutlass::gemm::kernel::GemvBatchedStrided<GemvKernel>;

    // Grid: let the kernel's swizzler tile the problem by the threadblock
    // shape (with a batch extent of 1 per tile).
    auto tile_size = BatchedGemmCoord(GemvKernel::ThreadBlockShape::kM,
                                      GemvKernel::ThreadBlockShape::kN,
                                      GemvKernel::ThreadBlockShape::kK, 1);
    typename GemvKernel::ThreadBlockSwizzle swizzler;
    auto tiled_shape = swizzler.get_tiled_shape(problem_size, tile_size);
    dim3 grid = swizzler.get_grid_shape(tiled_shape);

    // Block: taken directly from the kernel's Core constants; no runtime
    // block-size query is performed here.
    static int constexpr kThreadsPerN = GemvKernel::Core::kThreadsPerN;
    static int constexpr kThreadsPerK = GemvKernel::Core::kThreadsPerK;
    dim3 block(kThreadsPerN, kThreadsPerK, 1);

    // Shared storage requested at launch must stay below 48 KiB.
    int smem_size = static_cast<int>(
            sizeof(typename GemvKernel::ThreadBlockGemv::SharedStorage));
    megdnn_assert(smem_size < (48 << 10));

    kern<<<grid, block, smem_size, stream>>>(
            problem_size, tensor_a, batch_stride_a, tensor_b, batch_stride_b,
            tensor_c, batch_stride_c);
    after_kernel_launch();
}
// vim: syntax=cuda.doxygen
dnn/src/cuda/matrix_mul/opr_impl.h
浏览文件 @
2de2222e
...
...
@@ -41,8 +41,11 @@ public:
#if !MEGDNN_DISABLE_FLOAT16
class
AlgoBFloat16
;
#endif
#if CUDA_VERSION >= 9020
class
AlgoFloat32SIMT
;
class
AlgoFloat32SIMTSplitK
;
class
AlgoFloat32SIMTGemvBatchedStrided
;
#endif
class
AlgoPack
;
static
const
AlgoPack
&
algo_pack
()
{
...
...
dnn/test/cuda/cutlass_matmul.cpp
浏览文件 @
2de2222e
...
...
@@ -90,7 +90,7 @@ void test_multibatchsize(
if
(
std
::
regex_match
(
i
.
name
.
c_str
(),
std
::
regex
(
"("
+
std
::
string
(
algo
)
+
")(.*)"
)))
{
opr_reference
->
execution_policy
().
algo
=
i
;
opr_reference
->
execution_policy
().
algo
=
i
.
desc
;
break
;
}
}
...
...
@@ -119,7 +119,7 @@ void test_multibatchsize(
if
(
std
::
regex_match
(
i
.
name
.
c_str
(),
std
::
regex
(
"("
+
std
::
string
(
algo
)
+
")(.*)"
)))
{
opr_reference
->
execution_policy
().
algo
=
i
;
opr_reference
->
execution_policy
().
algo
=
i
.
desc
;
break
;
}
}
...
...
@@ -292,6 +292,30 @@ TEST_F(CUDA, CUTLASS_GEMM_SPLIT_K_MULTI_BATCHSIZE) {
[](
const
matrix_mul
::
TestArg
&
arg
)
{
return
arg
.
k
<=
arg
.
n
;
});
}
// Runs the multi-batch-size check for the algorithm whose name starts with
// "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_128" on fp32 inputs/outputs.
TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_128_MULTI_BATCHSIZE) {
    auto matmul_cases = matrix_mul::get_matmul_args_no_mask();
    test_multibatchsize(
            handle_cuda(), dtype::Float32(), dtype::Float32(),
            dtype::Float32(),
            "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_128", matmul_cases,
            param::MatrixMul::Format::DEFAULT);
}
// Runs the multi-batch-size check for the algorithm whose name starts with
// "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_64" on fp32 inputs/outputs.
TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_64_MULTI_BATCHSIZE) {
    auto matmul_cases = matrix_mul::get_matmul_args_no_mask();
    test_multibatchsize(
            handle_cuda(), dtype::Float32(), dtype::Float32(),
            dtype::Float32(),
            "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_64", matmul_cases,
            param::MatrixMul::Format::DEFAULT);
}
// Runs the multi-batch-size check for the algorithm whose name starts with
// "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_32" on fp32 inputs/outputs.
TEST_F(CUDA, CUTLASS_GEMV_BATCHED_STRIDED_32_MULTI_BATCHSIZE) {
    auto matmul_cases = matrix_mul::get_matmul_args_no_mask();
    test_multibatchsize(
            handle_cuda(), dtype::Float32(), dtype::Float32(),
            dtype::Float32(),
            "CUTLASS_FLOAT32_SIMT_GEMV_BATCHED_STRIDED_32", matmul_cases,
            param::MatrixMul::Format::DEFAULT);
}
#define MEGDNN_FOREACH_CUTLASS_KERNEL(cb) \
cb(1, 64, 256, 8, 32, 64, 8); \
cb(2, 256, 64, 8, 64, 32, 8); \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录