Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
355782ae
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
355782ae
编写于
1月 11, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(traced_module): clear node after trace module
GitOrigin-RevId: f7f602403481fdeb6a77435bc98c5d9e7a5fa58e
上级
fba54488
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
27 addition
and
1 deletion
+27
-1
imperative/python/megengine/traced_module/expr.py
imperative/python/megengine/traced_module/expr.py
+1
-0
imperative/python/megengine/traced_module/module_tracer.py
imperative/python/megengine/traced_module/module_tracer.py
+13
-0
imperative/python/megengine/traced_module/node.py
imperative/python/megengine/traced_module/node.py
+5
-0
imperative/python/megengine/traced_module/traced_module.py
imperative/python/megengine/traced_module/traced_module.py
+8
-1
未找到文件。
imperative/python/megengine/traced_module/expr.py
浏览文件 @
355782ae
...
@@ -763,6 +763,7 @@ class Constant(Expr):
...
@@ -763,6 +763,7 @@ class Constant(Expr):
current_graph
=
active_module_tracer
().
current_scope
()
current_graph
=
active_module_tracer
().
current_scope
()
current_graph
.
_namespace
.
auto_naming_for_outputs
(
expr
)
current_graph
.
_namespace
.
auto_naming_for_outputs
(
expr
)
current_graph
.
_insert
(
expr
)
current_graph
.
_insert
(
expr
)
active_module_tracer
().
current_constant_cache
().
append
(
expr
.
value
)
return
expr
.
outputs
[
0
]
return
expr
.
outputs
[
0
]
def
interpret
(
self
,
*
inputs
):
def
interpret
(
self
,
*
inputs
):
...
...
imperative/python/megengine/traced_module/module_tracer.py
浏览文件 @
355782ae
...
@@ -131,6 +131,7 @@ class module_tracer:
...
@@ -131,6 +131,7 @@ class module_tracer:
self
.
_active_scopes
=
[]
self
.
_active_scopes
=
[]
self
.
checker
=
TracedModuleChecker
(
self
)
self
.
checker
=
TracedModuleChecker
(
self
)
self
.
patcher
=
Patcher
(
wrap_fn
)
self
.
patcher
=
Patcher
(
wrap_fn
)
self
.
_activate_constant_cache
=
[]
@
classmethod
@
classmethod
def
register_as_builtin
(
cls
,
mod
):
def
register_as_builtin
(
cls
,
mod
):
...
@@ -145,16 +146,28 @@ class module_tracer:
...
@@ -145,16 +146,28 @@ class module_tracer:
def
push_scope
(
self
,
scope
):
def
push_scope
(
self
,
scope
):
self
.
_active_scopes
.
append
(
scope
)
self
.
_active_scopes
.
append
(
scope
)
self
.
checker
.
push_scope
()
self
.
checker
.
push_scope
()
self
.
_activate_constant_cache
.
append
([])
def
pop_scope
(
self
):
def
pop_scope
(
self
):
self
.
_active_scopes
.
pop
()
self
.
_active_scopes
.
pop
()
self
.
checker
.
pop_scope
()
self
.
checker
.
pop_scope
()
cache
=
self
.
_activate_constant_cache
.
pop
()
for
obj
in
cache
:
if
hasattr
(
obj
,
"_NodeMixin__node"
):
delattr
(
obj
,
"_NodeMixin__node"
)
def
current_scope
(
self
):
def
current_scope
(
self
):
if
self
.
_active_scopes
:
if
self
.
_active_scopes
:
return
self
.
_active_scopes
[
-
1
]
return
self
.
_active_scopes
[
-
1
]
return
None
return
None
def
current_constant_cache
(
self
):
if
self
.
_activate_constant_cache
:
return
self
.
_activate_constant_cache
[
-
1
]
return
None
def
top_scope
(
self
):
def
top_scope
(
self
):
if
self
.
_active_scopes
:
if
self
.
_active_scopes
:
return
self
.
_active_scopes
[
0
]
return
self
.
_active_scopes
[
0
]
...
...
imperative/python/megengine/traced_module/node.py
浏览文件 @
355782ae
...
@@ -379,6 +379,11 @@ class NodeMixin(abc.ABC):
...
@@ -379,6 +379,11 @@ class NodeMixin(abc.ABC):
if
isinstance
(
value
,
NodeMixin
):
if
isinstance
(
value
,
NodeMixin
):
value
.
_record_wrapped_nodes
(
node
)
value
.
_record_wrapped_nodes
(
node
)
@
classmethod
def
clear_node
(
cls
,
value
):
if
hasattr
(
value
,
"_NodeMixin__node"
):
delattr
(
value
,
"_NodeMixin__node"
)
@
classmethod
@
classmethod
def
get
(
cls
,
value
,
*
default
):
def
get
(
cls
,
value
,
*
default
):
return
getattr
(
value
,
"_NodeMixin__node"
,
*
default
)
return
getattr
(
value
,
"_NodeMixin__node"
,
*
default
)
...
...
imperative/python/megengine/traced_module/traced_module.py
浏览文件 @
355782ae
...
@@ -1980,7 +1980,10 @@ class TracedModule(Module):
...
@@ -1980,7 +1980,10 @@ class TracedModule(Module):
assert
(
assert
(
treedef
in
self
.
argdef_graph_map
treedef
in
self
.
argdef_graph_map
),
"support input args kwargs format:
\n
{}, but get:
\n
{}"
.
format
(
),
"support input args kwargs format:
\n
{}, but get:
\n
{}"
.
format
(
"
\n
"
.
join
(
"forward({})"
.
format
(
i
.
_args_kwargs_repr
())
for
i
in
self
.
argdef_graph_map
.
keys
()),
"
\n
"
.
join
(
"forward({})"
.
format
(
i
.
_args_kwargs_repr
())
for
i
in
self
.
argdef_graph_map
.
keys
()
),
treedef
.
_args_kwargs_repr
(),
treedef
.
_args_kwargs_repr
(),
)
)
inputs
=
filter
(
inputs
=
filter
(
...
@@ -2514,3 +2517,7 @@ def trace_module(
...
@@ -2514,3 +2517,7 @@ def trace_module(
set_symbolic_shape
(
use_sym_shape
)
set_symbolic_shape
(
use_sym_shape
)
set_active_module_tracer
(
None
)
set_active_module_tracer
(
None
)
unset_module_tracing
()
unset_module_tracing
()
for
t
in
mod
.
tensors
(
recursive
=
True
):
NodeMixin
.
clear_node
(
t
)
for
t
in
inputs
:
NodeMixin
.
clear_node
(
t
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录