Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
0185c1a9
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
0185c1a9
编写于
11月 26, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/traced_module): add argspec for top TracedModule
GitOrigin-RevId: 8e31a00c7e69b7efa15cfef6b7eee6861535eaea
上级
1daeba76
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
63 addition
and
5 deletion
+63
-5
imperative/python/megengine/traced_module/pytree.py
imperative/python/megengine/traced_module/pytree.py
+2
-0
imperative/python/megengine/traced_module/traced_module.py
imperative/python/megengine/traced_module/traced_module.py
+10
-2
imperative/python/megengine/traced_module/utils.py
imperative/python/megengine/traced_module/utils.py
+9
-3
imperative/python/test/unit/traced_module/test_trace_module.py
...ative/python/test/unit/traced_module/test_trace_module.py
+42
-0
未找到文件。
imperative/python/megengine/traced_module/pytree.py
浏览文件 @
0185c1a9
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
import
collections
import
collections
from
collections
import
OrderedDict
,
defaultdict
from
collections
import
OrderedDict
,
defaultdict
from
functools
import
partial
from
functools
import
partial
from
inspect
import
FullArgSpec
from
typing
import
Callable
,
NamedTuple
from
typing
import
Callable
,
NamedTuple
import
numpy
as
np
import
numpy
as
np
...
@@ -53,6 +54,7 @@ SUPPORTED_LEAF_TYPE = {
...
@@ -53,6 +54,7 @@ SUPPORTED_LEAF_TYPE = {
QuantMode
,
QuantMode
,
ArgsIndex
,
ArgsIndex
,
Group
,
Group
,
FullArgSpec
,
}
}
USER_REGISTERED_LEAF_TYPE
=
[]
USER_REGISTERED_LEAF_TYPE
=
[]
...
...
imperative/python/megengine/traced_module/traced_module.py
浏览文件 @
0185c1a9
...
@@ -1928,8 +1928,11 @@ class TracedModule(Module):
...
@@ -1928,8 +1928,11 @@ class TracedModule(Module):
self
.
watch_node_value
=
{}
self
.
watch_node_value
=
{}
self
.
end_points
=
[]
self
.
end_points
=
[]
self
.
is_qat
=
is_qat
self
.
is_qat
=
is_qat
self
.
argspec
=
None
def
forward
(
self
,
*
args
,
**
kwargs
):
def
forward
(
self
,
*
args
,
**
kwargs
):
if
hasattr
(
self
,
"argspec"
)
and
self
.
argspec
is
not
None
:
args
,
kwargs
=
_convert_kwargs_to_args
(
self
.
argspec
,
args
,
kwargs
,
True
)
inputs
,
treedef
=
tree_flatten
(((
self
,
*
args
),
kwargs
))
inputs
,
treedef
=
tree_flatten
(((
self
,
*
args
),
kwargs
))
assert
treedef
in
self
.
argdef_graph_map
assert
treedef
in
self
.
argdef_graph_map
inputs
=
filter
(
inputs
=
filter
(
...
@@ -2422,8 +2425,12 @@ def trace_module(
...
@@ -2422,8 +2425,12 @@ def trace_module(
NodeMixin
.
wrap_safe
(
NodeMixin
.
wrap_safe
(
builder
,
Input
.
make
(
name
=
"top"
,
type
=
ModuleNode
,
qualname
=
net_name
)
builder
,
Input
.
make
(
name
=
"top"
,
type
=
ModuleNode
,
qualname
=
net_name
)
)
)
args
,
kwargs
=
_convert_kwargs_to_args
(
mod
.
forward
,
args
,
kwargs
,
True
)
forward_argspec
=
(
mod
.
argspec
if
hasattr
(
mod
,
"argspec"
)
else
inspect
.
getfullargspec
(
mod
.
forward
)
)
args
,
kwargs
=
_convert_kwargs_to_args
(
forward_argspec
,
args
,
kwargs
,
True
)
inputs
,
_
=
tree_flatten
((
args
,
kwargs
))
inputs
,
_
=
tree_flatten
((
args
,
kwargs
))
for
_
,
i
in
enumerate
(
inputs
):
for
_
,
i
in
enumerate
(
inputs
):
# assert isinstance(i, Tensor), "not support "
# assert isinstance(i, Tensor), "not support "
...
@@ -2439,6 +2446,7 @@ def trace_module(
...
@@ -2439,6 +2446,7 @@ def trace_module(
builder
(
*
args
,
**
kwargs
)
builder
(
*
args
,
**
kwargs
)
active_module_tracer
().
pop_scope
()
active_module_tracer
().
pop_scope
()
traced_mod
=
builder
.
build
()
traced_mod
=
builder
.
build
()
traced_mod
.
argspec
=
forward_argspec
traced_mod
.
graph
.
_reset_ids
()
traced_mod
.
graph
.
_reset_ids
()
return
traced_mod
return
traced_mod
finally
:
finally
:
...
...
imperative/python/megengine/traced_module/utils.py
浏览文件 @
0185c1a9
...
@@ -9,7 +9,8 @@ import collections
...
@@ -9,7 +9,8 @@ import collections
import
copy
import
copy
import
inspect
import
inspect
from
collections.abc
import
MutableMapping
,
MutableSequence
from
collections.abc
import
MutableMapping
,
MutableSequence
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Sequence
,
Type
from
inspect
import
FullArgSpec
from
typing
import
Callable
,
Dict
,
Iterable
,
List
,
Optional
,
Sequence
,
Type
,
Union
from
..
import
get_logger
from
..
import
get_logger
from
..module
import
Module
from
..module
import
Module
...
@@ -57,9 +58,14 @@ def replace_container_with_module_container(container):
...
@@ -57,9 +58,14 @@ def replace_container_with_module_container(container):
return
has_module
,
module_container
return
has_module
,
module_container
def
_convert_kwargs_to_args
(
func
,
args
,
kwargs
,
is_bounded
=
False
):
def
_convert_kwargs_to_args
(
argspecs
:
Union
[
Callable
,
FullArgSpec
],
args
,
kwargs
,
is_bounded
=
False
):
# is_bounded = True when func is a method and provided args don't include 'self'
# is_bounded = True when func is a method and provided args don't include 'self'
arg_specs
=
inspect
.
getfullargspec
(
func
)
arg_specs
=
(
inspect
.
getfullargspec
(
argspecs
)
if
isinstance
(
argspecs
,
Callable
)
else
argspecs
)
assert
isinstance
(
arg_specs
,
FullArgSpec
)
arg_specs_args
=
arg_specs
.
args
arg_specs_args
=
arg_specs
.
args
if
is_bounded
:
if
is_bounded
:
arg_specs_args
=
arg_specs
.
args
[
1
:]
arg_specs_args
=
arg_specs
.
args
[
1
:]
...
...
imperative/python/test/unit/traced_module/test_trace_module.py
浏览文件 @
0185c1a9
...
@@ -5,6 +5,7 @@ import numpy as np
...
@@ -5,6 +5,7 @@ import numpy as np
import
megengine.functional
as
F
import
megengine.functional
as
F
import
megengine.module
as
M
import
megengine.module
as
M
from
megengine
import
Tensor
from
megengine
import
Tensor
from
megengine.module.module
import
Module
from
megengine.traced_module
import
TracedModule
,
trace_module
from
megengine.traced_module
import
TracedModule
,
trace_module
from
megengine.traced_module.expr
import
CallFunction
from
megengine.traced_module.expr
import
CallFunction
...
@@ -89,5 +90,46 @@ def test_trace_module():
...
@@ -89,5 +90,46 @@ def test_trace_module():
m4
=
MyModule4
()
m4
=
MyModule4
()
tm4
=
trace_module
(
m4
,
a
,
b
)
tm4
=
trace_module
(
m4
,
a
,
b
)
np
.
testing
.
assert_equal
(
tm4
(
a
,
b
).
numpy
(),
3
)
np
.
testing
.
assert_equal
(
tm4
(
a
,
y
=
b
).
numpy
(),
3
)
np
.
testing
.
assert_equal
(
tm4
(
x
=
a
,
y
=
b
).
numpy
(),
3
)
tm4
=
trace_module
(
m4
,
a
,
y
=
b
)
np
.
testing
.
assert_equal
(
tm4
(
a
,
b
).
numpy
(),
3
)
np
.
testing
.
assert_equal
(
tm4
(
a
,
y
=
b
).
numpy
(),
3
)
np
.
testing
.
assert_equal
(
tm4
(
x
=
a
,
y
=
b
).
numpy
(),
3
)
tm4
=
trace_module
(
m4
,
x
=
a
,
y
=
b
)
np
.
testing
.
assert_equal
(
tm4
(
a
,
b
).
numpy
(),
3
)
np
.
testing
.
assert_equal
(
tm4
(
a
,
y
=
b
).
numpy
(),
3
)
np
.
testing
.
assert_equal
(
tm4
(
x
=
a
,
y
=
b
).
numpy
(),
3
)
tm5
=
trace_module
(
tm4
,
a
,
b
)
np
.
testing
.
assert_equal
(
tm5
(
a
,
b
).
numpy
(),
3
)
np
.
testing
.
assert_equal
(
tm5
(
a
,
y
=
b
).
numpy
(),
3
)
np
.
testing
.
assert_equal
(
tm5
(
x
=
a
,
y
=
b
).
numpy
(),
3
)
tm5
=
trace_module
(
tm4
,
a
,
y
=
b
)
np
.
testing
.
assert_equal
(
tm5
(
a
,
b
).
numpy
(),
3
)
np
.
testing
.
assert_equal
(
tm5
(
a
,
y
=
b
).
numpy
(),
3
)
np
.
testing
.
assert_equal
(
tm5
(
x
=
a
,
y
=
b
).
numpy
(),
3
)
tm5
=
trace_module
(
tm4
,
x
=
a
,
y
=
b
)
np
.
testing
.
assert_equal
(
tm5
(
a
,
b
).
numpy
(),
3
)
np
.
testing
.
assert_equal
(
tm5
(
a
,
y
=
b
).
numpy
(),
3
)
np
.
testing
.
assert_equal
(
tm5
(
x
=
a
,
y
=
b
).
numpy
(),
3
)
assert
len
(
tm4
.
graph
.
_exprs
)
==
1
assert
len
(
tm4
.
graph
.
_exprs
)
==
1
assert
isinstance
(
tm4
.
graph
.
_exprs
[
0
],
CallFunction
)
assert
isinstance
(
tm4
.
graph
.
_exprs
[
0
],
CallFunction
)
class
MyModule5
(
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
m1
=
tm4
def
forward
(
self
,
x
,
y
):
return
self
.
m1
(
x
,
y
)
tm6
=
trace_module
(
MyModule5
(),
a
,
b
)
assert
tm6
.
m1
.
argspec
is
None
assert
tm6
.
m1
.
_is_top
is
False
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录