Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
d610c987
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
d610c987
编写于
6月 07, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(lite): add disable configure by model info interface
GitOrigin-RevId: cd155a1fcf8bf6b845fa9118a6a791d662b2b624
上级
07bdb3bf
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
146 addition
and
9 deletion
+146
-9
lite/include/lite/network.h
lite/include/lite/network.h
+18
-0
lite/lite-c/include/lite-c/network_c.h
lite/lite-c/include/lite-c/network_c.h
+17
-0
lite/lite-c/src/network.cpp
lite/lite-c/src/network.cpp
+14
-0
lite/pylite/megenginelite/network.py
lite/pylite/megenginelite/network.py
+32
-0
lite/pylite/test/test_network.py
lite/pylite/test/test_network.py
+7
-0
lite/src/mge/network_impl.cpp
lite/src/mge/network_impl.cpp
+3
-2
lite/src/mge/network_impl.h
lite/src/mge/network_impl.h
+5
-1
lite/src/network.cpp
lite/src/network.cpp
+18
-3
lite/src/parse_model/model_parser.cpp
lite/src/parse_model/model_parser.cpp
+6
-2
lite/src/parse_model/model_parser.h
lite/src/parse_model/model_parser.h
+1
-1
lite/test/test_network_options.cpp
lite/test/test_network_options.cpp
+25
-0
未找到文件。
lite/include/lite/network.h
浏览文件 @
d610c987
...
...
@@ -117,6 +117,17 @@ struct LITE_API Config {
Options
options
=
{};
};
/*!
* \brief Extra Configuration for a network
*
* \param disable_configure_by_model_info disable the configuration dumped with model,
* if set true, all configuration in the model will not apply, users should configure
* the network.
*/
struct
LITE_API
ExtraConfig
{
bool
disable_configure_by_model_info
=
false
;
};
/*!
* \brief config the network input and output item
*
...
...
@@ -275,6 +286,12 @@ public:
//! get static peak memory info showed by Graph visualization
void
get_static_memory_alloc_info
(
const
std
::
string
&
log_dir
=
"logs/test"
)
const
;
/** @brief the extra configuration
*
* @param extra_config the extra configuration to set into the network
*/
void
extra_configure
(
const
ExtraConfig
&
extra_config
);
public:
friend
class
NetworkHelper
;
...
...
@@ -288,6 +305,7 @@ private:
private:
bool
m_loaded
=
false
;
Config
m_config
;
ExtraConfig
m_extra_config
;
NetworkIO
m_network_io
;
std
::
unique_ptr
<
NetworkImplBase
>
m_impl
;
std
::
string
m_extra_info
;
...
...
lite/lite-c/include/lite-c/network_c.h
浏览文件 @
d610c987
...
...
@@ -113,6 +113,17 @@ typedef struct LiteConfig {
//! get default config
LITE_API
LiteConfig
*
default_config
();
/*!
* \brief Exetra Configuration for a network
*
* \param disable_configure_by_model_info disable the configuration dumped with model,
* if set true, all configuration in the model will not apply, users should configure
* the network.
*/
typedef
struct
LiteExtraConfig
{
int
disable_configure_by_model_info
;
}
LiteExtraConfig
;
/*!
* \brief config the network input and output item
*
...
...
@@ -599,6 +610,12 @@ LITE_API int LITE_get_model_io_info_by_memory(
const
void
*
model_mem
,
size_t
size
,
const
LiteConfig
config
,
LiteNetworkIO
*
ios
);
/** @brief the extra configuration
*
* @param extra_config the extra configuration to set into the network
*/
LITE_API
int
LITE_extra_configure
(
LiteNetwork
network
,
LiteExtraConfig
extra_config
);
#ifdef __cplusplus
}
#endif
...
...
lite/lite-c/src/network.cpp
浏览文件 @
d610c987
...
...
@@ -181,6 +181,12 @@ InnerIO convert_to_inner_io(const lite::NetworkIO& network_io) {
return
innner_io
;
}
lite
::
ExtraConfig
convert_extra_config
(
const
LiteExtraConfig
&
extra_config
)
{
lite
::
ExtraConfig
ret
;
ret
.
disable_configure_by_model_info
=
extra_config
.
disable_configure_by_model_info
;
return
ret
;
}
int
LITE_make_default_network
(
LiteNetwork
*
network
)
{
LITE_CAPI_BEGIN
();
LITE_ASSERT
(
network
,
"The network pass to LITE api is null"
);
...
...
@@ -734,4 +740,12 @@ int LITE_get_model_io_info_by_memory(
LITE_CAPI_END
();
}
LITE_API
int
LITE_extra_configure
(
LiteNetwork
network
,
LiteExtraConfig
extra_config
)
{
LITE_CAPI_BEGIN
();
LITE_ASSERT
(
network
,
"The network pass to LITE api is null"
);
static_cast
<
lite
::
Network
*>
(
network
)
->
extra_configure
(
convert_extra_config
(
extra_config
));
LITE_CAPI_END
();
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
lite/pylite/megenginelite/network.py
浏览文件 @
d610c987
...
...
@@ -134,6 +134,31 @@ class LiteConfig(Structure):
return
data
.
__repr__
()
class
LiteExtraConfig
(
Structure
):
"""
Extra configuration when load and compile the graph
disable_configure_by_model_info: disable the configuration dumped with
model, if set true, all configuration in the model will not apply, users
should configure the network.
"""
_fields_
=
[
(
"disable_configure_by_model_info"
,
c_int
),
]
def
__init__
(
self
,
disable_model_config
=
False
):
self
.
disable_configure_by_model_info
=
disable_model_config
def
__repr__
(
self
):
data
=
{
"disable_configure_by_model_info"
:
bool
(
self
.
disable_configure_by_model_info
),
}
return
data
.
__repr__
()
class
LiteIO
(
Structure
):
"""
config the network input and output item
...
...
@@ -365,6 +390,7 @@ class _NetworkAPI(_LiteCObjBase):
"LITE_get_model_io_info_by_memory"
,
[
c_char_p
,
c_size_t
,
LiteConfig
,
POINTER
(
_LiteNetworkIO
)],
),
(
"LITE_extra_configure"
,
[
_Cnetwork
,
LiteExtraConfig
]),
]
...
...
@@ -541,6 +567,12 @@ class LiteNetwork(object):
ret_name
=
[
names
[
i
].
decode
(
"utf-8"
)
for
i
in
range
(
nr_output
.
value
)]
return
ret_name
def
extra_configure
(
self
,
extra_config
):
"""
Extra Configuration to the network.
"""
self
.
_api
.
LITE_extra_configure
(
self
.
_network
,
extra_config
)
def
share_weights_with
(
self
,
src_network
):
"""
share weights with the loaded network
...
...
lite/pylite/test/test_network.py
浏览文件 @
d610c987
...
...
@@ -112,6 +112,13 @@ class TestNetwork(TestShuffleNet):
network
.
load
(
model_path
)
self
.
do_forward
(
network
)
def
test_disable_model_config
(
self
):
model_path
=
os
.
path
.
join
(
self
.
source_dir
,
"test_packed_model_rc4.lite"
)
network
=
LiteNetwork
()
network
.
extra_configure
(
LiteExtraConfig
(
True
))
network
.
load
(
model_path
)
self
.
do_forward
(
network
)
def
test_pack_cache_to_model
(
self
):
model_path
=
os
.
path
.
join
(
self
.
source_dir
,
"test_pack_cache_to_model.lite"
)
network
=
LiteNetwork
()
...
...
lite/src/mge/network_impl.cpp
浏览文件 @
d610c987
...
...
@@ -31,7 +31,6 @@ using namespace mgb;
LITE_DYN_TYPE_OBJ_FINAL_IMPL
(
NetworkImplDft
);
void
NetworkImplDft
::
set_config
(
const
Config
&
config
)
{
m_user_config
=
std
::
make_unique
<
Config
>
();
*
m_user_config
=
config
;
m_compnode_locator
=
to_compnode_locator
(
m_user_config
->
device_type
);
m_compnode_locator
.
device
=
config
.
device_id
;
...
...
@@ -428,8 +427,11 @@ void NetworkImplDft::load_model(
global_layout_transform
();
//! some optimization option maybe invalid in some case, so here just
//! auto determine whether some options will apply.
adapt_option_valid
();
//! find how many compnode the model has, this should call before update_io
cross_compnode_model_detect
();
//! update the IO of the network
...
...
@@ -496,7 +498,6 @@ void NetworkImplDft::finish() const {
}
void
NetworkImplDft
::
set_io
(
const
NetworkIO
&
network_io
)
{
m_network_io
=
std
::
make_unique
<
NetworkIOInner
>
();
for
(
auto
&&
in
:
network_io
.
inputs
)
{
m_network_io
->
inputs
.
emplace_back
(
in
);
}
...
...
lite/src/mge/network_impl.h
浏览文件 @
d610c987
...
...
@@ -29,7 +29,11 @@ class NetworkImplDft final : public Network::NetworkImplBase {
LITE_DYN_TYPE_OBJ_FINAL_DECL
;
public:
NetworkImplDft
()
{
m_load_config
.
comp_graph
=
mgb
::
ComputingGraph
::
make
();
}
NetworkImplDft
()
{
m_load_config
.
comp_graph
=
mgb
::
ComputingGraph
::
make
();
m_user_config
=
std
::
make_unique
<
Config
>
();
m_network_io
=
std
::
make_unique
<
NetworkIOInner
>
();
}
using
S
=
megdnn
::
param
::
ExecutionPolicy
::
Strategy
;
using
Var
=
mgb
::
cg
::
SymbolVar
;
//! set the config of the network, include:
...
...
lite/src/network.cpp
浏览文件 @
d610c987
...
...
@@ -80,14 +80,17 @@ void Network::prase_model(std::shared_ptr<void> model_data, size_t size) {
ModelParser
model_parser
(
model_data
,
size
);
//! parse the model info
if
(
model_parser
.
parse_model_info
(
m_config
,
m_network_io
,
separate_config_map
,
m_extra_info
))
{
m_config
,
m_network_io
,
separate_config_map
,
m_extra_info
,
!
m_extra_config
.
disable_configure_by_model_info
))
{
if
(
m_config
.
backend
==
LiteBackend
::
LITE_DEFAULT
&&
m_impl
->
get_backend_type
()
!=
LiteBackend
::
LITE_DEFAULT
)
{
m_impl
.
reset
(
try_call_func
<
NetworkImplDft
,
lite
::
Network
::
NetworkImplBase
*>
(
"parse_model"
));
}
m_impl
->
set_config
(
m_config
);
m_impl
->
set_io
(
m_network_io
);
if
(
!
m_extra_config
.
disable_configure_by_model_info
)
{
m_impl
->
set_config
(
m_config
);
m_impl
->
set_io
(
m_network_io
);
}
}
//! decryption the model
size_t
model_length
;
...
...
@@ -290,6 +293,18 @@ void Network::get_static_memory_alloc_info(const std::string& log_dir) const {
LITE_ERROR_HANDLER_END
}
void
Network
::
extra_configure
(
const
ExtraConfig
&
extra_config
)
{
LITE_ERROR_HANDLER_BEGIN
if
(
!
extra_config
.
disable_configure_by_model_info
)
{
LITE_ASSERT
(
!
m_loaded
,
"disable_configure_by_model_info should be configured before model "
"loaded."
);
}
m_extra_config
=
extra_config
;
LITE_ERROR_HANDLER_END
}
/*********************** MGE special network function ***************/
void
Runtime
::
set_cpu_threads_number
(
...
...
lite/src/parse_model/model_parser.cpp
浏览文件 @
d610c987
...
...
@@ -43,7 +43,7 @@ void ModelParser::parse_header() {
bool
ModelParser
::
parse_model_info
(
Config
&
network_config
,
NetworkIO
&
network_io
,
std
::
unordered_map
<
std
::
string
,
LiteAny
>&
isolated_config_map
,
std
::
string
&
extra_info
)
const
{
std
::
string
&
extra_info
,
bool
configure_valid
)
const
{
//! no model info, no parse, direct return
if
(
m_is_bare_model
||
!
m_info
)
{
return
false
;
...
...
@@ -78,7 +78,7 @@ bool ModelParser::parse_model_info(
}
}
//! parse ModelInfo::algo_policy
if
(
m_info
->
algo_policy
())
{
if
(
m_info
->
algo_policy
()
&&
configure_valid
)
{
size_t
cache_length
=
m_info
->
algo_policy
()
->
size
();
const
uint8_t
*
cache
=
m_info
->
algo_policy
()
->
Data
();
if
(
m_info_cache_parse_func_name
==
"LITE_parse_cache"
)
{
...
...
@@ -93,6 +93,10 @@ bool ModelParser::parse_model_info(
}
else
{
LITE_THROW
(
"opencl binary cache is not given"
);
}
}
else
{
LITE_THROW
(
ssprintf
(
"model cache parse function of %s is not defined."
,
m_info_cache_parse_func_name
.
c_str
()));
}
}
return
true
;
...
...
lite/src/parse_model/model_parser.h
浏览文件 @
d610c987
...
...
@@ -25,7 +25,7 @@ public:
bool
parse_model_info
(
Config
&
network_config
,
NetworkIO
&
network_io
,
std
::
unordered_map
<
std
::
string
,
LiteAny
>&
isolated_config_map
,
std
::
string
&
extra_info
)
const
;
std
::
string
&
extra_info
,
bool
configure_valid
)
const
;
//! parse the model and decrypt the model
std
::
shared_ptr
<
void
>
parse_model
(
size_t
&
model_length
,
const
Config
&
config
)
const
;
...
...
lite/test/test_network_options.cpp
浏览文件 @
d610c987
...
...
@@ -7,6 +7,8 @@
#include "lite/global.h"
#include "megbrain/tensor.h"
#include "megbrain/utils/infile_persistent_cache.h"
#include "megbrain/utils/persistent_cache.h"
#include "test_common.h"
#include <string.h>
...
...
@@ -173,6 +175,29 @@ TEST(TestNetWorkOptions, test_cache) {
compare_lite_tensor
<
float
>
(
output_tensor
,
result_mgb
);
}
TEST
(
TestNetWorkOptions
,
DisableModelInfo
)
{
//! clear the cache set by other test
mgb
::
PersistentCache
::
inst
().
set_impl
(
std
::
make_shared
<
mgb
::
InMemoryPersistentCache
>
());
Config
config
;
auto
tensor
=
get_input_data
(
"./input_data.npy"
);
std
::
string
model_path
=
"./test_pack_cache_to_model.lite"
;
std
::
string
model_path2
=
"./test_pack_cache_to_model.lite"
;
std
::
string
input_name
=
"data"
;
std
::
shared_ptr
<
Network
>
network
=
std
::
make_shared
<
Network
>
(
config
);
network
->
extra_configure
({
true
});
Runtime
::
set_cpu_inplace_mode
(
network
);
network
->
load_model
(
model_path
);
//! the fast-run cache will not configure, so it is not support dump
ASSERT_EQ
(
mgb
::
PersistentCache
::
inst
().
support_dump_cache
(),
false
);
ASSERT_EQ
(
Runtime
::
is_cpu_inplace_mode
(
network
),
true
);
std
::
shared_ptr
<
Network
>
network2
=
std
::
make_shared
<
Network
>
(
config
);
network2
->
load_model
(
model_path2
);
//! the fast-run cache is configured by the model information
ASSERT_EQ
(
mgb
::
PersistentCache
::
inst
().
support_dump_cache
(),
true
);
}
TEST
(
TestNetWorkOptions
,
FastRunIgnorBatch
)
{
Config
config
;
auto
tensor
=
get_input_data
(
"./input_data.npy"
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录