Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
a6fe7f7f
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a6fe7f7f
编写于
10月 08, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mge/traced_module): refactor Node naming rule and merge GetAttr
GitOrigin-RevId: a43ad1273c84ddc03e93152d8500b68c5aad259a
上级
c48d58da
变更
6
展开全部
隐藏空白更改
内联
并排
Showing
6 changed file
with
802 addition
and
483 deletion
+802
-483
imperative/python/megengine/traced_module/expr.py
imperative/python/megengine/traced_module/expr.py
+87
-91
imperative/python/megengine/traced_module/module_tracer.py
imperative/python/megengine/traced_module/module_tracer.py
+1
-2
imperative/python/megengine/traced_module/node.py
imperative/python/megengine/traced_module/node.py
+76
-25
imperative/python/megengine/traced_module/traced_module.py
imperative/python/megengine/traced_module/traced_module.py
+346
-359
imperative/python/test/unit/traced_module/test_modification.py
...ative/python/test/unit/traced_module/test_modification.py
+97
-6
imperative/python/test/unit/traced_module/test_qat_module.py
imperative/python/test/unit/traced_module/test_qat_module.py
+195
-0
未找到文件。
imperative/python/megengine/traced_module/expr.py
浏览文件 @
a6fe7f7f
...
...
@@ -11,7 +11,7 @@ import collections
import
copy
import
inspect
import
re
from
typing
import
Callable
,
Dict
,
List
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Union
from
..core._imperative_rt
import
OpDef
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
...
...
@@ -32,6 +32,43 @@ def rstrip(s: str, __chars: str):
return
s
def
get_suffix_name
(
prefix
:
str
,
name
:
str
):
if
prefix
==
name
:
return
""
matchd
=
re
.
compile
(
"^%s\.(.*)"
%
prefix
).
match
(
name
)
if
matchd
is
None
:
return
None
return
matchd
.
group
(
1
)
def
is_call_module
(
expr
):
return
(
isinstance
(
expr
,
CallMethod
)
and
isinstance
(
expr
.
inputs
[
0
],
ModuleNode
)
and
expr
.
method
==
"__call__"
)
def
is_call_tensor_method
(
expr
):
return
isinstance
(
expr
,
CallMethod
)
and
not
is_call_module
(
expr
)
def
is_call_function
(
expr
):
return
isinstance
(
expr
,
CallFunction
)
def
is_constant
(
expr
):
return
isinstance
(
expr
,
Constant
)
def
is_getattr
(
expr
):
return
isinstance
(
expr
,
GetAttr
)
def
is_apply_def
(
expr
):
return
isinstance
(
expr
,
Apply
)
class
Expr
:
r
"""``Expr`` represents the operations (i.e. ``CallMethod``, ``CallFunction``, ``Apply``,
``GetAttr``, ``Input``, ``Constant``) on ``Node``.
...
...
@@ -76,50 +113,19 @@ class Expr:
self
.
const_val
.
append
((
idx
,
val
))
def
add_outputs
(
self
,
outputs
):
assert
active_module_tracer
()
is
not
None
self
.
outputs
=
[]
if
outputs
is
not
None
:
if
not
isinstance
(
outputs
,
collections
.
Sequence
):
outputs
=
(
outputs
,)
name
=
None
orig_name
=
None
if
isinstance
(
self
,
CallMethod
):
name
=
self
.
inputs
[
0
].
_name
orig_name
=
self
.
inputs
[
0
].
_orig_name
assert
isinstance
(
name
,
str
),
"The name of ({}) must be a str"
.
format
(
self
.
inputs
[
0
]
)
assert
isinstance
(
orig_name
,
str
),
"The orig_name of ({}) must be a str"
.
format
(
self
.
inputs
[
0
])
name
=
rstrip
(
name
,
"_out"
)
if
self
.
method
==
"__call__"
:
name
+=
"_out"
orig_name
+=
"_out"
else
:
strip_method
=
self
.
method
.
strip
(
"_"
)
name
=
"%s_out"
%
strip_method
orig_name
=
name
elif
isinstance
(
self
,
CallFunction
):
name
=
self
.
func
.
__name__
+
"_out"
elif
isinstance
(
self
,
Apply
):
name
=
str
(
self
.
opdef
).
lower
()
+
"_out"
for
i
in
outputs
:
assert
isinstance
(
i
,
RawTensor
),
"The output must be a Tensor"
o_name
=
(
active_module_tracer
().
current_scope
().
_create_unique_name
(
name
)
)
self
.
outputs
.
append
(
NodeMixin
.
get_wrapped_type
(
i
)(
expr
=
self
,
name
=
o_name
,
orig_name
=
orig_name
if
orig_name
else
o_name
,
)
)
for
i
,
node
in
zip
(
outputs
,
self
.
outputs
,):
NodeMixin
.
wrap_safe
(
i
,
node
)
if
outputs
is
None
:
return
current_graph
=
active_module_tracer
().
current_scope
()
if
not
isinstance
(
outputs
,
collections
.
Sequence
):
outputs
=
(
outputs
,)
for
i
in
outputs
:
assert
isinstance
(
i
,
RawTensor
),
"The output must be a Tensor"
node
=
NodeMixin
.
get_wrapped_type
(
i
)(
expr
=
self
,
name
=
""
,
qualname
=
""
,)
NodeMixin
.
wrap_safe
(
i
,
node
)
self
.
outputs
.
append
(
node
)
current_graph
.
_namespace
.
auto_naming_for_outputs
(
self
)
def
unflatten_args
(
self
,
inputs
):
if
self
.
arg_def
is
not
None
:
...
...
@@ -152,9 +158,7 @@ class Expr:
),
"({}) must be generated before ({})"
.
format
(
repl_node
,
self
)
idx
=
self
.
inputs
.
index
(
node
)
self
.
inputs
[
idx
]
=
repl_node
user_idx
=
node
.
users
.
index
(
self
)
assert
user_idx
>=
0
node
.
users
.
pop
(
user_idx
)
node
.
users
.
remove
(
self
)
repl_node
.
users
.
append
(
self
)
@
property
...
...
@@ -197,26 +201,23 @@ class Input(Expr):
r
"""A fake Expr which is used to mark the input of graph."""
name
=
None
def
__init__
(
self
,
name
=
None
,
type
=
None
,
orig_name
=
None
):
def
__init__
(
self
,
type
:
List
[
Node
],
name
:
str
=
"args"
,
qualname
:
str
=
""
):
super
().
__init__
()
assert
type
in
[
ModuleNode
,
TensorNode
]
assert
name
and
qualname
self
.
inputs
=
[]
node_cls
=
type
if
type
else
Node
if
orig_name
is
None
:
orig_name
=
name
self
.
outputs
=
[
node_cls
(
self
,
name
=
name
,
orig_name
=
orig_
name
),
node_cls
(
self
,
name
=
name
,
qualname
=
qual
name
),
]
self
.
name
=
name
@
classmethod
def
make
(
cls
,
*
args
,
**
kwargs
):
assert
active_module_tracer
()
is
not
None
expr
=
cls
(
*
args
,
**
kwargs
)
oup_node
=
expr
.
outputs
[
0
]
name
=
(
active_module_tracer
().
current_scope
().
_create_unique_name
(
oup_node
.
_name
)
)
oup_node
.
_name
=
name
active_module_tracer
().
current_scope
().
_add_input
(
oup_node
)
out_node
=
expr
.
outputs
[
0
]
active_module_tracer
().
current_scope
().
_add_input
(
out_node
)
return
expr
.
outputs
[
0
]
def
__repr__
(
self
):
...
...
@@ -230,34 +231,41 @@ class GetAttr(Expr):
name
=
None
r
"""name: the qualified name of the attribute to be retrieved."""
def
__init__
(
self
,
module
,
name
,
type
=
None
,
orig_name
=
None
):
def
__init__
(
self
,
module
:
ModuleNode
,
type
:
Union
[
Node
],
attr_name
:
str
,
name
:
str
=
""
,
):
super
().
__init__
()
assert
isinstance
(
module
,
ModuleNode
)
assert
type
in
[
TensorNode
,
ModuleNode
]
self
.
inputs
=
[
module
,
]
module
.
users
.
append
(
self
)
self
.
name
=
name
node_cls
=
type
if
type
else
Node
self
.
name
=
attr_name
self
.
outputs
=
[
node_cls
(
self
,
name
=
name
,
orig_name
=
orig_name
),
type
(
self
,
name
=
name
,
qualname
=
"{}.{}"
.
format
(
module
.
qualname
,
attr_name
)
),
]
@
classmethod
def
make
(
cls
,
*
args
,
**
kwargs
):
assert
active_module_tracer
()
is
not
None
current_graph
=
active_module_tracer
().
current_scope
()
expr
=
cls
(
*
args
,
**
kwargs
)
module
=
expr
.
inputs
[
0
]
oup_name
=
expr
.
name
while
module
.
_name
!=
"self"
:
oup_name
=
module
.
_name
+
"_"
+
oup_name
module
=
module
.
expr
.
inputs
[
0
]
oup_name
=
active_module_tracer
().
current_scope
().
_create_unique_name
(
oup_name
)
expr
.
outputs
[
0
].
_name
=
oup_name
active_module_tracer
().
current_scope
().
_insert
(
expr
)
current_graph
.
_namespace
.
auto_naming_for_outputs
(
expr
)
current_graph
.
_insert
(
expr
)
return
expr
.
outputs
[
0
]
def
interpret
(
self
,
*
inputs
):
return
(
getattr
(
inputs
[
0
],
self
.
name
),)
mod
=
inputs
[
0
]
module_path
,
_
,
name
=
self
.
name
.
rpartition
(
"."
)
if
module_path
==
""
:
return
(
getattr
(
mod
,
name
),)
module_names
=
module_path
.
split
(
"."
)
for
item
in
module_names
:
mod
=
getattr
(
mod
,
item
)
if
not
isinstance
(
mod
,
Module
):
raise
AttributeError
(
"`{}` is not an Module"
.
format
(
item
))
return
(
getattr
(
mod
,
name
),)
def
__repr__
(
self
):
out_type
=
"Tensor"
...
...
@@ -297,6 +305,7 @@ class CallMethod(Expr):
@
classmethod
def
make
(
cls
,
*
args
,
**
kwargs
):
assert
active_module_tracer
()
is
not
None
expr
=
cls
(
*
args
,
**
kwargs
)
active_module_tracer
().
current_scope
().
_insert
(
expr
)
return
expr
...
...
@@ -362,6 +371,7 @@ class Apply(Expr):
@
classmethod
def
make
(
cls
,
*
args
,
**
kwargs
):
assert
active_module_tracer
()
is
not
None
expr
=
cls
(
*
args
,
**
kwargs
)
active_module_tracer
().
current_scope
().
_insert
(
expr
)
return
expr
...
...
@@ -435,6 +445,7 @@ class CallFunction(Expr):
@
classmethod
def
make
(
cls
,
*
args
,
**
kwargs
):
assert
active_module_tracer
()
is
not
None
expr
=
cls
(
*
args
,
**
kwargs
)
active_module_tracer
().
current_scope
().
_insert
(
expr
)
return
expr
...
...
@@ -474,7 +485,7 @@ class Constant(Expr):
# TODO: constant cache to reduce the size of dumped model
_constant_cache
=
{}
def
__init__
(
self
,
c
,
name
=
None
):
def
__init__
(
self
,
c
,
name
:
str
=
""
,
qualname
:
str
=
""
):
super
().
__init__
()
assert
isinstance
(
c
,
(
RawTensor
,
Module
))
if
isinstance
(
c
,
Module
):
...
...
@@ -484,31 +495,16 @@ class Constant(Expr):
self
.
inputs
=
[]
node_cls
=
NodeMixin
.
get_wrapped_type
(
c
)
self
.
outputs
=
[
node_cls
(
self
,
name
=
name
,
orig_name
=
name
),
node_cls
(
self
,
name
=
name
,
qualname
=
qual
name
),
]
self
.
outputs
[
0
].
_name
=
name
if
name
else
"const_"
+
str
(
self
.
_id
)
@
classmethod
def
make
(
cls
,
*
args
,
**
kwargs
):
assert
active_module_tracer
()
is
not
None
expr
=
cls
(
*
args
,
**
kwargs
)
name
=
"const_module"
if
isinstance
(
expr
.
value
,
Module
)
else
"const_tensor"
full_name
=
name
if
(
isinstance
(
expr
.
value
,
RawTensor
)
and
id
(
expr
.
value
)
in
active_module_tracer
().
id2name
):
full_name
=
active_module_tracer
().
id2name
[
id
(
expr
.
value
)]
scope_name
=
active_module_tracer
().
current_scope
().
_module_name
if
full_name
and
scope_name
:
full_name
=
(
"self."
+
full_name
)[
len
(
scope_name
)
+
1
:]
else
:
full_name
=
name
else
:
full_name
=
name
name
=
active_module_tracer
().
current_scope
().
_create_unique_name
(
full_name
)
expr
.
outputs
[
0
].
_name
=
name
expr
.
outputs
[
0
].
_orig_name
=
full_name
active_module_tracer
().
current_scope
().
_insert
(
expr
)
current_graph
=
active_module_tracer
().
current_scope
()
current_graph
.
_namespace
.
auto_naming_for_outputs
(
expr
)
current_graph
.
_insert
(
expr
)
return
expr
.
outputs
[
0
]
def
interpret
(
self
,
*
inputs
):
...
...
imperative/python/megengine/traced_module/module_tracer.py
浏览文件 @
a6fe7f7f
...
...
@@ -128,10 +128,9 @@ class module_tracer:
_active_scopes
=
None
def
__init__
(
self
,
wrap_fn
,
id2name
):
def
__init__
(
self
,
wrap_fn
):
self
.
_active_scopes
=
[]
self
.
patcher
=
Patcher
(
wrap_fn
)
self
.
id2name
=
id2name
@
classmethod
def
register_as_builtin
(
cls
,
mod
):
...
...
imperative/python/megengine/traced_module/node.py
浏览文件 @
a6fe7f7f
...
...
@@ -29,17 +29,15 @@ class Node:
__total_id
=
0
# type: int
_id
=
None
# type: int
_top_graph
=
None
# type: weakref.ReferenceType
_name
=
None
# type: str
_orig_name
=
None
# type: str
_format_spec
=
""
# type: str
def
__init__
(
self
,
expr
,
name
:
str
,
orig_
name
:
str
):
def
__init__
(
self
,
expr
,
name
:
str
,
qual
name
:
str
):
self
.
expr
=
expr
self
.
users
=
[]
# List[Expr]
self
.
_id
=
Node
.
__total_id
Node
.
__total_id
+=
1
self
.
_name
=
name
self
.
_
orig_name
=
orig_
name
self
.
_
qualname
=
qual
name
self
.
actual_node
=
[]
# type: List[Node]
def
__repr__
(
self
):
...
...
@@ -54,21 +52,10 @@ class Node:
name
=
""
if
format_spec
in
[
"i"
,
"p"
,
"ip"
,
"pi"
]:
if
"p"
in
format_spec
:
graph
=
self
.
top_graph
prefix_name
=
""
if
graph
is
not
None
:
prefix_name
=
graph
.
_name
if
graph
.
_prefix_name
:
prefix_name
=
"{}_{}"
.
format
(
graph
.
_prefix_name
,
prefix_name
.
lstrip
(
"_"
)
)
if
name
:
name
=
"_"
+
name
.
lstrip
(
"_"
)
name
=
"{}{}"
.
format
(
prefix_name
,
name
)
prefix_name
=
self
.
top_graph
.
_name
name
=
"{}_{}"
.
format
(
prefix_name
,
name
)
if
"i"
in
format_spec
:
if
name
:
name
=
"_"
+
name
.
lstrip
(
"_"
)
name
=
"%{}{}"
.
format
(
self
.
_id
,
name
)
name
=
"%{}_{}"
.
format
(
self
.
_id
,
name
)
return
name
else
:
return
name
if
name
else
(
"%d"
%
self
.
_id
)
...
...
@@ -80,15 +67,62 @@ class Node:
@
name
.
setter
def
name
(
self
,
new_name
:
str
):
r
"""Set a new name to this Node."""
graph
=
self
.
top_graph
assert
graph
is
not
None
,
"The parent graph of this Node cannot be None."
assert
new_name
not
in
graph
.
_used_names
,
(
assert
new_name
not
in
graph
.
_
namespace
.
used_names
,
(
"The name(%s) is already in use. Please try a different one again."
%
(
new_name
)
)
new_name
=
graph
.
_create_unique_name
(
new_name
)
new_name
=
graph
.
_
namespace
.
create_unique_name
(
new_name
)
self
.
_name
=
new_name
self
.
_orig_name
=
new_name
@
property
def
qualname
(
self
):
r
"""Get the `qualname` of this Node. The `qualname` can be used to get the
submodule from the traced Module or Module.
Example:
.. code-block::
import megengine.module as M
import megengine.functional as F
import megengine.traced_module as tm
import megengine as mge
class block(M.Module):
def __init__(self):
super().__init__()
self.param = mge.Tensor([1.])
self.relu = M.ReLU()
def forward(self, x):
x = x + self.param
return self.relu(F.relu(x))
class module(M.Module):
def __init__(self):
super().__init__()
self.block = block()
def forward(self, x):
x = self.block(x)
return x
net = module()
traced_net = tm.trace_module(net, mge.Tensor([0.]))
traced_net = traced_net.flatten()
out_node = traced_net.graph.outputs[0]
# qualname : "module.block.relu.[out]"
qualname = out_node.qualname
# qualname : "block.relu"
qualname = qualname.split(".", 1)[-1].rsplit(".", 1)[0]
assert qualname in list(map(lambda x: x[0], net.named_modules()))
assert qualname in list(map(lambda x: x[0], traced_net.named_modules()))
"""
return
self
.
_qualname
@
property
def
top_graph
(
self
):
...
...
@@ -120,8 +154,8 @@ class ModuleNode(Node):
r
"""The type of the Module correspending to the ModuleNode."""
_owner
=
None
# type: weakref.ReferenceType
def
__init__
(
self
,
expr
,
name
:
str
=
None
,
orig_
name
:
str
=
None
):
super
().
__init__
(
expr
,
name
,
orig_
name
)
def
__init__
(
self
,
expr
,
name
:
str
=
None
,
qual
name
:
str
=
None
):
super
().
__init__
(
expr
,
name
,
qual
name
)
def
__getstate__
(
self
):
return
{
...
...
@@ -129,10 +163,15 @@ class ModuleNode(Node):
"users"
:
self
.
users
,
"_id"
:
self
.
_id
,
"_name"
:
self
.
_name
,
"_
orig_name"
:
self
.
_orig_
name
,
"_
qualname"
:
self
.
_qual
name
,
"module_type"
:
self
.
module_type
,
}
def
__setstate__
(
self
,
state
):
if
"_orig_name"
in
state
:
state
[
"_qualname"
]
=
state
.
pop
(
"_orig_name"
)
self
.
__dict__
.
update
(
state
)
@
property
def
owner
(
self
):
r
"""Get the ``Module`` corresponding to this ``ModuleNode``.
...
...
@@ -161,9 +200,21 @@ class TensorNode(Node):
"_dtype"
:
self
.
_dtype
,
"_device"
:
self
.
_device
,
"_name"
:
self
.
_name
,
"_
orig_name"
:
self
.
_orig_
name
,
"_
qualname"
:
self
.
_qual
name
,
}
def
__setstate__
(
self
,
state
):
if
"_orig_name"
in
state
:
qualname
=
state
.
pop
(
"_orig_name"
)
modulepath
,
comma
,
qualname
=
qualname
.
rpartition
(
"."
)
expr_name
=
state
[
"expr"
].
__class__
.
__name__
if
expr_name
not
in
[
"GetAttr"
]:
qualname
=
"[{}]"
.
format
(
qualname
)
if
comma
:
qualname
=
"{}.{}"
.
format
(
modulepath
,
qualname
)
state
[
"_qualname"
]
=
qualname
self
.
__dict__
.
update
(
state
)
@
property
def
shape
(
self
):
r
"""Get the shape of this Node."""
...
...
imperative/python/megengine/traced_module/traced_module.py
浏览文件 @
a6fe7f7f
此差异已折叠。
点击以展开。
imperative/python/test/unit/traced_module/test_modification.py
浏览文件 @
a6fe7f7f
...
...
@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
pickle
from
itertools
import
chain
import
numpy
as
np
...
...
@@ -13,8 +14,8 @@ import megengine.functional as F
import
megengine.module
as
M
from
megengine.module.identity
import
Identity
from
megengine.traced_module
import
trace_module
from
megengine.traced_module.expr
import
CallFunction
,
Expr
,
GetAttr
from
megengine.traced_module.node
import
Node
from
megengine.traced_module.expr
import
CallFunction
,
CallMethod
,
Expr
,
GetAttr
,
Input
from
megengine.traced_module.node
import
ModuleNode
,
Node
class
IdentityMod
(
M
.
Module
):
...
...
@@ -85,6 +86,34 @@ def test_search():
relu_expr
=
graph
.
get_function_by_type
(
F
.
relu
).
as_unique
()
assert
isinstance
(
relu_expr
,
CallFunction
)
and
relu_expr
.
func
==
F
.
relu
conv_node
=
graph
.
get_module_by_type
(
M
.
Conv2d
).
as_unique
()
assert
isinstance
(
conv_node
,
ModuleNode
)
and
conv_node
.
module_type
==
M
.
Conv2d
add_expr
=
graph
.
get_method_by_type
(
"__add__"
).
as_unique
()
assert
isinstance
(
add_expr
,
CallMethod
)
and
add_expr
.
method
==
"__add__"
conv_node
=
graph
.
get_node_by_name
(
"MyBlock_conv1"
).
as_unique
()
assert
isinstance
(
conv_node
,
ModuleNode
)
and
conv_node
.
module_type
==
M
.
Conv2d
def
test_producer_and_users
():
traced_module
,
*
_
=
_init_module
()
def
_check
(
exprs
):
for
expr
in
exprs
:
for
n
in
chain
(
expr
.
inputs
,
expr
.
outputs
):
if
not
isinstance
(
n
.
expr
,
Input
):
assert
n
.
expr
in
exprs
for
e
in
n
.
users
:
assert
e
in
exprs
assert
n
in
e
.
inputs
for
mod
in
traced_module
.
modules
():
if
not
hasattr
(
mod
,
"argdef_graph_map"
):
continue
for
g
in
mod
.
argdef_graph_map
.
values
():
_check
(
g
.
_exprs
)
def
test_insert
():
traced_module
,
x
,
expect
=
_init_block
()
...
...
@@ -97,6 +126,54 @@ def test_insert():
np
.
testing
.
assert_allclose
(
expect
-
1
,
1
-
traced_module
(
x
),
atol
=
1e-6
)
def
test_insert_module
():
class
Neg
(
M
.
Module
):
def
forward
(
self
,
x
):
return
F
.
neg
(
x
)
traced_module
,
x
,
expect
=
_init_block
()
graph
=
traced_module
.
graph
relu_out
=
graph
.
get_function_by_type
(
F
.
relu
).
as_unique
().
outputs
[
0
]
self
=
graph
.
inputs
[
0
]
setattr
(
traced_module
,
"neg"
,
Neg
())
with
graph
.
insert_exprs
():
neg_out
=
self
.
neg
(
relu_out
)
graph
.
replace_node
({
relu_out
:
neg_out
})
graph
.
compile
()
np
.
testing
.
assert_allclose
(
expect
-
1
,
1
-
traced_module
(
x
),
atol
=
1e-6
)
assert
traced_module
.
neg
.
graph
is
not
None
assert
len
(
traced_module
.
neg
.
graph
.
_exprs
)
==
1
def
test_add_input_and_output
():
traced_module
,
x
,
y
=
_init_module
()
data_node
=
traced_module
.
graph
.
add_input_node
(
shape
=
(
1
,
3
,
224
,
224
),
name
=
"data"
)
traced_module
.
graph
.
add_output_node
(
data_node
)
assert
data_node
.
name
==
"data"
assert
traced_module
.
graph
.
inputs
[
-
1
]
==
data_node
assert
len
(
traced_module
.
graph
.
inputs
)
==
3
assert
len
(
traced_module
.
graph
.
outputs
)
==
2
y1
,
y2
=
traced_module
(
x
,
x
)
np
.
testing
.
assert_equal
(
y1
.
numpy
(),
y
.
numpy
())
np
.
testing
.
assert_equal
(
y2
.
numpy
(),
x
.
numpy
())
y1
,
y2
=
traced_module
(
x
,
y
)
np
.
testing
.
assert_equal
(
y2
.
numpy
(),
y
.
numpy
())
traced_module
.
graph
.
reset_outputs
(
({
"orig_out"
:
traced_module
.
graph
.
outputs
[
0
]},
traced_module
.
graph
.
outputs
[
1
])
)
out
=
traced_module
(
x
,
x
)
assert
isinstance
(
out
,
tuple
)
assert
isinstance
(
out
[
0
],
dict
)
np
.
testing
.
assert_equal
(
out
[
0
][
"orig_out"
].
numpy
(),
y
.
numpy
())
np
.
testing
.
assert_equal
(
out
[
1
].
numpy
(),
x
.
numpy
())
def
test_delete
():
traced_module
,
x
,
expect
=
_init_block
()
graph
=
traced_module
.
graph
...
...
@@ -117,8 +194,10 @@ def test_delete():
def
test_flatten
():
traced_module
,
x
,
expect
=
_init_module
()
traced_module
=
traced_module
.
flatten
()
traced_module
.
graph
.
compile
()
assert
all
(
not
isinstance
(
i
,
GetAttr
)
for
i
in
traced_module
.
graph
.
_exprs
)
assert
len
(
traced_module
.
graph
.
_exprs
)
==
12
np
.
testing
.
assert_equal
(
expect
.
numpy
(),
traced_module
(
x
).
numpy
())
traced_module
=
traced_module
.
flatten
()
assert
len
(
traced_module
.
graph
.
_exprs
)
==
12
np
.
testing
.
assert_equal
(
expect
.
numpy
(),
traced_module
(
x
).
numpy
())
...
...
@@ -128,7 +207,7 @@ def test_id_and_name():
_total_ids
=
traced_module
.
graph
.
_total_ids
node_ids
=
[
n
.
_id
for
n
in
traced_module
.
graph
.
nodes
().
as_list
()]
assert
len
(
set
(
node_ids
))
==
len
(
node_ids
)
assert
max
(
node_ids
)
+
1
==
len
(
node_ids
)
assert
max
(
node_ids
)
+
1
==
_total_ids
[
0
]
expr_ids
=
[
n
.
_id
for
n
in
traced_module
.
graph
.
exprs
().
as_list
()]
assert
len
(
set
(
expr_ids
))
==
len
(
expr_ids
)
...
...
@@ -177,7 +256,7 @@ def test_id_and_name():
_check_name
(
flattened_module
)
def
test_set_name
():
def
test_set_n
ode_n
ame
():
traced_module
,
x
,
expect
=
_init_module
()
graph
=
traced_module
.
graph
output_node
=
graph
.
outputs
[
0
]
...
...
@@ -190,6 +269,18 @@ def test_set_name():
np
.
testing
.
assert_equal
(
str
(
graph
.
outputs
[
0
]),
"output"
)
def
test_set_graph_name
():
traced_module
,
x
,
expect
=
_init_module
()
graph
=
traced_module
.
graph
output_node
=
graph
.
outputs
[
0
]
node_name
=
output_node
.
name
graph
.
name
=
"Top"
node
=
graph
.
get_node_by_name
(
"{}_{}"
.
format
(
"Top"
,
node_name
)).
as_unique
()
assert
node
is
output_node
def
test_extra_block
():
class
PostProcess
(
M
.
Module
):
def
forward
(
self
,
x
):
...
...
imperative/python/test/unit/traced_module/test_qat_module.py
0 → 100644
浏览文件 @
a6fe7f7f
import
io
from
functools
import
partial
from
itertools
import
chain
from
typing
import
Callable
import
numpy
as
np
import
megengine
as
mge
import
megengine.functional
as
F
import
megengine.module
as
M
import
megengine.quantization
as
Q
from
megengine
import
Tensor
from
megengine.module.qat.module
import
QATModule
from
megengine.traced_module
import
TracedModule
,
trace_module
def
get_subattr
(
self
:
M
.
Module
,
name
:
str
):
if
name
==
""
:
return
self
module_path
,
_
,
name
=
name
.
rpartition
(
"."
)
if
module_path
==
""
:
return
getattr
(
self
,
name
)
module_names
=
module_path
.
split
(
"."
)
for
item
in
module_names
:
self
=
getattr
(
self
,
item
)
if
not
isinstance
(
self
,
M
.
Module
):
raise
AttributeError
(
"`{}` is not an Module"
.
format
(
item
))
return
getattr
(
self
,
name
)
class
Myblcok
(
M
.
Module
):
def
__init__
(
self
,):
super
().
__init__
()
self
.
conv0
=
M
.
ConvBnRelu2d
(
3
,
3
,
3
,
1
,
1
)
self
.
conv1
=
M
.
ConvBn2d
(
3
,
3
,
1
,
1
,
0
)
self
.
conv2
=
M
.
ConvBn2d
(
3
,
3
,
1
,
1
,
0
)
self
.
add
=
M
.
Elemwise
(
"FUSE_ADD_RELU"
)
def
forward
(
self
,
x
):
x
=
self
.
conv0
(
x
)
x0
=
self
.
conv1
(
x
)
x1
=
self
.
conv2
(
x
)
o
=
self
.
add
(
x0
,
x1
)
return
o
class
MyModule
(
M
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
block0
=
Myblcok
()
self
.
block1
=
Myblcok
()
def
forward
(
self
,
x
):
x
=
self
.
block0
(
x
)
x
=
self
.
block1
(
x
)
return
x
class
MyMinMaxObserver
(
Q
.
MinMaxObserver
):
pass
class
MyTQT
(
Q
.
TQT
):
pass
def
get_lsq_config
(
lsq_cls
):
return
Q
.
QConfig
(
weight_observer
=
None
,
act_observer
=
None
,
weight_fake_quant
=
partial
(
lsq_cls
,
dtype
=
"qint8_narrow"
),
act_fake_quant
=
partial
(
lsq_cls
,
dtype
=
"qint8"
),
)
def
get_observer_config
(
observer_cls
):
return
Q
.
QConfig
(
weight_observer
=
partial
(
observer_cls
,
dtype
=
"qint8_narrow"
),
act_observer
=
partial
(
observer_cls
,
dtype
=
"qint8"
),
weight_fake_quant
=
None
,
act_fake_quant
=
None
,
)
def
get_qparams
(
mod
:
QATModule
):
weight_qparams
,
act_qparams
=
None
,
None
if
mod
.
act_observer
is
not
None
:
act_qparams
=
mod
.
act_observer
.
get_qparams
()
if
mod
.
act_fake_quant
:
act_qparams
=
mod
.
act_fake_quant
.
get_qparams
()
if
mod
.
weight_observer
is
not
None
:
weight_qparams
=
mod
.
weight_observer
.
get_qparams
()
if
mod
.
weight_fake_quant
:
weight_qparams
=
mod
.
weight_fake_quant
.
get_qparams
()
return
weight_qparams
,
act_qparams
def
check_qparams
(
qparmsa
:
Q
.
QParams
,
qparmsb
:
Q
.
QParams
):
assert
qparmsa
.
dtype_meta
==
qparmsb
.
dtype_meta
assert
qparmsa
.
mode
==
qparmsb
.
mode
np
.
testing
.
assert_equal
(
qparmsa
.
scale
.
numpy
(),
qparmsb
.
scale
.
numpy
())
if
qparmsa
.
zero_point
is
not
None
:
np
.
testing
.
assert_equal
(
qparmsa
.
zero_point
.
numpy
(),
qparmsb
.
zero_point
.
numpy
())
def
build_observered_net
(
net
:
M
.
Module
,
observer_cls
):
qat_net
=
Q
.
quantize_qat
(
net
,
qconfig
=
get_observer_config
(
observer_cls
))
Q
.
enable_observer
(
qat_net
)
for
_
in
range
(
5
):
inp
=
Tensor
(
np
.
random
.
random
(
size
=
(
5
,
3
,
32
,
32
)))
qat_net
(
inp
)
Q
.
disable_observer
(
qat_net
)
return
qat_net
def
build_fakequanted_net
(
net
:
QATModule
,
fakequant_cls
):
qat_net
=
Q
.
reset_qconfig
(
net
,
get_lsq_config
(
fakequant_cls
))
return
qat_net
def
test_trace_qat
():
def
_check_qat_module
(
qat_net
:
QATModule
):
inp
=
Tensor
(
np
.
random
.
random
(
size
=
(
5
,
3
,
32
,
32
)))
traced_net
=
trace_module
(
qat_net
,
inp
)
for
name
,
qat_module
in
qat_net
.
named_modules
():
if
not
isinstance
(
qat_module
,
QATModule
):
continue
traced_qat_module
=
get_subattr
(
traced_net
,
name
)
weight_qparams
,
act_qparams
=
get_qparams
(
qat_module
)
traced_weight_qparams
,
traced_act_qparams
=
get_qparams
(
traced_qat_module
)
if
weight_qparams
:
check_qparams
(
weight_qparams
,
traced_weight_qparams
)
if
act_qparams
:
check_qparams
(
act_qparams
,
traced_act_qparams
)
_check_qat_module
(
build_observered_net
(
MyModule
(),
Q
.
MinMaxObserver
))
_check_qat_module
(
build_observered_net
(
MyModule
(),
MyMinMaxObserver
))
_check_qat_module
(
build_fakequanted_net
(
build_observered_net
(
MyModule
(),
Q
.
MinMaxObserver
),
Q
.
TQT
)
)
_check_qat_module
(
build_fakequanted_net
(
build_observered_net
(
MyModule
(),
Q
.
MinMaxObserver
),
MyTQT
)
)
def
test_load_param
():
def
_check_param
(
moda
:
M
.
Module
,
modb
:
M
.
Module
):
for
name
,
attr
in
chain
(
moda
.
named_parameters
(),
moda
.
named_buffers
()):
traced_attr
=
get_subattr
(
modb
,
name
)
np
.
testing
.
assert_equal
(
attr
.
numpy
(),
traced_attr
.
numpy
())
def
_check_module
(
build_func
:
Callable
):
net
=
build_func
()
buffer
=
io
.
BytesIO
()
mge
.
save
(
net
.
state_dict
(),
buffer
)
buffer
.
seek
(
0
)
inp
=
Tensor
(
np
.
random
.
random
(
size
=
(
5
,
3
,
32
,
32
)))
traced_net
=
trace_module
(
build_func
(),
inp
)
traced_net
.
load_state_dict
(
mge
.
load
(
buffer
))
_check_param
(
net
,
traced_net
)
buffer
.
seek
(
0
)
traced_net
=
trace_module
(
build_func
(),
inp
).
flatten
()
traced_net
.
load_state_dict
(
mge
.
load
(
buffer
))
_check_param
(
net
,
traced_net
)
_check_module
(
lambda
:
MyModule
())
_check_module
(
lambda
:
build_observered_net
(
MyModule
(),
Q
.
MinMaxObserver
))
def
test_qualname
():
def
_check_qualname
(
net
):
inp
=
Tensor
(
np
.
random
.
random
(
size
=
(
5
,
3
,
32
,
32
)))
traced_net
=
trace_module
(
net
,
inp
)
base_qualname
=
traced_net
.
graph
.
qualname
for
node
in
traced_net
.
graph
.
nodes
():
qualname
=
node
.
qualname
qualname
=
qualname
[
len
(
base_qualname
)
+
1
:]
if
qualname
.
endswith
(
"]"
):
qualname
=
qualname
.
rsplit
(
"."
,
1
)[
0
]
if
qualname
.
startswith
(
"["
):
qualname
=
""
traced_attr
=
get_subattr
(
traced_net
,
qualname
)
orig_attr
=
get_subattr
(
net
,
qualname
)
assert
traced_attr
is
not
None
assert
orig_attr
is
not
None
_check_qualname
(
MyModule
())
_check_qualname
(
build_observered_net
(
MyModule
(),
Q
.
MinMaxObserver
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录