Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
3a219209
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
3a219209
编写于
12月 27, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mge/traced_module): fix some bugs
GitOrigin-RevId: 88f98829cec88df120a05318209f78f065699bf8
上级
88c192c8
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
69 addition
and
20 deletion
+69
-20
imperative/python/megengine/traced_module/__init__.py
imperative/python/megengine/traced_module/__init__.py
+2
-0
imperative/python/megengine/traced_module/_passes/const_pass.py
...tive/python/megengine/traced_module/_passes/const_pass.py
+11
-5
imperative/python/megengine/traced_module/pytree.py
imperative/python/megengine/traced_module/pytree.py
+43
-7
imperative/python/megengine/traced_module/traced_module.py
imperative/python/megengine/traced_module/traced_module.py
+7
-4
imperative/python/megengine/traced_module/utils.py
imperative/python/megengine/traced_module/utils.py
+6
-4
未找到文件。
imperative/python/megengine/traced_module/__init__.py
浏览文件 @
3a219209
...
...
@@ -9,6 +9,7 @@
from
..core._imperative_rt.core2
import
set_cpp_apply_module_trace
from
.
import
compat
from
._passes
import
optimize
from
.pytree
import
register_supported_type
from
.traced_module
import
(
TracedModule
,
_register_all_builtin_module
,
...
...
@@ -23,6 +24,7 @@ set_cpp_apply_module_trace(cpp_apply_module_trace)
__all__
=
[
"register_as_builtin"
,
"register_supported_type"
,
"trace_module"
,
"wrap"
,
"TracedModule"
,
...
...
imperative/python/megengine/traced_module/_passes/const_pass.py
浏览文件 @
3a219209
...
...
@@ -12,7 +12,7 @@ from ...core.ops.builtin import GetVarShape
from
...logger
import
get_logger
from
...tensor
import
Tensor
from
..expr
import
Constant
,
Expr
,
is_apply_def
,
is_constant
,
is_getattr
from
..node
import
Node
,
TensorNode
from
..node
import
Node
,
NodeMixin
,
TensorNode
from
.matcher
import
PatternMatcher
from
.pass_base
import
BackwardPass
,
ForwardPass
,
register_pass
from
.pattern
import
is_op
...
...
@@ -21,6 +21,12 @@ from .utils import get_const_value
logger
=
get_logger
(
__name__
)
def
_as_const_node
(
x
):
node
=
Constant
.
make
(
x
)
NodeMixin
.
wrap
(
x
,
node
)
return
node
@
register_pass
(
"AttrToConstant"
)
class
AttrToConstant
(
BackwardPass
):
r
"""Convert :class:`~.GetAttr` to :class:`~.Constant` expr."""
...
...
@@ -35,10 +41,10 @@ class AttrToConstant(BackwardPass):
orig_node
=
expr
.
outputs
[
0
]
name
=
orig_node
.
name
with
graph
.
insert_exprs
(
expr
):
const_node
=
Constant
.
make
(
value
,
name
=
nam
e
)
const_node
=
_as_const_node
(
valu
e
)
graph
.
replace_node
({
orig_node
:
const_node
})
graph
.
compile
()
name
=
orig_node
.
name
const_node
.
name
=
name
return
const_node
.
expr
...
...
@@ -53,7 +59,7 @@ class FixInputShape(BackwardPass):
shape
=
Tensor
(
expr
.
inputs
[
0
].
shape
,
dtype
=
"int32"
)
graph
=
expr
.
top_graph
with
graph
.
insert_exprs
(
expr
):
const_shape
=
Constant
.
mak
e
(
shape
)
const_shape
=
_as_const_nod
e
(
shape
)
graph
.
replace_node
({
expr
.
outputs
[
0
]:
const_shape
})
graph
.
compile
()
const_shape
.
name
=
expr
.
outputs
[
0
].
name
...
...
@@ -73,7 +79,7 @@ class FlodConstant(ForwardPass):
const_var
=
expr
.
interpret
(
*
[
get_const_value
(
n
.
expr
)
for
n
in
expr
.
inputs
])[
0
]
graph
=
expr
.
top_graph
with
graph
.
insert_exprs
(
expr
):
const_node
=
Constant
.
mak
e
(
const_var
)
const_node
=
_as_const_nod
e
(
const_var
)
graph
.
replace_node
({
expr
.
outputs
[
0
]:
const_node
})
graph
.
compile
()
const_node
.
name
=
expr
.
outputs
[
0
].
name
...
...
imperative/python/megengine/traced_module/pytree.py
浏览文件 @
3a219209
...
...
@@ -10,7 +10,7 @@ import collections
from
collections
import
OrderedDict
,
defaultdict
from
functools
import
partial
from
inspect
import
FullArgSpec
from
typing
import
Callable
,
Named
Tuple
from
typing
import
Any
,
Callable
,
List
,
NamedTuple
,
Tuple
import
numpy
as
np
...
...
@@ -46,6 +46,8 @@ SUPPORTED_LEAF_TYPE = {
int
,
float
,
bool
,
bytes
,
bytearray
,
QuantDtypeMeta
,
CompNode
,
Device
,
...
...
@@ -74,18 +76,51 @@ SUPPORTED_LEAF_CLS = [
NodeType
=
NamedTuple
(
"NodeType"
,
[(
"flatten"
,
Callable
),
(
"unflatten"
,
Callable
)])
def
register_supported_type
(
type
,
flatten
=
None
,
unflatten
=
None
):
def
register_supported_type
(
type
,
flatten_fn
:
Callable
[[
Any
],
Tuple
[
List
,
Any
]]
=
None
,
unflatten_fn
:
Callable
[[
List
,
Any
],
Any
]
=
None
,
):
r
"""Call this function to register the ``type`` as a built-in type. The registered ``type``
can be used and serialized correctly in :py:class:`TracedModule`.
Examples:
.. code-block::
def dict_flatten(obj: Dict):
context, values = [], []
# obj.keys() needs to be sortable
keys = sorted(obj.keys())
for key in keys:
values.append(obj[key])
context.append(key)
return values, tuple(context)
def dict_unflatten(values: List, context: Any):
return dict(zip(context, values))
register_supported_type(dict, dict_flatten, dict_unflatten)
Args:
type: the type that needs to be registered.
flatten_fn: a function that should take an object created from ``type`` and return a
flat list of values. It can also return some context that is used in reconstructing
the object. Default: None
unflatten_fn: a function that should take a flat list of values and some context
(returned by flatten_fn). It returns the object by reconstructing
it from the list and the context. Default: None
"""
tp_info
=
(
type
.
__module__
,
type
.
__qualname__
)
if
flatten
and
unflatte
n
:
if
flatten
_fn
and
unflatten_f
n
:
USER_REGISTERED_CONTAINER_TYPE
.
append
(
tp_info
)
else
:
USER_REGISTERED_LEAF_TYPE
.
append
(
tp_info
)
_register_supported_type
(
type
,
flatten
,
unflatte
n
)
_register_supported_type
(
type
,
flatten
_fn
,
unflatten_f
n
)
def
_register_supported_type
(
type
,
flatten
=
None
,
unflatte
n
=
None
):
if
flatten
and
unflatte
n
:
SUPPORTED_TYPE
[
type
]
=
NodeType
(
flatten
,
unflatte
n
)
def
_register_supported_type
(
type
,
flatten
_fn
=
None
,
unflatten_f
n
=
None
):
if
flatten
_fn
and
unflatten_f
n
:
SUPPORTED_TYPE
[
type
]
=
NodeType
(
flatten
_fn
,
unflatten_f
n
)
else
:
SUPPORTED_LEAF_CLS
.
append
(
type
)
...
...
@@ -131,6 +166,7 @@ _register_supported_type(
_register_supported_type
(
OrderedDict
,
partial
(
_dict_flatten
,
True
),
partial
(
_dict_unflatten
,
OrderedDict
)
)
_register_supported_type
(
slice
,
lambda
x
:
([
x
.
start
,
x
.
stop
,
x
.
step
],
None
),
...
...
imperative/python/megengine/traced_module/traced_module.py
浏览文件 @
3a219209
...
...
@@ -42,6 +42,7 @@ from ..core._imperative_rt.core2 import (
)
from
..core._trace_option
import
set_symbolic_shape
from
..module
import
Module
from
..module
import
external
as
MExternal
from
..module.qat
import
QATModule
from
..quantization.fake_quant
import
LSQ
,
TQT
,
FakeQuantize
,
_FakeQuantize
from
..quantization.observer
import
(
...
...
@@ -207,6 +208,7 @@ def _wrap_method_to_tensor_node():
for
method
in
get_tensor_wrapable_method
():
patch
=
PatchedFn
(
TensorNode
,
method
)
if
type
(
getattr
(
Tensor
,
method
))
==
property
:
# Only support property.getter
patch
.
set_func
(
property
(
_any_method
(
method
,
patch
.
origin_fn
)))
else
:
patch
.
set_func
(
_any_method
(
method
,
patch
.
origin_fn
))
...
...
@@ -351,14 +353,14 @@ class _InsertExprs:
assert
(
node
.
top_graph
==
self
.
graph
),
"The input node ({}) is not in the graph ({})"
.
format
(
node
,
self
.
graph
)
if
isinstance
(
node
,
TensorNode
)
and
node
.
expr
in
self
.
graph
.
_exprs
:
if
node
.
expr
in
self
.
graph
.
_exprs
:
max_inp_expr_idx
=
max
(
max_inp_expr_idx
,
self
.
graph
.
_exprs
.
index
(
node
.
expr
)
)
max_inp_expr_idx
+=
1
insert_index
=
-
1
if
self
.
expr
i
s
not
None
:
if
self
.
expr
i
n
self
.
graph
.
_exprs
:
insert_index
=
self
.
graph
.
_exprs
.
index
(
self
.
expr
)
insert_index
+=
1
...
...
@@ -2070,7 +2072,8 @@ class TracedModule(Module):
for
inp_def
,
graph
in
self
.
argdef_graph_map
.
items
():
if
top_graph
is
not
None
:
graph
.
_top_graph
=
weakref
.
ref
(
top_graph
)
for
n
in
graph
.
_inputs
+
graph
.
outputs
:
for
n
in
graph
.
_inputs
+
graph
.
_outputs
:
n
.
expr
.
_top_graph
=
weakref
.
ref
(
graph
)
n
.
_top_graph
=
weakref
.
ref
(
graph
)
graph
.
_inputs
[
0
].
_owner
=
weakref
.
ref
(
self
)
for
i
,
n
in
enumerate
(
graph
.
_inputs
):
...
...
@@ -2375,7 +2378,7 @@ def wrap(func: Callable):
def
_register_all_builtin_module
():
for
sub_mod
in
[
M
,
M
.
qat
,
M
.
quantized
]:
for
sub_mod
in
[
M
,
M
.
qat
,
M
.
quantized
,
MExternal
]:
for
m
in
getmembers
(
sub_mod
):
if
(
isclass
(
m
[
1
])
...
...
imperative/python/megengine/traced_module/utils.py
浏览文件 @
3a219209
...
...
@@ -126,10 +126,12 @@ def _check_obj_attr(obj):
for
_
,
v
in
obj
.
items
():
leafs
,
_
=
tree_flatten
(
v
,
is_leaf
=
lambda
_
:
True
)
for
leaf
in
leafs
:
assert
_check_leaf_type
(
leaf
),
"Type {} is not supported by traced module"
.
format
(
leaf
if
isinstance
(
leaf
,
type
)
else
type
(
leaf
)
assert
_check_leaf_type
(
leaf
),
(
"Type {} is not supported in TracedModule serialization by default. "
"If you want to save this object to file, please call tm.register_supported_type({}) "
"before saving."
.
format
(
leaf
if
isinstance
(
leaf
,
type
)
else
type
(
leaf
),
type
(
leaf
).
__name__
)
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录