Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
02abc36e
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
02abc36e
编写于
6月 11, 2020
作者:
M
Megvii Engine Team
提交者:
Xu Xinran
6月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mbg/arm_common): fix nchw44-dot misc issue
GitOrigin-RevId: f870ad964c075fe55fb1c7ca131680873bec61eb
上级
9ed3882a
变更
7
隐藏空白更改
内联
并排
Showing
7 changed files
with
25 additions
and
14 deletions
+25
-14
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
...arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
+1
-1
dnn/src/arm_common/conv_bias/opr_impl.cpp
dnn/src/arm_common/conv_bias/opr_impl.cpp
+2
-2
dnn/test/arm_common/conv_bias_multi_thread.cpp
dnn/test/arm_common/conv_bias_multi_thread.cpp
+13
-8
sdk/load-and-run/src/mgblar.cpp
sdk/load-and-run/src/mgblar.cpp
+6
-0
src/core/include/megbrain/utils/persistent_cache.h
src/core/include/megbrain/utils/persistent_cache.h
+0
-1
src/opr/impl/dnn/convolution.cpp
src/opr/impl/dnn/convolution.cpp
+0
-1
src/plugin/impl/opr_footprint.cpp
src/plugin/impl/opr_footprint.cpp
+3
-1
未找到文件。
dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
浏览文件 @
02abc36e
...
...
@@ -182,7 +182,7 @@ bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable(
bool
ok_type
=
((
param
.
src_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
param
.
filter_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
&&
(
param
.
dst_type
.
enumv
()
==
DTypeEnum
::
QuantizedS8
)))
&&
(
fm
.
format
==
param
::
Convolution
::
Format
::
NCHW44
);
(
fm
.
format
==
param
::
Convolution
::
Format
::
NCHW44
_DOT
);
bool
ok_src_dst
=
(
oc
%
4
==
0
&&
oc
>=
4
&&
ic
<
4
);
bool
ok_filter
=
fm
.
spatial_ndim
==
2
&&
fh
==
fm
.
spatial
[
1
]
&&
(
fh
==
2
||
fh
==
3
||
fh
==
5
||
fh
==
7
);
...
...
dnn/src/arm_common/conv_bias/opr_impl.cpp
浏览文件 @
02abc36e
...
...
@@ -55,7 +55,6 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoS8ChanWiseStride2NCHW44
s8_channel_wise_stride2_nchw44
;
#if __ARM_FEATURE_DOTPROD
AlgoDotS8DirectNCHWNCHW44
ds8_direct_stride2_nchw_nchw44
;
AlgoDotS8DirectStride1
ds8_direct_stride1_large_group
{
true
};
AlgoDotS8DirectStride1
ds8_direct_stride1_small_group
{
false
};
AlgoDotS8DirectStride2
ds8_direct_stride2_large_group
{
true
};
...
...
@@ -66,6 +65,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
AlgoDotU8DirectStride2
du8_direct_stride2_small_group
{
false
};
AlgoDotS8Direct_NCHW44
ds8_direct_nchw44
;
AlgoDotS8DirectNCHWNCHW44
ds8_direct_nchw_nchw44
;
#endif
AlgoF32DirectNCHWNCHW44
f32_direct_stride2_nchw_nchw44
;
...
...
@@ -96,7 +96,6 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
public:
AlgoPack
()
{
#if __ARM_FEATURE_DOTPROD
direct_algos
.
emplace_back
(
&
ds8_direct_stride2_nchw_nchw44
);
direct_algos
.
emplace_back
(
&
ds8_direct_stride1_large_group
);
direct_algos
.
emplace_back
(
&
ds8_direct_stride1_small_group
);
direct_algos
.
emplace_back
(
&
ds8_direct_stride2_large_group
);
...
...
@@ -107,6 +106,7 @@ public:
direct_algos
.
emplace_back
(
&
du8_direct_stride2_small_group
);
direct_algos
.
emplace_back
(
&
ds8_direct_nchw44
);
direct_algos
.
emplace_back
(
&
ds8_direct_nchw_nchw44
);
#endif
direct_algos
.
emplace_back
(
&
qu8_direct_stride2_large_group
);
direct_algos
.
emplace_back
(
&
qu8_direct_stride2_small_group
);
...
...
dnn/test/arm_common/conv_bias_multi_thread.cpp
浏览文件 @
02abc36e
...
...
@@ -582,14 +582,19 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2_SMALL_GROUP) {
/****************************dot qint8 direct*************************/
#if __ARM_FEATURE_DOTPROD
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_DOT_NCHW_NCHW44
)
{
checker_conv_bias_qint8x8x8
(
get_nchw44_conv_bias_args
({
2
,
3
,
5
,
7
},
2
,
false
,
false
,
false
,
true
),
handle
(),
"ARMDOTS8_NCHW_NCHW44"
);
checker_conv_bias_qint8x8x8
(
get_nchw44_conv_bias_args
({
2
,
3
,
5
,
7
},
1
,
false
,
false
,
false
,
true
),
handle
(),
"ARMDOTS8_NCHW_NCHW44"
);
auto
args
=
get_nchw44_conv_bias_args
({
2
,
3
,
5
,
7
},
2
,
false
,
false
,
false
,
true
);
for
(
auto
&&
arg
:
args
)
{
arg
.
param
.
format
=
param
::
ConvBias
::
Format
::
NCHW44_DOT
;
}
checker_conv_bias_qint8x8x8
(
args
,
handle
(),
"ARMDOTS8_NCHW_NCHW44"
);
args
=
get_nchw44_conv_bias_args
({
2
,
3
,
5
,
7
},
1
,
false
,
false
,
false
,
true
);
for
(
auto
&&
arg
:
args
)
{
arg
.
param
.
format
=
param
::
ConvBias
::
Format
::
NCHW44_DOT
;
}
checker_conv_bias_qint8x8x8
(
args
,
handle
(),
"ARMDOTS8_NCHW_NCHW44"
);
}
TEST_F
(
ARM_COMMON_MULTI_THREADS
,
CONV_BIAS_INT8_STRIDE1_WITHDOTPROD_LARGE_GROUP
)
{
...
...
sdk/load-and-run/src/mgblar.cpp
浏览文件 @
02abc36e
...
...
@@ -987,6 +987,12 @@ Args Args::from_argv(int argc, char **argv) {
cb
(
nchw32
);
cb
(
nhwcd4
);
#undef cb
if
(
!
strcmp
(
argv
[
i
],
"--enable-nchw44-dot"
))
{
mgb_log_warn
(
"enable-nchw44-dot optimization"
);
graph_opt
.
graph_opt
.
enable_nchw44_dot
();
continue
;
}
if
(
!
strcmp
(
argv
[
i
],
"--enable-fuse-conv-bias-nonlinearity"
))
{
mgb_log_warn
(
"enable fuse-conv-bias-nonlinearity optimization"
);
graph_opt
.
graph_opt
.
enable_fuse_conv_bias_nonlinearity
();
...
...
src/core/include/megbrain/utils/persistent_cache.h
浏览文件 @
02abc36e
...
...
@@ -94,7 +94,6 @@ namespace mgb {
m_param
{
param
},
m_param_size
{
param_size
}
{
}
//! build a blob representation to be used as cache key
PersistentCache
::
Blob
build_blob
()
const
;
};
...
...
src/opr/impl/dnn/convolution.cpp
浏览文件 @
02abc36e
...
...
@@ -611,7 +611,6 @@ AlgoChooserProfileCache::Result AlgoChooser<Opr>::get_profile_result(
AlgoChooserProfileCache
::
Key
cache_key
{
origin_layouts
.
data
(),
origin_layouts
.
size
(),
&
origin_param
,
sizeof
(
origin_param
)};
{
auto
&&
rst
=
cache
.
get
(
cache_key
);
if
(
rst
.
valid
())
...
...
src/plugin/impl/opr_footprint.cpp
浏览文件 @
02abc36e
...
...
@@ -107,7 +107,8 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
src_shape
[
1
]
/
group
*
2
;
return
hybird_nchwx
?
computation
:
computation
*
8
;
}
if
(
param
.
format
==
Param
::
Format
::
NCHW44
)
{
if
(
param
.
format
==
Param
::
Format
::
NCHW44
||
param
.
format
==
Param
::
Format
::
NCHW44_DOT
)
{
//! if channel wise weight layout is {group/4, FH, FW, 1, 1, 4}
if
(
filter_shape
[
1
]
==
1
&&
filter_shape
[
2
]
==
1
)
{
group
*=
4
;
...
...
@@ -145,6 +146,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
if
(
param
.
format
==
Param
::
Format
::
NCHW4
||
param
.
format
==
Param
::
Format
::
NCHW88
||
param
.
format
==
Param
::
Format
::
NCHW44
||
param
.
format
==
Param
::
Format
::
NCHW44_DOT
||
param
.
format
==
Param
::
Format
::
NCHW32
)
{
return
eval_conv_computation_nchwx
();
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录