Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
2d6476a4
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
2d6476a4
编写于
6月 30, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(lite): add auto decide model inference format option
GitOrigin-RevId: fcbf945de59a8d9a861e3605a40e69c942c05f4e
上级
10a0349e
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
162 addition
and
14 deletion
+162
-14
lite/include/lite/network.h
lite/include/lite/network.h
+4
-0
lite/lite-c/include/lite-c/network_c.h
lite/lite-c/include/lite-c/network_c.h
+4
-0
lite/lite-c/src/network.cpp
lite/lite-c/src/network.cpp
+4
-1
lite/pylite/megenginelite/network.py
lite/pylite/megenginelite/network.py
+7
-1
lite/src/mge/network_impl.cpp
lite/src/mge/network_impl.cpp
+107
-10
lite/src/mge/network_impl.h
lite/src/mge/network_impl.h
+7
-2
lite/test/test_network_options.cpp
lite/test/test_network_options.cpp
+29
-0
未找到文件。
lite/include/lite/network.h
浏览文件 @
2d6476a4
...
@@ -114,6 +114,9 @@ struct LITE_API Options {
...
@@ -114,6 +114,9 @@ struct LITE_API Options {
* model is not pack json information data inside
* model is not pack json information data inside
*
*
* @param options configuration of Options
* @param options configuration of Options
*
* @param auto_optimize_inference lite will detect the device information add
* set the options heuristically
*/
*/
struct
LITE_API
Config
{
struct
LITE_API
Config
{
bool
has_compression
=
false
;
bool
has_compression
=
false
;
...
@@ -122,6 +125,7 @@ struct LITE_API Config {
...
@@ -122,6 +125,7 @@ struct LITE_API Config {
LiteBackend
backend
=
LiteBackend
::
LITE_DEFAULT
;
LiteBackend
backend
=
LiteBackend
::
LITE_DEFAULT
;
std
::
string
bare_model_cryption_name
=
{};
std
::
string
bare_model_cryption_name
=
{};
Options
options
=
{};
Options
options
=
{};
bool
auto_optimize_inference
=
false
;
};
};
/*!
/*!
...
...
lite/lite-c/include/lite-c/network_c.h
浏览文件 @
2d6476a4
...
@@ -100,6 +100,9 @@ extern LITE_API const LiteOptions default_option;
...
@@ -100,6 +100,9 @@ extern LITE_API const LiteOptions default_option;
*
*
*\param has_compression flag whether the model is compressed, the compress
*\param has_compression flag whether the model is compressed, the compress
*method will read form the model
*method will read form the model
*\param auto_optimize_inference lite will detect the device information add
* set the options heuristically
*/
*/
typedef
struct
LiteConfig
{
typedef
struct
LiteConfig
{
int
has_compression
;
int
has_compression
;
...
@@ -108,6 +111,7 @@ typedef struct LiteConfig {
...
@@ -108,6 +111,7 @@ typedef struct LiteConfig {
LiteBackend
backend
;
LiteBackend
backend
;
const
char
*
bare_model_cryption_name
;
const
char
*
bare_model_cryption_name
;
LiteOptions
options
;
LiteOptions
options
;
int
auto_optimize_inference
;
}
LiteConfig
;
}
LiteConfig
;
//! get default config
//! get default config
...
...
lite/lite-c/src/network.cpp
浏览文件 @
2d6476a4
...
@@ -42,7 +42,8 @@ LiteConfig default_config_t = {
...
@@ -42,7 +42,8 @@ LiteConfig default_config_t = {
.
device_type
=
LiteDeviceType
::
LITE_CPU
,
.
device_type
=
LiteDeviceType
::
LITE_CPU
,
.
backend
=
LiteBackend
::
LITE_DEFAULT
,
.
backend
=
LiteBackend
::
LITE_DEFAULT
,
.
bare_model_cryption_name
=
nullptr
,
.
bare_model_cryption_name
=
nullptr
,
.
options
=
default_option
};
.
options
=
default_option
,
.
auto_optimize_inference
=
false
};
LiteConfig
*
default_config
()
{
LiteConfig
*
default_config
()
{
return
&
default_config_t
;
return
&
default_config_t
;
}
}
...
@@ -133,6 +134,8 @@ lite::Config convert_to_lite_config(const LiteConfig c_config) {
...
@@ -133,6 +134,8 @@ lite::Config convert_to_lite_config(const LiteConfig c_config) {
lite_config
.
options
.
enable_nchw32
=
c_config
.
options
.
enable_nchw32
;
lite_config
.
options
.
enable_nchw32
=
c_config
.
options
.
enable_nchw32
;
lite_config
.
options
.
enable_nchw64
=
c_config
.
options
.
enable_nchw64
;
lite_config
.
options
.
enable_nchw64
=
c_config
.
options
.
enable_nchw64
;
lite_config
.
auto_optimize_inference
=
c_config
.
auto_optimize_inference
;
return
lite_config
;
return
lite_config
;
}
}
...
...
lite/pylite/megenginelite/network.py
浏览文件 @
2d6476a4
...
@@ -171,15 +171,18 @@ class LiteConfig(Structure):
...
@@ -171,15 +171,18 @@ class LiteConfig(Structure):
options: configuration of Options
options: configuration of Options
auto_optimize_inference: lite will detect the device information add set the options heuristically
Examples:
Examples:
.. code-block::
.. code-block::
from megenginelite import *
from megenginelite import *
config = LiteConfig()
config = LiteConfig()
config.has_compression =
f
alse
config.has_compression =
F
alse
config.device_type = LiteDeviceType.LITE_CPU
config.device_type = LiteDeviceType.LITE_CPU
config.backend = LiteBackend.LITE_DEFAULT
config.backend = LiteBackend.LITE_DEFAULT
config.bare_model_cryption_name = "AES_default".encode("utf-8")
config.bare_model_cryption_name = "AES_default".encode("utf-8")
config.auto_optimize_inference = False
"""
"""
_fields_
=
[
_fields_
=
[
...
@@ -189,6 +192,7 @@ class LiteConfig(Structure):
...
@@ -189,6 +192,7 @@ class LiteConfig(Structure):
(
"backend"
,
c_int
),
(
"backend"
,
c_int
),
(
"_bare_model_cryption_name"
,
c_char_p
),
(
"_bare_model_cryption_name"
,
c_char_p
),
(
"options"
,
LiteOptions
),
(
"options"
,
LiteOptions
),
(
"auto_optimize_inference"
,
c_int
),
]
]
def
__init__
(
self
,
device_type
=
LiteDeviceType
.
LITE_CPU
,
option
=
None
):
def
__init__
(
self
,
device_type
=
LiteDeviceType
.
LITE_CPU
,
option
=
None
):
...
@@ -202,6 +206,7 @@ class LiteConfig(Structure):
...
@@ -202,6 +206,7 @@ class LiteConfig(Structure):
self
.
use_loader_dynamic_param
=
0
self
.
use_loader_dynamic_param
=
0
self
.
has_compression
=
0
self
.
has_compression
=
0
self
.
backend
=
LiteBackend
.
LITE_DEFAULT
self
.
backend
=
LiteBackend
.
LITE_DEFAULT
self
.
auto_optimize_inference
=
0
@
property
@
property
def
bare_model_cryption_name
(
self
):
def
bare_model_cryption_name
(
self
):
...
@@ -223,6 +228,7 @@ class LiteConfig(Structure):
...
@@ -223,6 +228,7 @@ class LiteConfig(Structure):
"backend"
:
LiteBackend
(
self
.
backend
),
"backend"
:
LiteBackend
(
self
.
backend
),
"bare_model_cryption_name"
:
self
.
bare_model_cryption_name
,
"bare_model_cryption_name"
:
self
.
bare_model_cryption_name
,
"options"
:
self
.
options
,
"options"
:
self
.
options
,
"auto_optimize_inference"
:
self
.
auto_optimize_inference
,
}
}
return
data
.
__repr__
()
return
data
.
__repr__
()
...
...
lite/src/mge/network_impl.cpp
浏览文件 @
2d6476a4
...
@@ -21,6 +21,10 @@
...
@@ -21,6 +21,10 @@
#include "megcore_opencl.h"
#include "megcore_opencl.h"
#endif
#endif
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO
#include "cpuinfo.h"
#endif
#include <fstream>
#include <fstream>
#include <memory>
#include <memory>
#include <set>
#include <set>
...
@@ -42,14 +46,7 @@ void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) {
...
@@ -42,14 +46,7 @@ void NetworkImplDft::shared_weight_with(const NetworkImplBase* src_network) {
LITE_ASSERT
(
src_impl
.
m_loader
,
"Clone network must after the network is loaded."
);
LITE_ASSERT
(
src_impl
.
m_loader
,
"Clone network must after the network is loaded."
);
m_load_result
=
src_impl
.
m_loader
->
load
(
m_load_config
,
true
);
m_load_result
=
src_impl
.
m_loader
->
load
(
m_load_config
,
true
);
//! flag weather the mode is cross compnode model
configure_after_loaded
();
cross_compnode_model_detect
();
//! update the IO of the network
update_io
();
//! replace the IO when there is device input or output
compile_graph
();
}
}
void
NetworkImplDft
::
application_config
()
{
void
NetworkImplDft
::
application_config
()
{
...
@@ -364,7 +361,7 @@ void NetworkImplDft::adapt_option_valid() {
...
@@ -364,7 +361,7 @@ void NetworkImplDft::adapt_option_valid() {
}
}
}
}
void
NetworkImplDft
::
global_layout_transform
()
{
void
NetworkImplDft
::
layout_transform_optimization
()
{
if
(
m_set_layout_transform
)
{
if
(
m_set_layout_transform
)
{
mgb
::
ThinHashMap
<
mgb
::
SymbolVar
,
mgb
::
SymbolVar
>
out_var_map
;
mgb
::
ThinHashMap
<
mgb
::
SymbolVar
,
mgb
::
SymbolVar
>
out_var_map
;
auto
output_var_array
=
mgb
::
gopt
::
layout_transform
(
auto
output_var_array
=
mgb
::
gopt
::
layout_transform
(
...
@@ -382,6 +379,103 @@ void NetworkImplDft::global_layout_transform() {
...
@@ -382,6 +379,103 @@ void NetworkImplDft::global_layout_transform() {
for
(
auto
&&
item
:
m_load_result
.
output_var_map
)
{
for
(
auto
&&
item
:
m_load_result
.
output_var_map
)
{
item
.
second
=
out_var_map
[
item
.
second
];
item
.
second
=
out_var_map
[
item
.
second
];
}
}
}
else
if
(
m_user_config
->
auto_optimize_inference
)
{
//! set model weight preprocess
m_load_config
.
comp_graph
->
options
().
graph_opt
.
weight_preprocess
=
true
;
LITE_LOG
(
"weight_preprocess is enabled, this maybe use more memory when "
"infernece."
);
//! get the current format and data type of the model
bool
is_model_nchw
=
true
;
//! is any convolution is int8
bool
is_model_int8
=
false
;
//! is all convolution is float32
bool
is_model_float32
=
true
;
float
conv_cnt
=
0
;
float
dimshuffle_cnt
=
0
;
auto
detect_int8_model
=
[
&
](
const
VarNode
*
input
)
{
if
(
input
->
dtype
().
enumv
()
==
megdnn
::
DTypeEnum
::
QuantizedS8
||
input
->
dtype
().
enumv
()
==
megdnn
::
DTypeEnum
::
Quantized8Asymm
)
{
is_model_int8
=
true
;
is_model_float32
=
false
;
}
else
if
(
input
->
dtype
().
enumv
()
==
megdnn
::
DTypeEnum
::
Float32
)
{
is_model_float32
=
(
is_model_float32
&&
true
);
}
else
{
is_model_float32
=
false
;
}
};
cg
::
DepOprIter
dep
([
&
](
cg
::
OperatorNodeBase
*
opr
)
{
if
(
auto
conv
=
opr
->
try_cast_final
<
opr
::
ConvolutionForward
>
())
{
if
(
conv
->
param
().
format
!=
megdnn
::
param
::
ConvBias
::
Format
::
NCHW
)
{
is_model_nchw
=
false
;
}
conv_cnt
++
;
detect_int8_model
(
conv
->
input
(
0
));
}
else
if
(
auto
conv_bias
=
opr
->
try_cast_final
<
opr
::
ConvBias
>
())
{
if
(
conv_bias
->
param
().
format
!=
megdnn
::
param
::
ConvBias
::
Format
::
NCHW
)
{
is_model_nchw
=
false
;
}
conv_cnt
++
;
detect_int8_model
(
conv
->
input
(
0
));
}
else
if
(
auto
dimshuffle
=
opr
->
try_cast_final
<
opr
::
Dimshuffle
>
())
{
LITE_MARK_USED_VAR
(
dimshuffle
);
dimshuffle_cnt
++
;
}
});
for
(
auto
&&
i
:
m_load_result
.
output_var_list
)
dep
.
add
(
i
);
float
radio_dimshuffle_conv
=
0
;
if
(
conv_cnt
>
0
)
{
radio_dimshuffle_conv
=
dimshuffle_cnt
/
conv_cnt
;
}
//! format optimize can only applied on nchw model,
//! shufflenet like model will hurt the performance when using nchw88 or nchw44
//! format, here just heuristically decide the gate radio of
//! dimshuffle and convolution
if
(
!
is_model_nchw
||
radio_dimshuffle_conv
>
0.15
f
)
{
return
;
}
//! determine the layout by the device information
//! TODO: shufflenet like model use nchw88 or nchw44 will hurt the
//! performance
if
(
m_user_config
->
device_type
==
LITE_CPU
)
{
#if defined(MGB_ENABLE_CPUINFO_CHECK) && MGB_ENABLE_CPUINFO
cpuinfo_initialize
();
//! if all convolution and matmul data type is float32
if
(
is_model_float32
)
{
//! if device is x86
//! if x86 support avx, use format nchw88
if
(
cpuinfo_has_x86_avx
())
{
m_load_config
.
comp_graph
->
options
().
graph_opt
.
enable_nchw88
();
LITE_LOG
(
"Configure model inference with nchw88 format."
);
}
else
if
(
cpuinfo_has_x86_sse2
()
&&
!
cpuinfo_has_x86_sse3
())
{
//! if x86 only support sse2, use format nchw44
m_load_config
.
comp_graph
->
options
().
graph_opt
.
enable_nchw44
();
LITE_LOG
(
"Configure model inference with nchw44 format."
);
}
else
if
(
cpuinfo_has_arm_neon
())
{
//! if device is arm, use format nchw44
m_load_config
.
comp_graph
->
options
().
graph_opt
.
enable_nchw44
();
LITE_LOG
(
"Configure model inference with nchw44 format."
);
}
}
else
if
(
is_model_int8
)
{
//! if date type of convolution is int8
//! if device is arm and support dot, use nchw44-dot format
if
(
cpuinfo_has_arm_neon
()
&&
cpuinfo_has_arm_neon_dot
())
{
m_load_config
.
comp_graph
->
options
().
graph_opt
.
enable_nchw44_dot
();
LITE_LOG
(
"Configure model inference with nchw44-dot format."
);
}
else
if
(
cpuinfo_has_arm_neon
())
{
//! if device is arm and do not support dot, use nchw44 format
m_load_config
.
comp_graph
->
options
().
graph_opt
.
enable_nchw44
();
LITE_LOG
(
"Configure model inference with nchw44 format."
);
}
}
#endif
}
}
}
}
}
...
@@ -422,10 +516,13 @@ void NetworkImplDft::load_model(
...
@@ -422,10 +516,13 @@ void NetworkImplDft::load_model(
}
}
m_load_result
=
m_loader
->
load
(
m_load_config
,
true
);
m_load_result
=
m_loader
->
load
(
m_load_config
,
true
);
configure_after_loaded
();
}
void
NetworkImplDft
::
configure_after_loaded
()
{
modify_exection_policy
();
modify_exection_policy
();
global_layout_transform
();
layout_transform_optimization
();
//! some optimization option maybe invalid in some case, so here just
//! some optimization option maybe invalid in some case, so here just
//! auto determine whether some options will apply.
//! auto determine whether some options will apply.
...
...
lite/src/mge/network_impl.h
浏览文件 @
2d6476a4
...
@@ -178,8 +178,10 @@ private:
...
@@ -178,8 +178,10 @@ private:
//! call_back to the outputspec
//! call_back to the outputspec
void
make_output_spec
();
void
make_output_spec
();
//! do the global layout transform for the given platform target
//! do layout transform for the given platform target, maybe the global
void
global_layout_transform
();
//! layout optimization or heuristically choose the best layout according to
//! the device information
void
layout_transform_optimization
();
//! modify the execution policy
//! modify the execution policy
void
modify_exection_policy
();
void
modify_exection_policy
();
...
@@ -223,6 +225,9 @@ private:
...
@@ -223,6 +225,9 @@ private:
//! adapt option valid, it should call after update_io
//! adapt option valid, it should call after update_io
void
adapt_option_valid
();
void
adapt_option_valid
();
//! configure and optimize network after loaded
void
configure_after_loaded
();
private:
private:
bool
m_async
=
false
;
bool
m_async
=
false
;
bool
m_is_cpu_inplace_mode
=
false
;
bool
m_is_cpu_inplace_mode
=
false
;
...
...
lite/test/test_network_options.cpp
浏览文件 @
2d6476a4
...
@@ -48,6 +48,35 @@ TEST(TestNetWorkOptions, no_var_sanity_check_and_record) {
...
@@ -48,6 +48,35 @@ TEST(TestNetWorkOptions, no_var_sanity_check_and_record) {
compare_lite_tensor
<
float
>
(
output_tensor
,
result_mgb
);
compare_lite_tensor
<
float
>
(
output_tensor
,
result_mgb
);
}
}
TEST
(
TestNetWorkOptions
,
auto_optimize_inference_layout
)
{
Config
config
;
auto
tensor
=
get_input_data
(
"./input_data.npy"
);
std
::
string
model_path
=
"./shufflenet.mge"
;
std
::
string
input_name
=
"data"
;
auto
result_mgb
=
mgb_lar
(
model_path
,
config
,
input_name
,
tensor
);
config
.
auto_optimize_inference
=
true
;
std
::
shared_ptr
<
Network
>
network
=
std
::
make_shared
<
Network
>
(
config
);
network
->
load_model
(
model_path
);
std
::
shared_ptr
<
Tensor
>
input_tensor
=
network
->
get_io_tensor
(
input_name
);
auto
src_ptr
=
tensor
->
get_memory_ptr
();
auto
src_layout
=
tensor
->
get_layout
();
input_tensor
->
reset
(
src_ptr
,
src_layout
);
std
::
shared_ptr
<
Tensor
>
output_tensor
=
network
->
get_output_tensor
(
0
);
auto
result_tensor
=
std
::
make_shared
<
Tensor
>
(
LiteDeviceType
::
LITE_CPU
,
Layout
{{
1
,
1000
},
2
,
LiteDataType
::
LITE_FLOAT
});
void
*
out_data
=
result_tensor
->
get_memory_ptr
();
output_tensor
->
reset
(
out_data
,
result_tensor
->
get_layout
());
network
->
forward
();
network
->
wait
();
compare_lite_tensor
<
float
>
(
output_tensor
,
result_mgb
);
}
TEST
(
TestNetWorkOptions
,
const_shape
)
{
TEST
(
TestNetWorkOptions
,
const_shape
)
{
Config
config
;
Config
config
;
auto
tensor
=
get_input_data
(
"./input_data.npy"
);
auto
tensor
=
get_input_data
(
"./input_data.npy"
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录