Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
01d2473c
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
01d2473c
编写于
10月 31, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/traced_module): fix TracedModule flatten
GitOrigin-RevId: 7b15fe492b4486d009d603227bec05485457a7da
上级
23c1fda7
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
36 addition
and
3 deletion
+36
-3
imperative/python/megengine/traced_module/traced_module.py
imperative/python/megengine/traced_module/traced_module.py
+1
-3
imperative/python/test/unit/traced_module/test_modification.py
...ative/python/test/unit/traced_module/test_modification.py
+35
-0
未找到文件。
imperative/python/megengine/traced_module/traced_module.py
浏览文件 @
01d2473c
...
...
@@ -2078,9 +2078,7 @@ class TracedModule(Module):
for
node
,
repl_node
in
repl_dict
.
items
():
assert
node
in
graph
.
_inputs
or
node
in
graph
.
_outputs
for
i
in
node
.
users
:
if
i
not
in
repl_node
.
users
:
repl_node
.
users
.
append
(
i
)
repl_node
.
users
.
extend
(
node
.
users
)
rename_blacklist
=
list
(
chain
(
call
.
inputs
,
call
.
outputs
))
...
...
imperative/python/test/unit/traced_module/test_modification.py
浏览文件 @
01d2473c
...
...
@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
pickle
from
collections
import
defaultdict
from
itertools
import
chain
import
numpy
as
np
...
...
@@ -52,6 +53,25 @@ class MyModule(M.Module):
return
x
class
MyBlock1
(
M
.
Module
):
def
forward
(
self
,
a
):
y
=
F
.
concat
([
a
,
a
])
return
a
,
y
class
MyModule1
(
M
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
block0
=
MyBlock1
()
self
.
block1
=
MyBlock1
()
def
forward
(
self
,
a
):
a
,
y1
=
self
.
block0
(
a
)
a
=
a
+
1
a
,
y2
=
self
.
block1
(
a
)
return
a
,
y1
+
y2
class
NewModule
(
M
.
Module
):
def
__init__
(
self
,
traced_module
):
super
(
NewModule
,
self
).
__init__
()
...
...
@@ -64,6 +84,17 @@ class NewModule(M.Module):
return
x
def
_check_expr_users
(
traced_module
):
node_user
=
defaultdict
(
list
)
for
expr
in
traced_module
.
graph
.
_exprs
:
for
node
in
expr
.
inputs
:
node_user
[
node
].
append
(
expr
)
for
node
in
traced_module
.
graph
.
nodes
():
node
.
users
.
sort
(
key
=
lambda
m
:
m
.
_id
)
node_user
[
node
].
sort
(
key
=
lambda
m
:
m
.
_id
)
assert
node
.
users
==
node_user
[
node
]
def
_init_cls
(
cls
):
module
=
cls
()
x
=
F
.
ones
((
1
,
3
,
3
,
3
))
...
...
@@ -201,6 +232,10 @@ def test_flatten():
assert
len
(
traced_module
.
graph
.
_exprs
)
==
12
np
.
testing
.
assert_equal
(
expect
.
numpy
(),
traced_module
(
x
).
numpy
())
traced_module
,
x
,
expect
=
_init_cls
(
MyModule1
)
traced_module
=
traced_module
.
flatten
()
_check_expr_users
(
traced_module
)
def
test_id_and_name
():
def
_check_id
(
traced_module
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录