Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
e474994f
MegEngine
项目概览
MegEngine 天元
/
MegEngine
9 个月 前同步成功
通知
392
Star
4702
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
e474994f
编写于
4月 02, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative/jit): catch input tensors name when tracing
GitOrigin-RevId: 9c692548663654265f9f9e2753f8637d444cb78d
上级
aed681d3
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
23 addition
and
4 deletion
+23
-4
imperative/python/megengine/jit/tracing.py
imperative/python/megengine/jit/tracing.py
+4
-2
imperative/python/test/unit/utils/test_dump_naming.py
imperative/python/test/unit/utils/test_dump_naming.py
+19
-2
未找到文件。
imperative/python/megengine/jit/tracing.py
浏览文件 @
e474994f
...
...
@@ -772,7 +772,8 @@ class trace:
len
(
self
.
_output_bindings
)
)
)
if
arg_names
is
None
:
without_arg_names
=
arg_names
is
None
if
without_arg_names
:
arg_names
=
[
"arg_%d"
%
i
for
i
in
range
(
len
(
self
.
_arg_bindings
))]
if
arg_names
and
not
isinstance
(
arg_names
,
collections
.
abc
.
Sequence
):
arg_names
=
(
arg_names
,)
...
...
@@ -802,7 +803,7 @@ class trace:
dtype
=
info
.
dtype
,
device
=
dumped_device
(
info
),
shape
=
info
.
shape
or
(
1
,),
name
=
arg_names
[
i
]
if
arg_names
else
None
,
name
=
info
.
name
if
without_arg_names
and
info
.
name
else
arg_names
[
i
]
,
)
for
k
,
h
in
self
.
_kwarg_bindings
.
items
():
info
=
self
.
_tinfo
[
h
]
...
...
@@ -889,6 +890,7 @@ class trace:
return
h
,
info
=
self
.
_new_handle
()
info
.
external
=
False
info
.
name
=
x
.
c_name
info
.
device
=
x
.
device
info
.
dtype
=
x
.
dtype
info
.
shape
=
x
.
numpy
().
shape
...
...
imperative/python/test/unit/utils/test_dump_naming.py
浏览文件 @
e474994f
...
...
@@ -203,14 +203,31 @@ def test_with_same_operators(symbolic):
assert
ops
[
-
2
].
name
==
"simple.RELU[0]"
def
test_not_keep_opr_name
():
@
pytest
.
mark
.
parametrize
(
"symbolic"
,
[
False
,
True
])
def
test_not_keep_opr_name
(
symbolic
):
def
f
(
x
):
return
2
*
x
op
=
_dump_and_load
(
f
,
True
,
False
)[
-
1
]
op
=
_dump_and_load
(
f
,
symbolic
,
False
)[
-
1
]
assert
op
.
name
==
"MUL(x,const<2>[2])[4]"
@
pytest
.
mark
.
parametrize
(
"tensor_name, var_name"
,
[(
"data"
,
"data"
),
(
None
,
"arg_0"
)])
def
test_catch_input_name
(
tensor_name
,
var_name
):
def
f
(
x
):
return
2
*
x
func
=
trace
(
f
,
symbolic
=
True
,
capture_as_const
=
True
)
x
=
Tensor
(
np
.
ones
(
shape
=
(
2
,
3
)),
name
=
tensor_name
)
func
(
x
).
numpy
()
file
=
io
.
BytesIO
()
func
.
dump
(
file
,
optimize_for_inference
=
False
,
keep_opr_name
=
True
,
keep_var_name
=
2
)
file
.
seek
(
0
)
*
_
,
outputs
=
G
.
load_graph
(
file
)
op
=
cgtools
.
get_oprs_seq
(
outputs
)[
-
1
]
assert
op
.
inputs
[
0
].
name
==
var_name
@
pytest
.
mark
.
parametrize
(
"symbolic"
,
[
False
,
True
])
def
test_quantized_module_auto_naming
(
symbolic
):
class
Simple
(
M
.
Module
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录