Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
f642b05e
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
f642b05e
编写于
9月 16, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test(mge): update traced_module unit test
GitOrigin-RevId: 3948d50d7901a85737a19795bc1866ceb08bd29d
上级
fb20cb36
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
106 addition
and
3 deletion
+106
-3
imperative/python/megengine/traced_module/expr.py
imperative/python/megengine/traced_module/expr.py
+5
-1
imperative/python/megengine/traced_module/node.py
imperative/python/megengine/traced_module/node.py
+1
-1
imperative/python/test/unit/traced_module/test_modification.py
...ative/python/test/unit/traced_module/test_modification.py
+100
-1
未找到文件。
imperative/python/megengine/traced_module/expr.py
浏览文件 @
f642b05e
...
...
@@ -229,6 +229,7 @@ class GetAttr(Expr):
name
=
None
r
"""name: the qualified name of the attribute to be retrieved."""
def
__init__
(
self
,
module
,
name
,
type
=
None
,
orig_name
=
None
):
super
().
__init__
()
assert
isinstance
(
module
,
ModuleNode
)
...
...
@@ -276,6 +277,7 @@ class CallMethod(Expr):
method: the method name.
Default: "__call__"
"""
def
__init__
(
self
,
node
,
method
=
"__call__"
):
super
().
__init__
()
if
isinstance
(
node
,
type
):
...
...
@@ -351,6 +353,7 @@ class Apply(Expr):
opdef: the applied :class:`OpDef`.
"""
opdef
=
None
def
__init__
(
self
,
opdef
):
super
().
__init__
()
assert
isinstance
(
opdef
,
OpDef
)
...
...
@@ -422,6 +425,7 @@ class CallFunction(Expr):
Args:
func: a built-in function.
"""
def
__init__
(
self
,
func
):
super
().
__init__
()
assert
isinstance
(
func
,
Callable
)
...
...
imperative/python/megengine/traced_module/node.py
浏览文件 @
f642b05e
imperative/python/test/unit/traced_module/test_modification.py
浏览文件 @
f642b05e
...
...
@@ -5,12 +5,21 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
pickle
import
numpy
as
np
import
megengine.functional
as
F
import
megengine.module
as
M
from
megengine.module.identity
import
Identity
from
megengine.traced_module
import
trace_module
from
megengine.traced_module.expr
import
CallFunction
,
GetAttr
from
megengine.traced_module.expr
import
CallFunction
,
Expr
,
GetAttr
from
megengine.traced_module.node
import
Node
class
IdentityMod
(
M
.
Module
):
def
forward
(
self
,
x
):
return
x
class
MyBlock
(
M
.
Module
):
...
...
@@ -18,11 +27,13 @@ class MyBlock(M.Module):
super
(
MyBlock
,
self
).
__init__
()
self
.
conv1
=
M
.
Conv2d
(
in_channels
,
channels
,
3
,
1
,
padding
=
1
,
bias
=
False
)
self
.
bn1
=
M
.
BatchNorm2d
(
channels
)
self
.
nothing
=
IdentityMod
()
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
F
.
relu
(
x
)
+
1
x
=
self
.
nothing
(
x
)
return
x
...
...
@@ -31,10 +42,24 @@ class MyModule(M.Module):
super
(
MyModule
,
self
).
__init__
()
self
.
block0
=
MyBlock
()
self
.
block1
=
MyBlock
()
self
.
nothing
=
IdentityMod
()
def
forward
(
self
,
x
):
x
=
self
.
block0
(
x
)
x
=
self
.
block1
(
x
)
x
=
self
.
nothing
(
x
)
return
x
class
NewModule
(
M
.
Module
):
def
__init__
(
self
,
traced_module
):
super
(
NewModule
,
self
).
__init__
()
self
.
module
=
traced_module
def
forward
(
self
,
x
):
x
=
x
-
1
x
=
self
.
module
(
x
)
x
=
x
+
1
return
x
...
...
@@ -82,6 +107,12 @@ def test_delete():
graph
.
compile
()
np
.
testing
.
assert_allclose
(
expect
-
1
,
F
.
relu
(
traced_module
(
x
)
-
1
),
atol
=
1e-6
)
# clear graph
graph
.
replace_node
({
graph
.
outputs
[
0
]:
graph
.
inputs
[
1
]})
graph
.
compile
()
np
.
testing
.
assert_equal
(
len
(
list
(
graph
.
_exprs
)),
0
)
np
.
testing
.
assert_equal
(
traced_module
(
x
).
numpy
(),
x
.
numpy
())
def
test_flatten
():
traced_module
,
x
,
expect
=
_init_module
()
...
...
@@ -89,6 +120,74 @@ def test_flatten():
traced_module
.
graph
.
compile
()
assert
all
(
not
isinstance
(
i
,
GetAttr
)
for
i
in
traced_module
.
graph
.
_exprs
)
assert
len
(
traced_module
.
graph
.
_exprs
)
==
12
np
.
testing
.
assert_equal
(
expect
.
numpy
(),
traced_module
(
x
).
numpy
())
def
test_id_and_name
():
def
_check_id
(
traced_module
):
_total_ids
=
traced_module
.
graph
.
_total_ids
node_ids
=
[
n
.
_id
for
n
in
traced_module
.
graph
.
nodes
().
as_list
()]
assert
len
(
set
(
node_ids
))
==
len
(
node_ids
)
assert
max
(
node_ids
)
+
1
==
len
(
node_ids
)
expr_ids
=
[
n
.
_id
for
n
in
traced_module
.
graph
.
exprs
().
as_list
()]
assert
len
(
set
(
expr_ids
))
==
len
(
expr_ids
)
assert
max
(
expr_ids
)
+
1
==
_total_ids
[
1
]
def
_check_name
(
flatened_module
):
node_names
=
[
n
.
_name
for
n
in
flatened_module
.
graph
.
nodes
().
as_list
()]
assert
len
(
set
(
node_names
))
==
len
(
node_names
)
traced_module
,
x
,
expect
=
_init_module
()
_check_id
(
traced_module
)
flattened_module
=
traced_module
.
flatten
()
_check_id
(
flattened_module
)
_check_name
(
flattened_module
)
# pickle check
obj
=
pickle
.
dumps
(
traced_module
)
traced_module
=
pickle
.
loads
(
obj
)
Node
.
_set_next_id
(
159
)
Expr
.
_set_next_id
(
1024
)
graph
=
traced_module
.
graph
for
expr
in
graph
.
get_function_by_type
(
F
.
relu
).
as_list
():
relu_out
=
expr
.
outputs
[
0
]
cur_graph
=
expr
.
top_graph
with
cur_graph
.
insert_exprs
():
neg_out
=
F
.
neg
(
relu_out
)
cur_graph
.
replace_node
({
relu_out
:
neg_out
})
cur_graph
.
compile
()
_check_id
(
traced_module
)
flattened_module
=
traced_module
.
flatten
()
_check_id
(
flattened_module
)
_check_name
(
flattened_module
)
# check trace TracedModule
obj
=
pickle
.
dumps
(
traced_module
)
traced_module
=
pickle
.
loads
(
obj
)
module
=
NewModule
(
traced_module
)
traced_module
=
trace_module
(
module
,
x
)
_check_id
(
traced_module
)
flattened_module
=
traced_module
.
flatten
()
_check_id
(
flattened_module
)
_check_name
(
flattened_module
)
def
test_set_name
():
traced_module
,
x
,
expect
=
_init_module
()
graph
=
traced_module
.
graph
output_node
=
graph
.
outputs
[
0
]
def
rename
(
name
):
output_node
.
name
=
name
np
.
testing
.
assert_raises
(
AssertionError
,
rename
,
"block1_out"
)
rename
(
"output"
)
np
.
testing
.
assert_equal
(
str
(
graph
.
outputs
[
0
]),
"output"
)
def
test_extra_block
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录