Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
eef0308b
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
eef0308b
编写于
6月 01, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(serialization): add no_change_graph and version param whem dump model
GitOrigin-RevId: 65064452c939ecd7313fe25d608f2b23e0778a77
上级
4ab5f970
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
95 addition
and
23 deletion
+95
-23
imperative/python/megengine/core/tensor/megbrain_graph.py
imperative/python/megengine/core/tensor/megbrain_graph.py
+15
-1
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+15
-0
imperative/python/src/graph_rt.cpp
imperative/python/src/graph_rt.cpp
+10
-4
src/opr/impl/dnn/dnn.sereg.v2.h
src/opr/impl/dnn/dnn.sereg.v2.h
+11
-2
src/serialization/impl/serializer.cpp
src/serialization/impl/serializer.cpp
+4
-3
src/serialization/impl/serializer_oss_v2.cpp
src/serialization/impl/serializer_oss_v2.cpp
+21
-8
src/serialization/include/megbrain/serialization/load_dump_config.h
...ization/include/megbrain/serialization/load_dump_config.h
+9
-2
src/serialization/include/megbrain/serialization/oss_opr_load_dump.h
...zation/include/megbrain/serialization/oss_opr_load_dump.h
+4
-1
src/serialization/include/megbrain/serialization/serializer.h
...serialization/include/megbrain/serialization/serializer.h
+3
-1
src/serialization/test/serializer_oss.cpp
src/serialization/test/serializer_oss.cpp
+3
-1
未找到文件。
imperative/python/megengine/core/tensor/megbrain_graph.py
浏览文件 @
eef0308b
...
@@ -367,10 +367,12 @@ def dump_graph(
...
@@ -367,10 +367,12 @@ def dump_graph(
keep_opr_name
:
bool
=
False
,
keep_opr_name
:
bool
=
False
,
keep_param_name
:
bool
=
False
,
keep_param_name
:
bool
=
False
,
keep_opr_priority
:
bool
=
False
,
keep_opr_priority
:
bool
=
False
,
no_change_graph
:
bool
=
False
,
strip_info_file
=
None
,
strip_info_file
=
None
,
append_json
=
False
,
append_json
=
False
,
metadata
=
None
,
metadata
=
None
,
dump_format
=
None
dump_format
=
None
,
model_version
:
int
=
2
)
->
Tuple
[
bytes
,
CompGraphDumpResult
]:
)
->
Tuple
[
bytes
,
CompGraphDumpResult
]:
r
"""serialize the computing graph of `output_vars` and get byte result.
r
"""serialize the computing graph of `output_vars` and get byte result.
...
@@ -386,12 +388,22 @@ def dump_graph(
...
@@ -386,12 +388,22 @@ def dump_graph(
keep_param_name: whether to keep param names, so param values can be
keep_param_name: whether to keep param names, so param values can be
easily manipulated after loading model
easily manipulated after loading model
keep_opr_priority: whether to keep priority setting for operators
keep_opr_priority: whether to keep priority setting for operators
no_change_graph: whether to change the compute graph when dump, for
model compatibility, some operators will convert to its compatible
format in this version.
* if set False, some operators maybe convert to other operator for
compatibility, all operators will ensure compatibility.
* if set True, no operator will change in the graph when dump.
strip_info_file: a string for path or a file handler. if is not None,
strip_info_file: a string for path or a file handler. if is not None,
then the dump information for code strip would be written to ``strip_info_file``
then the dump information for code strip would be written to ``strip_info_file``
append_json: will be check when `strip_info_file` is not None. if set
append_json: will be check when `strip_info_file` is not None. if set
true, the information for code strip will be append to strip_info_file.
true, the information for code strip will be append to strip_info_file.
if set false, will rewrite strip_info_file
if set false, will rewrite strip_info_file
dump_format: using different dump formats.
dump_format: using different dump formats.
model_version: the model version of "FBS_V2", begin with version 2, this
works only when dump format is "FBS_V2".
Note:
Note:
The underlying C++ API only accepts a var list. If a dict is given,
The underlying C++ API only accepts a var list. If a dict is given,
...
@@ -441,8 +453,10 @@ def dump_graph(
...
@@ -441,8 +453,10 @@ def dump_graph(
keep_opr_name
,
keep_opr_name
,
keep_param_name
,
keep_param_name
,
keep_opr_priority
,
keep_opr_priority
,
no_change_graph
,
metadata
,
metadata
,
dump_format
,
dump_format
,
model_version
,
stat
,
stat
,
inputs
,
inputs
,
outputs
,
outputs
,
...
...
imperative/python/megengine/jit/tracing.py
浏览文件 @
eef0308b
...
@@ -549,6 +549,7 @@ class trace:
...
@@ -549,6 +549,7 @@ class trace:
keep_opr_name
:
bool
=
False
,
keep_opr_name
:
bool
=
False
,
keep_param_name
:
bool
=
False
,
keep_param_name
:
bool
=
False
,
keep_opr_priority
:
bool
=
False
,
keep_opr_priority
:
bool
=
False
,
no_change_graph
:
bool
=
False
,
strip_info_file
=
None
,
strip_info_file
=
None
,
append_json
=
False
,
append_json
=
False
,
optimize_for_inference
=
True
,
optimize_for_inference
=
True
,
...
@@ -562,6 +563,7 @@ class trace:
...
@@ -562,6 +563,7 @@ class trace:
resize_input
=
False
,
resize_input
=
False
,
input_transform
=
None
,
input_transform
=
None
,
dump_format
:
str
=
None
,
dump_format
:
str
=
None
,
model_version
:
int
=
2
,
**
kwargs
**
kwargs
):
):
r
"""Serializes trace to file system.
r
"""Serializes trace to file system.
...
@@ -583,6 +585,14 @@ class trace:
...
@@ -583,6 +585,14 @@ class trace:
keep_param_name: whether to keep param names, so param values can be
keep_param_name: whether to keep param names, so param values can be
easily manipulated after loading model
easily manipulated after loading model
keep_opr_priority: whether to keep priority setting for operators
keep_opr_priority: whether to keep priority setting for operators
no_change_graph: whether to change the compute graph when dump, for
model compatibility, some operators will convert to its compatible
format in this version.
* if set False, some operators maybe convert to other operator for
compatibility, all operators will ensure compatibility.
* if set True, no operator will change in the graph when dump.
strip_info_file: a string for path or a file handler. if is not None,
strip_info_file: a string for path or a file handler. if is not None,
then the dump information for code strip would be written to ``strip_info_file``
then the dump information for code strip would be written to ``strip_info_file``
append_json: will be check when `strip_info_file` is not None. if set
append_json: will be check when `strip_info_file` is not None. if set
...
@@ -616,6 +626,9 @@ class trace:
...
@@ -616,6 +626,9 @@ class trace:
dump_format: using different dump formats. the open source MegEngine
dump_format: using different dump formats. the open source MegEngine
defaults to the FBS_V2 format, there are two format FBS_V2 and FBS to choose,
defaults to the FBS_V2 format, there are two format FBS_V2 and FBS to choose,
internal MegEngine have an other choice of internal proprietary formats
internal MegEngine have an other choice of internal proprietary formats
model_version: the model version of FBS_V2, begin with version 2, this
works only when dump format is FBS_V2.
Keyword Arguments:
Keyword Arguments:
...
@@ -762,10 +775,12 @@ class trace:
...
@@ -762,10 +775,12 @@ class trace:
keep_opr_name
=
keep_opr_name
,
keep_opr_name
=
keep_opr_name
,
keep_param_name
=
keep_param_name
,
keep_param_name
=
keep_param_name
,
keep_opr_priority
=
keep_opr_priority
,
keep_opr_priority
=
keep_opr_priority
,
no_change_graph
=
no_change_graph
,
strip_info_file
=
strip_info_file
,
strip_info_file
=
strip_info_file
,
append_json
=
append_json
,
append_json
=
append_json
,
metadata
=
metadata
,
metadata
=
metadata
,
dump_format
=
dump_format
,
dump_format
=
dump_format
,
model_version
=
model_version
,
)
)
file
.
write
(
dump_content
)
file
.
write
(
dump_content
)
...
...
imperative/python/src/graph_rt.cpp
浏览文件 @
eef0308b
...
@@ -381,20 +381,26 @@ void init_graph_rt(py::module m) {
...
@@ -381,20 +381,26 @@ void init_graph_rt(py::module m) {
m
.
def
(
"dump_graph"
,
m
.
def
(
"dump_graph"
,
[](
const
std
::
vector
<
VarNode
*>&
dest_vars
,
int
keep_var_name
,
[](
const
std
::
vector
<
VarNode
*>&
dest_vars
,
int
keep_var_name
,
bool
keep_opr_name
,
bool
keep_param_name
,
bool
keep_opr_priority
,
bool
keep_opr_name
,
bool
keep_param_name
,
bool
keep_opr_priority
,
std
::
optional
<
_SerializationMetadata
>
metadata
,
bool
no_change_graph
,
std
::
optional
<
_SerializationMetadata
>
metadata
,
std
::
optional
<
_SerializationFormat
>
dump_format
,
py
::
list
&
stat
,
std
::
optional
<
_SerializationFormat
>
dump_format
,
py
::
list
&
inputs
,
py
::
list
&
outputs
,
py
::
list
&
params
)
{
std
::
optional
<
int
>
model_version
,
py
::
list
&
stat
,
py
::
list
&
inputs
,
py
::
list
&
outputs
,
py
::
list
&
params
)
{
std
::
vector
<
uint8_t
>
buf
;
std
::
vector
<
uint8_t
>
buf
;
ser
::
GraphDumpFormat
format
=
ser
::
GraphDumpFormat
::
FLATBUFFERS_V2
;
ser
::
GraphDumpFormat
format
=
ser
::
GraphDumpFormat
::
FLATBUFFERS_V2
;
int
version
=
2
;
if
(
dump_format
.
has_value
())
{
if
(
dump_format
.
has_value
())
{
format
=
dump_format
.
value
();
format
=
dump_format
.
value
();
}
}
if
(
model_version
.
has_value
())
{
version
=
model_version
.
value
();
}
auto
dumper
=
ser
::
GraphDumper
::
make
(
auto
dumper
=
ser
::
GraphDumper
::
make
(
ser
::
OutputFile
::
make_vector_proxy
(
&
buf
),
format
);
ser
::
OutputFile
::
make_vector_proxy
(
&
buf
),
format
,
version
);
SymbolVarArray
symvars
(
dest_vars
.
begin
(),
dest_vars
.
end
());
SymbolVarArray
symvars
(
dest_vars
.
begin
(),
dest_vars
.
end
());
ser
::
GraphDumper
::
DumpConfig
config
{
ser
::
GraphDumper
::
DumpConfig
config
{
keep_var_name
,
keep_param_name
,
keep_opr_priority
,
keep_opr_name
};
keep_var_name
,
keep_param_name
,
keep_opr_priority
,
keep_opr_name
};
config
.
no_change_graph
=
no_change_graph
;
ser
::
GraphDumper
::
DumpResult
rst
;
ser
::
GraphDumper
::
DumpResult
rst
;
if
(
metadata
)
if
(
metadata
)
...
...
src/opr/impl/dnn/dnn.sereg.v2.h
浏览文件 @
eef0308b
...
@@ -21,6 +21,13 @@ struct OprLoadDumpImplV2<opr::Softmax, 1> {
...
@@ -21,6 +21,13 @@ struct OprLoadDumpImplV2<opr::Softmax, 1> {
ctx
.
write_param
<
PersisParam
>
(
opr
.
cast_final_safe
<
Opr
>
().
param
());
ctx
.
write_param
<
PersisParam
>
(
opr
.
cast_final_safe
<
Opr
>
().
param
());
}
}
/** This converter is just a example for Operator serialization compatible,
* Just in this situation: when optimize the softmax Operator by
* fusing the elemwise and reduce to a big Operator, but the whole softmax
* Operator can't be recognized by old version, in order to model
* compatibility the softmax Operator should be covert to elemwise and
* reduce Operators when dump the model
*/
static
cg
::
OperatorNodeBase
*
replace_opr
(
static
cg
::
OperatorNodeBase
*
replace_opr
(
cg
::
OperatorNodeBase
*
opr
,
const
VarNodeArray
&
inputs
)
{
cg
::
OperatorNodeBase
*
opr
,
const
VarNodeArray
&
inputs
)
{
int32_t
axis
=
opr
->
cast_final_safe
<
Opr
>
().
param
().
axis
;
int32_t
axis
=
opr
->
cast_final_safe
<
Opr
>
().
param
().
axis
;
...
@@ -196,9 +203,11 @@ namespace opr {
...
@@ -196,9 +203,11 @@ namespace opr {
#define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \
#define SERGE_OPR_V2_NO_CONVERTER(_cls, _arity) \
MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION);
MGB_SEREG_OPR_V2(_cls, _arity, nullptr, VERSION_2, CURRENT_VERSION);
SERGE_OPR_V2_CONVERTER
(
//! this is just a example for Operator compatibility
/*SERGE_OPR_V2_CONVERTER(
Softmax, 1,
Softmax, 1,
(
mgb
::
serialization
::
OprLoadDumpImplV2
<
opr
::
Softmax
,
1
>::
replace_opr
));
(mgb::serialization::OprLoadDumpImplV2<opr::Softmax, 1>::replace_opr));*/
SERGE_OPR_V2_NO_CONVERTER
(
Softmax
,
1
)
SERGE_OPR_V2_NO_CONVERTER
(
ConvBiasForward
,
0
)
SERGE_OPR_V2_NO_CONVERTER
(
ConvBiasForward
,
0
)
SERGE_OPR_V2_NO_CONVERTER
(
BatchConvBiasForward
,
0
);
SERGE_OPR_V2_NO_CONVERTER
(
BatchConvBiasForward
,
0
);
...
...
src/serialization/impl/serializer.cpp
浏览文件 @
eef0308b
...
@@ -59,7 +59,8 @@ std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file);
...
@@ -59,7 +59,8 @@ std::unique_ptr<GraphLoader> make_fbs_loader(std::unique_ptr<InputFile> file);
std
::
unique_ptr
<
GraphDumper
>
make_fbs_dumper
(
std
::
unique_ptr
<
OutputFile
>
file
);
std
::
unique_ptr
<
GraphDumper
>
make_fbs_dumper
(
std
::
unique_ptr
<
OutputFile
>
file
);
std
::
unique_ptr
<
GraphLoader
>
make_fbs_v2_loader
(
std
::
unique_ptr
<
InputFile
>
file
);
std
::
unique_ptr
<
GraphLoader
>
make_fbs_v2_loader
(
std
::
unique_ptr
<
InputFile
>
file
);
std
::
unique_ptr
<
GraphDumper
>
make_fbs_v2_dumper
(
std
::
unique_ptr
<
OutputFile
>
file
);
std
::
unique_ptr
<
GraphDumper
>
make_fbs_v2_dumper
(
std
::
unique_ptr
<
OutputFile
>
file
,
int
version
);
bool
is_fbs_file
(
InputFile
&
file
);
bool
is_fbs_file
(
InputFile
&
file
);
bool
is_fbs_v2_file
(
InputFile
&
file
);
bool
is_fbs_v2_file
(
InputFile
&
file
);
...
@@ -72,7 +73,7 @@ bool GraphDumper::should_remove_in_dump(cg::OperatorNodeBase* opr) {
...
@@ -72,7 +73,7 @@ bool GraphDumper::should_remove_in_dump(cg::OperatorNodeBase* opr) {
}
}
std
::
unique_ptr
<
GraphDumper
>
GraphDumper
::
make
(
std
::
unique_ptr
<
GraphDumper
>
GraphDumper
::
make
(
std
::
unique_ptr
<
OutputFile
>
file
,
GraphDumpFormat
format
)
{
std
::
unique_ptr
<
OutputFile
>
file
,
GraphDumpFormat
format
,
int
version
)
{
switch
(
format
)
{
switch
(
format
)
{
case
GraphDumpFormat
::
FLATBUFFERS
:
case
GraphDumpFormat
::
FLATBUFFERS
:
#if MGB_ENABLE_FBS_SERIALIZATION
#if MGB_ENABLE_FBS_SERIALIZATION
...
@@ -81,7 +82,7 @@ std::unique_ptr<GraphDumper> GraphDumper::make(
...
@@ -81,7 +82,7 @@ std::unique_ptr<GraphDumper> GraphDumper::make(
MGB_FALLTHRU
MGB_FALLTHRU
case
GraphDumpFormat
::
FLATBUFFERS_V2
:
case
GraphDumpFormat
::
FLATBUFFERS_V2
:
#if MGB_ENABLE_FBS_SERIALIZATION
#if MGB_ENABLE_FBS_SERIALIZATION
return
make_fbs_v2_dumper
(
std
::
move
(
file
));
return
make_fbs_v2_dumper
(
std
::
move
(
file
)
,
version
);
#endif
#endif
MGB_FALLTHRU
MGB_FALLTHRU
default:
default:
...
...
src/serialization/impl/serializer_oss_v2.cpp
浏览文件 @
eef0308b
...
@@ -194,7 +194,7 @@ void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) {
...
@@ -194,7 +194,7 @@ void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) {
}
}
}
else
{
}
else
{
auto
registry
=
OprRegistryV2
::
versioned_find_by_typeinfo
(
auto
registry
=
OprRegistryV2
::
versioned_find_by_typeinfo
(
opr
->
dyn_typeinfo
(),
CURRENT_VERSION
);
opr
->
dyn_typeinfo
(),
m_version
);
if
(
!
registry
||
!
registry
->
dumper
)
{
if
(
!
registry
||
!
registry
->
dumper
)
{
mgb_throw
(
mgb_throw
(
cg
::
OperatorNodeExcExtraInfo
::
ExcMaker
{
opr
}.
make
<
MegBrainError
>
,
cg
::
OperatorNodeExcExtraInfo
::
ExcMaker
{
opr
}.
make
<
MegBrainError
>
,
...
@@ -202,6 +202,9 @@ void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) {
...
@@ -202,6 +202,9 @@ void GraphDumperOSSV2::init_oprs_to_dump(const SymbolVarArray& endpoints) {
"operator %s"
,
"operator %s"
,
opr
->
dyn_typeinfo
()
->
name
);
opr
->
dyn_typeinfo
()
->
name
);
}
}
mgb_assert
(
registry
->
version
<=
m_version
,
"The Operator version should less than model version"
);
m_oprs_to_dump
.
emplace_back
(
opr
,
registry
);
m_oprs_to_dump
.
emplace_back
(
opr
,
registry
);
}
}
};
};
...
@@ -352,7 +355,10 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump(
...
@@ -352,7 +355,10 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump(
const
Metadata
&
metadata
)
{
const
Metadata
&
metadata
)
{
mgb_throw_if
(
output_vars
.
empty
(),
SerializationError
,
"Can't dump empty graph"
);
mgb_throw_if
(
output_vars
.
empty
(),
SerializationError
,
"Can't dump empty graph"
);
auto
&&
new_output_vars
=
converter_all_opr_to_compatiable
(
output_vars
);
auto
new_output_vars
=
output_vars
;
if
(
!
config
.
no_change_graph
)
{
new_output_vars
=
converter_all_opr_to_compatiable
(
output_vars
);
}
auto
begin_pos
=
m_file
->
tell
();
auto
begin_pos
=
m_file
->
tell
();
m_config
=
config
;
m_config
=
config
;
...
@@ -416,6 +422,7 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump(
...
@@ -416,6 +422,7 @@ GraphDumper::DumpResult GraphDumperOSSV2::dump(
fbs
::
v2
::
ModelBuilder
model
(
m_builder
);
fbs
::
v2
::
ModelBuilder
model
(
m_builder
);
model
.
add_mge_version
(
MGB_VERSION
);
model
.
add_mge_version
(
MGB_VERSION
);
model
.
add_model_version
(
m_version
);
model
.
add_oprs
(
fb_oprs
);
model
.
add_oprs
(
fb_oprs
);
model
.
add_middle_tensors
(
fb_mid_tensor
);
model
.
add_middle_tensors
(
fb_mid_tensor
);
model
.
add_output_vars_idx
(
fb_output_vars
);
model
.
add_output_vars_idx
(
fb_output_vars
);
...
@@ -694,10 +701,8 @@ void GraphLoaderOSSV2::OprLoadContextImpl::load_single_opr(
...
@@ -694,10 +701,8 @@ void GraphLoaderOSSV2::OprLoadContextImpl::load_single_opr(
OprRegistryV2
::
versioned_find_by_id
(
type_id
,
opr_version
);
OprRegistryV2
::
versioned_find_by_id
(
type_id
,
opr_version
);
mgb_throw_if
(
mgb_throw_if
(
!
registry
,
SerializationError
,
!
registry
,
SerializationError
,
"failed to find opr with type %s , use python env "
"failed to find opr with type %s and version %d."
,
"config.dump_registered_oprs() to get a dict that maps from "
fbopr
->
type
()
->
str
().
c_str
(),
opr_version
);
"opr id to opr name"
,
fbopr
->
type
()
->
str
().
c_str
());
// load inputs
// load inputs
VarNodeArray
inputs
;
VarNodeArray
inputs
;
...
@@ -811,12 +816,19 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re
...
@@ -811,12 +816,19 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re
m_model
=
fbs
::
v2
::
GetModel
(
m_model_buf
.
data
());
m_model
=
fbs
::
v2
::
GetModel
(
m_model_buf
.
data
());
m_mgb_version
=
m_model
->
mge_version
();
m_mgb_version
=
m_model
->
mge_version
();
m_model_version
=
m_model
->
model_version
();
if
(
m_model
->
mge_version
()
>
MGB_VERSION
)
{
if
(
m_model
->
mge_version
()
>
MGB_VERSION
)
{
mgb_log_warn
(
mgb_log_warn
(
"loading model from future runtime: version=%u "
"loading model from future runtime: version=%u "
"model_version=%u"
,
"model_version=%u"
,
MGB_VERSION
,
m_model
->
mge_version
());
MGB_VERSION
,
m_model
->
mge_version
());
}
}
if
(
m_model_version
>
CURRENT_VERSION
)
{
mgb_log_warn
(
"The model dump in the future version %d, try to load it, maybe case "
"load error in %d version."
,
m_model_version
,
CURRENT_VERSION
);
}
if
(
m_shared_tensor_map
.
empty
())
{
if
(
m_shared_tensor_map
.
empty
())
{
m_shared_tensor_map
.
resize
(
m_model
->
nr_shared_tensor
());
m_shared_tensor_map
.
resize
(
m_model
->
nr_shared_tensor
());
...
@@ -845,8 +857,9 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re
...
@@ -845,8 +857,9 @@ GraphLoader::LoadResult GraphLoaderOSSV2::load(const LoadConfig& config, bool re
return
result
;
return
result
;
}
}
std
::
unique_ptr
<
GraphDumper
>
make_fbs_v2_dumper
(
std
::
unique_ptr
<
OutputFile
>
file
)
{
std
::
unique_ptr
<
GraphDumper
>
make_fbs_v2_dumper
(
return
std
::
make_unique
<
GraphDumperOSSV2
>
(
std
::
move
(
file
));
std
::
unique_ptr
<
OutputFile
>
file
,
int
version
)
{
return
std
::
make_unique
<
GraphDumperOSSV2
>
(
std
::
move
(
file
),
version
);
}
}
std
::
unique_ptr
<
GraphLoader
>
make_fbs_v2_loader
(
std
::
unique_ptr
<
InputFile
>
file
)
{
std
::
unique_ptr
<
GraphLoader
>
make_fbs_v2_loader
(
std
::
unique_ptr
<
InputFile
>
file
)
{
...
...
src/serialization/include/megbrain/serialization/load_dump_config.h
浏览文件 @
eef0308b
...
@@ -58,18 +58,25 @@ struct GraphDumpConfig {
...
@@ -58,18 +58,25 @@ struct GraphDumpConfig {
//! names. this list record the mapping between output node and it's name
//! names. this list record the mapping between output node and it's name
std
::
vector
<
std
::
pair
<
std
::
string
,
SymbolVar
>>
alias_name_map
;
std
::
vector
<
std
::
pair
<
std
::
string
,
SymbolVar
>>
alias_name_map
;
//! whether just to dump all the op with no change the graph, sometimes the
//! opr maybe not compatible, if false, some opr will converter to the compatibility
//! format and then dump
bool
no_change_graph
;
GraphDumpConfig
(
GraphDumpConfig
(
int
keep_var_name_
=
1
,
bool
keep_param_name_
=
false
,
int
keep_var_name_
=
1
,
bool
keep_param_name_
=
false
,
bool
keep_opr_priority_
=
false
,
bool
keep_op_name_
=
true
,
bool
keep_opr_priority_
=
false
,
bool
keep_op_name_
=
true
,
const
std
::
shared_ptr
<
UserDataContainer
>&
user_data_
=
const
std
::
shared_ptr
<
UserDataContainer
>&
user_data_
=
std
::
make_shared
<
UserDataContainer
>
(),
std
::
make_shared
<
UserDataContainer
>
(),
const
TensorValueDumper
&
tensor_value_dumper_
=
{})
const
TensorValueDumper
&
tensor_value_dumper_
=
{},
bool
no_change_graph_
=
false
)
:
keep_var_name
{
keep_var_name_
},
:
keep_var_name
{
keep_var_name_
},
keep_param_name
{
keep_param_name_
},
keep_param_name
{
keep_param_name_
},
keep_opr_priority
{
keep_opr_priority_
},
keep_opr_priority
{
keep_opr_priority_
},
keep_op_name
{
keep_op_name_
},
keep_op_name
{
keep_op_name_
},
user_data
{
user_data_
},
user_data
{
user_data_
},
tensor_value_dumper
{
tensor_value_dumper_
}
{}
tensor_value_dumper
{
tensor_value_dumper_
},
no_change_graph
{
no_change_graph_
}
{}
};
};
//! config for loading a whole graph; setup in GraphLoader
//! config for loading a whole graph; setup in GraphLoader
...
...
src/serialization/include/megbrain/serialization/oss_opr_load_dump.h
浏览文件 @
eef0308b
...
@@ -15,6 +15,7 @@ namespace serialization {
...
@@ -15,6 +15,7 @@ namespace serialization {
class
GraphDumperOSSV2
final
:
public
GraphDumper
,
OprDumpContextFlatBuffers
{
class
GraphDumperOSSV2
final
:
public
GraphDumper
,
OprDumpContextFlatBuffers
{
const
std
::
unique_ptr
<
OutputFile
>
m_file
;
const
std
::
unique_ptr
<
OutputFile
>
m_file
;
int
m_version
;
flatbuffers
::
FlatBufferBuilder
m_builder
;
flatbuffers
::
FlatBufferBuilder
m_builder
;
DumpConfig
m_config
;
DumpConfig
m_config
;
...
@@ -51,7 +52,8 @@ class GraphDumperOSSV2 final : public GraphDumper, OprDumpContextFlatBuffers {
...
@@ -51,7 +52,8 @@ class GraphDumperOSSV2 final : public GraphDumper, OprDumpContextFlatBuffers {
flatbuffers
::
Offset
<
fbs
::
DType
>
build_dtype
(
DType
dtype
);
flatbuffers
::
Offset
<
fbs
::
DType
>
build_dtype
(
DType
dtype
);
public:
public:
GraphDumperOSSV2
(
std
::
unique_ptr
<
OutputFile
>
file
)
:
m_file
{
std
::
move
(
file
)}
{}
GraphDumperOSSV2
(
std
::
unique_ptr
<
OutputFile
>
file
,
int
version
)
:
m_file
{
std
::
move
(
file
)},
m_version
{
version
}
{}
DumpResult
dump
(
DumpResult
dump
(
const
SymbolVarArray
&
output_vars
,
const
DumpConfig
&
config
=
{},
const
SymbolVarArray
&
output_vars
,
const
DumpConfig
&
config
=
{},
...
@@ -95,6 +97,7 @@ class GraphLoaderOSSV2 final : public GraphLoader {
...
@@ -95,6 +97,7 @@ class GraphLoaderOSSV2 final : public GraphLoader {
const
fbs
::
v2
::
Model
*
m_model
;
const
fbs
::
v2
::
Model
*
m_model
;
SharedTensorIDMap
m_shared_tensor_map
;
SharedTensorIDMap
m_shared_tensor_map
;
uint32_t
m_mgb_version
=
0
;
uint32_t
m_mgb_version
=
0
;
uint32_t
m_model_version
=
CURRENT_VERSION
;
bool
m_model_loaded
=
false
;
bool
m_model_loaded
=
false
;
void
verify
();
void
verify
();
...
...
src/serialization/include/megbrain/serialization/serializer.h
浏览文件 @
eef0308b
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "megbrain/serialization/file.h"
#include "megbrain/serialization/file.h"
#include "megbrain/serialization/load_dump_config.h"
#include "megbrain/serialization/load_dump_config.h"
#include "megbrain/serialization/metadata.h"
#include "megbrain/serialization/metadata.h"
#include "megbrain/serialization/opr_load_dump.h"
namespace
mgb
{
namespace
mgb
{
namespace
serialization
{
namespace
serialization
{
...
@@ -160,7 +161,8 @@ public:
...
@@ -160,7 +161,8 @@ public:
};
};
MGE_WIN_DECLSPEC_FUC
static
std
::
unique_ptr
<
GraphDumper
>
make
(
MGE_WIN_DECLSPEC_FUC
static
std
::
unique_ptr
<
GraphDumper
>
make
(
std
::
unique_ptr
<
OutputFile
>
file
,
GraphDumpFormat
format
=
{});
std
::
unique_ptr
<
OutputFile
>
file
,
GraphDumpFormat
format
=
{},
int
version
=
VERSION_2
);
virtual
~
GraphDumper
()
=
default
;
virtual
~
GraphDumper
()
=
default
;
...
...
src/serialization/test/serializer_oss.cpp
浏览文件 @
eef0308b
...
@@ -987,7 +987,9 @@ TEST(TestSerializer2, TestSoftMaxLoadDump) {
...
@@ -987,7 +987,9 @@ TEST(TestSerializer2, TestSoftMaxLoadDump) {
OutputFile
::
make_fs
(
fname
.
c_str
()),
GraphDumpFormat
::
FLATBUFFERS_V2
);
OutputFile
::
make_fs
(
fname
.
c_str
()),
GraphDumpFormat
::
FLATBUFFERS_V2
);
auto
rst
=
dumper
->
dump
({
x
});
auto
rst
=
dumper
->
dump
({
x
});
func
->
execute
().
wait
();
func
->
execute
().
wait
();
ASSERT_EQ
(
rst
.
nr_opr
,
6
);
//! if convert to reduce and elemwise, nr_opr is 6
// ASSERT_EQ(rst.nr_opr, 6);
ASSERT_EQ
(
rst
.
nr_opr
,
2
);
ASSERT_EQ
(
rst
.
inputs
.
size
(),
1
);
ASSERT_EQ
(
rst
.
inputs
.
size
(),
1
);
ASSERT_EQ
(
rst
.
outputs
.
size
(),
1
);
ASSERT_EQ
(
rst
.
outputs
.
size
(),
1
);
ASSERT_EQ
(
rst
.
params
.
size
(),
0
);
ASSERT_EQ
(
rst
.
params
.
size
(),
0
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录