Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
e9db061e
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
396
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
e9db061e
编写于
2月 05, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mgb): fix compiling error for cuda-11.1
GitOrigin-RevId: f63e71afa75160746f0d69c67282b3a18b544ed1
上级
cd02d7c8
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
36 addition
and
15 deletion
+36
-15
dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp
dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp
+34
-14
dnn/src/cuda/matrix_mul/cublasLt_wrapper.h
dnn/src/cuda/matrix_mul/cublasLt_wrapper.h
+2
-1
未找到文件。
dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp
浏览文件 @
e9db061e
...
...
@@ -12,6 +12,7 @@
#include "src/common/utils.h"
#include "src/cuda/utils.h"
#if CUDA_VERSION >= 10010
namespace
megdnn
{
namespace
cuda
{
static
cudaDataType_t
to_cuda_dtype
(
DType
tp
)
{
...
...
@@ -31,6 +32,22 @@ static cudaDataType_t to_cuda_dtype(DType tp) {
"dtype must be float16/float32/int8/qs8/int32"
));
}
}
static
cublasComputeType_t
to_cublas_compute_type
(
DType
tp
)
{
switch
(
tp
.
enumv
())
{
case
DTypeEnum
::
Float16
:
return
CUBLAS_COMPUTE_16F
;
case
DTypeEnum
::
Float32
:
return
CUBLAS_COMPUTE_32F
;
case
DTypeEnum
::
Int32
:
case
DTypeEnum
::
QuantizedS32
:
return
CUBLAS_COMPUTE_32I
;
default:
megdnn_throw
(
megdnn_mangle
(
"dtype must be float16/float32/int32/Qs32"
));
}
}
static
const
char
*
cuda_type_to_str
(
cudaDataType_t
tp
)
{
switch
(
tp
)
{
case
CUDA_R_16F
:
...
...
@@ -46,6 +63,7 @@ static const char* cuda_type_to_str(cudaDataType_t tp) {
megdnn_mangle
(
"dtype must be float16/float32/int8/int32"
));
}
}
static
size_t
cuda_dtype_size
(
cudaDataType_t
dt
)
{
switch
(
dt
)
{
case
CUDA_R_8I
:
...
...
@@ -60,6 +78,7 @@ static size_t cuda_dtype_size(cudaDataType_t dt) {
megdnn_mangle
(
"dtype must be float16/float32/int8/int32"
));
}
}
CUBLASLTMatmulDesc
::~
CUBLASLTMatmulDesc
()
{
if
(
matmul_desc
)
cublas_check
(
cublasLtMatmulDescDestroy
(
matmul_desc
));
...
...
@@ -86,9 +105,10 @@ void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) {
uint32_t
pm
=
CUBLAS_POINTER_MODE_DEVICE
;
dt_b
=
to_cuda_dtype
(
args
.
layout_b
.
dtype
);
dt_a
=
to_cuda_dtype
(
args
.
layout_a
.
dtype
);
dt_compute
=
dt_c
=
to_cuda_dtype
(
args
.
layout_c
.
dtype
);
dt_c
=
to_cuda_dtype
(
args
.
layout_c
.
dtype
);
dt_compute
=
to_cublas_compute_type
(
args
.
layout_c
.
dtype
);
megdnn_assert
(
dt_a
==
dt_b
,
"matrix A and B should have same precision"
);
cublas_check
(
cublasLtMatmulDescCreate
(
&
matmul_desc
,
dt_compute
));
cublas_check
(
cublasLtMatmulDescCreate
(
&
matmul_desc
,
dt_compute
,
dt_c
));
cublas_check
(
cublasLtMatmulDescSetAttribute
(
matmul_desc
,
CUBLASLT_MATMUL_DESC_POINTER_MODE
,
&
pm
,
sizeof
(
pm
)));
...
...
@@ -100,7 +120,7 @@ void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) {
* So we calculate C^t = B^t * A^t by cublas. Here the transpose symbol
* implies row-major to column-major conversion
*/
if
(
dt_c
ompute
==
CUDA_R_32I
)
{
if
(
dt_c
==
CUDA_R_32I
)
{
/**
* \NOTE: To use IMMA kernels, use computeType = CUDA_R_32I and
* CUBLASLT_ORDER_COL32 for matrices A,C,D and
...
...
@@ -209,7 +229,7 @@ void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) {
bool
CUBLASLTMatmulDesc
::
is_available
(
const
SizeArgs
&
args
,
size_t
ws_limit
)
{
bool
support
;
cublasLtMatmulAlgo_t
algo
;
switch
(
dt_c
ompute
)
{
switch
(
dt_c
)
{
case
CUDA_R_16F
:
support
=
(
dt_a
==
CUDA_R_16F
);
break
;
...
...
@@ -239,17 +259,17 @@ WorkspaceBundle CUBLASLTMatmulDesc::get_workspace_bundle(
cublasLtMatmulHeuristicResult_t
result
{};
status
=
cublasLtMatmulAlgoCheck
(
cublasLt_handle
,
matmul_desc
,
dt_c
ompute
==
CUDA_R_32I
?
layout_trans_b
:
layout_b
,
dt_c
ompute
==
CUDA_R_32I
?
layout_trans_a
:
layout_a
,
dt_c
ompute
==
CUDA_R_32I
?
layout_trans_c
:
layout_c
,
dt_c
ompute
==
CUDA_R_32I
?
layout_trans_c
:
layout_c
,
&
algo
,
dt_c
==
CUDA_R_32I
?
layout_trans_b
:
layout_b
,
dt_c
==
CUDA_R_32I
?
layout_trans_a
:
layout_a
,
dt_c
==
CUDA_R_32I
?
layout_trans_c
:
layout_c
,
dt_c
==
CUDA_R_32I
?
layout_trans_c
:
layout_c
,
&
algo
,
&
result
);
// return empty WorkspaceBundle if cublasLtMatmulAlgoCheck() failed
if
(
status
!=
CUBLAS_STATUS_SUCCESS
)
return
{
nullptr
,
{}};
algo_workspace_size
=
result
.
workspaceSize
;
return
{
nullptr
,
(
dt_c
ompute
==
CUDA_R_32I
)
(
dt_c
==
CUDA_R_32I
)
?
SmallVector
<
size_t
>
{
algo_workspace_size
,
workspace_b
,
workspace_a
,
workspace_c
}
:
SmallVector
<
size_t
>
{
algo_workspace_size
}};
...
...
@@ -273,7 +293,7 @@ bool CUBLASLTMatmulDesc::get_algorithm_heuristic(const SizeArgs& args,
* \Note: algo_ws_limit must be zero if cublasLtGetVersion() <= 10100
*/
// algo_ws_limit = 0;
if
(
dt_c
ompute
==
CUDA_R_32I
)
{
if
(
dt_c
==
CUDA_R_32I
)
{
//[FIXME]: cublasLt(Version 10020) produce wrong result when k in
//[64*n+1 , 64*n+32] for small matrix
...
...
@@ -291,10 +311,10 @@ bool CUBLASLTMatmulDesc::get_algorithm_heuristic(const SizeArgs& args,
sizeof
(
algo_ws_limit
)));
status
=
cublasLtMatmulAlgoGetHeuristic
(
cublasLt_handle
,
matmul_desc
,
dt_c
ompute
==
CUDA_R_32I
?
layout_trans_b
:
layout_b
,
dt_c
ompute
==
CUDA_R_32I
?
layout_trans_a
:
layout_a
,
dt_c
ompute
==
CUDA_R_32I
?
layout_trans_c
:
layout_c
,
dt_c
ompute
==
CUDA_R_32I
?
layout_trans_c
:
layout_c
,
algo_pref
,
1
,
dt_c
==
CUDA_R_32I
?
layout_trans_b
:
layout_b
,
dt_c
==
CUDA_R_32I
?
layout_trans_a
:
layout_a
,
dt_c
==
CUDA_R_32I
?
layout_trans_c
:
layout_c
,
dt_c
==
CUDA_R_32I
?
layout_trans_c
:
layout_c
,
algo_pref
,
1
,
&
algo_result
,
&
return_algo_count
);
if
(
status
==
CUBLAS_STATUS_SUCCESS
&&
return_algo_count
>
0
&&
// perform cublasLtAlgoCheck() to make sure the algo is correct
...
...
dnn/src/cuda/matrix_mul/cublasLt_wrapper.h
浏览文件 @
e9db061e
...
...
@@ -47,7 +47,8 @@ struct CUBLASLTMatmulDesc {
};
bool
is_batched
;
cublasLtMatmulDesc_t
matmul_desc
;
cudaDataType_t
dt_a
,
dt_b
,
dt_c
,
dt_compute
;
cudaDataType_t
dt_a
,
dt_b
,
dt_c
;
cublasComputeType_t
dt_compute
;
cublasLtMatrixLayout_t
layout_a
,
layout_b
,
layout_c
;
cublasLtMatrixLayout_t
layout_trans_a
,
layout_trans_b
,
layout_trans_c
;
size_t
workspace_a
,
workspace_b
,
workspace_c
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录