Commit a0a5fcf1 (MegEngine 天元 / MegEngine)
Authored Apr 19, 2022 by Megvii Engine Team

feat(dnn): support tf32

GitOrigin-RevId: 9e5871f933744468b91b7ab5ac6159a4b7a67084
Parent commit: f0088335
Showing 5 changed files with 23 additions and 5 deletions (+23 −5):

    dnn/src/cuda/batched_matrix_mul/cublas.cpp       +8  −0
    dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp    +1  −1
    dnn/src/cuda/matrix_mul/cublas.cpp               +12 −2
    dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp     +1  −1
    dnn/src/cuda/matrix_mul/cublas_lt.cpp            +1  −1
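TF32 (TensorFloat-32) is a tensor-core math mode introduced with Ampere GPUs and CUDA 11: GEMM inputs and outputs stay fp32, but the multiply units round each operand to a 19-bit format (1 sign bit, 8 exponent bits, 10 mantissa bits) and accumulate in fp32. This commit opts MegEngine's cuBLAS and cuBLASLt matmul paths into that mode. A minimal sketch of the core API call follows; the helper name enable_tf32 is illustrative and not from the patch, while the CUDART_VERSION guard mirrors the one used throughout the diff:

    // Sketch only: opt a cuBLAS handle into tensor-op math.
    // enable_tf32 is an illustrative name, not a function in this commit.
    #include <cassert>
    #include <cuda_runtime.h>
    #include <cublas_v2.h>

    void enable_tf32(cublasHandle_t handle) {
    #if CUDART_VERSION >= 11000
        // CUDA 11+: allow fp32 GEMMs to run on TF32 tensor cores.
        cublasStatus_t st = cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH);
    #else
        // Older toolkits: the pre-CUDA-11 tensor-op mode, matching the
        // #else branches added by this patch.
        cublasStatus_t st = cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
    #endif
        assert(st == CUBLAS_STATUS_SUCCESS);
    }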
dnn/src/cuda/batched_matrix_mul/cublas.cpp

@@ -88,7 +88,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const {
 #if CUDART_VERSION >= 9010
     auto io16_c32 = [&]() {
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
+#else
         cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
+#endif
         auto zero = handle->zero_device();
         auto one = handle->one_device();
         cublas_check(cublasGemmBatchedEx(

@@ -104,7 +108,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const {
 #if CUDART_VERSION >= 9000
     auto io16_c16 = [&]() {
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
+#else
         cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
+#endif
         auto zero = handle->zero_device_h();
         auto one = handle->one_device_h();
         cublas_check(cublasHgemmBatched(
dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp

@@ -124,7 +124,7 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const
         batched_igemm();
     } else if (desc.dt_compute == CUBLAS_COMPUTE_16F) {
         batched_hgemm();
-    } else if (desc.dt_compute == CUBLAS_COMPUTE_32F) {
+    } else if (desc.dt_compute == CUBLAS_COMPUTE_32F_FAST_TF32) {
         batched_sgemm();
     } else {
         megdnn_throw("compute_type must be int32/float16/float32");
dnn/src/cuda/matrix_mul/cublas.cpp

@@ -49,18 +49,26 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const {
     auto sgemm = [&]() {
         auto zero = handle->zero_device();
         auto one = handle->one_device();
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
+#endif
         cublas_check(cublasSgemm(
                 cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                 param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one,
                 args.tensor_b.ptr<dt_float32>(), args.tensor_b.layout.stride[0],
                 args.tensor_a.ptr<dt_float32>(), args.tensor_a.layout.stride[0], zero,
                 args.tensor_c.ptr<dt_float32>(), args.tensor_c.layout.stride[0]));
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH));
+#endif
     };
     auto sgemm_ex = [&]() {
         auto zero = handle->zero_device();
         auto one = handle->one_device();
-#if CUDART_VERSION >= 9000
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
+#elif CUDART_VERSION >= 9000
         cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
 #endif
         auto sgemm_ex_err = cublasSgemmEx(

@@ -78,7 +86,9 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const {
     };
     auto hgemm = [&]() {
-#if CUDART_VERSION >= 9000
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
+#elif CUDART_VERSION >= 9000
         cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
 #endif
         auto one_half = handle->one_device_h();
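Note the set/reset pattern added in the sgemm lambda above: the handle is switched to CUBLAS_TF32_TENSOR_OP_MATH just before cublasSgemm and restored to CUBLAS_DEFAULT_MATH afterwards, so TF32 does not leak into unrelated calls on the shared handle. A scope guard is one way to make that restore automatic; the class below is a hypothetical sketch, not code from this commit:

    // Hypothetical RAII guard (not in this commit): saves the current cuBLAS
    // math mode and restores it when the guard goes out of scope.
    #include <cublas_v2.h>

    class CublasMathModeGuard {
        cublasHandle_t m_handle;
        cublasMath_t m_saved{CUBLAS_DEFAULT_MATH};

    public:
        CublasMathModeGuard(cublasHandle_t handle, cublasMath_t mode)
                : m_handle(handle) {
            cublasGetMathMode(m_handle, &m_saved);  // remember the old mode
            cublasSetMathMode(m_handle, mode);      // e.g. CUBLAS_TF32_TENSOR_OP_MATH
        }
        ~CublasMathModeGuard() { cublasSetMathMode(m_handle, m_saved); }
    };

With such a guard at the top of the lambda, the explicit #if-guarded reset after cublasSgemm would be unnecessary.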
dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp

@@ -28,7 +28,7 @@ static cublasComputeType_t to_cublas_compute_type(DType tp) {
         case DTypeEnum::Float16:
            return CUBLAS_COMPUTE_16F;
         case DTypeEnum::Float32:
-            return CUBLAS_COMPUTE_32F;
+            return CUBLAS_COMPUTE_32F_FAST_TF32;
         case DTypeEnum::Int32:
         case DTypeEnum::QuantizedS32:
             return CUBLAS_COMPUTE_32I;
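Unlike the cuBLAS handle above, cuBLASLt has no handle-wide math mode: the compute type travels with each matmul descriptor, which is why changing the Float32 mapping in to_cublas_compute_type is enough to move every fp32 cublasLt GEMM onto TF32. A minimal sketch of where that enum ends up (cublasLtMatmulDescCreate is the standard CUDA 11 API; the helper name make_tf32_desc is illustrative):

    // Sketch only: create a matmul descriptor that computes in TF32 while
    // keeping fp32 alpha/beta scaling. Error handling omitted for brevity.
    #include <cublasLt.h>

    cublasLtMatmulDesc_t make_tf32_desc() {
        cublasLtMatmulDesc_t desc = nullptr;
        cublasLtMatmulDescCreate(&desc, CUBLAS_COMPUTE_32F_FAST_TF32, CUDA_R_32F);
        return desc;
    }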
dnn/src/cuda/matrix_mul/cublas_lt.cpp

@@ -107,7 +107,7 @@ void MatrixMulForwardImpl::AlgoCuBlasLt::exec(const ExecArgs& args) const {
         case CUBLAS_COMPUTE_16F:
             hgemm();
             break;
-        case CUBLAS_COMPUTE_32F:
+        case CUBLAS_COMPUTE_32F_FAST_TF32:
             sgemm();
             break;
         case CUBLAS_COMPUTE_32I:
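The numerical trade-off behind all five changes: TF32 keeps fp32's 8-bit exponent, so dynamic range is unchanged, but it carries only 10 explicit mantissa bits, so per-operation relative error grows from roughly 2^-24 to roughly 2^-11. A rough way to preview the effect on given data, as an illustration only (real tensor cores round to nearest; masking bits truncates, which is slightly pessimistic):

    // Illustration (not from the patch): approximate TF32 by zeroing the
    // 13 low mantissa bits of an fp32 value.
    #include <cstdint>
    #include <cstring>

    float truncate_to_tf32(float x) {
        std::uint32_t bits;
        std::memcpy(&bits, &x, sizeof(bits));
        bits &= 0xFFFFE000u;  // keep sign, 8 exponent bits, top 10 mantissa bits
        std::memcpy(&x, &bits, sizeof(bits));
        return x;
    }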