Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a404c508
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a404c508
编写于
10月 29, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge): support dump with specific format
GitOrigin-RevId: 57a7c0de02ec6ee30a67b5cc069dbdd7dc0f6437
上级
fba523a1
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
38 addition
and
9 deletion
+38
-9
imperative/python/megengine/core/tensor/megbrain_graph.py
imperative/python/megengine/core/tensor/megbrain_graph.py
+12
-4
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+3
-0
imperative/python/src/graph_rt.cpp
imperative/python/src/graph_rt.cpp
+15
-3
imperative/python/test/unit/jit/test_tracing.py
imperative/python/test/unit/jit/test_tracing.py
+8
-2
未找到文件。
imperative/python/megengine/core/tensor/megbrain_graph.py
浏览文件 @
a404c508
...
@@ -11,13 +11,12 @@ import json
...
@@ -11,13 +11,12 @@ import json
import
os
import
os
import
weakref
import
weakref
from
concurrent.futures
import
ThreadPoolExecutor
from
concurrent.futures
import
ThreadPoolExecutor
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
from
..
import
_imperative_rt
from
..
import
_imperative_rt
from
.._imperative_rt
import
GraphOptimizeOptions
from
.._imperative_rt
import
GraphOptimizeOptions
,
SerializationFormat
from
.._imperative_rt.core2
import
apply
,
set_cpp_apply_backward_varnode
from
.._wrap
import
as_device
from
.._wrap
import
as_device
from
..ops.builtin
import
OpDef
from
..ops.builtin
import
OpDef
...
@@ -377,7 +376,8 @@ def dump_graph(
...
@@ -377,7 +376,8 @@ def dump_graph(
keep_opr_priority
:
bool
=
False
,
keep_opr_priority
:
bool
=
False
,
strip_info_file
=
None
,
strip_info_file
=
None
,
append_json
=
False
,
append_json
=
False
,
metadata
=
None
metadata
=
None
,
dump_format
=
None
)
->
Tuple
[
bytes
,
CompGraphDumpResult
]:
)
->
Tuple
[
bytes
,
CompGraphDumpResult
]:
r
"""serialize the computing graph of `output_vars` and get byte result.
r
"""serialize the computing graph of `output_vars` and get byte result.
...
@@ -398,6 +398,7 @@ def dump_graph(
...
@@ -398,6 +398,7 @@ def dump_graph(
append_json: will be check when `strip_info_file` is not None. if set
append_json: will be check when `strip_info_file` is not None. if set
true, the information for code strip will be append to strip_info_file.
true, the information for code strip will be append to strip_info_file.
if set false, will rewrite strip_info_file
if set false, will rewrite strip_info_file
dump_format: using different dump formats.
Note:
Note:
The underlying C++ API only accepts a var list. If a dict is given,
The underlying C++ API only accepts a var list. If a dict is given,
...
@@ -434,6 +435,12 @@ def dump_graph(
...
@@ -434,6 +435,12 @@ def dump_graph(
outputs
=
[]
outputs
=
[]
params
=
[]
params
=
[]
dump_format_map
=
{
None
:
None
,
"FBS"
:
SerializationFormat
.
FBS
,
}
dump_format
=
dump_format_map
[
dump_format
]
dump_content
=
_imperative_rt
.
dump_graph
(
dump_content
=
_imperative_rt
.
dump_graph
(
ov
,
ov
,
keep_var_name
,
keep_var_name
,
...
@@ -441,6 +448,7 @@ def dump_graph(
...
@@ -441,6 +448,7 @@ def dump_graph(
keep_param_name
,
keep_param_name
,
keep_opr_priority
,
keep_opr_priority
,
metadata
,
metadata
,
dump_format
,
stat
,
stat
,
inputs
,
inputs
,
outputs
,
outputs
,
...
...
imperative/python/megengine/jit/tracing.py
浏览文件 @
a404c508
...
@@ -1008,6 +1008,7 @@ class trace:
...
@@ -1008,6 +1008,7 @@ class trace:
maxerr
=
1e-4
,
maxerr
=
1e-4
,
resize_input
=
False
,
resize_input
=
False
,
input_transform
=
None
,
input_transform
=
None
,
dump_format
:
str
=
None
,
**
kwargs
**
kwargs
):
):
r
"""Serializes trace to file system.
r
"""Serializes trace to file system.
...
@@ -1059,6 +1060,7 @@ class trace:
...
@@ -1059,6 +1060,7 @@ class trace:
resize_input: whether resize input image to fit input var shape.
resize_input: whether resize input image to fit input var shape.
input_transform: a python expression to transform the input data.
input_transform: a python expression to transform the input data.
Example: data / np.std(data)
Example: data / np.std(data)
dump_format: using different dump formats.
Keyword Arguments:
Keyword Arguments:
...
@@ -1265,6 +1267,7 @@ class trace:
...
@@ -1265,6 +1267,7 @@ class trace:
strip_info_file
=
strip_info_file
,
strip_info_file
=
strip_info_file
,
append_json
=
append_json
,
append_json
=
append_json
,
metadata
=
metadata
,
metadata
=
metadata
,
dump_format
=
dump_format
,
)
)
file
.
write
(
dump_content
)
file
.
write
(
dump_content
)
...
...
imperative/python/src/graph_rt.cpp
浏览文件 @
a404c508
...
@@ -35,6 +35,7 @@ using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
...
@@ -35,6 +35,7 @@ using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
using
_LayoutTransform
=
_OptimizeForInferenceOptions
::
LayoutTransform
;
using
_LayoutTransform
=
_OptimizeForInferenceOptions
::
LayoutTransform
;
using
_AlgoStrategy
=
opr
::
mixin
::
AlgoChooserHelper
::
ExecutionPolicy
::
Strategy
;
using
_AlgoStrategy
=
opr
::
mixin
::
AlgoChooserHelper
::
ExecutionPolicy
::
Strategy
;
using
_SerializationMetadata
=
mgb
::
serialization
::
Metadata
;
using
_SerializationMetadata
=
mgb
::
serialization
::
Metadata
;
using
_SerializationFormat
=
mgb
::
serialization
::
GraphDumpFormat
;
namespace
{
namespace
{
class
_CompGraphProfilerImpl
{
class
_CompGraphProfilerImpl
{
...
@@ -310,6 +311,10 @@ void init_graph_rt(py::module m) {
...
@@ -310,6 +311,10 @@ void init_graph_rt(py::module m) {
.
value
(
"NCHW64"
,
_LayoutTransform
::
NCHW64
)
.
value
(
"NCHW64"
,
_LayoutTransform
::
NCHW64
)
.
export_values
();
.
export_values
();
py
::
enum_
<
_SerializationFormat
>
(
m
,
"SerializationFormat"
)
.
value
(
"FBS"
,
_SerializationFormat
::
FLATBUFFERS
)
.
export_values
();
m
.
def
(
"optimize_for_inference"
,
m
.
def
(
"optimize_for_inference"
,
[](
const
VarNodeArray
&
dest_vars
,
const
_OptimizeForInferenceOptions
&
opt
)
{
[](
const
VarNodeArray
&
dest_vars
,
const
_OptimizeForInferenceOptions
&
opt
)
{
SymbolVarArray
symvars
(
dest_vars
.
begin
(),
dest_vars
.
end
());
SymbolVarArray
symvars
(
dest_vars
.
begin
(),
dest_vars
.
end
());
...
@@ -380,11 +385,18 @@ void init_graph_rt(py::module m) {
...
@@ -380,11 +385,18 @@ void init_graph_rt(py::module m) {
m
.
def
(
"dump_graph"
,
m
.
def
(
"dump_graph"
,
[](
const
std
::
vector
<
VarNode
*>&
dest_vars
,
int
keep_var_name
,
[](
const
std
::
vector
<
VarNode
*>&
dest_vars
,
int
keep_var_name
,
bool
keep_opr_name
,
bool
keep_param_name
,
bool
keep_opr_priority
,
bool
keep_opr_name
,
bool
keep_param_name
,
bool
keep_opr_priority
,
std
::
optional
<
_SerializationMetadata
>
metadata
,
py
::
list
&
stat
,
std
::
optional
<
_SerializationMetadata
>
metadata
,
std
::
optional
<
_SerializationFormat
>
dump_format
,
py
::
list
&
stat
,
py
::
list
&
inputs
,
py
::
list
&
outputs
,
py
::
list
&
params
)
{
py
::
list
&
inputs
,
py
::
list
&
outputs
,
py
::
list
&
params
)
{
std
::
vector
<
uint8_t
>
buf
;
std
::
vector
<
uint8_t
>
buf
;
auto
dumper
=
ser
::
GraphDumpFormat
format
;
ser
::
GraphDumper
::
make
(
ser
::
OutputFile
::
make_vector_proxy
(
&
buf
));
if
(
dump_format
.
has_value
())
{
format
=
dump_format
.
value
();
}
else
{
format
=
{};
}
auto
dumper
=
ser
::
GraphDumper
::
make
(
ser
::
OutputFile
::
make_vector_proxy
(
&
buf
),
format
);
SymbolVarArray
symvars
(
dest_vars
.
begin
(),
dest_vars
.
end
());
SymbolVarArray
symvars
(
dest_vars
.
begin
(),
dest_vars
.
end
());
ser
::
GraphDumper
::
DumpConfig
config
{
ser
::
GraphDumper
::
DumpConfig
config
{
...
...
imperative/python/test/unit/jit/test_tracing.py
浏览文件 @
a404c508
...
@@ -190,7 +190,13 @@ def test_print_in_trace():
...
@@ -190,7 +190,13 @@ def test_print_in_trace():
np
.
testing
.
assert_equal
(
z
,
buf
)
np
.
testing
.
assert_equal
(
z
,
buf
)
def
test_dump
():
@
pytest
.
mark
.
parametrize
(
"dump_format"
,
[
"FBS"
,
],
)
def
test_dump
(
dump_format
):
@
trace
(
symbolic
=
True
,
capture_as_const
=
True
)
@
trace
(
symbolic
=
True
,
capture_as_const
=
True
)
def
f
(
a
,
b
):
def
f
(
a
,
b
):
return
a
+
b
return
a
+
b
...
@@ -205,7 +211,7 @@ def test_dump():
...
@@ -205,7 +211,7 @@ def test_dump():
np
.
testing
.
assert_equal
(
f
(
a
,
b
).
numpy
(),
y
)
np
.
testing
.
assert_equal
(
f
(
a
,
b
).
numpy
(),
y
)
file
=
io
.
BytesIO
()
file
=
io
.
BytesIO
()
dump_info
=
f
.
dump
(
file
)
dump_info
=
f
.
dump
(
file
,
dump_format
=
dump_format
)
assert
dump_info
.
nr_opr
==
3
assert
dump_info
.
nr_opr
==
3
np
.
testing
.
assert_equal
(
dump_info
.
inputs
,
[
"arg_0"
,
"arg_1"
])
np
.
testing
.
assert_equal
(
dump_info
.
inputs
,
[
"arg_0"
,
"arg_1"
])
np
.
testing
.
assert_equal
(
dump_info
.
outputs
,
[
"ADD"
])
np
.
testing
.
assert_equal
(
dump_info
.
outputs
,
[
"ADD"
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录