Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
7b19bc76
MegEngine
项目概览
MegEngine 天元
/
MegEngine
12 个月 前同步成功
通知
393
Star
4703
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
7b19bc76
编写于
9月 08, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(traced_module): support traced module backward compatible serialization
GitOrigin-RevId: aaa9e51c74c11fa7955ae7bbfac476fa9bcf0d7d
上级
ffbfe59c
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
1314 addition
and
191 deletion
+1314
-191
imperative/python/megengine/traced_module/__init__.py
imperative/python/megengine/traced_module/__init__.py
+1
-0
imperative/python/megengine/traced_module/compat.py
imperative/python/megengine/traced_module/compat.py
+136
-0
imperative/python/megengine/traced_module/expr.py
imperative/python/megengine/traced_module/expr.py
+276
-27
imperative/python/megengine/traced_module/module_tracer.py
imperative/python/megengine/traced_module/module_tracer.py
+0
-1
imperative/python/megengine/traced_module/node.py
imperative/python/megengine/traced_module/node.py
+54
-5
imperative/python/megengine/traced_module/pytree.py
imperative/python/megengine/traced_module/pytree.py
+51
-31
imperative/python/megengine/traced_module/serialization.py
imperative/python/megengine/traced_module/serialization.py
+146
-18
imperative/python/megengine/traced_module/traced_module.py
imperative/python/megengine/traced_module/traced_module.py
+207
-79
imperative/python/megengine/traced_module/utils.py
imperative/python/megengine/traced_module/utils.py
+101
-1
imperative/python/test/unit/core/test_serialization.py
imperative/python/test/unit/core/test_serialization.py
+0
-23
imperative/python/test/unit/traced_module/test_modification.py
...ative/python/test/unit/traced_module/test_modification.py
+3
-3
imperative/python/test/unit/traced_module/test_qat_module.py
imperative/python/test/unit/traced_module/test_qat_module.py
+24
-2
imperative/python/test/unit/traced_module/test_serialization.py
...tive/python/test/unit/traced_module/test_serialization.py
+315
-1
未找到文件。
imperative/python/megengine/traced_module/__init__.py
浏览文件 @
7b19bc76
...
...
@@ -7,6 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
..core._imperative_rt.core2
import
set_cpp_apply_module_trace
from
.
import
compat
from
.traced_module
import
(
TracedModule
,
_register_all_builtin_module
,
...
...
imperative/python/megengine/traced_module/compat.py
0 → 100644
浏览文件 @
7b19bc76
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
numpy
as
np
from
..
import
tensor
from
..core.ops.builtin
import
BatchNorm
from
.expr
import
CallMethod
,
Constant
from
.node
import
TensorNode
from
.serialization
import
(
register_functional_loader
,
register_module_loader
,
register_opdef_loader
,
register_tensor_method_loader
,
)
"""
# Expr loaders examples
from ..core.ops.builtin import Elemwise
@register_opdef_loader(Elemwise)
def add_opdef_loader(expr):
if expr.opdef_state["mode"] == "ADD":
expr.opdef_state["mode"] == "MUL"
node = expr.inputs[1]
astype_expr = CallMethod(node, "astype")
oup = TensorNode(
astype_expr,
shape=node.shape,
dtype=expr.inputs[0].dtype,
qparams=node.qparams,
)
astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
astype_expr.return_val = (oup,)
expr.inputs[1] = oup
@register_functional_loader(("megengine.functional.nn", "conv2d"))
def conv2df_loader(expr):
# expr.func = ("megengine.functional.nn","conv2d")
kwargs = expr.kwargs
orig_weight = expr.named_args["weight"]
astype_expr = CallMethod(orig_weight, "astype")
oup = TensorNode(
astype_expr,
shape=orig_weight.shape,
dtype=orig_weight.dtype,
qparams=orig_weight.qparams,
)
astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
astype_expr.return_val = (oup,)
expr.set_arg("weight", oup)
@register_module_loader(("megengine.module.conv", "Conv2d"))
def conv2dm_loader(expr):
module = expr.inputs[0].owner
args = list(expr.args)
orig_inp = args[1]
astype_expr = CallMethod(orig_inp, "astype")
oup = TensorNode(
astype_expr,
shape=orig_inp.shape,
dtype=orig_inp.dtype,
qparams=orig_inp.qparams,
)
astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
astype_expr.return_val = (oup,)
args[1] = oup
expr.set_args_kwargs(*args)
@register_tensor_method_loader("__add__")
def add_loader(expr):
args = list(expr.args)
if not isinstance(args[1], TensorNode):
args[1] = tensor(args[1])
node = Constant(args[1], "const").outputs[0]
astype_expr = CallMethod(node, "astype")
oup = TensorNode(
astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams,
)
astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
astype_expr.return_val = (oup,)
args[1] = oup
expr.set_args_kwargs(*args)
"""
@
register_module_loader
(
(
"megengine.module.batchnorm"
,
"BatchNorm1d"
),
(
"megengine.module.batchnorm"
,
"BatchNorm2d"
),
(
"megengine.module.batchnorm"
,
"SyncBatchNorm"
),
)
def
bn2d_module_loader
(
expr
):
# mge 1.6
if
not
hasattr
(
expr
,
"version"
):
module
=
expr
.
inputs
[
0
].
owner
if
not
hasattr
(
module
,
"param_dim"
):
module
.
param_dim
=
"dim_1c11"
@
register_module_loader
(
(
"megengine.module.conv_bn"
,
"ConvBn2d"
),
(
"megengine.module.conv_bn"
,
"ConvBnRelu2d"
),
(
"megengine.module.qat.conv_bn"
,
"ConvBn2d"
),
(
"megengine.module.qat.conv_bn"
,
"ConvBnRelu2d"
),
)
def
convbn2d_module_loader
(
expr
):
# mge 1.6
if
not
hasattr
(
expr
,
"version"
):
module
=
expr
.
inputs
[
0
].
owner
if
not
hasattr
(
module
.
bn
,
"param_dim"
):
module
.
bn
.
param_dim
=
"dim_1c11"
@
register_opdef_loader
(
BatchNorm
)
def
bn_opdef_loader
(
expr
):
# mge 1.6
if
not
hasattr
(
expr
,
"version"
):
output
=
expr
.
outputs
[
-
1
]
oup
=
TensorNode
(
expr
,
shape
=
(
0
,),
dtype
=
None
,
qparams
=
output
.
_qparams
,)
expr
.
outputs
.
insert
(
4
,
oup
)
imperative/python/megengine/traced_module/expr.py
浏览文件 @
7b19bc76
...
...
@@ -11,19 +11,28 @@ import collections
import
copy
import
inspect
import
re
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Union
import
weakref
from
importlib
import
import_module
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
..core._imperative_rt
import
OpDef
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
from
..core._imperative_rt.core2
import
apply
,
set_module_tracing
,
unset_module_tracing
from
..core._imperative_rt.core2
import
(
apply
,
is_tracing_module
,
set_module_tracing
,
unset_module_tracing
,
)
from
..core.ops.builtin
import
FakeQuant
from
..core.ops.special
import
Const
from
..module
import
Module
from
..tensor
import
Parameter
,
Tensor
from
..version
import
__version__
from
.module_tracer
import
active_module_tracer
,
module_tracer
from
.node
import
ModuleNode
,
Node
,
NodeMixin
,
TensorNode
from
.pytree
import
ArgsIndex
,
TreeDef
,
_is_const_leaf
,
_is_leaf
,
tree_flatten
from
.serialization
import
get_opdef_state
,
load_opdef_from_state
from
.serialization
import
_ModuleState
from
.utils
import
_check_builtin_module_attr
,
_check_obj_attr
,
_convert_kwargs_to_args
def
rstrip
(
s
:
str
,
__chars
:
str
):
...
...
@@ -112,6 +121,7 @@ class Expr:
node
.
users
.
append
(
self
)
else
:
assert
node
is
None
assert
not
isinstance
(
val
,
(
Module
,
RawTensor
))
assert
_is_leaf
(
val
)
and
_is_const_leaf
(
val
)
idx
=
len
(
self
.
inputs
)
+
len
(
self
.
const_val
)
self
.
const_val
.
append
((
idx
,
val
))
...
...
@@ -132,14 +142,14 @@ class Expr:
current_graph
.
_namespace
.
auto_naming_for_outputs
(
self
)
def
unflatten_args
(
self
,
inputs
):
if
self
.
arg_def
is
not
None
:
inputs
=
list
(
inputs
)
for
idx
,
val
in
self
.
const_val
:
inputs
.
insert
(
idx
,
val
)
args
,
kwargs
=
self
.
arg_def
.
unflatten
(
inputs
)
return
args
,
kwargs
else
:
return
inputs
,
{}
assert
self
.
arg_def
is
not
None
,
"{} expr doesn't have args/kwargs"
.
format
(
type
(
self
).
__name__
)
inputs
=
list
(
inputs
)
for
idx
,
val
in
self
.
const_val
:
inputs
.
insert
(
idx
,
val
)
args
,
kwargs
=
self
.
arg_def
.
unflatten
(
inputs
)
return
args
,
kwargs
def
replace_inputs
(
self
,
repl_dict
:
Dict
[
Node
,
Node
]):
r
"""Replace the input Nodes of this Expr.
...
...
@@ -165,6 +175,39 @@ class Expr:
node
.
users
.
remove
(
self
)
repl_node
.
users
.
append
(
self
)
@
property
def
_support_set_args_kwargs
(
self
):
return
False
def
set_args_kwargs
(
self
,
*
args
,
**
kwargs
):
r
""" Set args and kwargs for Expr.
"""
assert
(
self
.
_support_set_args_kwargs
),
"Doesn't support set args/kwargs for {} expr"
.
format
(
type
(
self
).
__name__
)
args
,
kwargs
=
_convert_kwargs_to_args
(
self
.
_get_func
(),
args
,
kwargs
)
inputs
,
arg_def
=
tree_flatten
((
args
,
kwargs
))
orig_inputs
=
self
.
inputs
self
.
inputs
=
[]
self
.
const_val
=
[]
for
val
in
inputs
:
if
isinstance
(
val
,
(
TensorNode
,
ModuleNode
)):
self
.
inputs
.
append
(
val
)
else
:
assert
_is_leaf
(
val
)
and
_is_const_leaf
(
val
)
idx
=
len
(
self
.
inputs
)
+
len
(
self
.
const_val
)
self
.
const_val
.
append
((
idx
,
val
))
for
n
in
orig_inputs
:
if
n
not
in
self
.
inputs
:
n
.
users
.
remove
(
self
)
for
n
in
self
.
inputs
:
if
n
not
in
orig_inputs
:
n
.
users
.
append
(
self
)
self
.
arg_def
=
arg_def
@
property
def
kwargs
(
self
):
r
"""Get the keyword arguments of the operation corresponding to this Expr."""
...
...
@@ -177,6 +220,61 @@ class Expr:
args
,
_
=
self
.
unflatten_args
(
self
.
inputs
)
return
args
def
_get_func
(
self
):
# get called function when the expr is interpreted
raise
NotImplementedError
@
property
def
named_args
(
self
):
func
=
self
.
_get_func
()
return
inspect
.
getcallargs
(
func
,
*
self
.
args
,
**
self
.
kwargs
)
def
set_arg
(
self
,
name
,
val
):
func
=
self
.
_get_func
()
if
name
in
self
.
kwargs
:
new_kwargs
=
self
.
kwargs
new_kwargs
[
name
]
=
val
self
.
set_args_kwargs
(
*
self
.
args
,
**
new_kwargs
)
else
:
arg_spec
=
inspect
.
getfullargspec
(
func
)
if
name
in
arg_spec
.
args
:
ind
=
arg_spec
.
args
.
index
(
name
)
new_args
=
list
(
self
.
args
)
new_args
[
ind
]
=
val
self
.
set_args_kwargs
(
*
new_args
)
elif
name
==
arg_spec
.
varargs
:
assert
arg_spec
.
varargs
is
not
None
assert
len
(
self
.
args
)
>=
len
(
arg_spec
.
args
)
val
=
(
val
,)
if
not
isinstance
(
val
,
Sequence
)
else
val
self
.
set_args_kwargs
(
*
self
.
args
[
0
:
len
(
arg_spec
.
args
)],
*
val
)
else
:
assert
(
arg_spec
.
varkw
is
not
None
),
"func {} does't have argument named {}"
.
format
(
func
,
name
)
new_kwargs
=
self
.
kwargs
new_kwargs
[
name
]
=
val
self
.
set_args_kwargs
(
*
self
.
args
,
**
new_kwargs
)
@
property
def
return_val
(
self
):
return
self
.
out_def
.
unflatten
(
self
.
outputs
)
@
return_val
.
setter
def
return_val
(
self
,
new_outputs
):
outputs
,
out_def
=
tree_flatten
(
new_outputs
,
is_leaf
=
lambda
x
:
isinstance
(
x
,
Node
)
)
assert
all
(
isinstance
(
o
,
Node
)
for
o
in
outputs
),
"Return values of expr must be ModuleNode or TensorNode or Container with them"
assert
all
(
o
.
expr
in
(
None
,
self
)
for
o
in
outputs
),
"Some nodes are produced by other expr, can not be output of expr {}"
.
format
(
self
)
self
.
outputs
=
outputs
self
.
out_def
=
out_def
@
property
def
top_graph
(
self
):
r
"""Get the parent graph of this Expr."""
...
...
@@ -184,12 +282,6 @@ class Expr:
return
self
.
_top_graph
()
return
None
def
__getstate__
(
self
):
state
=
self
.
__dict__
.
copy
()
if
"_top_graph"
in
state
:
state
.
pop
(
"_top_graph"
)
return
state
@
classmethod
def
_get_next_id
(
cls
):
return
cls
.
__total_id
...
...
@@ -199,6 +291,23 @@ class Expr:
assert
isinstance
(
id
,
int
)
cls
.
__total_id
=
id
def
__copy__
(
self
):
cls
=
self
.
__class__
result
=
cls
.
__new__
(
cls
)
result
.
__dict__
.
update
(
self
.
__dict__
)
return
result
def
__deepcopy__
(
self
,
memo
):
cls
=
self
.
__class__
result
=
cls
.
__new__
(
cls
)
state
=
{}
memo
[
id
(
self
)]
=
result
for
k
,
v
in
self
.
__dict__
.
items
():
if
not
isinstance
(
v
,
weakref
.
ReferenceType
):
state
[
k
]
=
copy
.
deepcopy
(
v
,
memo
)
result
.
__dict__
.
update
(
state
)
return
result
# expr: None (i.e. fake expression which is used to mark input)
class
Input
(
Expr
):
...
...
@@ -229,6 +338,17 @@ class Input(Expr):
def
__repr__
(
self
):
return
"%{}:
\t
{} = Input()"
.
format
(
self
.
_id
,
self
.
outputs
[
0
])
def
__getstate__
(
self
):
state
=
{
"_id"
:
self
.
_id
,
"_disable_remove"
:
self
.
_disable_remove
,
"inputs"
:
self
.
inputs
,
"outputs"
:
self
.
outputs
,
"name"
:
self
.
name
,
}
_check_obj_attr
(
state
)
return
state
# expr: outputs = getattr(inputs[0], self.name)
class
GetAttr
(
Expr
):
...
...
@@ -276,11 +396,23 @@ class GetAttr(Expr):
def
__repr__
(
self
):
out_type
=
"Tensor"
if
isinstance
(
self
.
outputs
[
0
],
ModuleNode
):
out_type
=
self
.
outputs
[
0
].
module_type
.
__name__
m_type
=
self
.
outputs
[
0
].
module_type
out_type
=
m_type
.
__name__
if
isinstance
(
m_type
,
type
)
else
m_type
[
1
]
return
'%{}:
\t
{} = getattr({}, "{}") -> ({})'
.
format
(
self
.
_id
,
self
.
outputs
[
0
],
self
.
inputs
[
0
],
self
.
name
,
out_type
)
def
__getstate__
(
self
):
state
=
{
"_id"
:
self
.
_id
,
"_disable_remove"
:
self
.
_disable_remove
,
"inputs"
:
self
.
inputs
,
"outputs"
:
self
.
outputs
,
"name"
:
self
.
name
,
}
_check_obj_attr
(
state
)
return
state
# expr: outputs = inputs[0].__call__(*inputs[1:])
class
CallMethod
(
Expr
):
...
...
@@ -307,6 +439,7 @@ class CallMethod(Expr):
node
,
]
self
.
const_val
=
[]
self
.
arg_def
=
tree_flatten
(((
node
,),
{}))[
1
]
self
.
method
=
method
@
classmethod
...
...
@@ -342,6 +475,27 @@ class CallMethod(Expr):
outputs
,
_
=
tree_flatten
(
outputs
,
is_leaf
=
lambda
x
:
isinstance
(
x
,
RawTensor
))
return
outputs
def
_get_func
(
self
):
if
isinstance
(
self
.
args
[
0
],
type
):
obj_type
=
self
.
args
[
0
]
elif
isinstance
(
self
.
args
[
0
],
ModuleNode
):
obj_type
=
self
.
args
[
0
].
module_type
else
:
assert
isinstance
(
self
.
args
[
0
],
TensorNode
)
obj_type
=
Tensor
meth
=
getattr
(
obj_type
,
"forward"
if
issubclass
(
obj_type
,
Module
)
else
self
.
method
)
return
meth
@
property
def
_support_set_args_kwargs
(
self
):
# only expr call tensor method or builtin module support modify args/kwargs
return
(
isinstance
(
self
.
args
[
0
],
(
TensorNode
,
type
))
or
self
.
args
[
0
].
module_type
is
not
Module
)
def
__repr__
(
self
):
args
=
", "
.
join
(
str
(
i
)
for
i
in
self
.
args
[
1
:])
kwargs
=
", "
.
join
(
"{}={}"
.
format
(
k
,
v
)
for
k
,
v
in
self
.
kwargs
.
items
())
...
...
@@ -359,6 +513,21 @@ class CallMethod(Expr):
", "
.
join
([
args
,
kwargs
]),
)
def
__getstate__
(
self
):
state
=
{
"_id"
:
self
.
_id
,
"_disable_remove"
:
self
.
_disable_remove
,
"inputs"
:
self
.
inputs
,
"const_val"
:
self
.
const_val
,
"method"
:
self
.
method
,
"arg_def"
:
self
.
arg_def
,
"out_def"
:
self
.
out_def
,
"outputs"
:
self
.
outputs
,
"version"
:
__version__
,
}
_check_obj_attr
(
state
)
return
state
# expr: outputs = apply(self.opdef, *inputs)
class
Apply
(
Expr
):
...
...
@@ -394,14 +563,32 @@ class Apply(Expr):
)
def
__getstate__
(
self
):
state
=
super
().
__getstate__
()
state
[
"opdef"
]
=
get_opdef_state
(
state
[
"opdef"
])
opdef_state
=
self
.
opdef
.
__getstate__
()
opdef_state
[
"opdef_type"
]
=
type
(
self
.
opdef
)
state
=
{
"_id"
:
self
.
_id
,
"_disable_remove"
:
self
.
_disable_remove
,
"opdef_state"
:
opdef_state
,
"inputs"
:
self
.
inputs
,
"outputs"
:
self
.
outputs
,
"version"
:
__version__
,
}
_check_obj_attr
(
state
)
return
state
def
__setstate__
(
self
,
state
):
state
[
"opdef"
]
=
load_opdef_from_state
(
state
[
"opdef"
])
for
k
,
v
in
state
.
items
():
setattr
(
self
,
k
,
v
)
# compat with mge 1.6
if
"opdef"
in
state
and
"opdef_state"
not
in
state
:
opdef_state
=
state
.
pop
(
"opdef"
)
opdef_state
[
"opdef_type"
]
=
opdef_state
.
pop
(
"type"
)
state
[
"opdef_state"
]
=
opdef_state
self
.
__dict__
.
update
(
state
)
assert
isinstance
(
state
[
"opdef_state"
],
dict
)
opdef_state
=
state
[
"opdef_state"
].
copy
()
opdef_type
=
opdef_state
.
pop
(
"opdef_type"
)
opdef_obj
=
opdef_type
()
opdef_obj
.
__setstate__
(
opdef_state
)
setattr
(
self
,
"opdef"
,
opdef_obj
)
@
classmethod
def
apply_module_trace_hook
(
cls
,
opdef
,
*
inputs
):
...
...
@@ -458,12 +645,24 @@ class CallFunction(Expr):
def
interpret
(
self
,
*
inputs
):
args
,
kwargs
=
self
.
unflatten_args
(
inputs
)
outputs
=
self
.
func
(
*
args
,
**
kwargs
)
func
=
(
self
.
func
if
not
is_tracing_module
()
else
active_module_tracer
().
patcher
.
wrap_fn
(
self
.
func
)
)
outputs
=
func
(
*
args
,
**
kwargs
)
if
outputs
is
None
:
return
outputs
outputs
,
_
=
tree_flatten
(
outputs
,
is_leaf
=
lambda
x
:
isinstance
(
x
,
RawTensor
))
return
outputs
def
_get_func
(
self
):
return
self
.
func
@
property
def
_support_set_args_kwargs
(
self
):
return
True
def
__repr__
(
self
):
args
=
", "
.
join
(
str
(
i
)
for
i
in
self
.
args
)
kwargs
=
", "
.
join
(
"{}={}"
.
format
(
k
,
v
)
for
k
,
v
in
self
.
kwargs
.
items
())
...
...
@@ -477,6 +676,33 @@ class CallFunction(Expr):
", "
.
join
([
args
,
kwargs
]),
)
def
__getstate__
(
self
):
state
=
{
"_id"
:
self
.
_id
,
"_disable_remove"
:
self
.
_disable_remove
,
"func"
:
(
self
.
func
.
__module__
,
self
.
func
.
__qualname__
),
"const_val"
:
self
.
const_val
,
"inputs"
:
self
.
inputs
,
"arg_def"
:
self
.
arg_def
,
"out_def"
:
self
.
out_def
,
"outputs"
:
self
.
outputs
,
"version"
:
__version__
,
}
_check_obj_attr
(
state
)
return
state
def
__setstate__
(
self
,
state
):
self
.
__dict__
.
update
(
state
)
try
:
if
isinstance
(
self
.
func
,
tuple
):
mname
,
fname
=
self
.
func
f
=
import_module
(
mname
)
for
i
in
fname
.
split
(
"."
):
f
=
getattr
(
f
,
i
)
self
.
func
=
f
except
Exception
:
pass
# expr outputs = self.value
class
Constant
(
Expr
):
...
...
@@ -496,6 +722,13 @@ class Constant(Expr):
assert
isinstance
(
c
,
(
RawTensor
,
Module
))
if
isinstance
(
c
,
Module
):
assert
module_tracer
.
is_builtin
(
c
)
or
c
.
is_qat
if
isinstance
(
c
,
RawTensor
):
if
is_tracing_module
():
unset_module_tracing
()
c
=
Tensor
(
c
)
set_module_tracing
()
else
:
c
=
Tensor
(
c
)
self
.
value
=
c
self
.
name
=
name
self
.
inputs
=
[]
...
...
@@ -530,9 +763,25 @@ class Constant(Expr):
)
def
__getstate__
(
self
):
state
=
self
.
__dict__
.
copy
()
if
"_top_graph"
in
state
:
state
.
pop
(
"_top_graph"
)
state
=
{
"_id"
:
self
.
_id
,
"_disable_remove"
:
self
.
_disable_remove
,
"value"
:
self
.
value
,
"name"
:
self
.
name
,
"inputs"
:
self
.
inputs
,
"outputs"
:
self
.
outputs
,
}
_check_obj_attr
(
state
)
if
isinstance
(
self
.
value
,
RawTensor
):
state
[
"value"
]
=
Tensor
(
self
.
value
)
if
isinstance
(
self
.
value
,
Module
)
and
module_tracer
.
is_builtin
(
self
.
value
):
_check_builtin_module_attr
(
self
.
value
)
state
[
"value"
]
=
_ModuleState
.
get_module_state
(
self
.
value
)
return
state
def
__setstate__
(
self
,
state
):
for
k
,
v
in
state
.
items
():
if
isinstance
(
v
,
_ModuleState
):
state
[
k
]
=
v
.
to_module
()
self
.
__dict__
.
update
(
state
)
imperative/python/megengine/traced_module/module_tracer.py
浏览文件 @
7b19bc76
...
...
@@ -72,7 +72,6 @@ BUILTIN_ARRAY_METHOD = [
"astype"
,
"reshape"
,
"_broadcast"
,
"transpose"
,
"flatten"
,
"sum"
,
"prod"
,
...
...
imperative/python/megengine/traced_module/node.py
浏览文件 @
7b19bc76
...
...
@@ -6,7 +6,9 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
abc
import
copy
import
weakref
from
importlib
import
import_module
from
typing
import
Any
,
Dict
,
List
,
Tuple
,
Type
import
numpy
...
...
@@ -14,7 +16,9 @@ import numpy
from
..
import
get_logger
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
from
..module
import
Module
from
..quantization.utils
import
QParams
from
..tensor
import
Tensor
from
.utils
import
_check_obj_attr
logger
=
get_logger
(
__name__
)
...
...
@@ -145,6 +149,23 @@ class Node:
assert
isinstance
(
id
,
int
)
cls
.
__total_id
=
id
def
__copy__
(
self
):
cls
=
self
.
__class__
result
=
cls
.
__new__
(
cls
)
result
.
__dict__
.
update
(
self
.
__dict__
)
return
result
def
__deepcopy__
(
self
,
memo
):
cls
=
self
.
__class__
result
=
cls
.
__new__
(
cls
)
state
=
{}
memo
[
id
(
self
)]
=
result
for
k
,
v
in
self
.
__dict__
.
items
():
if
not
isinstance
(
v
,
weakref
.
ReferenceType
)
and
k
!=
"actual_node"
:
state
[
k
]
=
copy
.
deepcopy
(
v
,
memo
)
result
.
__dict__
.
update
(
state
)
return
result
class
ModuleNode
(
Node
):
r
"""``ModuleNode`` represents the Module objects."""
...
...
@@ -157,19 +178,28 @@ class ModuleNode(Node):
super
().
__init__
(
expr
,
name
,
qualname
)
def
__getstate__
(
self
):
return
{
state
=
{
"expr"
:
self
.
expr
,
"users"
:
self
.
users
,
"_id"
:
self
.
_id
,
"_name"
:
self
.
_name
,
"_qualname"
:
self
.
_qualname
,
"module_type"
:
self
.
module_type
,
"module_type"
:
(
self
.
module_type
.
__module__
,
self
.
module_type
.
__qualname__
)
,
}
_check_obj_attr
(
state
)
return
state
def
__setstate__
(
self
,
state
):
if
"_orig_name"
in
state
:
state
[
"_qualname"
]
=
state
.
pop
(
"_orig_name"
)
self
.
__dict__
.
update
(
state
)
try
:
if
isinstance
(
self
.
module_type
,
tuple
):
mname
,
classname
=
self
.
module_type
mtype
=
getattr
(
import_module
(
mname
),
classname
)
self
.
module_type
=
mtype
except
Exception
:
pass
@
property
def
owner
(
self
):
...
...
@@ -185,12 +215,26 @@ class TensorNode(Node):
_shape
=
None
# type: Tuple[int]
_dtype
=
None
# type: numpy.dtype
_qparams
=
None
_qparams
=
None
# type: QParams
_device
=
None
_value
=
None
# type: Tensor
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
=
None
,
qualname
:
str
=
None
,
shape
:
Tuple
[
int
]
=
None
,
dtype
:
numpy
.
dtype
=
None
,
qparams
:
QParams
=
None
,
):
super
().
__init__
(
expr
,
name
,
qualname
)
self
.
_shape
=
shape
self
.
_dtype
=
shape
self
.
_qparams
=
qparams
def
__getstate__
(
self
):
return
{
state
=
{
"expr"
:
self
.
expr
,
"users"
:
self
.
users
,
"_id"
:
self
.
_id
,
...
...
@@ -201,6 +245,8 @@ class TensorNode(Node):
"_name"
:
self
.
_name
,
"_qualname"
:
self
.
_qualname
,
}
_check_obj_attr
(
state
)
return
state
def
__setstate__
(
self
,
state
):
if
"_orig_name"
in
state
:
...
...
@@ -276,7 +322,10 @@ class NodeMixin(abc.ABC):
assert
isinstance
(
node
,
TensorNode
)
assert
isinstance
(
value
,
RawTensor
)
if
isinstance
(
value
,
RawTensor
):
node
.
_dtype
=
value
.
dtype
try
:
node
.
_dtype
=
value
.
dtype
except
RuntimeError
:
node
.
_dtype
=
None
node
.
_shape
=
(
value
.
_tuple_shape
if
isinstance
(
value
,
Tensor
)
else
value
.
shape
)
...
...
imperative/python/megengine/traced_module/pytree.py
浏览文件 @
7b19bc76
...
...
@@ -7,15 +7,18 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
collections
from
collections
import
OrderedDict
from
collections
import
OrderedDict
,
defaultdict
from
functools
import
partial
from
typing
import
Callable
,
NamedTuple
import
numpy
as
np
from
..core._imperative_rt
import
OpDef
from
..core._imperative_rt.common
import
CompNode
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
from
..core._wrap
import
Device
from
..core.tensor.dtype
import
QuantDtypeMeta
from
..distributed
import
Group
from
..module
import
Module
from
..quantization.utils
import
LSQParams
,
QParams
,
QuantMode
from
..tensor
import
Parameter
,
Tensor
...
...
@@ -49,45 +52,54 @@ SUPPORTED_LEAF_TYPE = {
type
(
Ellipsis
),
QuantMode
,
ArgsIndex
,
Group
,
}
USER_REGISTERED_LEAF_TYPE
=
[]
USER_REGISTERED_CONTAINER_TYPE
=
[]
# if isinstance(object, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS) is True, the object could be threated as leaf node of pytree
SUPPORTED_LEAF_CLS
=
[
Module
,
Node
,
NodeMixin
,
np
.
dtype
,
np
.
ndarray
,
np
.
number
]
SUPPORTED_LEAF_CLS
=
[
Module
,
Node
,
NodeMixin
,
np
.
dtype
,
np
.
ndarray
,
np
.
number
,
np
.
bool_
,
OpDef
,
]
NodeType
=
NamedTuple
(
"NodeType"
,
[(
"flatten"
,
Callable
),
(
"unflatten"
,
Callable
)])
def
register_supported_type
(
type
,
flatten
=
None
,
unflatten
=
None
):
tp_info
=
(
type
.
__module__
,
type
.
__qualname__
)
if
flatten
and
unflatten
:
SUPPORTED_TYPE
[
type
]
=
NodeType
(
flatten
,
unflatten
)
USER_REGISTERED_CONTAINER_TYPE
.
append
(
tp_info
)
else
:
SUPPORTED_LEAF_CLS
.
append
(
type
)
def
_dict_flatten
(
inp
):
aux_data
=
[]
results
=
[]
for
key
,
value
in
sorted
(
inp
.
items
()):
results
.
append
(
value
)
aux_data
.
append
(
key
)
return
results
,
tuple
(
aux_data
)
USER_REGISTERED_LEAF_TYPE
.
append
(
tp_info
)
_register_supported_type
(
type
,
flatten
,
unflatten
)
def
_dict_unflatten
(
inps
,
aux_data
):
return
dict
(
zip
(
aux_data
,
inps
))
def
_register_supported_type
(
type
,
flatten
=
None
,
unflatten
=
None
):
if
flatten
and
unflatten
:
SUPPORTED_TYPE
[
type
]
=
NodeType
(
flatten
,
unflatten
)
else
:
SUPPORTED_LEAF_CLS
.
append
(
type
)
def
_
ordereddict_flatten
(
inp
):
def
_
dict_flatten
(
ordered
,
inp
):
aux_data
=
[]
results
=
[]
for
key
,
value
in
inp
.
items
():
dict_items
=
inp
.
items
()
if
ordered
else
sorted
(
inp
.
items
())
for
key
,
value
in
dict_items
:
results
.
append
(
value
)
aux_data
.
append
(
key
)
return
results
,
tuple
(
aux_data
)
def
_
ordereddict_unflatten
(
inps
,
aux_data
):
return
OrderedDict
(
zip
(
aux_data
,
inps
))
def
_
dict_unflatten
(
dict_type
,
inps
,
aux_data
):
return
dict_type
(
zip
(
aux_data
,
inps
))
def
qparams_flatten
(
inp
):
...
...
@@ -99,33 +111,41 @@ def qparams_flatten(inp):
return
results
,
tuple
(
aux_data
)
def
qparams_unflatten
(
inp
,
aux_data
):
obj
=
QParams
.
__new__
(
QParams
)
def
qparams_unflatten
(
qparam_type
,
inp
,
aux_data
):
obj
=
qparam_type
.
__new__
(
qparam_type
)
for
k
,
v
in
zip
(
aux_data
,
inp
):
setattr
(
obj
,
k
,
v
)
return
obj
register_supported_type
(
list
,
lambda
x
:
(
x
,
None
),
lambda
x
,
aux_data
:
list
(
x
))
register_supported_type
(
tuple
,
lambda
x
:
(
x
,
None
),
lambda
x
,
aux_data
:
tuple
(
x
))
register_supported_type
(
dict
,
_dict_flatten
,
_dict_unflatten
)
register_supported_type
(
collections
.
OrderedDict
,
_ordereddict_flatten
,
_ordereddict_unflatten
_register_supported_type
(
list
,
lambda
x
:
(
x
,
None
),
lambda
x
,
aux_data
:
list
(
x
))
_register_supported_type
(
tuple
,
lambda
x
:
(
x
,
None
),
lambda
x
,
aux_data
:
tuple
(
x
))
_register_supported_type
(
dict
,
partial
(
_dict_flatten
,
False
),
partial
(
_dict_unflatten
,
dict
)
)
_register_supported_type
(
defaultdict
,
partial
(
_dict_flatten
,
False
),
partial
(
_dict_unflatten
,
defaultdict
)
)
register_supported_type
(
_register_supported_type
(
OrderedDict
,
partial
(
_dict_flatten
,
True
),
partial
(
_dict_unflatten
,
OrderedDict
)
)
_register_supported_type
(
slice
,
lambda
x
:
([
x
.
start
,
x
.
stop
,
x
.
step
],
None
),
lambda
x
,
aux_data
:
slice
(
x
[
0
],
x
[
1
],
x
[
2
]),
)
register_supported_type
(
QParams
,
qparams_flatten
,
qparams_unflatten
)
_register_supported_type
(
QParams
,
qparams_flatten
,
partial
(
qparams_unflatten
,
QParams
))
_register_supported_type
(
LSQParams
,
qparams_flatten
,
partial
(
qparams_unflatten
,
LSQParams
)
)
def
_is_leaf
(
obj
):
if
isinstance
(
obj
,
type
):
return
issubclass
(
obj
,
tuple
(
SUPPORTED_LEAF_CLS
))
or
obj
in
SUPPORTED_LEAF_TYPE
obj_type
=
obj
if
isinstance
(
obj
,
type
)
else
type
(
obj
)
return
(
isinstance
(
obj
,
tuple
(
SUPPORTED_LEAF_CLS
))
or
type
(
obj
)
in
SUPPORTED_LEAF_TYPE
issubclass
(
obj_type
,
tuple
(
SUPPORTED_LEAF_CLS
))
or
obj_type
in
SUPPORTED_LEAF_TYPE
)
...
...
imperative/python/megengine/traced_module/serialization.py
浏览文件 @
7b19bc76
...
...
@@ -5,30 +5,158 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
typing
import
Dict
from
importlib
import
import_module
from
typing
import
Dict
,
Tuple
from
..core._imperative_rt
import
OpDef
from
..core.ops
import
builtin
from
..tensor
import
Tensor
from
..version
import
__version__
from
.utils
import
_convert_kwargs_to_args
OPDEF_PARAM_LOADER
=
{}
OPDEF_LOADER
=
{}
FUNCTIONAL_LOADER
=
{}
TENSORMETHOD_LOADER
=
{}
MODULE_LOADER
=
{}
def
get_opdef_state
(
obj
:
OpDef
)
->
Dict
:
state
=
obj
.
__getstate__
()
state
[
"type"
]
=
type
(
obj
)
state
[
"version"
]
=
__version__
return
state
class
_ModuleState
:
obj
=
None
def
__init__
(
self
,
module
:
Tuple
,
state
:
Dict
,
version
:
str
):
self
.
module
=
module
self
.
state
=
state
self
.
version
=
version
def
load_opdef_from_state
(
state
:
Dict
)
->
OpDef
:
assert
"type"
in
state
and
issubclass
(
state
[
"type"
],
OpDef
)
assert
"version"
in
state
opdef_type
=
state
.
pop
(
"type"
)
if
opdef_type
in
OPDEF_PARAM_LOADER
:
loader
=
OPDEF_PARAM_LOADER
[
opdef_type
]
state
=
loader
(
state
)
state
.
pop
(
"version"
)
opdef_obj
=
opdef_type
()
opdef_obj
.
__setstate__
(
state
)
return
opdef_obj
@
classmethod
def
get_module_state
(
cls
,
module
):
typem
=
(
type
(
module
).
__module__
,
type
(
module
).
__qualname__
)
state
=
module
.
__dict__
.
copy
()
state
.
pop
(
"_m_dump_modulestate"
,
None
)
if
hasattr
(
module
,
"_m_dump_modulestate"
):
assert
isinstance
(
module
.
_m_dump_modulestate
,
cls
)
module
.
_m_dump_modulestate
.
__init__
(
typem
,
state
,
__version__
)
else
:
module
.
__dict__
[
"_m_dump_modulestate"
]
=
_ModuleState
(
typem
,
state
,
__version__
)
return
module
.
_m_dump_modulestate
def
__getstate__
(
self
):
return
{
"module"
:
self
.
module
,
"state"
:
self
.
state
,
"version"
:
self
.
version
}
def
to_module
(
self
):
if
self
.
obj
is
None
:
typem
=
getattr
(
import_module
(
self
.
module
[
0
]),
self
.
module
[
1
])
m_obj
=
typem
.
__new__
(
typem
)
m_obj
.
__dict__
.
update
(
self
.
state
)
self
.
obj
=
m_obj
return
self
.
obj
def
register_opdef_loader
(
*
opdefs
):
def
callback
(
loader
):
for
opdef
in
opdefs
:
assert
opdef
not
in
OPDEF_LOADER
OPDEF_LOADER
[
opdef
]
=
loader
return
loader
return
callback
def
register_functional_loader
(
*
funcs
):
def
callback
(
loader
):
for
func
in
funcs
:
assert
func
not
in
FUNCTIONAL_LOADER
FUNCTIONAL_LOADER
[
func
]
=
loader
return
loader
return
callback
def
register_module_loader
(
*
module_types
):
def
callback
(
loader
):
for
module_type
in
module_types
:
assert
module_type
not
in
MODULE_LOADER
MODULE_LOADER
[
module_type
]
=
loader
return
loader
return
callback
def
register_tensor_method_loader
(
*
methods
):
def
callback
(
loader
):
for
method
in
methods
:
assert
method
not
in
TENSORMETHOD_LOADER
TENSORMETHOD_LOADER
[
method
]
=
loader
return
loader
return
callback
def
_replace_args_kwargs
(
expr
,
new_args
,
new_kwargs
):
if
len
(
new_args
)
!=
len
(
expr
.
args
)
or
set
(
new_kwargs
.
keys
())
!=
set
(
expr
.
kwargs
.
keys
()
):
expr
.
set_args_kwargs
(
*
new_args
,
**
new_kwargs
)
def
load_functional
(
expr
):
func
=
(
(
expr
.
func
.
__module__
,
expr
.
func
.
__qualname__
)
if
callable
(
expr
.
func
)
else
expr
.
func
)
assert
isinstance
(
func
,
tuple
)
if
func
in
FUNCTIONAL_LOADER
:
loader
=
FUNCTIONAL_LOADER
[
func
]
loader
(
expr
)
mname
,
fname
=
func
f
=
import_module
(
mname
)
for
i
in
fname
.
split
(
"."
):
f
=
getattr
(
f
,
i
)
expr
.
func
=
f
assert
callable
(
expr
.
func
)
if
not
hasattr
(
expr
,
"version"
)
or
expr
.
version
!=
__version__
:
args
,
kwargs
=
_convert_kwargs_to_args
(
expr
.
func
,
expr
.
args
,
expr
.
kwargs
)
_replace_args_kwargs
(
expr
,
args
,
kwargs
)
def
load_call_module_expr
(
expr
):
m_type
=
expr
.
inputs
[
0
].
module_type
if
isinstance
(
m_type
,
type
):
m_type
=
(
m_type
.
__module__
,
m_type
.
__qualname__
)
if
m_type
in
MODULE_LOADER
:
MODULE_LOADER
[
m_type
](
expr
)
if
isinstance
(
expr
.
inputs
[
0
].
module_type
,
tuple
):
mname
,
classname
=
expr
.
inputs
[
0
].
module_type
expr
.
inputs
[
0
].
module_type
=
getattr
(
import_module
(
mname
),
classname
)
if
not
hasattr
(
expr
,
"version"
)
or
expr
.
version
!=
__version__
:
fwd_func
=
getattr
(
expr
.
inputs
[
0
].
module_type
,
"forward"
)
args
,
kwargs
=
_convert_kwargs_to_args
(
fwd_func
,
expr
.
args
,
expr
.
kwargs
)
_replace_args_kwargs
(
expr
,
args
,
kwargs
)
def
load_call_tensor_method_expr
(
expr
):
if
expr
.
method
in
TENSORMETHOD_LOADER
:
loader
=
TENSORMETHOD_LOADER
[
expr
.
method
]
loader
(
expr
)
if
not
hasattr
(
expr
,
"version"
)
or
expr
.
version
!=
__version__
:
tmethod
=
(
getattr
(
expr
.
args
[
0
],
expr
.
method
)
if
isinstance
(
expr
.
args
[
0
],
type
)
else
getattr
(
Tensor
,
expr
.
method
)
)
args
,
kwargs
=
_convert_kwargs_to_args
(
tmethod
,
expr
.
args
,
expr
.
kwargs
)
_replace_args_kwargs
(
expr
,
args
,
kwargs
)
def
load_apply_expr
(
expr
):
opdef_type
=
type
(
expr
.
opdef
)
if
opdef_type
in
OPDEF_LOADER
:
OPDEF_LOADER
[
opdef_type
](
expr
)
opdef_state
=
expr
.
opdef_state
opdef_obj
=
opdef_state
.
pop
(
"opdef_type"
)()
opdef_obj
.
__setstate__
(
opdef_state
)
expr
.
opdef
=
opdef_obj
imperative/python/megengine/traced_module/traced_module.py
浏览文件 @
7b19bc76
...
...
@@ -14,6 +14,7 @@ import inspect
import
keyword
import
re
import
weakref
from
importlib
import
import_module
from
inspect
import
getcallargs
,
getmembers
,
isclass
,
ismethod
from
itertools
import
chain
from
types
import
FunctionType
...
...
@@ -53,6 +54,7 @@ from ..quantization.observer import (
SyncMinMaxObserver
,
)
from
..tensor
import
Tensor
from
..version
import
__version__
from
.expr
import
(
Apply
,
CallFunction
,
...
...
@@ -80,8 +82,27 @@ from .module_tracer import (
set_active_module_tracer
,
)
from
.node
import
ModuleNode
,
Node
,
NodeMixin
,
TensorNode
from
.pytree
import
ArgsIndex
,
tree_flatten
from
.utils
import
replace_container_with_module_container
from
.pytree
import
(
USER_REGISTERED_CONTAINER_TYPE
,
USER_REGISTERED_LEAF_TYPE
,
ArgsIndex
,
TreeDef
,
_register_supported_type
,
tree_flatten
,
)
from
.serialization
import
(
_ModuleState
,
load_apply_expr
,
load_call_module_expr
,
load_call_tensor_method_expr
,
load_functional
,
)
from
.utils
import
(
_check_builtin_module_attr
,
_check_obj_attr
,
_convert_kwargs_to_args
,
replace_container_with_module_container
,
)
logger
=
get_logger
(
__name__
)
...
...
@@ -341,7 +362,7 @@ class NameSpace:
def
create_unique_name
(
self
,
name
:
str
,
node
:
Any
=
None
)
->
str
:
assert
isinstance
(
name
,
str
),
"The name must be a string"
if
name
in
self
.
_used_names
and
self
.
_used_names
[
name
]
is
node
:
if
name
in
self
.
_used_names
and
(
self
.
_used_names
[
name
]
is
node
)
:
return
name
name
=
re
.
sub
(
"[^0-9a-zA-Z_]+"
,
"_"
,
name
)
...
...
@@ -1067,6 +1088,7 @@ class InternalGraph:
if
node2value
[
n
][
1
]
==
0
:
node2value
.
pop
(
n
)
if
values
is
not
None
:
assert
len
(
values
)
==
len
(
expr
.
outputs
)
for
n
,
v
in
zip
(
expr
.
outputs
,
values
):
if
ref_count
(
n
)
>
0
:
node2value
[
n
]
=
[
v
,
ref_count
(
n
)]
...
...
@@ -1105,13 +1127,27 @@ class InternalGraph:
return
res
def
__getstate__
(
self
):
state
=
self
.
__dict__
.
copy
()
if
"_top_graph"
in
state
:
state
.
pop
(
"_top_graph"
)
state
=
{
"_exprs"
:
self
.
_exprs
,
"_inputs"
:
self
.
_inputs
,
"_outputs"
:
self
.
_outputs
,
"_watch_point"
:
[],
"_end_point"
:
[],
"_namespace"
:
self
.
_namespace
,
"_rst"
:
collections
.
defaultdict
(
list
),
"_name"
:
self
.
_name
,
"_qualname"
:
self
.
_qualname
,
}
if
self
.
_total_ids
:
state
[
"_total_ids"
]
=
self
.
_total_ids
_check_obj_attr
(
state
)
return
state
def
__setstate__
(
self
,
state
):
old_version
=
False
if
"_module_name"
in
state
:
old_version
=
True
state
[
"_qualname"
]
=
state
.
pop
(
"_module_name"
)
...
...
@@ -1144,6 +1180,25 @@ class InternalGraph:
self
.
_namespace
=
NameSpace
(
self
.
_name
,
self
.
_qualname
)
self
.
_re_associate_name
()
def
__copy__
(
self
):
cls
=
self
.
__class__
result
=
cls
.
__new__
(
cls
)
result
.
__dict__
.
update
(
self
.
__dict__
)
return
result
def
__deepcopy__
(
self
,
memo
):
if
id
(
self
)
in
memo
:
return
memo
[
id
(
self
)]
cls
=
self
.
__class__
result
=
cls
.
__new__
(
cls
)
state
=
{}
memo
[
id
(
self
)]
=
result
for
k
,
v
in
self
.
__dict__
.
items
():
if
not
isinstance
(
v
,
weakref
.
ReferenceType
):
state
[
k
]
=
copy
.
deepcopy
(
v
,
memo
)
result
.
__dict__
.
update
(
state
)
return
result
def
_get_meth_name
(
obj
,
func
):
tp
=
obj
if
isinstance
(
obj
,
type
)
else
type
(
obj
)
...
...
@@ -1157,9 +1212,7 @@ def _get_meth_name(obj, func):
def
_wrapped_function
(
orig_func
):
@
functools
.
wraps
(
orig_func
)
def
wrapped_fn
(
*
args
,
**
kwargs
):
method_func
=
wrapped_fn
if
"method_func"
in
kwargs
:
method_func
=
kwargs
.
pop
(
"method_func"
)
method_func
=
kwargs
.
pop
(
"method_func"
,
wrapped_fn
)
if
is_tracing_module
():
unset_module_tracing
()
inputs
,
tree_def
=
tree_flatten
((
args
,
kwargs
))
...
...
@@ -1167,11 +1220,11 @@ def _wrapped_function(orig_func):
if
not
NodeMixin
.
get
(
i
,
None
):
if
isinstance
(
i
,
(
RawTensor
,
NodeMixin
)):
NodeMixin
.
wrap_safe
(
i
,
Constant
.
make
(
i
))
meth_name
,
arg_type
=
None
,
None
if
args
:
meth_name
=
_get_meth_name
(
args
[
0
],
method_func
)
arg_type
=
args
[
0
]
if
isinstance
(
args
[
0
],
type
)
else
type
(
args
[
0
])
args
,
kwargs
=
_convert_kwargs_to_args
(
orig_func
,
args
,
kwargs
)
meth_name
=
_get_meth_name
(
args
[
0
],
method_func
)
arg_type
=
args
[
0
]
if
isinstance
(
args
[
0
],
type
)
else
type
(
args
[
0
])
if
meth_name
and
arg_type
and
issubclass
(
arg_type
,
RawTensor
):
inputs
,
tree_def
=
tree_flatten
((
args
,
kwargs
))
self
=
inputs
[
0
]
if
meth_name
==
"__new__"
:
if
all
([
not
isinstance
(
i
,
RawTensor
)
for
i
in
inputs
]):
...
...
@@ -1190,6 +1243,7 @@ def _wrapped_function(orig_func):
call_node
=
CallMethod
.
make
(
NodeMixin
.
get
(
self
),
meth_name
)
call_node
.
add_inputs
(
inputs
[
1
:])
else
:
inputs
,
tree_def
=
tree_flatten
((
args
,
kwargs
))
call_node
=
CallFunction
.
make
(
orig_func
)
call_node
.
add_inputs
(
inputs
)
...
...
@@ -1228,9 +1282,11 @@ class TracedModuleBuilder(NodeMixin):
"_record_wrapped_nodes"
,
"_argdef_graph_map"
,
"_argdef_outdef_map"
,
"_check_qat_module"
,
"nodes"
,
"__class__"
,
"__dict__"
,
"_is_top"
,
]
def
__init__
(
self
,
mod
,
is_top_module
=
False
):
...
...
@@ -1301,22 +1357,18 @@ class TracedModuleBuilder(NodeMixin):
qat_module
.
weight_fake_quant
.
set_qparams
(
qparams
)
def
build
(
self
):
if
self
.
_is_builtin
or
isinstance
(
self
.
_mod
,
TracedModule
):
if
module_tracer
.
is_builtin
(
self
.
_mod
)
or
isinstance
(
self
.
_mod
,
TracedModule
):
mod_type
=
type
(
self
.
_mod
)
else
:
assert
isinstance
(
self
.
_mod
,
(
Observer
,
_FakeQuantize
))
mod_type
=
(
Observer
if
isinstance
(
self
.
_mod
,
Observer
)
else
_FakeQuantize
)
if
self
.
_is_builtin
:
assert
module_tracer
.
is_builtin
(
self
.
_mod
)
mod_type
=
type
(
self
.
_mod
)
for
node
in
self
.
nodes
:
node
.
module_type
=
mod_type
return
self
.
_mod
else
:
is_qat
=
isinstance
(
self
.
_mod
,
QATModule
)
is_qat
=
isinstance
(
self
.
_mod
,
QATModule
)
or
(
isinstance
(
self
.
_mod
,
TracedModule
)
and
self
.
_mod
.
is_qat
)
traced_module
=
TracedModule
(
self
.
_is_top
,
self
.
_argdef_graph_map
,
self
.
_argdef_outdef_map
,
is_qat
)
...
...
@@ -1338,15 +1390,18 @@ class TracedModuleBuilder(NodeMixin):
traced_module
.
with_act
=
self
.
_mod
.
with_act
traced_module
.
with_weight
=
self
.
_mod
.
with_weight
if
not
hasattr
(
traced_module
,
"act_fake_quant"
):
traced_module
.
act_fakequant
=
None
traced_module
.
act_fake
_
quant
=
None
if
not
hasattr
(
traced_module
,
"act_observer"
):
traced_module
.
act_observer
=
None
if
not
hasattr
(
traced_module
,
"weight_fake_quant"
):
traced_module
.
weight_fakequant
=
None
traced_module
.
weight_fake
_
quant
=
None
if
not
hasattr
(
traced_module
,
"weight_observer"
):
traced_module
.
weight_observer
=
None
set_module_tracing
()
if
self
.
_is_top
:
traced_module
.
_update_ref
()
return
traced_module
def
_record_wrapped_nodes
(
self
,
node
):
...
...
@@ -1357,6 +1412,7 @@ class TracedModuleBuilder(NodeMixin):
# prepare args and kwargs for inner graph
if
"method_func"
in
kwargs
:
kwargs
.
pop
(
"method_func"
)
args
,
kwargs
=
_convert_kwargs_to_args
(
self
.
_mod
.
forward
,
args
,
kwargs
,
True
)
def
mark_constant
(
x
):
node
=
NodeMixin
.
get
(
x
,
None
)
...
...
@@ -1372,11 +1428,7 @@ class TracedModuleBuilder(NodeMixin):
callnode
.
arg_def
=
tree_def
if
(
self
.
_is_builtin
or
tree_def
in
self
.
_argdef_graph_map
or
isinstance
(
self
.
_mod
,
TracedModule
)
):
if
self
.
_is_builtin
or
tree_def
in
self
.
_argdef_graph_map
:
unset_module_tracing
()
rst
=
self
.
_mod
(
*
args
,
**
kwargs
)
outputs
,
out_def
=
tree_flatten
(
rst
,
is_leaf
=
_is_leaf
)
...
...
@@ -1385,33 +1437,7 @@ class TracedModuleBuilder(NodeMixin):
self
.
_body
=
None
elif
tree_def
in
self
.
_argdef_graph_map
:
self
.
_body
=
self
.
_argdef_graph_map
[
tree_def
]
else
:
self
.
_mod
.
_is_top
=
False
self
.
_body
=
self
.
_mod
.
argdef_graph_map
[
tree_def
]
module_qualname
=
NodeMixin
.
get
(
self
).
qualname
if
module_qualname
!=
self
.
_body
.
qualname
:
src_name
,
dst_name
=
self
.
_body
.
qualname
,
module_qualname
def
replace_qualname
(
g
):
attr_name
=
get_suffix_name
(
src_name
,
g
.
qualname
)
if
attr_name
is
not
None
:
g
.
_qualname
=
(
(
"%s.%s"
%
(
dst_name
,
attr_name
))
if
attr_name
else
dst_name
)
assert
get_suffix_name
(
dst_name
,
g
.
qualname
)
is
not
None
for
mod
in
self
.
_mod
.
modules
():
if
not
hasattr
(
mod
,
"argdef_graph_map"
):
continue
for
g
in
mod
.
argdef_graph_map
.
values
():
replace_qualname
(
g
)
g
.
_namespace
.
qualname
=
g
.
qualname
for
n
in
g
.
nodes
(
False
):
replace_qualname
(
n
)
else
:
self_node
=
None
orig_self
=
NodeMixin
.
get
(
self
)
parent_graph
=
active_module_tracer
().
current_scope
()
module_qualname
=
orig_self
.
_qualname
...
...
@@ -1423,20 +1449,14 @@ class TracedModuleBuilder(NodeMixin):
active_module_tracer
().
push_scope
(
self
.
_body
)
# rebind self to new input node
if
self_node
:
NodeMixin
.
wrap_safe
(
self
,
self_node
)
active_module_tracer
().
current_scope
().
_add_input
(
self_node
)
else
:
NodeMixin
.
wrap_safe
(
self
,
self_node
if
self_node
else
Input
.
make
(
name
=
"self"
,
qualname
=
module_qualname
,
type
=
NodeMixin
.
get_wrapped_type
(
self
),
),
)
NodeMixin
.
wrap_safe
(
self
,
Input
.
make
(
name
=
"self"
,
qualname
=
module_qualname
,
type
=
NodeMixin
.
get_wrapped_type
(
self
),
),
)
origin_inp_node
=
[
NodeMixin
.
get
(
i
,
None
)
for
i
in
inputs
[
1
:]]
# prepare args and kwargs for inner graph
...
...
@@ -1470,8 +1490,23 @@ class TracedModuleBuilder(NodeMixin):
return
x
args
=
[
self
]
for
i
,
v
in
enumerate
(
inputs
[
1
:]):
args
.
append
(
wrap
(
v
,
idx2key
[
i
+
1
]))
orig_traced_inputs
=
(
None
if
not
isinstance
(
self
.
_mod
,
TracedModule
)
else
self
.
_mod
.
argdef_graph_map
[
tree_def
].
inputs
)
ind
=
1
for
v
in
inputs
[
1
:]:
if
isinstance
(
v
,
(
RawTensor
,
NodeMixin
)):
args_name
=
(
orig_traced_inputs
[
ind
].
_name
if
orig_traced_inputs
else
idx2key
[
ind
]
)
ind
+=
1
args
.
append
(
wrap
(
v
,
args_name
))
else
:
args
.
append
(
v
)
args
,
kwargs
=
tree_def
.
unflatten
(
args
)
active_module_tracer
().
patcher
.
auto_patch
(
...
...
@@ -1514,7 +1549,6 @@ class TracedModuleBuilder(NodeMixin):
attr
=
getattr
(
type
(
self
.
_mod
),
name
).
__get__
(
self
,
type
(
self
))
else
:
attr
=
getattr
(
self
.
_mod
,
name
)
if
(
isinstance
(
attr
,
FunctionType
)
and
id
(
attr
)
in
active_module_tracer
().
patcher
.
patched_fn_ids
...
...
@@ -1568,7 +1602,7 @@ class TracedModuleBuilder(NodeMixin):
wrapped
=
self
.
__getattr__
(
name
)
if
isinstance
(
wrapped
,
TracedModuleBuilder
):
if
not
isinstance
(
mod_attr
,
(
List
,
Dict
)):
if
not
isinstance
(
mod_attr
,
(
List
,
Dict
,
QATModule
)):
assert
mod_attr
is
wrapped
.
_mod
else
:
assert
mod_attr
is
wrapped
...
...
@@ -1977,8 +2011,6 @@ class TracedModule(Module):
def
graph
(
self
)
->
InternalGraph
:
"""Return the ``InternalGraph`` of this ``TracedModule``.
"""
if
self
.
_is_top
:
self
.
_update_ref
()
assert
len
(
self
.
argdef_graph_map
)
==
1
return
list
(
self
.
argdef_graph_map
.
values
())[
0
]
...
...
@@ -2112,7 +2144,7 @@ class TracedModule(Module):
if
hasattr
(
obj
,
"argdef_graph_map"
)
else
None
)
if
expr_graph
is
not
None
:
if
expr_graph
is
not
None
and
not
obj
.
is_qat
:
exprs
=
_flatten_subgraph
(
graph
,
expr_graph
,
expr
,
obj
)
if
parent_graph
is
not
None
:
...
...
@@ -2137,26 +2169,119 @@ class TracedModule(Module):
)
new_module
.
graph
.
_re_associate_name
()
new_module
.
graph
.
compile
()
new_module
.
_update_ref
()
new_module
.
graph
.
_reset_ids
()
return
new_module
def
__getstate__
(
self
):
d
=
self
.
__dict__
d
=
self
.
__dict__
.
copy
()
for
k
in
Module
.
__dict__
:
d
.
pop
(
k
,
None
)
_check_obj_attr
(
d
)
for
k
in
d
:
if
module_tracer
.
is_builtin
(
d
[
k
]):
assert
_check_builtin_module_attr
(
d
[
k
]
),
"Module {} can not be serialized. "
.
format
(
type
(
d
[
k
]))
d
[
k
]
=
_ModuleState
.
get_module_state
(
d
[
k
])
dump_info
=
{
"version"
:
__version__
,
"register_type"
:
USER_REGISTERED_LEAF_TYPE
,
"register_container_type"
:
USER_REGISTERED_CONTAINER_TYPE
,
"register_mdule"
:
USER_REGISTERED_MODULE
,
"register_function"
:
USER_REGISTERED_FUNCTION
,
}
d
[
"dump_info"
]
=
dump_info
return
d
def
__setstate__
(
self
,
state
):
for
k
,
v
in
state
.
items
():
if
isinstance
(
v
,
_ModuleState
):
state
[
k
]
=
v
.
to_module
()
self
.
__dict__
.
update
(
state
)
self
.
_update_ref
()
for
_
,
graph
in
self
.
argdef_graph_map
.
items
():
for
expr
in
graph
.
_exprs
:
if
isinstance
(
expr
,
CallFunction
):
load_functional
(
expr
)
if
isinstance
(
expr
,
CallMethod
):
if
expr
.
method
==
"__call__"
:
load_call_module_expr
(
expr
)
else
:
load_call_tensor_method_expr
(
expr
)
if
isinstance
(
expr
,
Apply
):
load_apply_expr
(
expr
)
for
_
,
graph
in
self
.
argdef_graph_map
.
items
():
ind
=
0
while
ind
<
len
(
graph
.
_exprs
):
cur_expr
=
graph
.
_exprs
[
ind
]
has_new_expr
=
False
for
i
in
cur_expr
.
inputs
:
if
i
.
expr
not
in
graph
.
_exprs
and
not
isinstance
(
i
.
expr
,
Input
):
graph
.
_exprs
.
insert
(
ind
,
i
.
expr
)
has_new_expr
=
True
if
not
has_new_expr
:
ind
+=
1
for
expr
in
graph
.
_exprs
:
for
i
in
expr
.
inputs
:
if
expr
.
inputs
.
count
(
i
)
!=
i
.
users
.
count
(
expr
):
add_or_del_count
=
expr
.
inputs
.
count
(
i
)
-
i
.
users
.
count
(
expr
)
if
add_or_del_count
>
0
:
i
.
users
.
extend
([
expr
]
*
add_or_del_count
)
else
:
[
i
.
users
.
remove
(
expr
)
for
i
in
range
(
-
add_or_del_count
)]
for
o
in
expr
.
outputs
:
if
o
.
expr
is
not
expr
:
assert
o
not
in
o
.
expr
.
outputs
o
.
expr
=
expr
for
node
in
graph
.
nodes
(
False
):
# remove users of node which doesn't use node as input
node
.
users
=
[
e
for
e
in
node
.
users
if
node
in
e
.
inputs
]
for
expr
in
graph
.
_exprs
:
graph
.
_namespace
.
auto_naming_for_outputs
(
expr
)
self
.
_update_ref
()
for
_
,
graph
in
self
.
argdef_graph_map
.
items
():
graph
.
_reset_ids
()
def
__copy__
(
self
):
cls
=
self
.
__class__
result
=
cls
.
__new__
(
cls
)
result
.
__dict__
.
update
(
self
.
__dict__
)
return
result
def
__deepcopy__
(
self
,
memo
):
cls
=
self
.
__class__
result
=
cls
.
__new__
(
cls
)
state
=
{}
memo
[
id
(
self
)]
=
result
for
k
,
v
in
self
.
__dict__
.
items
():
if
not
isinstance
(
v
,
weakref
.
ReferenceType
):
state
[
k
]
=
copy
.
deepcopy
(
v
,
memo
)
result
.
__dict__
.
update
(
state
)
result
.
_update_ref
()
return
result
def
cpp_apply_module_trace
(
opdef
,
*
args
):
return
Apply
.
apply_module_trace_hook
(
opdef
,
*
args
)
USER_REGISTERED_MODULE
=
[]
USER_REGISTERED_FUNCTION
=
[]
def
register_as_builtin
(
mod_cls
:
Type
[
Module
])
->
None
:
r
"""Registers class ``mod_cls`` (subclass of :class:`~.Module`) as builtin module.
Args:
mod_cls: the module class which will be treated as builtin module in tracing.
"""
USER_REGISTERED_MODULE
.
append
((
mod_cls
.
__module__
,
mod_cls
.
__qualname__
))
module_tracer
.
register_as_builtin
(
mod_cls
)
...
...
@@ -2181,6 +2306,7 @@ def wrap(func: Callable):
Args:
func: the function of the global function to insert into the graph when it's called.
"""
USER_REGISTERED_FUNCTION
.
append
((
func
.
__module__
,
func
.
__qualname__
))
assert
callable
(
func
),
"func must be a callable"
assert
hasattr
(
func
,
"__code__"
)
fn_name
=
func
.
__code__
.
co_name
...
...
@@ -2247,6 +2373,8 @@ def trace_module(
NodeMixin
.
wrap_safe
(
builder
,
Input
.
make
(
name
=
"top"
,
type
=
ModuleNode
,
qualname
=
net_name
)
)
args
,
kwargs
=
_convert_kwargs_to_args
(
mod
.
forward
,
args
,
kwargs
,
True
)
inputs
,
_
=
tree_flatten
((
args
,
kwargs
))
for
_
,
i
in
enumerate
(
inputs
):
# assert isinstance(i, Tensor), "not support "
...
...
imperative/python/megengine/traced_module/utils.py
浏览文件 @
7b19bc76
...
...
@@ -5,12 +5,17 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
collections
import
copy
import
inspect
from
collections.abc
import
MutableMapping
,
MutableSequence
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Sequence
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Sequence
,
Type
from
..
import
get_logger
from
..module
import
Module
logger
=
get_logger
(
__name__
)
def
replace_container_with_module_container
(
container
):
has_module
=
False
...
...
@@ -52,6 +57,101 @@ def replace_container_with_module_container(container):
return
has_module
,
module_container
def
_convert_kwargs_to_args
(
func
,
args
,
kwargs
,
is_bounded
=
False
):
# is_bounded = True when func is a method and provided args don't include 'self'
arg_specs
=
inspect
.
getfullargspec
(
func
)
arg_specs_args
=
arg_specs
.
args
if
is_bounded
:
arg_specs_args
=
arg_specs
.
args
[
1
:]
new_args
=
[]
new_kwargs
=
{}
new_args
.
extend
(
args
)
if
set
(
arg_specs_args
[
0
:
len
(
new_args
)])
&
set
(
kwargs
.
keys
()):
repeated_arg_name
=
set
(
arg_specs_args
[
0
:
len
(
new_args
)])
&
set
(
kwargs
.
keys
())
raise
TypeError
(
"{} got multiple values for argument {}"
.
format
(
func
.
__qualname__
,
", "
.
join
(
repeated_arg_name
)
)
)
if
len
(
new_args
)
<
len
(
arg_specs
.
args
):
for
ind
in
range
(
len
(
new_args
),
len
(
arg_specs_args
)):
arg_name
=
arg_specs_args
[
ind
]
if
arg_name
in
kwargs
:
new_args
.
append
(
kwargs
[
arg_name
])
else
:
index
=
ind
-
len
(
arg_specs_args
)
+
len
(
arg_specs
.
defaults
)
assert
index
<
len
(
arg_specs
.
defaults
)
and
index
>=
0
new_args
.
append
(
arg_specs
.
defaults
[
index
])
for
kwarg_name
in
arg_specs
.
kwonlyargs
:
if
kwarg_name
in
kwargs
:
new_kwargs
[
kwarg_name
]
=
kwargs
[
kwarg_name
]
else
:
assert
kwarg_name
in
arg_specs
.
kwonlydefaults
new_kwargs
[
kwarg_name
]
=
arg_specs
.
kwonlydefaults
[
kwarg_name
]
for
k
,
v
in
kwargs
.
items
():
if
k
not
in
arg_specs
.
args
and
k
not
in
arg_specs
.
kwonlyargs
:
if
arg_specs
.
varkw
is
None
:
raise
TypeError
(
"{} got an unexpected keyword argument {}"
.
format
(
func
.
__qualname__
,
k
)
)
new_kwargs
[
k
]
=
v
return
tuple
(
new_args
),
new_kwargs
def
_check_obj_attr
(
obj
):
# check if all the attributes of a obj is serializable
from
.pytree
import
tree_flatten
from
.pytree
import
SUPPORTED_LEAF_CLS
,
SUPPORTED_LEAF_TYPE
,
TreeDef
from
.expr
import
Expr
from
.traced_module
import
TracedModule
,
InternalGraph
,
NameSpace
def
_check_leaf_type
(
leaf
):
leaf_type
=
leaf
if
isinstance
(
leaf
,
type
)
else
type
(
leaf
)
traced_module_types
=
[
Expr
,
TreeDef
,
TracedModule
,
InternalGraph
,
NameSpace
]
return
(
issubclass
(
leaf_type
,
tuple
(
SUPPORTED_LEAF_CLS
+
traced_module_types
))
or
leaf_type
in
SUPPORTED_LEAF_TYPE
)
for
_
,
v
in
obj
.
items
():
leafs
,
_
=
tree_flatten
(
v
,
is_leaf
=
lambda
_
:
True
)
for
leaf
in
leafs
:
assert
_check_leaf_type
(
leaf
),
"Type {} is not supported by traced module"
.
format
(
leaf
if
isinstance
(
leaf
,
type
)
else
type
(
leaf
)
)
def
_check_builtin_module_attr
(
mod
):
from
.pytree
import
_is_leaf
as
_check_leaf_type
from
.pytree
import
tree_flatten
# check if all the attributes of a builtin module is serializable
is_non_serializable_module
=
lambda
m
:
isinstance
(
m
,
Module
)
and
not
_check_builtin_module_attr
(
m
)
for
k
,
v
in
mod
.
__dict__
.
items
():
if
k
==
"_m_dump_modulestate"
:
continue
if
is_non_serializable_module
(
v
):
return
False
elif
not
isinstance
(
v
,
Module
):
leafs
,
_
=
tree_flatten
(
v
,
is_leaf
=
lambda
_
:
True
)
for
leaf
in
leafs
:
if
not
_check_leaf_type
(
leaf
)
or
is_non_serializable_module
(
leaf
):
logger
.
warn
(
"Type {} is not supported by traced module"
.
format
(
leaf
if
isinstance
(
leaf
,
type
)
else
type
(
leaf
)
)
)
return
False
return
True
class
_ModuleList
(
Module
,
MutableSequence
):
r
"""A List-like container.
...
...
imperative/python/test/unit/core/test_serialization.py
浏览文件 @
7b19bc76
...
...
@@ -15,7 +15,6 @@ import numpy as np
import
megengine
as
mge
from
megengine
import
Parameter
,
Tensor
from
megengine.core.ops
import
builtin
from
megengine.traced_module.serialization
import
get_opdef_state
,
load_opdef_from_state
def
test_tensor_serialization
():
...
...
@@ -88,25 +87,3 @@ def test_compatibility():
test_old_tensor
(
"tensor_v1_1.mge"
)
test_old_tensor
(
"tensor_v1_2.mge"
)
def
test_opdef_serialization
():
with
TemporaryFile
()
as
f
:
x
=
builtin
.
Elemwise
(
mode
=
"Add"
)
pickle
.
dump
(
get_opdef_state
(
x
),
f
)
f
.
seek
(
0
)
load_x
=
load_opdef_from_state
(
pickle
.
load
(
f
))
assert
x
==
load_x
with
TemporaryFile
()
as
f
:
x
=
builtin
.
Convolution
(
stride_h
=
9
,
compute_mode
=
"float32"
)
x
.
strategy
=
(
builtin
.
Convolution
.
Strategy
.
PROFILE
|
builtin
.
Convolution
.
Strategy
.
HEURISTIC
|
builtin
.
Convolution
.
Strategy
.
REPRODUCIBLE
)
pickle
.
dump
(
get_opdef_state
(
x
),
f
)
f
.
seek
(
0
)
load_x
=
load_opdef_from_state
(
pickle
.
load
(
f
))
assert
x
.
strategy
==
load_x
.
strategy
assert
x
==
load_x
imperative/python/test/unit/traced_module/test_modification.py
浏览文件 @
7b19bc76
...
...
@@ -85,12 +85,12 @@ class NewModule(M.Module):
return
x
def
_check_expr_users
(
trac
ed_module
):
def
_check_expr_users
(
flatten
ed_module
):
node_user
=
defaultdict
(
list
)
for
expr
in
trac
ed_module
.
graph
.
_exprs
:
for
expr
in
flatten
ed_module
.
graph
.
_exprs
:
for
node
in
expr
.
inputs
:
node_user
[
node
].
append
(
expr
)
for
node
in
trac
ed_module
.
graph
.
nodes
():
for
node
in
flatten
ed_module
.
graph
.
nodes
():
node
.
users
.
sort
(
key
=
lambda
m
:
m
.
_id
)
node_user
[
node
].
sort
(
key
=
lambda
m
:
m
.
_id
)
assert
node
.
users
==
node_user
[
node
]
...
...
imperative/python/test/unit/traced_module/test_qat_module.py
浏览文件 @
7b19bc76
...
...
@@ -8,6 +8,7 @@ import numpy as np
import
megengine
as
mge
import
megengine.functional
as
F
import
megengine.module
as
M
import
megengine.module.qat
as
QM
import
megengine.quantization
as
Q
from
megengine
import
Tensor
from
megengine.module.qat.module
import
QATModule
...
...
@@ -28,10 +29,18 @@ def get_subattr(self: M.Module, name: str):
return
getattr
(
self
,
name
)
class
MyConvBnRelu2d
(
M
.
ConvBnRelu2d
):
pass
class
MyQATConvBnRelu2d
(
QM
.
ConvBnRelu2d
):
pass
class
Myblcok
(
M
.
Module
):
def
__init__
(
self
,):
super
().
__init__
()
self
.
conv0
=
M
.
ConvBnRelu2d
(
3
,
3
,
3
,
1
,
1
)
self
.
conv0
=
M
y
ConvBnRelu2d
(
3
,
3
,
3
,
1
,
1
)
self
.
conv1
=
M
.
ConvBn2d
(
3
,
3
,
1
,
1
,
0
)
self
.
conv2
=
M
.
ConvBn2d
(
3
,
3
,
1
,
1
,
0
)
self
.
add
=
M
.
Elemwise
(
"FUSE_ADD_RELU"
)
...
...
@@ -106,7 +115,11 @@ def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams):
def
build_observered_net
(
net
:
M
.
Module
,
observer_cls
):
qat_net
=
Q
.
quantize_qat
(
net
,
qconfig
=
get_observer_config
(
observer_cls
))
qat_net
=
Q
.
quantize_qat
(
net
,
qconfig
=
get_observer_config
(
observer_cls
),
mapping
=
{
MyConvBnRelu2d
:
MyQATConvBnRelu2d
},
)
Q
.
enable_observer
(
qat_net
)
inp
=
Tensor
(
np
.
random
.
random
(
size
=
(
5
,
3
,
32
,
32
)))
qat_net
(
inp
)
...
...
@@ -134,6 +147,15 @@ def test_trace_qat():
check_qparams
(
weight_qparams
,
traced_weight_qparams
)
if
act_qparams
:
check_qparams
(
act_qparams
,
traced_act_qparams
)
flatten_traced_net
=
traced_net
.
flatten
()
conv0_node
=
flatten_traced_net
.
graph
.
get_node_by_name
(
"MyModule_block0_conv0"
).
as_unique
()
conv0_out_node
=
flatten_traced_net
.
graph
.
get_node_by_name
(
"MyModule_block0_conv0_out"
).
as_unique
()
assert
isinstance
(
conv0_node
.
owner
,
TracedModule
)
assert
conv0_out_node
.
expr
.
inputs
[
0
]
is
conv0_node
_check_qat_module
(
build_observered_net
(
MyModule
(),
Q
.
MinMaxObserver
))
_check_qat_module
(
build_observered_net
(
MyModule
(),
MyMinMaxObserver
))
...
...
imperative/python/test/unit/traced_module/test_serialization.py
浏览文件 @
7b19bc76
...
...
@@ -6,14 +6,59 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
pickle
from
collections
import
defaultdict
from
tempfile
import
TemporaryFile
import
numpy
as
np
import
megengine.functional
as
F
import
megengine.module
as
M
import
megengine.traced_module.serialization
as
S
from
megengine
import
Tensor
from
megengine.core._imperative_rt.core2
import
apply
from
megengine.core.ops
import
builtin
from
megengine.core.ops.builtin
import
Elemwise
from
megengine.module
import
Module
from
megengine.traced_module
import
trace_module
from
megengine.traced_module.expr
import
CallMethod
,
Constant
from
megengine.traced_module.node
import
TensorNode
from
megengine.traced_module.serialization
import
(
register_functional_loader
,
register_module_loader
,
register_opdef_loader
,
register_tensor_method_loader
,
)
from
megengine.traced_module.utils
import
_convert_kwargs_to_args
def
_check_id
(
traced_module
):
_total_ids
=
traced_module
.
graph
.
_total_ids
node_ids
=
[
n
.
_id
for
n
in
traced_module
.
graph
.
nodes
().
as_list
()]
assert
len
(
set
(
node_ids
))
==
len
(
node_ids
)
assert
max
(
node_ids
)
+
1
==
_total_ids
[
0
]
expr_ids
=
[
n
.
_id
for
n
in
traced_module
.
graph
.
exprs
().
as_list
()]
assert
len
(
set
(
expr_ids
))
==
len
(
expr_ids
)
assert
max
(
expr_ids
)
+
1
==
_total_ids
[
1
]
def
_check_name
(
flatened_module
):
node_names
=
[
n
.
_name
for
n
in
flatened_module
.
graph
.
nodes
().
as_list
()]
assert
len
(
set
(
node_names
))
==
len
(
node_names
)
def
_check_expr_users
(
traced_module
):
node_user
=
defaultdict
(
list
)
for
expr
in
traced_module
.
graph
.
_exprs
:
for
node
in
expr
.
inputs
:
node_user
[
node
].
append
(
expr
)
if
isinstance
(
expr
,
CallMethod
)
and
expr
.
graph
:
_check_expr_users
(
expr
.
inputs
[
0
].
owner
)
for
node
in
traced_module
.
graph
.
nodes
(
False
):
node
.
users
.
sort
(
key
=
lambda
m
:
m
.
_id
)
node_user
[
node
].
sort
(
key
=
lambda
m
:
m
.
_id
)
assert
node
.
users
==
node_user
[
node
]
class
MyBlock
(
Module
):
...
...
@@ -48,5 +93,274 @@ def test_dump_and_load():
traced_module
=
trace_module
(
module
,
x
)
np
.
testing
.
assert_array_equal
(
expect
,
traced_module
(
x
))
obj
=
pickle
.
dumps
(
traced_module
)
pickle
.
loads
(
obj
)
new_tm
=
pickle
.
loads
(
obj
)
_check_id
(
new_tm
)
_check_expr_users
(
new_tm
)
traced_module
.
graph
.
_reset_ids
()
old_nodes
=
traced_module
.
graph
.
nodes
().
as_list
()
new_nodes
=
new_tm
.
graph
.
nodes
().
as_list
()
old_exprs
=
traced_module
.
graph
.
exprs
().
as_list
()
new_exprs
=
new_tm
.
graph
.
exprs
().
as_list
()
assert
len
(
old_nodes
)
==
len
(
new_nodes
)
for
i
,
j
in
zip
(
old_nodes
,
new_nodes
):
assert
i
.
_name
==
j
.
_name
assert
i
.
_qualname
==
j
.
_qualname
assert
i
.
_id
==
j
.
_id
assert
len
(
old_exprs
)
==
len
(
new_exprs
)
for
i
,
j
in
zip
(
old_exprs
,
new_exprs
):
assert
i
.
_id
==
j
.
_id
np
.
testing
.
assert_array_equal
(
expect
,
traced_module
(
x
))
def
test_opdef_loader
():
class
MyModule1
(
Module
):
def
forward
(
self
,
x
,
y
):
op
=
Elemwise
(
"ADD"
)
return
apply
(
op
,
x
,
y
)[
0
]
m
=
MyModule1
()
x
=
Tensor
(
np
.
ones
((
20
)))
y
=
Tensor
(
np
.
ones
((
20
)))
traced_module
=
trace_module
(
m
,
x
,
y
)
orig_loader_dict
=
S
.
OPDEF_LOADER
S
.
OPDEF_LOADER
=
{}
@
register_opdef_loader
(
Elemwise
)
def
add_opdef_loader
(
expr
):
if
expr
.
opdef_state
[
"mode"
]
==
"ADD"
:
expr
.
opdef_state
[
"mode"
]
=
"MUL"
node
=
expr
.
inputs
[
1
]
astype_expr
=
CallMethod
(
node
,
"astype"
)
oup
=
TensorNode
(
astype_expr
,
shape
=
node
.
shape
,
dtype
=
expr
.
inputs
[
0
].
dtype
,
qparams
=
node
.
qparams
,
)
astype_expr
.
set_args_kwargs
(
node
,
expr
.
inputs
[
0
].
dtype
)
astype_expr
.
return_val
=
(
oup
,)
expr
.
inputs
[
1
]
=
oup
obj
=
pickle
.
dumps
(
traced_module
)
new_module
=
pickle
.
loads
(
obj
)
_check_id
(
new_module
)
_check_expr_users
(
new_module
)
_check_name
(
new_module
.
flatten
())
assert
(
isinstance
(
new_module
.
graph
.
_exprs
[
0
],
CallMethod
)
and
new_module
.
graph
.
_exprs
[
1
].
opdef
.
mode
==
"MUL"
and
len
(
new_module
.
graph
.
_exprs
)
==
2
)
result
=
new_module
(
x
,
y
)
np
.
testing
.
assert_equal
(
result
.
numpy
(),
x
.
numpy
())
S
.
OPDEF_LOADER
=
orig_loader_dict
def
test_functional_loader
():
class
MyModule2
(
Module
):
def
forward
(
self
,
x
,
y
):
return
F
.
conv2d
(
x
,
y
)
m
=
MyModule2
()
x
=
Tensor
(
np
.
random
.
random
((
1
,
3
,
32
,
32
)))
y
=
Tensor
(
np
.
random
.
random
((
3
,
3
,
3
,
3
)))
traced_module
=
trace_module
(
m
,
x
,
y
)
orig_loader_dict
=
S
.
FUNCTIONAL_LOADER
S
.
FUNCTIONAL_LOADER
=
{}
@
register_functional_loader
((
"megengine.functional.nn"
,
"conv2d"
))
def
conv2df_loader
(
expr
):
# expr.func = ("megengine.functional.nn","conv2d")
kwargs
=
expr
.
kwargs
orig_weight
=
expr
.
named_args
[
"weight"
]
astype_expr
=
CallMethod
(
orig_weight
,
"astype"
)
oup
=
TensorNode
(
astype_expr
,
shape
=
orig_weight
.
shape
,
dtype
=
orig_weight
.
dtype
,
qparams
=
orig_weight
.
qparams
,
)
astype_expr
.
set_args_kwargs
(
orig_weight
,
expr
.
named_args
[
"inp"
].
dtype
)
astype_expr
.
return_val
=
(
oup
,)
expr
.
set_arg
(
"weight"
,
oup
)
obj
=
pickle
.
dumps
(
traced_module
)
new_module
=
pickle
.
loads
(
obj
)
_check_expr_users
(
new_module
)
_check_id
(
new_module
)
result
=
new_module
(
x
,
y
)
gt
=
m
(
x
,
y
)
assert
(
isinstance
(
new_module
.
graph
.
_exprs
[
0
],
CallMethod
)
and
len
(
new_module
.
graph
.
_exprs
)
==
2
)
np
.
testing
.
assert_equal
(
result
.
numpy
(),
gt
.
numpy
())
S
.
FUNCTIONAL_LOADER
=
orig_loader_dict
def
test_tensor_method_loader
():
class
MyModule3
(
Module
):
def
forward
(
self
,
x
):
return
x
+
1
m
=
MyModule3
()
x
=
Tensor
(
np
.
ones
((
20
)))
traced_module
=
trace_module
(
m
,
x
)
orig_loader_dict
=
S
.
TENSORMETHOD_LOADER
S
.
TENSORMETHOD_LOADER
=
{}
@
register_tensor_method_loader
(
"__add__"
)
def
add_loader
(
expr
):
args
=
list
(
expr
.
args
)
if
not
isinstance
(
args
[
1
],
TensorNode
):
args
[
1
]
=
Tensor
(
args
[
1
])
node
=
Constant
(
args
[
1
],
"const"
).
outputs
[
0
]
astype_expr
=
CallMethod
(
node
,
"astype"
)
oup
=
TensorNode
(
astype_expr
,
shape
=
node
.
shape
,
dtype
=
node
.
dtype
,
qparams
=
node
.
qparams
,
)
astype_expr
.
set_args_kwargs
(
node
,
expr
.
inputs
[
0
].
dtype
)
astype_expr
.
return_val
=
(
oup
,)
add_expr
=
CallMethod
(
oup
,
"__add__"
)
add_expr
.
set_args_kwargs
(
oup
,
oup
)
oup1
=
TensorNode
(
add_expr
,
shape
=
oup
.
shape
,
dtype
=
oup
.
dtype
,
qparams
=
node
.
qparams
,
)
add_expr
.
return_val
=
oup1
args
[
1
]
=
oup1
expr
.
set_args_kwargs
(
*
args
)
obj
=
pickle
.
dumps
(
traced_module
)
new_module
=
pickle
.
loads
(
obj
)
_check_expr_users
(
new_module
)
_check_id
(
new_module
)
result
=
new_module
(
x
)
gt
=
m
(
x
)
assert
(
isinstance
(
new_module
.
graph
.
_exprs
[
0
],
Constant
)
and
len
(
new_module
.
graph
.
_exprs
)
==
4
)
np
.
testing
.
assert_equal
(
result
.
numpy
(),
(
x
+
2
).
numpy
())
S
.
TENSORMETHOD_LOADER
=
orig_loader_dict
def
test_module_loader
():
class
MyModule4
(
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
M
.
Conv2d
(
3
,
3
,
3
)
def
forward
(
self
,
x
):
return
self
.
conv
(
x
)
m
=
MyModule4
()
x
=
Tensor
(
np
.
random
.
random
((
1
,
3
,
32
,
32
)))
traced_module
=
trace_module
(
m
,
x
)
orig_loader_dict
=
S
.
MODULE_LOADER
S
.
MODULE_LOADER
=
{}
@
register_module_loader
((
"megengine.module.conv"
,
"Conv2d"
))
def
conv2dm_loader
(
expr
):
module
=
expr
.
inputs
[
0
].
owner
args
=
list
(
expr
.
args
)
orig_inp
=
args
[
1
]
astype_expr
=
CallMethod
(
orig_inp
,
"astype"
)
oup
=
TensorNode
(
astype_expr
,
shape
=
orig_inp
.
shape
,
dtype
=
orig_inp
.
dtype
,
qparams
=
orig_inp
.
qparams
,
)
astype_expr
.
set_args_kwargs
(
orig_inp
,
module
.
weight
.
dtype
)
astype_expr
.
return_val
=
(
oup
,)
args
[
1
]
=
oup
expr
.
set_args_kwargs
(
*
args
)
obj
=
pickle
.
dumps
(
traced_module
)
new_module
=
pickle
.
loads
(
obj
)
result
=
new_module
(
x
)
gt
=
m
(
x
)
assert
(
isinstance
(
new_module
.
graph
.
_exprs
[
1
],
CallMethod
)
and
len
(
new_module
.
graph
.
_exprs
)
==
3
)
np
.
testing
.
assert_equal
(
result
.
numpy
(),
gt
.
numpy
())
S
.
MODULE_LOADER
=
orig_loader_dict
def
test_shared_module
():
class
MyModule
(
M
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
a
=
M
.
Elemwise
(
"ADD"
)
self
.
b
=
self
.
a
def
forward
(
self
,
x
,
y
):
z
=
self
.
a
(
x
,
y
)
z
=
self
.
b
(
z
,
y
)
return
z
x
=
Tensor
(
1
)
y
=
Tensor
(
2
)
m
=
MyModule
()
tm
=
trace_module
(
m
,
x
,
y
)
obj
=
pickle
.
dumps
(
tm
)
load_tm
=
pickle
.
loads
(
obj
)
_check_expr_users
(
load_tm
)
_check_name
(
load_tm
.
flatten
())
_check_id
(
load_tm
)
assert
load_tm
.
a
is
load_tm
.
b
def
test_convert_kwargs_to_args
():
def
func
(
a
,
b
,
c
=
4
,
*
,
d
,
e
=
3
,
f
=
4
):
pass
args
=
(
1
,)
kwargs
=
{
"b"
:
1
,
"d"
:
6
}
new_args
,
new_kwargs
=
_convert_kwargs_to_args
(
func
,
args
,
kwargs
)
assert
new_args
==
(
1
,
1
,
4
)
assert
new_kwargs
==
{
"d"
:
6
,
"e"
:
3
,
"f"
:
4
}
args
=
(
1
,)
kwargs
=
{
"d"
:
6
}
new_args
,
new_kwargs
=
_convert_kwargs_to_args
(
func
,
args
,
kwargs
,
is_bounded
=
True
)
assert
new_args
==
(
1
,
4
)
assert
new_kwargs
==
{
"d"
:
6
,
"e"
:
3
,
"f"
:
4
}
def
func1
(
a
,
b
,
c
,
d
,
e
,
*
,
f
):
pass
args
=
()
kwargs
=
{
"a"
:
1
,
"b"
:
2
,
"c"
:
3
,
"d"
:
4
,
"e"
:
5
,
"f"
:
6
}
new_args
,
new_kwargs
=
_convert_kwargs_to_args
(
func1
,
args
,
kwargs
)
assert
new_args
==
(
1
,
2
,
3
,
4
,
5
)
assert
new_kwargs
==
{
"f"
:
6
}
def
test_opdef_serialization
():
with
TemporaryFile
()
as
f
:
x
=
builtin
.
Elemwise
(
mode
=
"Add"
)
pickle
.
dump
(
x
,
f
)
f
.
seek
(
0
)
load_x
=
pickle
.
load
(
f
)
assert
x
==
load_x
with
TemporaryFile
()
as
f
:
x
=
builtin
.
Convolution
(
stride_h
=
9
,
compute_mode
=
"float32"
)
x
.
strategy
=
(
builtin
.
Convolution
.
Strategy
.
PROFILE
|
builtin
.
Convolution
.
Strategy
.
HEURISTIC
|
builtin
.
Convolution
.
Strategy
.
REPRODUCIBLE
)
pickle
.
dump
(
x
,
f
)
f
.
seek
(
0
)
load_x
=
pickle
.
load
(
f
)
assert
x
.
strategy
==
load_x
.
strategy
assert
x
==
load_x
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录