Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
3b452d8c
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
3b452d8c
编写于
7月 28, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mgb): cuda conv support nhwc format and fp16 dtype
GitOrigin-RevId: b8ddcd108a4370a0b093c51bd90ebde0e007cb24
上级
10bcf757
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
45 addition
and
4 deletion
+45
-4
dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp
dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp
+6
-0
dnn/src/cuda/conv_bias/helper.cpp
dnn/src/cuda/conv_bias/helper.cpp
+4
-4
dnn/test/cuda/conv_bias.cpp
dnn/test/cuda/conv_bias.cpp
+35
-0
未找到文件。
dnn/src/cuda/conv_bias/cudnn_conv_bias_activation.cpp
浏览文件 @
3b452d8c
...
...
@@ -69,6 +69,12 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available(
return
false
;
}
if
(
args
.
src_layout
->
dtype
.
enumv
()
==
DTypeEnum
::
Float16
&&
args
.
dst_layout
->
dtype
.
enumv
()
==
DTypeEnum
::
Float16
&&
param
.
format
==
param
::
ConvBias
::
Format
::
NHWC
)
{
return
false
;
}
//! FIXME: conv kernel of cudnn for NCHW4_NCHW tensor format causes illegal
//! memory access errors, so we have to disable this kernel here.
if
(
param
.
format
==
param
::
ConvBias
::
Format
::
NCHW4_NCHW
||
...
...
dnn/src/cuda/conv_bias/helper.cpp
浏览文件 @
3b452d8c
...
...
@@ -151,14 +151,14 @@ bool is_cudnn_supported(const BiasForwardSizeArgs& args) {
if
(
args
.
handle
->
is_tegra_k1
())
return
false
;
// TODO: We only support NCHW format now. It seems cuDNN provides support
// for NHWC as well.
if
(
args
.
filter_meta
.
format
==
param
::
Convolution
::
Format
::
NCHW4
)
{
if
(
args
.
filter_meta
.
format
==
param
::
Convolution
::
Format
::
NCHW4
||
args
.
filter_meta
.
format
==
param
::
Convolution
::
Format
::
NCHW32
)
{
if
(
args
.
dst_layout
->
dtype
.
enumv
()
!=
DTypeEnum
::
Int8
&&
args
.
dst_layout
->
dtype
.
enumv
()
!=
DTypeEnum
::
QuantizedS8
)
{
return
false
;
}
}
else
if
(
args
.
filter_meta
.
format
!=
param
::
Convolution
::
Format
::
NCHW
)
{
}
else
if
(
args
.
filter_meta
.
format
!=
param
::
Convolution
::
Format
::
NCHW
&&
args
.
filter_meta
.
format
!=
param
::
Convolution
::
Format
::
NHWC
)
{
return
false
;
}
auto
&
fm
=
args
.
filter_meta
;
...
...
dnn/test/cuda/conv_bias.cpp
浏览文件 @
3b452d8c
...
...
@@ -216,6 +216,41 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_QS8) {
}
}
TEST_F
(
CUDA
,
CONV_BIAS_FORWARD_FLOAT16
)
{
require_compute_capability
(
6
,
1
);
Checker
<
ConvBiasForward
>
checker
(
handle_cuda
());
ConvBias
::
Param
param
;
param
.
format
=
ConvBias
::
Param
::
Format
::
NHWC
;
param
.
nonlineMode
=
ConvBias
::
Param
::
NonlineMode
::
IDENTITY
;
checker
.
set_epsilon
(
2e-2
)
.
set_dtype
(
0
,
dtype
::
Float16
())
.
set_dtype
(
1
,
dtype
::
Float16
())
.
set_dtype
(
2
,
dtype
::
Float16
())
.
set_dtype
(
3
,
dtype
::
Float16
())
.
set_dtype
(
4
,
dtype
::
Float16
());
{
auto
src_shape
=
TensorShape
{
20
,
224
,
224
,
4
};
auto
filter_shape
=
TensorShape
{
24
,
1
,
1
,
4
};
auto
bias_shape
=
TensorShape
{
1
,
1
,
1
,
24
};
checker
.
set_param
(
param
).
execs
(
{
src_shape
,
filter_shape
,
bias_shape
,
{},
{}});
param
.
compute_mode
=
ConvBias
::
Param
::
ComputeMode
::
FLOAT32
;
checker
.
set_param
(
param
).
execs
(
{
src_shape
,
filter_shape
,
bias_shape
,
{},
{}});
}
{
param
.
sparse
=
ConvBias
::
Param
::
Sparse
::
GROUP
;
auto
src_shape
=
TensorShape
{
20
,
224
,
224
,
16
};
auto
filter_shape
=
TensorShape
{
4
,
4
,
1
,
1
,
4
};
auto
bias_shape
=
TensorShape
{
1
,
1
,
1
,
16
};
checker
.
set_param
(
param
).
execs
(
{
src_shape
,
filter_shape
,
bias_shape
,
{},
{}});
}
}
TEST_F
(
CUDA
,
CONV_BIAS_NCHW_QS8
)
{
//! not support NonlineMode::SIGMOID and NonlineMode::H_SWISH
require_compute_capability
(
6
,
1
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录