Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
ba8bd010
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
ba8bd010
编写于
11月 01, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/traced_module): fix insert module
GitOrigin-RevId: 755e1c68f60b0fc994eec56697d0515a8343e9f5
上级
b8316de5
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
36 addition
and
17 deletion
+36
-17
imperative/python/megengine/traced_module/traced_module.py
imperative/python/megengine/traced_module/traced_module.py
+9
-13
imperative/python/test/unit/traced_module/test_modification.py
...ative/python/test/unit/traced_module/test_modification.py
+27
-4
未找到文件。
imperative/python/megengine/traced_module/traced_module.py
浏览文件 @
ba8bd010
...
...
@@ -293,19 +293,10 @@ class _InsertExprs:
module
=
self
.
graph
.
inputs
[
0
].
owner
for
mod
,
parent
in
module
.
modules
(
with_parent
=
True
):
name
=
mod
.
_name
if
isinstance
(
mod
,
TracedModuleBuilder
):
mod
=
mod
.
build
()
if
hasattr
(
mod
,
"argdef_graph_map"
):
for
g
in
mod
.
argdef_graph_map
.
values
():
for
n
in
g
.
nodes
(
False
):
if
isinstance
(
n
,
TensorNode
):
n
.
value
=
None
setattr
(
parent
,
name
,
mod
)
for
node
in
self
.
global_scope
.
nodes
(
False
):
node
.
value
=
None
for
k
,
v
in
module
.
__dict__
.
items
():
if
isinstance
(
v
,
TracedModuleBuilder
):
v
=
v
.
build
()
setattr
(
module
,
k
,
v
)
extra_inp_nodes
=
set
(
self
.
global_scope
.
inputs
)
max_inp_expr_idx
=
-
1
...
...
@@ -334,6 +325,9 @@ class _InsertExprs:
self
.
graph
.
_namespace
.
merge
(
self
.
global_scope
.
_namespace
)
self
.
root_graph
.
_total_ids
=
(
Node
.
_get_next_id
(),
Expr
.
_get_next_id
())
self
.
root_graph
.
inputs
[
0
].
owner
.
_update_ref
()
for
node
in
self
.
root_graph
.
nodes
():
if
isinstance
(
node
,
TensorNode
):
node
.
value
=
None
return
True
...
...
@@ -1519,6 +1513,7 @@ class TracedModuleBuilder(NodeMixin):
return
active_module_tracer
().
patcher
.
wrap_fn
(
attr
)
if
isinstance
(
attr
,
(
List
,
Dict
)):
flag
=
_set_convert_node_flag
(
False
)
unset_module_tracing
()
has_module
,
m_container
=
replace_container_with_module_container
(
attr
)
if
m_container
:
...
...
@@ -1529,6 +1524,7 @@ class TracedModuleBuilder(NodeMixin):
" Module and Non-Module objects."
)
set_module_tracing
()
_set_convert_node_flag
(
flag
)
if
isinstance
(
attr
,
Module
):
attr
=
TracedModuleBuilder
(
attr
)
...
...
imperative/python/test/unit/traced_module/test_modification.py
浏览文件 @
ba8bd010
...
...
@@ -16,7 +16,7 @@ import megengine.module as M
from
megengine.module.identity
import
Identity
from
megengine.traced_module
import
trace_module
from
megengine.traced_module.expr
import
CallFunction
,
CallMethod
,
Expr
,
GetAttr
,
Input
from
megengine.traced_module.node
import
ModuleNode
,
Node
from
megengine.traced_module.node
import
ModuleNode
,
Node
,
TensorNode
class
IdentityMod
(
M
.
Module
):
...
...
@@ -159,21 +159,44 @@ def test_insert():
def
test_insert_module
():
class
Neg
(
M
.
Module
):
def
__init__
(
self
,
name
):
super
().
__init__
(
name
)
self
.
identity
=
M
.
Identity
()
self
.
identity_list
=
[
M
.
Identity
(),
M
.
Identity
()]
self
.
identity_dict
=
{
"0"
:
M
.
Identity
(),
"1"
:
M
.
Identity
()}
self
.
param
=
F
.
zeros
((
1
,))
def
forward
(
self
,
x
):
return
F
.
neg
(
x
)
x
=
self
.
identity
(
x
)
for
m
in
self
.
identity_dict
:
x
=
self
.
identity_dict
[
m
](
x
)
for
m
in
self
.
identity_list
:
x
=
m
(
x
)
return
F
.
neg
(
x
)
+
self
.
param
traced_module
,
x
,
expect
=
_init_block
()
graph
=
traced_module
.
graph
relu_out
=
graph
.
get_function_by_type
(
F
.
relu
).
as_unique
().
outputs
[
0
]
self
=
graph
.
inputs
[
0
]
setattr
(
traced_module
,
"neg"
,
Neg
())
setattr
(
traced_module
,
"neg"
,
Neg
(
name
=
"neg"
))
setattr
(
traced_module
,
"neg2"
,
Neg
(
name
=
"neg"
))
setattr
(
traced_module
,
"param"
,
F
.
zeros
((
1
,)))
with
graph
.
insert_exprs
():
neg_out
=
self
.
neg
(
relu_out
)
neg_out
=
self
.
neg2
(
relu_out
)
neg_out
=
neg_out
+
self
.
param
graph
.
replace_node
({
relu_out
:
neg_out
})
graph
.
compile
()
np
.
testing
.
assert_allclose
(
expect
-
1
,
1
-
traced_module
(
x
),
atol
=
1e-6
)
assert
traced_module
.
neg
.
graph
is
not
None
assert
len
(
traced_module
.
neg
.
graph
.
_exprs
)
==
1
assert
traced_module
.
neg2
.
graph
is
not
None
assert
traced_module
.
neg2
.
param
is
not
None
assert
len
(
traced_module
.
neg
.
graph
.
_exprs
)
==
13
for
n
in
traced_module
.
graph
.
nodes
():
if
isinstance
(
n
,
TensorNode
):
assert
n
.
value
is
None
def
test_add_input_and_output
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录