Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
54a4d70e
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
54a4d70e
编写于
5月 11, 2021
作者:
M
Megvii Engine Team
提交者:
huangxinda
7月 19, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(src/serialization): add support of serializing metadata
GitOrigin-RevId: b563c94451b06055d53c99a85bb5689b3f907365
上级
721091fa
变更
17
显示空白变更内容
内联
并排
Showing
17 changed file
with
365 addition
and
34 deletion
+365
-34
imperative/python/megengine/core/tensor/megbrain_graph.py
imperative/python/megengine/core/tensor/megbrain_graph.py
+51
-6
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+20
-2
imperative/python/megengine/tools/load_network_and_run.py
imperative/python/megengine/tools/load_network_and_run.py
+2
-1
imperative/python/megengine/utils/comp_graph_tools.py
imperative/python/megengine/utils/comp_graph_tools.py
+2
-1
imperative/python/megengine/utils/network.py
imperative/python/megengine/utils/network.py
+46
-7
imperative/python/src/graph_rt.cpp
imperative/python/src/graph_rt.cpp
+26
-3
imperative/python/test/unit/jit/test_tracing.py
imperative/python/test/unit/jit/test_tracing.py
+3
-4
imperative/python/test/unit/utils/test_cgtools.py
imperative/python/test/unit/utils/test_cgtools.py
+5
-3
imperative/python/test/unit/utils/test_dump_naming.py
imperative/python/test/unit/utils/test_dump_naming.py
+2
-2
imperative/python/test/unit/utils/test_network.py
imperative/python/test/unit/utils/test_network.py
+44
-0
sdk/load-and-run/dump_with_testcase_mge.py
sdk/load-and-run/dump_with_testcase_mge.py
+2
-1
src/gopt/include/megbrain/gopt/inference.h
src/gopt/include/megbrain/gopt/inference.h
+25
-1
src/serialization/impl/schema.fbs
src/serialization/impl/schema.fbs
+8
-0
src/serialization/impl/serializer_oss.cpp
src/serialization/impl/serializer_oss.cpp
+40
-2
src/serialization/include/megbrain/serialization/metadata.h
src/serialization/include/megbrain/serialization/metadata.h
+46
-0
src/serialization/include/megbrain/serialization/serializer.h
...serialization/include/megbrain/serialization/serializer.h
+6
-1
src/serialization/test/serializer_oss.cpp
src/serialization/test/serializer_oss.cpp
+37
-0
未找到文件。
imperative/python/megengine/core/tensor/megbrain_graph.py
浏览文件 @
54a4d70e
...
@@ -11,7 +11,7 @@ import json
...
@@ -11,7 +11,7 @@ import json
import
os
import
os
import
weakref
import
weakref
from
concurrent.futures
import
ThreadPoolExecutor
from
concurrent.futures
import
ThreadPoolExecutor
from
typing
import
Dict
,
List
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
...
@@ -256,6 +256,9 @@ def optimize_for_inference(dest_vars, **kwargs):
...
@@ -256,6 +256,9 @@ def optimize_for_inference(dest_vars, **kwargs):
* enable_chwn4 --
* enable_chwn4 --
whether to use CHWN4 data layout, currently
whether to use CHWN4 data layout, currently
used in nvidia backend with tensorcore.
used in nvidia backend with tensorcore.
* enable_nchw64 --
whether to use NCHW64 data layout, used for fast int4
support on Nvidia GPU.
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
into one opr.
into one opr.
...
@@ -273,6 +276,7 @@ def optimize_for_inference(dest_vars, **kwargs):
...
@@ -273,6 +276,7 @@ def optimize_for_inference(dest_vars, **kwargs):
"enable_nchw44"
:
GraphOptimizeOptions
.
LayoutTransform
.
NCHW44
,
"enable_nchw44"
:
GraphOptimizeOptions
.
LayoutTransform
.
NCHW44
,
"enable_nchw44_dot"
:
GraphOptimizeOptions
.
LayoutTransform
.
NCHW44_DOT
,
"enable_nchw44_dot"
:
GraphOptimizeOptions
.
LayoutTransform
.
NCHW44_DOT
,
"enable_chwn4"
:
GraphOptimizeOptions
.
LayoutTransform
.
CHWN4
,
"enable_chwn4"
:
GraphOptimizeOptions
.
LayoutTransform
.
CHWN4
,
"enable_nchw64"
:
GraphOptimizeOptions
.
LayoutTransform
.
NCHW64
,
}
}
for
k
,
v
in
inference_optimize_layout_transform_map
.
items
():
for
k
,
v
in
inference_optimize_layout_transform_map
.
items
():
...
@@ -293,7 +297,46 @@ def optimize_for_inference(dest_vars, **kwargs):
...
@@ -293,7 +297,46 @@ def optimize_for_inference(dest_vars, **kwargs):
dest_vars
=
_unwrap
(
dest_vars
)
dest_vars
=
_unwrap
(
dest_vars
)
res_vars
=
_imperative_rt
.
optimize_for_inference
(
dest_vars
,
inference_options
)
res_vars
=
_imperative_rt
.
optimize_for_inference
(
dest_vars
,
inference_options
)
return
_wrap
(
res_vars
)
return
_wrap
(
res_vars
),
inference_options
.
serialize
()
def
deserialize_infer_option
(
x
:
int
)
->
Dict
[
str
,
bool
]:
r
"""
Deserailize optimize options generated by ``imperative_rt.GraphOptimizeOptions``.
:param x: inference options represented by int.
:return: inference options represented by dict.
"""
inference_options
=
GraphOptimizeOptions
.
deserialize
(
x
)
inference_optimize_layout_transform_map
=
{
GraphOptimizeOptions
.
LayoutTransform
.
NHWCD4
:
"enable_hwcd4"
,
GraphOptimizeOptions
.
LayoutTransform
.
NCHW4
:
"enable_nchw4"
,
GraphOptimizeOptions
.
LayoutTransform
.
NCHW88
:
"enable_nchw88"
,
GraphOptimizeOptions
.
LayoutTransform
.
NCHW32
:
"enable_nchw32"
,
GraphOptimizeOptions
.
LayoutTransform
.
NCHW44
:
"enable_nchw44"
,
GraphOptimizeOptions
.
LayoutTransform
.
NCHW44_DOT
:
"enable_nchw44_dot"
,
GraphOptimizeOptions
.
LayoutTransform
.
CHWN4
:
"enable_chwn4"
,
GraphOptimizeOptions
.
LayoutTransform
.
NCHW64
:
"enable_nchw64"
,
}
ret
=
dict
()
layout
=
inference_options
.
layout_transform
if
layout
!=
GraphOptimizeOptions
.
LayoutTransform
.
DEFAULT
:
ret
[
inference_optimize_layout_transform_map
[
layout
]]
=
True
if
inference_options
.
f16_io_f32_comp
:
ret
[
"enable_io16xc32"
]
=
True
if
inference_options
.
f16_io_comp
:
ret
[
"enable_ioc16"
]
=
True
if
inference_options
.
fuse_conv_bias_nonlinearity
:
ret
[
"enable_fuse_conv_bias_nonlinearity"
]
=
True
if
inference_options
.
fuse_conv_bias_with_z
:
ret
[
"enable_fuse_conv_bias_with_z"
]
=
True
return
ret
def
modify_opr_algo_strategy_inplace
(
dest_vars
,
strategy
:
str
):
def
modify_opr_algo_strategy_inplace
(
dest_vars
,
strategy
:
str
):
...
@@ -331,7 +374,8 @@ def dump_graph(
...
@@ -331,7 +374,8 @@ def dump_graph(
keep_param_name
:
bool
=
False
,
keep_param_name
:
bool
=
False
,
keep_opr_priority
:
bool
=
False
,
keep_opr_priority
:
bool
=
False
,
strip_info_file
=
None
,
strip_info_file
=
None
,
append_json
=
False
append_json
=
False
,
metadata
=
None
)
->
Tuple
[
bytes
,
CompGraphDumpResult
]:
)
->
Tuple
[
bytes
,
CompGraphDumpResult
]:
"""
"""
serialize the computing graph of `output_vars` and get byte result.
serialize the computing graph of `output_vars` and get byte result.
...
@@ -393,6 +437,7 @@ def dump_graph(
...
@@ -393,6 +437,7 @@ def dump_graph(
keep_opr_name
,
keep_opr_name
,
keep_param_name
,
keep_param_name
,
keep_opr_priority
,
keep_opr_priority
,
metadata
,
stat
,
stat
,
inputs
,
inputs
,
outputs
,
outputs
,
...
@@ -427,7 +472,7 @@ def dump_graph(
...
@@ -427,7 +472,7 @@ def dump_graph(
CompGraphLoadResult
=
collections
.
namedtuple
(
CompGraphLoadResult
=
collections
.
namedtuple
(
"CompGraphLoadResult"
,
[
"graph"
,
"output_vars_dict"
,
"output_vars_list"
]
"CompGraphLoadResult"
,
[
"graph"
,
"output_vars_dict"
,
"output_vars_list"
,
"metadata"
]
)
)
...
@@ -450,8 +495,8 @@ def load_graph(fpath) -> CompGraphLoadResult:
...
@@ -450,8 +495,8 @@ def load_graph(fpath) -> CompGraphLoadResult:
buf
=
open
(
fpath
,
"rb"
).
read
()
buf
=
open
(
fpath
,
"rb"
).
read
()
else
:
else
:
buf
=
fpath
.
read
()
buf
=
fpath
.
read
()
cg
=
_imperative_rt
.
load_graph
(
buf
,
output_vars_map
,
output_vars_list
)
cg
,
metadata
=
_imperative_rt
.
load_graph
(
buf
,
output_vars_map
,
output_vars_list
)
return
CompGraphLoadResult
(
cg
,
dict
(
output_vars_map
),
output_vars_list
)
return
CompGraphLoadResult
(
cg
,
dict
(
output_vars_map
),
output_vars_list
,
metadata
)
def
_wrap
(
x
):
def
_wrap
(
x
):
...
...
imperative/python/megengine/jit/tracing.py
浏览文件 @
54a4d70e
...
@@ -12,10 +12,12 @@ import functools
...
@@ -12,10 +12,12 @@ import functools
import
itertools
import
itertools
import
json
import
json
import
os
import
os
import
pickle
from
typing
import
Any
import
numpy
as
np
import
numpy
as
np
from
..core._imperative_rt
import
GraphProfiler
from
..core._imperative_rt
import
GraphProfiler
,
SerializationMetadata
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
from
..core._imperative_rt.core2
import
(
from
..core._imperative_rt.core2
import
(
TensorWeakRef
,
TensorWeakRef
,
...
@@ -670,6 +672,8 @@ class trace:
...
@@ -670,6 +672,8 @@ class trace:
strip_info_file
=
None
,
strip_info_file
=
None
,
append_json
=
False
,
append_json
=
False
,
optimize_for_inference
=
True
,
optimize_for_inference
=
True
,
user_info
:
Any
=
None
,
enable_metadata
:
bool
=
True
,
**
kwargs
**
kwargs
):
):
r
"""
r
"""
...
@@ -697,6 +701,8 @@ class trace:
...
@@ -697,6 +701,8 @@ class trace:
if set false, will rewrite strip_info_file
if set false, will rewrite strip_info_file
:param optimize_for_inference: enbale optmizations,
:param optimize_for_inference: enbale optmizations,
will skip all optimize options if this is False. Default: True
will skip all optimize options if this is False. Default: True
:param user_info: any type object, which will be pickled to bytes.
:param enable_metadata: whether to save metadata into output file.
:Keyword Arguments:
:Keyword Arguments:
...
@@ -729,6 +735,9 @@ class trace:
...
@@ -729,6 +735,9 @@ class trace:
* enable_chwn4 --
* enable_chwn4 --
whether to use CHWN4 data layout, currently
whether to use CHWN4 data layout, currently
used in nvidia backend with tensorcore.
used in nvidia backend with tensorcore.
* enable_nchw64 --
whether to use NCHW64 data layout, used for fast int4
support on Nvidia GPU.
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
into one opr.
into one opr.
...
@@ -851,7 +860,15 @@ class trace:
...
@@ -851,7 +860,15 @@ class trace:
dest_vars
.
append
(
v
)
dest_vars
.
append
(
v
)
if
optimize_for_inference
:
if
optimize_for_inference
:
dest_vars
=
G
.
optimize_for_inference
(
dest_vars
,
**
kwargs
)
dest_vars
,
optimize_options
=
G
.
optimize_for_inference
(
dest_vars
,
**
kwargs
)
metadata
=
SerializationMetadata
()
if
enable_metadata
:
metadata
.
user_info
=
pickle
.
dumps
(
user_info
)
metadata
.
is_valid
=
True
metadata
.
graph_modified
=
False
if
optimize_for_inference
:
metadata
.
optimize_options
=
optimize_options
if
isinstance
(
file
,
str
):
if
isinstance
(
file
,
str
):
permission
=
"wb"
if
append
==
False
else
"ab"
permission
=
"wb"
if
append
==
False
else
"ab"
...
@@ -864,6 +881,7 @@ class trace:
...
@@ -864,6 +881,7 @@ class trace:
keep_opr_priority
=
keep_opr_priority
,
keep_opr_priority
=
keep_opr_priority
,
strip_info_file
=
strip_info_file
,
strip_info_file
=
strip_info_file
,
append_json
=
append_json
,
append_json
=
append_json
,
metadata
=
metadata
,
)
)
file
.
write
(
dump_content
)
file
.
write
(
dump_content
)
return
dump_info
return
dump_info
...
...
imperative/python/megengine/tools/load_network_and_run.py
浏览文件 @
54a4d70e
...
@@ -411,7 +411,8 @@ def main():
...
@@ -411,7 +411,8 @@ def main():
args
.
embed_input
=
True
args
.
embed_input
=
True
logger
.
info
(
"loading model ..."
)
logger
.
info
(
"loading model ..."
)
graph
,
_
,
output_vars
=
G
.
load_graph
(
args
.
net
)
ret
=
G
.
load_graph
(
args
.
net
)
graph
,
output_vars
=
ret
.
graph
,
ret
.
output_vars_list
input_vars
=
tools
.
get_dep_vars
(
output_vars
,
"Host2DeviceCopy"
)
input_vars
=
tools
.
get_dep_vars
(
output_vars
,
"Host2DeviceCopy"
)
if
args
.
output_name
is
not
None
:
if
args
.
output_name
is
not
None
:
...
...
imperative/python/megengine/utils/comp_graph_tools.py
浏览文件 @
54a4d70e
...
@@ -391,7 +391,8 @@ class GraphInference:
...
@@ -391,7 +391,8 @@ class GraphInference:
optimize_for_inference
:
bool
=
False
,
optimize_for_inference
:
bool
=
False
,
**
kwargs
**
kwargs
):
):
self
.
_graph
,
_
,
output_nodes
=
G
.
load_graph
(
file
)
ret
=
G
.
load_graph
(
file
)
self
.
_graph
,
output_nodes
=
ret
.
graph
,
ret
.
output_vars_list
if
outputs
is
not
None
:
if
outputs
is
not
None
:
output_nodes
=
find_vars_by_name
(
output_nodes
,
outputs
)
output_nodes
=
find_vars_by_name
(
output_nodes
,
outputs
)
self
.
_origin_outputs
=
output_nodes
self
.
_origin_outputs
=
output_nodes
...
...
imperative/python/megengine/utils/network.py
浏览文件 @
54a4d70e
...
@@ -9,14 +9,12 @@
...
@@ -9,14 +9,12 @@
import
collections
import
collections
import
fnmatch
import
fnmatch
import
itertools
import
itertools
import
pickle
import
re
import
re
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
typing
import
Dict
,
List
,
Sequence
from
typing
import
Any
,
Dict
,
List
,
Sequence
import
numpy
as
np
from
..core._imperative_rt
import
ComputingGraph
,
SerializationMetadata
from
..core._imperative_rt
import
ComputingGraph
from
..core._imperative_rt.core2
import
SymbolVar
from
..core._trace_option
import
set_symbolic_shape
as
_set_symbolic_shape
from
..core._trace_option
import
set_symbolic_shape
as
_set_symbolic_shape
from
..core.tensor
import
megbrain_graph
as
G
from
..core.tensor
import
megbrain_graph
as
G
from
..logger
import
get_logger
from
..logger
import
get_logger
...
@@ -42,6 +40,30 @@ class Network:
...
@@ -42,6 +40,30 @@ class Network:
self
.
all_oprs_map
=
OrderedDict
()
self
.
all_oprs_map
=
OrderedDict
()
self
.
all_vars_map
=
OrderedDict
()
self
.
all_vars_map
=
OrderedDict
()
self
.
graph
=
ComputingGraph
()
self
.
graph
=
ComputingGraph
()
self
.
_metadata
=
None
@
property
def
metadata
(
self
):
r
"""
Load metadata as a dict.
"""
if
not
self
.
_metadata
.
is_valid
:
logger
.
info
(
"metadata is not valid!"
)
return
None
ret
=
dict
()
try
:
user_info
=
pickle
.
loads
(
self
.
_metadata
.
user_info
)
except
:
# pylint: disable=bare-except
logger
.
warning
(
"can't parse user info by pickle, so return the original bytes object!"
)
user_info
=
self
.
_metadata
.
user_info
ret
[
"user_info"
]
=
user_info
ret
[
"graph_modified"
]
=
self
.
_metadata
.
graph_modified
ret
[
"optimized_for_inference"
]
=
self
.
_metadata
.
optimized_for_inference
if
ret
[
"optimized_for_inference"
]:
ret
.
update
(
G
.
deserialize_infer_option
(
self
.
_metadata
.
optimize_options
))
return
ret
@
classmethod
@
classmethod
def
load
(
cls
,
model_path
:
str
,
outspec
:
List
[
str
]
=
None
):
def
load
(
cls
,
model_path
:
str
,
outspec
:
List
[
str
]
=
None
):
...
@@ -51,7 +73,8 @@ class Network:
...
@@ -51,7 +73,8 @@ class Network:
:param outspec: only load the subgraph with outspec as its endpoints.
:param outspec: only load the subgraph with outspec as its endpoints.
"""
"""
self
=
cls
()
self
=
cls
()
_
,
_
,
outputs
=
G
.
load_graph
(
model_path
)
ret
=
G
.
load_graph
(
model_path
)
outputs
,
self
.
_metadata
=
ret
.
output_vars_list
,
ret
.
metadata
if
outspec
is
not
None
:
if
outspec
is
not
None
:
output_spec
=
outspec
.
copy
()
output_spec
=
outspec
.
copy
()
all_vars
=
get_dep_vars
(
outputs
)
+
outputs
all_vars
=
get_dep_vars
(
outputs
)
+
outputs
...
@@ -125,6 +148,9 @@ class Network:
...
@@ -125,6 +148,9 @@ class Network:
* enable_chwn4 --
* enable_chwn4 --
whether to use CHWN4 data layout, currently
whether to use CHWN4 data layout, currently
used in nvidia backend with tensorcore.
used in nvidia backend with tensorcore.
* enable_nchw64 --
whether to use NCHW64 data layout, used for fast int4
support on Nvidia GPU.
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
* enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearty
into one opr.
into one opr.
...
@@ -152,6 +178,8 @@ class Network:
...
@@ -152,6 +178,8 @@ class Network:
append_json
=
False
,
append_json
=
False
,
optimize_for_inference
=
True
,
optimize_for_inference
=
True
,
append
=
False
,
append
=
False
,
user_info
:
Any
=
None
,
enable_metadata
=
True
,
**
kwargs
**
kwargs
):
):
"""
"""
...
@@ -176,6 +204,8 @@ class Network:
...
@@ -176,6 +204,8 @@ class Network:
if set false, will rewrite strip_info_file
if set false, will rewrite strip_info_file
:param optimize_for_inference: enbale optmizations,
:param optimize_for_inference: enbale optmizations,
will skip all optimize options if this is False. Default: True
will skip all optimize options if this is False. Default: True
:param user_info: any type object, which will be pickled to bytes.
:param enable_metadata: whether to save metadata into output file.
:Keyword Arguments:
:Keyword Arguments:
...
@@ -201,7 +231,15 @@ class Network:
...
@@ -201,7 +231,15 @@ class Network:
)
)
if
optimize_for_inference
:
if
optimize_for_inference
:
out
=
G
.
optimize_for_inference
(
out
,
**
kwargs
)
out
,
optimize_options
=
G
.
optimize_for_inference
(
out
,
**
kwargs
)
metadata
=
SerializationMetadata
()
if
enable_metadata
:
metadata
.
is_valid
=
True
metadata
.
graph_modified
=
True
metadata
.
user_info
=
pickle
.
dumps
(
user_info
)
if
optimize_for_inference
:
metadata
.
optimize_options
=
optimize_options
dump_content
,
_
=
G
.
dump_graph
(
dump_content
,
_
=
G
.
dump_graph
(
out
,
out
,
...
@@ -211,6 +249,7 @@ class Network:
...
@@ -211,6 +249,7 @@ class Network:
keep_opr_priority
=
keep_opr_priority
,
keep_opr_priority
=
keep_opr_priority
,
strip_info_file
=
strip_info_file
,
strip_info_file
=
strip_info_file
,
append_json
=
append_json
,
append_json
=
append_json
,
metadata
=
metadata
,
)
)
if
isinstance
(
file
,
str
):
if
isinstance
(
file
,
str
):
permission
=
"wb"
if
append
==
False
else
"ab"
permission
=
"wb"
if
append
==
False
else
"ab"
...
...
imperative/python/src/graph_rt.cpp
浏览文件 @
54a4d70e
...
@@ -34,6 +34,7 @@ namespace ser = mgb::serialization;
...
@@ -34,6 +34,7 @@ namespace ser = mgb::serialization;
using
_OptimizeForInferenceOptions
=
mgb
::
gopt
::
OptimizeForInferenceOptions
;
using
_OptimizeForInferenceOptions
=
mgb
::
gopt
::
OptimizeForInferenceOptions
;
using
_LayoutTransform
=
_OptimizeForInferenceOptions
::
LayoutTransform
;
using
_LayoutTransform
=
_OptimizeForInferenceOptions
::
LayoutTransform
;
using
_AlgoStrategy
=
opr
::
mixin
::
AlgoChooserHelper
::
ExecutionPolicy
::
Strategy
;
using
_AlgoStrategy
=
opr
::
mixin
::
AlgoChooserHelper
::
ExecutionPolicy
::
Strategy
;
using
_SerializationMetadata
=
mgb
::
serialization
::
Metadata
;
namespace
{
namespace
{
class
_CompGraphProfilerImpl
{
class
_CompGraphProfilerImpl
{
...
@@ -240,6 +241,8 @@ void init_graph_rt(py::module m) {
...
@@ -240,6 +241,8 @@ void init_graph_rt(py::module m) {
auto
GraphOptimizeOptions
=
py
::
class_
<
_OptimizeForInferenceOptions
>
(
m
,
"GraphOptimizeOptions"
)
auto
GraphOptimizeOptions
=
py
::
class_
<
_OptimizeForInferenceOptions
>
(
m
,
"GraphOptimizeOptions"
)
.
def
(
py
::
init
())
.
def
(
py
::
init
())
.
def
(
"serialize"
,
&
_OptimizeForInferenceOptions
::
serialize
)
.
def_static
(
"deserialize"
,
&
_OptimizeForInferenceOptions
::
deserialize
)
.
def_readwrite
(
"f16_io_f32_comp"
,
&
_OptimizeForInferenceOptions
::
f16_io_f32_comp
)
.
def_readwrite
(
"f16_io_f32_comp"
,
&
_OptimizeForInferenceOptions
::
f16_io_f32_comp
)
.
def_readwrite
(
"f16_io_comp"
,
&
_OptimizeForInferenceOptions
::
f16_io_comp
)
.
def_readwrite
(
"f16_io_comp"
,
&
_OptimizeForInferenceOptions
::
f16_io_comp
)
.
def_readwrite
(
"fuse_conv_bias_nonlinearity"
,
&
_OptimizeForInferenceOptions
::
fuse_conv_bias_nonlinearity
)
.
def_readwrite
(
"fuse_conv_bias_nonlinearity"
,
&
_OptimizeForInferenceOptions
::
fuse_conv_bias_nonlinearity
)
...
@@ -256,6 +259,7 @@ void init_graph_rt(py::module m) {
...
@@ -256,6 +259,7 @@ void init_graph_rt(py::module m) {
.
value
(
"NCHW44_DOT"
,
_LayoutTransform
::
NCHW44_DOT
)
.
value
(
"NCHW44_DOT"
,
_LayoutTransform
::
NCHW44_DOT
)
.
value
(
"NCHW32"
,
_LayoutTransform
::
NCHW32
)
.
value
(
"NCHW32"
,
_LayoutTransform
::
NCHW32
)
.
value
(
"CHWN4"
,
_LayoutTransform
::
CHWN4
)
.
value
(
"CHWN4"
,
_LayoutTransform
::
CHWN4
)
.
value
(
"NCHW64"
,
_LayoutTransform
::
NCHW64
)
.
export_values
()
.
export_values
()
;
;
...
@@ -307,12 +311,24 @@ void init_graph_rt(py::module m) {
...
@@ -307,12 +311,24 @@ void init_graph_rt(py::module m) {
})
->
to_string
();
})
->
to_string
();
});
});
py
::
class_
<
_SerializationMetadata
>
(
m
,
"SerializationMetadata"
)
.
def
(
py
::
init
())
.
def_property
(
"user_info"
,
[](
const
_SerializationMetadata
&
meta
){
return
py
::
bytes
(
meta
.
get_user_info
());
},
&
_SerializationMetadata
::
set_user_info
)
.
def_readonly
(
"optimized_for_inference"
,
&
_SerializationMetadata
::
optimized_for_inference
)
.
def_property
(
"optimize_options"
,
&
_SerializationMetadata
::
get_optimize_options
,
&
_SerializationMetadata
::
set_optimize_options
)
.
def_readwrite
(
"graph_modified"
,
&
_SerializationMetadata
::
graph_modified
)
.
def_readwrite
(
"is_valid"
,
&
_SerializationMetadata
::
is_valid
)
;
m
.
def
(
"dump_graph"
,
[](
m
.
def
(
"dump_graph"
,
[](
const
std
::
vector
<
VarNode
*>&
dest_vars
,
const
std
::
vector
<
VarNode
*>&
dest_vars
,
int
keep_var_name
,
int
keep_var_name
,
bool
keep_opr_name
,
bool
keep_opr_name
,
bool
keep_param_name
,
bool
keep_param_name
,
bool
keep_opr_priority
,
bool
keep_opr_priority
,
std
::
optional
<
_SerializationMetadata
>
metadata
,
py
::
list
&
stat
,
py
::
list
&
stat
,
py
::
list
&
inputs
,
py
::
list
&
inputs
,
py
::
list
&
outputs
,
py
::
list
&
outputs
,
...
@@ -325,7 +341,12 @@ void init_graph_rt(py::module m) {
...
@@ -325,7 +341,12 @@ void init_graph_rt(py::module m) {
ser
::
GraphDumper
::
DumpConfig
config
{
keep_var_name
,
keep_param_name
,
ser
::
GraphDumper
::
DumpConfig
config
{
keep_var_name
,
keep_param_name
,
keep_opr_priority
,
keep_opr_name
};
keep_opr_priority
,
keep_opr_name
};
auto
rst
=
dumper
->
dump
(
symvars
,
config
);
ser
::
GraphDumper
::
DumpResult
rst
;
if
(
metadata
)
rst
=
dumper
->
dump
(
symvars
,
config
,
*
metadata
);
else
rst
=
dumper
->
dump
(
symvars
,
config
);
for
(
auto
i
:
rst
.
inputs
)
{
for
(
auto
i
:
rst
.
inputs
)
{
inputs
.
append
(
py
::
cast
(
i
));
inputs
.
append
(
py
::
cast
(
i
));
}
}
...
@@ -377,8 +398,10 @@ void init_graph_rt(py::module m) {
...
@@ -377,8 +398,10 @@ void init_graph_rt(py::module m) {
for
(
const
auto
&
var
:
rst
.
output_var_list
)
{
for
(
const
auto
&
var
:
rst
.
output_var_list
)
{
iter
.
add
(
var
);
iter
.
add
(
var
);
}
}
return
rst
.
graph
;
auto
ret
=
py
::
tuple
(
2
);
ret
[
0
]
=
py
::
cast
(
rst
.
graph
);
ret
[
1
]
=
py
::
cast
(
rst
.
metadata
);
return
ret
;
});
});
#define CURRENT_CLASS cg::ComputingGraph::Options
#define CURRENT_CLASS cg::ComputingGraph::Options
...
...
imperative/python/test/unit/jit/test_tracing.py
浏览文件 @
54a4d70e
...
@@ -239,8 +239,7 @@ def test_dump_volatile():
...
@@ -239,8 +239,7 @@ def test_dump_volatile():
file
=
io
.
BytesIO
()
file
=
io
.
BytesIO
()
f
.
dump
(
file
,
optimize_for_inference
=
False
)
f
.
dump
(
file
,
optimize_for_inference
=
False
)
file
.
seek
(
0
)
file
.
seek
(
0
)
cg
,
_
,
outputs
=
G
.
load_graph
(
file
)
(
out
,)
=
G
.
load_graph
(
file
).
output_vars_list
(
out
,)
=
outputs
assert
(
assert
(
cgtools
.
get_owner_opr_type
(
cgtools
.
get_owner_opr_inputs
(
out
)[
1
])
cgtools
.
get_owner_opr_type
(
cgtools
.
get_owner_opr_inputs
(
out
)[
1
])
==
"ImmutableTensor"
==
"ImmutableTensor"
...
@@ -337,12 +336,12 @@ def test_goptions_log_exp():
...
@@ -337,12 +336,12 @@ def test_goptions_log_exp():
f
(
tensor
(
1.0
))
f
(
tensor
(
1.0
))
_
,
out
=
mkstemp
()
_
,
out
=
mkstemp
()
f
.
dump
(
out
,
optimize_for_inference
=
False
)
f
.
dump
(
out
,
optimize_for_inference
=
False
)
*
_
,
outputs
=
G
.
load_graph
(
out
)
outputs
=
G
.
load_graph
(
out
).
output_vars_list
oprs_1
=
cgtools
.
get_oprs_seq
(
outputs
)
oprs_1
=
cgtools
.
get_oprs_seq
(
outputs
)
g
(
tensor
(
1.0
))
g
(
tensor
(
1.0
))
g
.
dump
(
out
,
optimize_for_inference
=
False
)
g
.
dump
(
out
,
optimize_for_inference
=
False
)
*
_
,
outputs
=
G
.
load_graph
(
out
)
outputs
=
G
.
load_graph
(
out
).
output_vars_list
oprs_2
=
cgtools
.
get_oprs_seq
(
outputs
)
oprs_2
=
cgtools
.
get_oprs_seq
(
outputs
)
assert
len
(
oprs_1
)
-
len
(
oprs_2
)
==
2
assert
len
(
oprs_1
)
-
len
(
oprs_2
)
==
2
...
...
imperative/python/test/unit/utils/test_cgtools.py
浏览文件 @
54a4d70e
...
@@ -88,7 +88,7 @@ def test_graph_traversal():
...
@@ -88,7 +88,7 @@ def test_graph_traversal():
file
=
io
.
BytesIO
()
file
=
io
.
BytesIO
()
fun
.
dump
(
file
,
optimize_for_inference
=
False
)
fun
.
dump
(
file
,
optimize_for_inference
=
False
)
file
.
seek
(
0
)
file
.
seek
(
0
)
cg
,
_
,
outputs
=
mgb_graph
.
load_graph
(
file
)
outputs
=
mgb_graph
.
load_graph
(
file
).
output_vars_list
_
,
map_vars
,
var2oprs
,
*
_
=
cgtools
.
graph_traversal
(
outputs
)
_
,
map_vars
,
var2oprs
,
*
_
=
cgtools
.
graph_traversal
(
outputs
)
input_var
=
map_vars
[
1
]
input_var
=
map_vars
[
1
]
...
@@ -101,7 +101,9 @@ def test_load_refcnt():
...
@@ -101,7 +101,9 @@ def test_load_refcnt():
graph
=
mgb_graph
.
Graph
()
graph
=
mgb_graph
.
Graph
()
varnode
=
graph
.
make_const
(
0
)
varnode
=
graph
.
make_const
(
0
)
buf
,
_
=
mgb_graph
.
dump_graph
([
varnode
])
buf
,
_
=
mgb_graph
.
dump_graph
([
varnode
])
graph
,
_
,
(
varnode
,)
=
mgb_graph
.
load_graph
(
io
.
BytesIO
(
buf
))
ret
=
mgb_graph
.
load_graph
(
io
.
BytesIO
(
buf
))
graph
,
(
varnode
,)
=
ret
.
graph
,
ret
.
output_vars_list
del
ret
del
graph
del
graph
varnode
.
owner
varnode
.
owner
...
@@ -132,7 +134,7 @@ def test_get_opr_seq():
...
@@ -132,7 +134,7 @@ def test_get_opr_seq():
file
=
io
.
BytesIO
()
file
=
io
.
BytesIO
()
func
.
dump
(
file
,
optimize_for_inference
=
False
)
func
.
dump
(
file
,
optimize_for_inference
=
False
)
file
.
seek
(
0
)
file
.
seek
(
0
)
*
_
,
outputs
=
mgb_graph
.
load_graph
(
file
)
outputs
=
mgb_graph
.
load_graph
(
file
).
output_vars_list
seq_1
=
cgtools
.
get_oprs_seq
(
outputs
,
True
)
seq_1
=
cgtools
.
get_oprs_seq
(
outputs
,
True
)
assert
len
(
seq_1
)
==
5
assert
len
(
seq_1
)
==
5
...
...
imperative/python/test/unit/utils/test_dump_naming.py
浏览文件 @
54a4d70e
...
@@ -35,7 +35,7 @@ def _dump_and_load(func, symbolic, keep_opr_name=True):
...
@@ -35,7 +35,7 @@ def _dump_and_load(func, symbolic, keep_opr_name=True):
keep_var_name
=
2
,
keep_var_name
=
2
,
)
)
file
.
seek
(
0
)
file
.
seek
(
0
)
*
_
,
outputs
=
G
.
load_graph
(
file
)
outputs
=
G
.
load_graph
(
file
).
output_vars_list
ops
=
cgtools
.
get_oprs_seq
(
outputs
)
ops
=
cgtools
.
get_oprs_seq
(
outputs
)
return
ops
return
ops
...
@@ -223,7 +223,7 @@ def test_catch_input_name(tensor_name, var_name):
...
@@ -223,7 +223,7 @@ def test_catch_input_name(tensor_name, var_name):
file
=
io
.
BytesIO
()
file
=
io
.
BytesIO
()
func
.
dump
(
file
,
optimize_for_inference
=
False
,
keep_opr_name
=
True
,
keep_var_name
=
2
)
func
.
dump
(
file
,
optimize_for_inference
=
False
,
keep_opr_name
=
True
,
keep_var_name
=
2
)
file
.
seek
(
0
)
file
.
seek
(
0
)
*
_
,
outputs
=
G
.
load_graph
(
file
)
outputs
=
G
.
load_graph
(
file
).
output_vars_list
op
=
cgtools
.
get_oprs_seq
(
outputs
)[
-
1
]
op
=
cgtools
.
get_oprs_seq
(
outputs
)[
-
1
]
assert
op
.
inputs
[
0
].
name
==
var_name
assert
op
.
inputs
[
0
].
name
==
var_name
...
...
imperative/python/test/unit/utils/test_network.py
浏览文件 @
54a4d70e
...
@@ -14,6 +14,50 @@ from megengine.utils.network import as_oprnode, set_symbolic_shape
...
@@ -14,6 +14,50 @@ from megengine.utils.network import as_oprnode, set_symbolic_shape
from
megengine.utils.network_node
import
Host2DeviceCopy
,
VarNode
from
megengine.utils.network_node
import
Host2DeviceCopy
,
VarNode
def
test_metadata
():
x
=
Tensor
(
0
)
@
trace
(
symbolic
=
True
,
capture_as_const
=
True
)
def
fwd
(
x
):
return
x
*
2
fwd
(
x
)
orig_model
=
io
.
BytesIO
()
fwd
.
dump
(
orig_model
,
user_info
=
"test"
,
optimize_for_inference
=
False
)
orig_model
.
seek
(
0
)
graph
=
Net
.
load
(
orig_model
)
assert
graph
.
metadata
==
{
"user_info"
:
"test"
,
"graph_modified"
:
False
,
# False: tracing.dump
"optimized_for_inference"
:
False
,
}
orig_model
.
seek
(
0
)
graph
.
dump
(
orig_model
,
user_info
=
{
"str"
:
"x"
,
"tensor"
:
x
,
"module"
:
M
.
Module
,
"none"
:
None
},
optimize_for_inference
=
True
,
enable_nchw4
=
True
,
enable_ioc16
=
True
,
)
orig_model
.
seek
(
0
)
graph
=
Net
.
load
(
orig_model
)
assert
graph
.
metadata
==
{
"user_info"
:
{
"str"
:
"x"
,
"tensor"
:
x
,
"module"
:
M
.
Module
,
"none"
:
None
},
"graph_modified"
:
True
,
# True: Network.dump
"optimized_for_inference"
:
True
,
"enable_nchw4"
:
True
,
"enable_ioc16"
:
True
,
}
orig_model
.
seek
(
0
)
fwd
.
dump
(
orig_model
,
enable_metadata
=
False
)
orig_model
.
seek
(
0
)
graph
=
Net
.
load
(
orig_model
)
assert
graph
.
metadata
is
None
def
test_replace_var
():
def
test_replace_var
():
a
=
Tensor
([
1
,
2
])
a
=
Tensor
([
1
,
2
])
...
...
sdk/load-and-run/dump_with_testcase_mge.py
浏览文件 @
54a4d70e
...
@@ -170,7 +170,8 @@ def gen_one_testcase(args, inputs, spec):
...
@@ -170,7 +170,8 @@ def gen_one_testcase(args, inputs, spec):
def
make_feeds
(
args
):
def
make_feeds
(
args
):
cg_rt
,
_
,
outputs
=
G
.
load_graph
(
args
.
input
)
ret
=
G
.
load_graph
(
args
.
input
)
cg_rt
,
outputs
=
ret
.
graph
,
ret
.
output_vars_list
inputs
=
cgtools
.
get_dep_vars
(
outputs
,
"Host2DeviceCopy"
)
inputs
=
cgtools
.
get_dep_vars
(
outputs
,
"Host2DeviceCopy"
)
inputs
=
{
i
.
name
:
i
for
i
in
inputs
}
inputs
=
{
i
.
name
:
i
for
i
in
inputs
}
...
...
src/gopt/include/megbrain/gopt/inference.h
浏览文件 @
54a4d70e
...
@@ -322,7 +322,31 @@ namespace gopt {
...
@@ -322,7 +322,31 @@ namespace gopt {
static
std
::
unique_ptr
<
EnableNchw44DotPass
>
make_nchw44_dot_converter
();
static
std
::
unique_ptr
<
EnableNchw44DotPass
>
make_nchw44_dot_converter
();
};
};
struct
OptimizeForInferenceOptions
:
cg
::
GraphCommonOptimizeOptions
{};
struct
OptimizeForInferenceOptions
:
cg
::
GraphCommonOptimizeOptions
{
uint64_t
serialize
()
{
uint64_t
ret
=
0
;
ret
|=
(
uint64_t
)
layout_transform
<<
32
;
if
(
f16_io_f32_comp
)
ret
|=
1u
;
if
(
f16_io_comp
)
ret
|=
1u
<<
1
;
if
(
fuse_conv_bias_nonlinearity
)
ret
|=
1u
<<
2
;
if
(
fuse_conv_bias_with_z
)
ret
|=
1u
<<
3
;
if
(
weight_preprocess
)
ret
|=
1u
<<
4
;
if
(
fuse_preprocess
)
ret
|=
1u
<<
5
;
return
ret
;
}
static
OptimizeForInferenceOptions
deserialize
(
uint64_t
buf
)
{
OptimizeForInferenceOptions
ret
;
ret
.
f16_io_f32_comp
=
buf
&
1u
;
ret
.
f16_io_comp
=
buf
&
1u
<<
1
;
ret
.
fuse_conv_bias_nonlinearity
=
buf
&
1u
<<
2
;
ret
.
fuse_conv_bias_with_z
=
buf
&
1u
<<
3
;
ret
.
weight_preprocess
=
buf
&
1u
<<
4
;
ret
.
fuse_preprocess
=
buf
&
1u
<<
5
;
ret
.
layout_transform
=
(
LayoutTransform
)(
buf
>>
32
);
return
ret
;
}
};
/*!
/*!
* \brief optimize a computing graph for inference
* \brief optimize a computing graph for inference
...
...
src/serialization/impl/schema.fbs
浏览文件 @
54a4d70e
...
@@ -128,6 +128,13 @@ table Operator {
...
@@ -128,6 +128,13 @@ table Operator {
name:string;
name:string;
}
}
table Metadata {
is_valid:bool;
graph_modified:bool;
user_info:string;
optimize_options:ulong;
}
struct OutputVar {
struct OutputVar {
compact_id:uint;
compact_id:uint;
original_id:uint;
original_id:uint;
...
@@ -141,6 +148,7 @@ table Graph {
...
@@ -141,6 +148,7 @@ table Graph {
nr_shared_tensor:uint;
nr_shared_tensor:uint;
oprs:[Operator];
oprs:[Operator];
output_vars_idx:[OutputVar];
output_vars_idx:[OutputVar];
metadata:Metadata;
}
}
root_type Graph;
root_type Graph;
src/serialization/impl/serializer_oss.cpp
浏览文件 @
54a4d70e
...
@@ -30,6 +30,7 @@
...
@@ -30,6 +30,7 @@
#include "megbrain/serialization/internal/flatbuffers_helper.h"
#include "megbrain/serialization/internal/flatbuffers_helper.h"
#include "megbrain/serialization/internal/schema_generated.h"
#include "megbrain/serialization/internal/schema_generated.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/opr_load_dump.h"
#include "megbrain/serialization/metadata.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/version.h"
#include "megbrain/version.h"
...
@@ -115,6 +116,7 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
...
@@ -115,6 +116,7 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
std
::
vector
<
flatbuffers
::
Offset
<
void
>>
m_cur_opr_param
;
std
::
vector
<
flatbuffers
::
Offset
<
void
>>
m_cur_opr_param
;
void
init_oprs_to_dump
(
const
SymbolVarArray
&
endpoints
);
void
init_oprs_to_dump
(
const
SymbolVarArray
&
endpoints
);
flatbuffers
::
Offset
<
fbs
::
Metadata
>
build_metadata
(
const
Metadata
&
metadata
);
flatbuffers
::
Offset
<
fbs
::
Operator
>
build_single_opr
(
flatbuffers
::
Offset
<
fbs
::
Operator
>
build_single_opr
(
cg
::
OperatorNodeBase
*
opr
,
const
OprRegistry
*
registry
);
cg
::
OperatorNodeBase
*
opr
,
const
OprRegistry
*
registry
);
...
@@ -123,7 +125,8 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
...
@@ -123,7 +125,8 @@ class GraphDumperOSS final : public GraphDumper, OprDumpContextFlatBuffers {
public:
public:
GraphDumperOSS
(
std
::
unique_ptr
<
OutputFile
>
file
)
:
m_file
{
std
::
move
(
file
)}
{}
GraphDumperOSS
(
std
::
unique_ptr
<
OutputFile
>
file
)
:
m_file
{
std
::
move
(
file
)}
{}
DumpResult
dump
(
const
SymbolVarArray
&
output_vars
,
DumpResult
dump
(
const
SymbolVarArray
&
output_vars
,
const
DumpConfig
&
config
=
{})
override
;
const
DumpConfig
&
config
=
{},
const
Metadata
&
metadata
=
{})
override
;
const
GraphDumpConfig
&
config
()
const
override
{
return
m_config
;
}
const
GraphDumpConfig
&
config
()
const
override
{
return
m_config
;
}
void
dump_tensor
(
const
std
::
string
&
name
,
const
HostTensorND
&
tensor
,
void
dump_tensor
(
const
std
::
string
&
name
,
const
HostTensorND
&
tensor
,
TensorWriteMethod
method
)
override
;
TensorWriteMethod
method
)
override
;
...
@@ -185,6 +188,17 @@ void GraphDumperOSS::init_oprs_to_dump(const SymbolVarArray& endpoints) {
...
@@ -185,6 +188,17 @@ void GraphDumperOSS::init_oprs_to_dump(const SymbolVarArray& endpoints) {
}
}
}
}
flatbuffers
::
Offset
<
fbs
::
Metadata
>
GraphDumperOSS
::
build_metadata
(
const
Metadata
&
metadata
)
{
auto
user_info
=
m_builder
.
CreateSharedString
(
metadata
.
user_info
);
fbs
::
MetadataBuilder
builder
(
m_builder
);
builder
.
add_is_valid
(
metadata
.
is_valid
);
builder
.
add_graph_modified
(
metadata
.
graph_modified
);
builder
.
add_user_info
(
user_info
);
builder
.
add_optimize_options
(
metadata
.
optimize_options
);
return
builder
.
Finish
();
}
flatbuffers
::
Offset
<
fbs
::
Operator
>
GraphDumperOSS
::
build_single_opr
(
flatbuffers
::
Offset
<
fbs
::
Operator
>
GraphDumperOSS
::
build_single_opr
(
cg
::
OperatorNodeBase
*
opr
,
const
OprRegistry
*
registry
)
{
cg
::
OperatorNodeBase
*
opr
,
const
OprRegistry
*
registry
)
{
m_cur_opr
=
opr
;
m_cur_opr
=
opr
;
...
@@ -282,7 +296,8 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
...
@@ -282,7 +296,8 @@ flatbuffers::Offset<fbs::Operator> GraphDumperOSS::build_single_opr(
}
}
GraphDumper
::
DumpResult
GraphDumperOSS
::
dump
(
GraphDumper
::
DumpResult
GraphDumperOSS
::
dump
(
const
SymbolVarArray
&
output_vars
,
const
DumpConfig
&
config
)
{
const
SymbolVarArray
&
output_vars
,
const
DumpConfig
&
config
,
const
Metadata
&
metadata
)
{
mgb_throw_if
(
output_vars
.
empty
(),
SerializationError
,
mgb_throw_if
(
output_vars
.
empty
(),
SerializationError
,
"Can't dump empty graph"
);
"Can't dump empty graph"
);
...
@@ -323,6 +338,9 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
...
@@ -323,6 +338,9 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
uint64_t
offset_to_fbs
=
0
;
uint64_t
offset_to_fbs
=
0
;
m_file
->
write
(
&
offset_to_fbs
,
sizeof
(
offset_to_fbs
));
m_file
->
write
(
&
offset_to_fbs
,
sizeof
(
offset_to_fbs
));
// Dump metadata
auto
fbmeta
=
build_metadata
(
metadata
);
// Dump operators
// Dump operators
init_oprs_to_dump
(
output_vars
);
init_oprs_to_dump
(
output_vars
);
std
::
vector
<
flatbuffers
::
Offset
<
fbs
::
Operator
>>
oprs
;
std
::
vector
<
flatbuffers
::
Offset
<
fbs
::
Operator
>>
oprs
;
...
@@ -350,6 +368,7 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
...
@@ -350,6 +368,7 @@ GraphDumper::DumpResult GraphDumperOSS::dump(
graph
.
add_oprs
(
fb_oprs
);
graph
.
add_oprs
(
fb_oprs
);
graph
.
add_output_vars_idx
(
fb_output_vars
);
graph
.
add_output_vars_idx
(
fb_output_vars
);
graph
.
add_nr_shared_tensor
(
m_nr_shared_tensor
);
graph
.
add_nr_shared_tensor
(
m_nr_shared_tensor
);
graph
.
add_metadata
(
fbmeta
);
m_builder
.
FinishSizePrefixed
(
graph
.
Finish
(),
fbs
::
GraphIdentifier
());
m_builder
.
FinishSizePrefixed
(
graph
.
Finish
(),
fbs
::
GraphIdentifier
());
// Write actual offset_to_fbs
// Write actual offset_to_fbs
...
@@ -531,6 +550,7 @@ public:
...
@@ -531,6 +550,7 @@ public:
mgb_assert
(
nr
==
1
);
mgb_assert
(
nr
==
1
);
}
}
Metadata
load_metadata
();
LoadResult
load_oprs
();
LoadResult
load_oprs
();
CompNode
load_comp_node
(
const
fbs
::
CompNode
*
comp_node
);
CompNode
load_comp_node
(
const
fbs
::
CompNode
*
comp_node
);
...
@@ -700,6 +720,22 @@ GraphLoaderOSS::OprLoadContextImpl::load_tensor_shared() {
...
@@ -700,6 +720,22 @@ GraphLoaderOSS::OprLoadContextImpl::load_tensor_shared() {
return
sh_ptr_ref
;
return
sh_ptr_ref
;
}
}
Metadata
GraphLoaderOSS
::
OprLoadContextImpl
::
load_metadata
()
{
const
auto
*
fbmeta
=
m_loader
->
m_graph
->
metadata
();
Metadata
ret
;
ret
.
is_valid
=
fbmeta
->
is_valid
();
ret
.
graph_modified
=
fbmeta
->
graph_modified
();
if
(
fbmeta
->
user_info
())
{
ret
.
user_info
=
fbmeta
->
user_info
()
->
str
();
ret
.
has_user_info
=
true
;
}
if
(
fbmeta
->
optimize_options
())
{
ret
.
optimize_options
=
fbmeta
->
optimize_options
();
ret
.
optimized_for_inference
=
true
;
}
return
ret
;
}
void
GraphLoaderOSS
::
OprLoadContextImpl
::
load_single_opr
(
void
GraphLoaderOSS
::
OprLoadContextImpl
::
load_single_opr
(
const
fbs
::
Operator
*
fbopr
)
{
const
fbs
::
Operator
*
fbopr
)
{
m_cur_opr_tensor_cnt
=
0
;
m_cur_opr_tensor_cnt
=
0
;
...
@@ -872,7 +908,9 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config,
...
@@ -872,7 +908,9 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config,
}
}
OprLoadContextImpl
ctx
{
this
,
m_graph
->
mgb_version
()};
OprLoadContextImpl
ctx
{
this
,
m_graph
->
mgb_version
()};
auto
metadata
=
ctx
.
load_metadata
();
auto
result
=
ctx
.
load_oprs
();
auto
result
=
ctx
.
load_oprs
();
result
.
metadata
=
metadata
;
auto
fbs_end
=
tensor_begin
+
offset_to_fbs
+
sizeof
(
size
)
+
size
;
auto
fbs_end
=
tensor_begin
+
offset_to_fbs
+
sizeof
(
size
)
+
size
;
auto
cur
=
m_file
->
tell
();
auto
cur
=
m_file
->
tell
();
...
...
src/serialization/include/megbrain/serialization/metadata.h
0 → 100644
浏览文件 @
54a4d70e
/**
* \file src/serialization/include/megbrain/serialization/metadata.h
*
* This file is part of MegBrain, a deep learning framework developed by Megvii.
*
* \brief MegEngine model's metadata
*
* \copyright Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
*/
#pragma once
#include <string>
namespace
mgb
{
namespace
serialization
{
struct
Metadata
{
bool
is_valid
=
false
;
bool
graph_modified
=
false
;
bool
has_user_info
=
false
;
std
::
string
user_info
;
bool
optimized_for_inference
=
false
;
uint64_t
optimize_options
;
#define ADD_PROPERTY(type, name) \
type get_##name() const { return name; } \
void set_##name(type x) { \
name = x; \
has_##name = true; \
}
ADD_PROPERTY
(
std
::
string
,
user_info
)
#undef ADD_PROPERTY
uint64_t
get_optimize_options
()
{
return
optimize_options
;
}
void
set_optimize_options
(
uint64_t
value
)
{
optimized_for_inference
=
true
;
optimize_options
=
value
;
}
};
}
// namespace serialization
}
// namespace mgb
\ No newline at end of file
src/serialization/include/megbrain/serialization/serializer.h
浏览文件 @
54a4d70e
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#include "megbrain/serialization/dump_format.h"
#include "megbrain/serialization/dump_format.h"
#include "megbrain/serialization/file.h"
#include "megbrain/serialization/file.h"
#include "megbrain/serialization/load_dump_config.h"
#include "megbrain/serialization/load_dump_config.h"
#include "megbrain/serialization/metadata.h"
namespace
mgb
{
namespace
mgb
{
namespace
serialization
{
namespace
serialization
{
...
@@ -32,6 +33,9 @@ namespace serialization {
...
@@ -32,6 +33,9 @@ namespace serialization {
//! expliit dtor decl to reduce binary size
//! expliit dtor decl to reduce binary size
~
LoadResult
()
noexcept
;
~
LoadResult
()
noexcept
;
//! metadata
Metadata
metadata
;
using
TensorMap
=
std
::
unordered_map
<
using
TensorMap
=
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
HostTensorND
>>
;
std
::
string
,
std
::
shared_ptr
<
HostTensorND
>>
;
...
@@ -178,7 +182,8 @@ namespace serialization {
...
@@ -178,7 +182,8 @@ namespace serialization {
virtual
DumpResult
dump
(
virtual
DumpResult
dump
(
const
SymbolVarArray
&
output_vars
,
const
SymbolVarArray
&
output_vars
,
const
DumpConfig
&
config
=
{})
=
0
;
const
DumpConfig
&
config
=
{},
const
Metadata
&
metadata
=
{})
=
0
;
virtual
GraphDumpFormat
format
()
const
=
0
;
virtual
GraphDumpFormat
format
()
const
=
0
;
};
};
...
...
src/serialization/test/serializer_oss.cpp
浏览文件 @
54a4d70e
...
@@ -92,6 +92,43 @@ TEST(TestSerializer2, MultiGraphDumpLoad) {
...
@@ -92,6 +92,43 @@ TEST(TestSerializer2, MultiGraphDumpLoad) {
load
();
load
();
}
}
TEST
(
TestSerializer2
,
Metadata
)
{
auto
fname
=
GET_OUTPUT_FILE
();
TensorShape
shape
{
2
,
3
};
auto
dump
=
[
&
]()
{
auto
cn
=
CompNode
::
load
(
"xpu0"
);
auto
host_x
=
std
::
make_shared
<
HostTensorND
>
(
cn
,
shape
),
host_y
=
std
::
make_shared
<
HostTensorND
>
(
cn
,
shape
);
auto
graph
=
ComputingGraph
::
make
();
auto
x
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
,
{
"x"
}),
y
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_y
,
{
"y"
});
using
Mode
=
opr
::
Elemwise
::
Mode
;
auto
z
=
opr
::
Elemwise
::
make
({
x
,
y
},
Mode
::
ADD
,
{
"add(x, y)"
});
Metadata
metadata
;
metadata
.
user_info
=
"TEST_METADATA"
;
metadata
.
has_user_info
=
true
;
auto
dumper
=
GraphDumper
::
make
(
OutputFile
::
make_fs
(
fname
.
c_str
()),
GraphDumpFormat
::
FLATBUFFERS
);
auto
rst
=
dumper
->
dump
({
z
.
rename
(
"z"
)},
{},
metadata
);
};
auto
load
=
[
&
]()
{
HostTensorGenerator
<>
gen
;
auto
loader
=
GraphLoader
::
make
(
InputFile
::
make_fs
(
fname
.
c_str
()),
GraphDumpFormat
::
FLATBUFFERS
);
auto
rst
=
loader
->
load
();
auto
metadata
=
rst
.
metadata
;
int
cmp
=
strcmp
(
metadata
.
user_info
.
c_str
(),
"TEST_METADATA"
);
EXPECT_EQ
(
cmp
,
0
);
};
dump
();
load
();
}
TEST
(
TestSerializer2
,
APlusB
)
{
TEST
(
TestSerializer2
,
APlusB
)
{
auto
fname
=
GET_OUTPUT_FILE
();
auto
fname
=
GET_OUTPUT_FILE
();
TensorShape
shape
{
2
,
3
};
TensorShape
shape
{
2
,
3
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录