Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
3ff5ca5f
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
3ff5ca5f
编写于
12月 27, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/traced_module): support to modify the name of Node during graph surgery
GitOrigin-RevId: 9ecf6f2c5b700d4c91947def2cdc00cce4e0efc7
上级
3a219209
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
55 addition
and
1 deletion
+55
-1
imperative/python/megengine/traced_module/module_tracer.py
imperative/python/megengine/traced_module/module_tracer.py
+5
-1
imperative/python/megengine/traced_module/traced_module.py
imperative/python/megengine/traced_module/traced_module.py
+23
-0
imperative/python/test/unit/traced_module/test_modification.py
...ative/python/test/unit/traced_module/test_modification.py
+27
-0
未找到文件。
imperative/python/megengine/traced_module/module_tracer.py
浏览文件 @
3ff5ca5f
...
...
@@ -92,7 +92,6 @@ BUILTIN_TENSOR_WRAP_METHOD = [
"dtype"
,
"grad"
,
"item"
,
"name"
,
"ndim"
,
"numpy"
,
"qparams"
,
...
...
@@ -152,6 +151,11 @@ class module_tracer:
return
self
.
_active_scopes
[
-
1
]
return
None
def
top_scope
(
self
):
if
self
.
_active_scopes
:
return
self
.
_active_scopes
[
0
]
return
None
class
NotExist
:
pass
...
...
imperative/python/megengine/traced_module/traced_module.py
浏览文件 @
3ff5ca5f
...
...
@@ -180,6 +180,25 @@ def _tensor_to_node(tensors):
return
nodes
def
_name_setter
(
node
:
Node
,
new_name
:
str
):
surgery_mode
=
_set_graph_surgery_mode
(
False
)
graph
=
active_module_tracer
().
current_scope
()
if
node
.
top_graph
is
not
None
:
top_graph
=
active_module_tracer
().
top_scope
()
if
node
is
top_graph
.
_namespace
.
used_names
.
get
(
node
.
_name
,
None
):
graph
=
top_graph
else
:
graph
=
node
.
top_graph
assert
(
graph
.
_namespace
.
used_names
.
get
(
new_name
,
None
)
is
None
),
"The name(%s) is already in use. Please try a different one again."
%
(
new_name
)
graph
.
_namespace
.
unassociate_name_with_obj
(
node
)
node
.
_name
=
graph
.
_namespace
.
create_unique_name
(
new_name
,
node
)
_set_graph_surgery_mode
(
surgery_mode
)
def
_wrap_method_to_tensor_node
():
def
_any_method
(
name
,
func
):
def
_any
(
*
args
,
**
kwargs
):
...
...
@@ -213,6 +232,10 @@ def _wrap_method_to_tensor_node():
else
:
patch
.
set_func
(
_any_method
(
method
,
patch
.
origin_fn
))
tensor_method_patch
.
append
(
patch
)
patch
=
PatchedFn
(
Node
,
"name"
)
patch
.
set_func
(
property
(
patch
.
origin_fn
.
fget
,
_name_setter
))
tensor_method_patch
.
append
(
patch
)
return
tensor_method_patch
...
...
imperative/python/test/unit/traced_module/test_modification.py
浏览文件 @
3ff5ca5f
...
...
@@ -377,6 +377,33 @@ def test_set_node_name():
rename
(
"output"
)
np
.
testing
.
assert_equal
(
str
(
graph
.
outputs
[
0
]),
"output"
)
def
add_1
(
x
):
x
=
x
+
1
x
.
name
=
"func_add_1"
return
x
class
ModuleAdd_3
(
M
.
Module
):
def
forward
(
self
,
x
):
x
=
x
+
1
x
.
name
=
"module_add_1"
x
=
x
+
2
return
x
setattr
(
traced_module
,
"add_3"
,
ModuleAdd_3
())
self
=
graph
.
inputs
[
0
]
with
graph
.
insert_exprs
():
x
=
output_node
+
1
x
.
name
=
"_add_1"
x
=
add_1
(
x
)
x
=
self
.
add_3
(
x
)
graph
.
replace_node
({
output_node
:
x
})
graph
.
compile
()
assert
"_add_1"
in
graph
.
_namespace
.
used_names
assert
"func_add_1"
in
graph
.
_namespace
.
used_names
assert
"module_add_1"
in
traced_module
.
add_3
.
graph
.
_namespace
.
used_names
def
test_set_graph_name
():
traced_module
,
x
,
expect
=
_init_module
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录