Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
7b19bc76
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
7b19bc76
编写于
9月 08, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(traced_module): support traced module backward compatible serialization
GitOrigin-RevId: aaa9e51c74c11fa7955ae7bbfac476fa9bcf0d7d
上级
ffbfe59c
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
1314 addition
and
191 deletion
+1314
-191
imperative/python/megengine/traced_module/__init__.py
imperative/python/megengine/traced_module/__init__.py
+1
-0
imperative/python/megengine/traced_module/compat.py
imperative/python/megengine/traced_module/compat.py
+136
-0
imperative/python/megengine/traced_module/expr.py
imperative/python/megengine/traced_module/expr.py
+276
-27
imperative/python/megengine/traced_module/module_tracer.py
imperative/python/megengine/traced_module/module_tracer.py
+0
-1
imperative/python/megengine/traced_module/node.py
imperative/python/megengine/traced_module/node.py
+54
-5
imperative/python/megengine/traced_module/pytree.py
imperative/python/megengine/traced_module/pytree.py
+51
-31
imperative/python/megengine/traced_module/serialization.py
imperative/python/megengine/traced_module/serialization.py
+146
-18
imperative/python/megengine/traced_module/traced_module.py
imperative/python/megengine/traced_module/traced_module.py
+207
-79
imperative/python/megengine/traced_module/utils.py
imperative/python/megengine/traced_module/utils.py
+101
-1
imperative/python/test/unit/core/test_serialization.py
imperative/python/test/unit/core/test_serialization.py
+0
-23
imperative/python/test/unit/traced_module/test_modification.py
...ative/python/test/unit/traced_module/test_modification.py
+3
-3
imperative/python/test/unit/traced_module/test_qat_module.py
imperative/python/test/unit/traced_module/test_qat_module.py
+24
-2
imperative/python/test/unit/traced_module/test_serialization.py
...tive/python/test/unit/traced_module/test_serialization.py
+315
-1
未找到文件。
imperative/python/megengine/traced_module/__init__.py
浏览文件 @
7b19bc76
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
..core._imperative_rt.core2
import
set_cpp_apply_module_trace
from
..core._imperative_rt.core2
import
set_cpp_apply_module_trace
from
.
import
compat
from
.traced_module
import
(
from
.traced_module
import
(
TracedModule
,
TracedModule
,
_register_all_builtin_module
,
_register_all_builtin_module
,
...
...
imperative/python/megengine/traced_module/compat.py
0 → 100644
浏览文件 @
7b19bc76
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
numpy
as
np
from
..
import
tensor
from
..core.ops.builtin
import
BatchNorm
from
.expr
import
CallMethod
,
Constant
from
.node
import
TensorNode
from
.serialization
import
(
register_functional_loader
,
register_module_loader
,
register_opdef_loader
,
register_tensor_method_loader
,
)
"""
# Expr loaders examples
from ..core.ops.builtin import Elemwise
@register_opdef_loader(Elemwise)
def add_opdef_loader(expr):
if expr.opdef_state["mode"] == "ADD":
expr.opdef_state["mode"] == "MUL"
node = expr.inputs[1]
astype_expr = CallMethod(node, "astype")
oup = TensorNode(
astype_expr,
shape=node.shape,
dtype=expr.inputs[0].dtype,
qparams=node.qparams,
)
astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
astype_expr.return_val = (oup,)
expr.inputs[1] = oup
@register_functional_loader(("megengine.functional.nn", "conv2d"))
def conv2df_loader(expr):
# expr.func = ("megengine.functional.nn","conv2d")
kwargs = expr.kwargs
orig_weight = expr.named_args["weight"]
astype_expr = CallMethod(orig_weight, "astype")
oup = TensorNode(
astype_expr,
shape=orig_weight.shape,
dtype=orig_weight.dtype,
qparams=orig_weight.qparams,
)
astype_expr.set_args_kwargs(orig_weight, expr.named_args["inp"].dtype)
astype_expr.return_val = (oup,)
expr.set_arg("weight", oup)
@register_module_loader(("megengine.module.conv", "Conv2d"))
def conv2dm_loader(expr):
module = expr.inputs[0].owner
args = list(expr.args)
orig_inp = args[1]
astype_expr = CallMethod(orig_inp, "astype")
oup = TensorNode(
astype_expr,
shape=orig_inp.shape,
dtype=orig_inp.dtype,
qparams=orig_inp.qparams,
)
astype_expr.set_args_kwargs(orig_inp, module.weight.dtype)
astype_expr.return_val = (oup,)
args[1] = oup
expr.set_args_kwargs(*args)
@register_tensor_method_loader("__add__")
def add_loader(expr):
args = list(expr.args)
if not isinstance(args[1], TensorNode):
args[1] = tensor(args[1])
node = Constant(args[1], "const").outputs[0]
astype_expr = CallMethod(node, "astype")
oup = TensorNode(
astype_expr, shape=node.shape, dtype=node.dtype, qparams=node.qparams,
)
astype_expr.set_args_kwargs(node, expr.inputs[0].dtype)
astype_expr.return_val = (oup,)
args[1] = oup
expr.set_args_kwargs(*args)
"""
@
register_module_loader
(
(
"megengine.module.batchnorm"
,
"BatchNorm1d"
),
(
"megengine.module.batchnorm"
,
"BatchNorm2d"
),
(
"megengine.module.batchnorm"
,
"SyncBatchNorm"
),
)
def
bn2d_module_loader
(
expr
):
# mge 1.6
if
not
hasattr
(
expr
,
"version"
):
module
=
expr
.
inputs
[
0
].
owner
if
not
hasattr
(
module
,
"param_dim"
):
module
.
param_dim
=
"dim_1c11"
@
register_module_loader
(
(
"megengine.module.conv_bn"
,
"ConvBn2d"
),
(
"megengine.module.conv_bn"
,
"ConvBnRelu2d"
),
(
"megengine.module.qat.conv_bn"
,
"ConvBn2d"
),
(
"megengine.module.qat.conv_bn"
,
"ConvBnRelu2d"
),
)
def
convbn2d_module_loader
(
expr
):
# mge 1.6
if
not
hasattr
(
expr
,
"version"
):
module
=
expr
.
inputs
[
0
].
owner
if
not
hasattr
(
module
.
bn
,
"param_dim"
):
module
.
bn
.
param_dim
=
"dim_1c11"
@
register_opdef_loader
(
BatchNorm
)
def
bn_opdef_loader
(
expr
):
# mge 1.6
if
not
hasattr
(
expr
,
"version"
):
output
=
expr
.
outputs
[
-
1
]
oup
=
TensorNode
(
expr
,
shape
=
(
0
,),
dtype
=
None
,
qparams
=
output
.
_qparams
,)
expr
.
outputs
.
insert
(
4
,
oup
)
imperative/python/megengine/traced_module/expr.py
浏览文件 @
7b19bc76
...
@@ -11,19 +11,28 @@ import collections
...
@@ -11,19 +11,28 @@ import collections
import
copy
import
copy
import
inspect
import
inspect
import
re
import
re
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Union
import
weakref
from
importlib
import
import_module
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Union
from
..core._imperative_rt
import
OpDef
from
..core._imperative_rt
import
OpDef
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
from
..core._imperative_rt.core2
import
apply
,
set_module_tracing
,
unset_module_tracing
from
..core._imperative_rt.core2
import
(
apply
,
is_tracing_module
,
set_module_tracing
,
unset_module_tracing
,
)
from
..core.ops.builtin
import
FakeQuant
from
..core.ops.builtin
import
FakeQuant
from
..core.ops.special
import
Const
from
..core.ops.special
import
Const
from
..module
import
Module
from
..module
import
Module
from
..tensor
import
Parameter
,
Tensor
from
..tensor
import
Parameter
,
Tensor
from
..version
import
__version__
from
.module_tracer
import
active_module_tracer
,
module_tracer
from
.module_tracer
import
active_module_tracer
,
module_tracer
from
.node
import
ModuleNode
,
Node
,
NodeMixin
,
TensorNode
from
.node
import
ModuleNode
,
Node
,
NodeMixin
,
TensorNode
from
.pytree
import
ArgsIndex
,
TreeDef
,
_is_const_leaf
,
_is_leaf
,
tree_flatten
from
.pytree
import
ArgsIndex
,
TreeDef
,
_is_const_leaf
,
_is_leaf
,
tree_flatten
from
.serialization
import
get_opdef_state
,
load_opdef_from_state
from
.serialization
import
_ModuleState
from
.utils
import
_check_builtin_module_attr
,
_check_obj_attr
,
_convert_kwargs_to_args
def
rstrip
(
s
:
str
,
__chars
:
str
):
def
rstrip
(
s
:
str
,
__chars
:
str
):
...
@@ -112,6 +121,7 @@ class Expr:
...
@@ -112,6 +121,7 @@ class Expr:
node
.
users
.
append
(
self
)
node
.
users
.
append
(
self
)
else
:
else
:
assert
node
is
None
assert
node
is
None
assert
not
isinstance
(
val
,
(
Module
,
RawTensor
))
assert
_is_leaf
(
val
)
and
_is_const_leaf
(
val
)
assert
_is_leaf
(
val
)
and
_is_const_leaf
(
val
)
idx
=
len
(
self
.
inputs
)
+
len
(
self
.
const_val
)
idx
=
len
(
self
.
inputs
)
+
len
(
self
.
const_val
)
self
.
const_val
.
append
((
idx
,
val
))
self
.
const_val
.
append
((
idx
,
val
))
...
@@ -132,14 +142,14 @@ class Expr:
...
@@ -132,14 +142,14 @@ class Expr:
current_graph
.
_namespace
.
auto_naming_for_outputs
(
self
)
current_graph
.
_namespace
.
auto_naming_for_outputs
(
self
)
def
unflatten_args
(
self
,
inputs
):
def
unflatten_args
(
self
,
inputs
):
if
self
.
arg_def
is
not
None
:
assert
self
.
arg_def
is
not
None
,
"{} expr doesn't have args/kwargs"
.
format
(
inputs
=
list
(
inputs
)
type
(
self
).
__name__
for
idx
,
val
in
self
.
const_val
:
)
inputs
.
insert
(
idx
,
val
)
inputs
=
list
(
inputs
)
args
,
kwargs
=
self
.
arg_def
.
unflatten
(
inputs
)
for
idx
,
val
in
self
.
const_val
:
return
args
,
kwargs
inputs
.
insert
(
idx
,
val
)
else
:
args
,
kwargs
=
self
.
arg_def
.
unflatten
(
inputs
)
return
inputs
,
{}
return
args
,
kwargs
def
replace_inputs
(
self
,
repl_dict
:
Dict
[
Node
,
Node
]):
def
replace_inputs
(
self
,
repl_dict
:
Dict
[
Node
,
Node
]):
r
"""Replace the input Nodes of this Expr.
r
"""Replace the input Nodes of this Expr.
...
@@ -165,6 +175,39 @@ class Expr:
...
@@ -165,6 +175,39 @@ class Expr:
node
.
users
.
remove
(
self
)
node
.
users
.
remove
(
self
)
repl_node
.
users
.
append
(
self
)
repl_node
.
users
.
append
(
self
)
@
property
def
_support_set_args_kwargs
(
self
):
return
False
def
set_args_kwargs
(
self
,
*
args
,
**
kwargs
):
r
""" Set args and kwargs for Expr.
"""
assert
(
self
.
_support_set_args_kwargs
),
"Doesn't support set args/kwargs for {} expr"
.
format
(
type
(
self
).
__name__
)
args
,
kwargs
=
_convert_kwargs_to_args
(
self
.
_get_func
(),
args
,
kwargs
)
inputs
,
arg_def
=
tree_flatten
((
args
,
kwargs
))
orig_inputs
=
self
.
inputs
self
.
inputs
=
[]
self
.
const_val
=
[]
for
val
in
inputs
:
if
isinstance
(
val
,
(
TensorNode
,
ModuleNode
)):
self
.
inputs
.
append
(
val
)
else
:
assert
_is_leaf
(
val
)
and
_is_const_leaf
(
val
)
idx
=
len
(
self
.
inputs
)
+
len
(
self
.
const_val
)
self
.
const_val
.
append
((
idx
,
val
))
for
n
in
orig_inputs
:
if
n
not
in
self
.
inputs
:
n
.
users
.
remove
(
self
)
for
n
in
self
.
inputs
:
if
n
not
in
orig_inputs
:
n
.
users
.
append
(
self
)
self
.
arg_def
=
arg_def
@
property
@
property
def
kwargs
(
self
):
def
kwargs
(
self
):
r
"""Get the keyword arguments of the operation corresponding to this Expr."""
r
"""Get the keyword arguments of the operation corresponding to this Expr."""
...
@@ -177,6 +220,61 @@ class Expr:
...
@@ -177,6 +220,61 @@ class Expr:
args
,
_
=
self
.
unflatten_args
(
self
.
inputs
)
args
,
_
=
self
.
unflatten_args
(
self
.
inputs
)
return
args
return
args
def
_get_func
(
self
):
# get called function when the expr is interpreted
raise
NotImplementedError
@
property
def
named_args
(
self
):
func
=
self
.
_get_func
()
return
inspect
.
getcallargs
(
func
,
*
self
.
args
,
**
self
.
kwargs
)
def
set_arg
(
self
,
name
,
val
):
func
=
self
.
_get_func
()
if
name
in
self
.
kwargs
:
new_kwargs
=
self
.
kwargs
new_kwargs
[
name
]
=
val
self
.
set_args_kwargs
(
*
self
.
args
,
**
new_kwargs
)
else
:
arg_spec
=
inspect
.
getfullargspec
(
func
)
if
name
in
arg_spec
.
args
:
ind
=
arg_spec
.
args
.
index
(
name
)
new_args
=
list
(
self
.
args
)
new_args
[
ind
]
=
val
self
.
set_args_kwargs
(
*
new_args
)
elif
name
==
arg_spec
.
varargs
:
assert
arg_spec
.
varargs
is
not
None
assert
len
(
self
.
args
)
>=
len
(
arg_spec
.
args
)
val
=
(
val
,)
if
not
isinstance
(
val
,
Sequence
)
else
val
self
.
set_args_kwargs
(
*
self
.
args
[
0
:
len
(
arg_spec
.
args
)],
*
val
)
else
:
assert
(
arg_spec
.
varkw
is
not
None
),
"func {} does't have argument named {}"
.
format
(
func
,
name
)
new_kwargs
=
self
.
kwargs
new_kwargs
[
name
]
=
val
self
.
set_args_kwargs
(
*
self
.
args
,
**
new_kwargs
)
@
property
def
return_val
(
self
):
return
self
.
out_def
.
unflatten
(
self
.
outputs
)
@
return_val
.
setter
def
return_val
(
self
,
new_outputs
):
outputs
,
out_def
=
tree_flatten
(
new_outputs
,
is_leaf
=
lambda
x
:
isinstance
(
x
,
Node
)
)
assert
all
(
isinstance
(
o
,
Node
)
for
o
in
outputs
),
"Return values of expr must be ModuleNode or TensorNode or Container with them"
assert
all
(
o
.
expr
in
(
None
,
self
)
for
o
in
outputs
),
"Some nodes are produced by other expr, can not be output of expr {}"
.
format
(
self
)
self
.
outputs
=
outputs
self
.
out_def
=
out_def
@
property
@
property
def
top_graph
(
self
):
def
top_graph
(
self
):
r
"""Get the parent graph of this Expr."""
r
"""Get the parent graph of this Expr."""
...
@@ -184,12 +282,6 @@ class Expr:
...
@@ -184,12 +282,6 @@ class Expr:
return
self
.
_top_graph
()
return
self
.
_top_graph
()
return
None
return
None
def
__getstate__
(
self
):
state
=
self
.
__dict__
.
copy
()
if
"_top_graph"
in
state
:
state
.
pop
(
"_top_graph"
)
return
state
@
classmethod
@
classmethod
def
_get_next_id
(
cls
):
def
_get_next_id
(
cls
):
return
cls
.
__total_id
return
cls
.
__total_id
...
@@ -199,6 +291,23 @@ class Expr:
...
@@ -199,6 +291,23 @@ class Expr:
assert
isinstance
(
id
,
int
)
assert
isinstance
(
id
,
int
)
cls
.
__total_id
=
id
cls
.
__total_id
=
id
def
__copy__
(
self
):
cls
=
self
.
__class__
result
=
cls
.
__new__
(
cls
)
result
.
__dict__
.
update
(
self
.
__dict__
)
return
result
def
__deepcopy__
(
self
,
memo
):
cls
=
self
.
__class__
result
=
cls
.
__new__
(
cls
)
state
=
{}
memo
[
id
(
self
)]
=
result
for
k
,
v
in
self
.
__dict__
.
items
():
if
not
isinstance
(
v
,
weakref
.
ReferenceType
):
state
[
k
]
=
copy
.
deepcopy
(
v
,
memo
)
result
.
__dict__
.
update
(
state
)
return
result
# expr: None (i.e. fake expression which is used to mark input)
# expr: None (i.e. fake expression which is used to mark input)
class
Input
(
Expr
):
class
Input
(
Expr
):
...
@@ -229,6 +338,17 @@ class Input(Expr):
...
@@ -229,6 +338,17 @@ class Input(Expr):
def
__repr__
(
self
):
def
__repr__
(
self
):
return
"%{}:
\t
{} = Input()"
.
format
(
self
.
_id
,
self
.
outputs
[
0
])
return
"%{}:
\t
{} = Input()"
.
format
(
self
.
_id
,
self
.
outputs
[
0
])
def
__getstate__
(
self
):
state
=
{
"_id"
:
self
.
_id
,
"_disable_remove"
:
self
.
_disable_remove
,
"inputs"
:
self
.
inputs
,
"outputs"
:
self
.
outputs
,
"name"
:
self
.
name
,
}
_check_obj_attr
(
state
)
return
state
# expr: outputs = getattr(inputs[0], self.name)
# expr: outputs = getattr(inputs[0], self.name)
class
GetAttr
(
Expr
):
class
GetAttr
(
Expr
):
...
@@ -276,11 +396,23 @@ class GetAttr(Expr):
...
@@ -276,11 +396,23 @@ class GetAttr(Expr):
def
__repr__
(
self
):
def
__repr__
(
self
):
out_type
=
"Tensor"
out_type
=
"Tensor"
if
isinstance
(
self
.
outputs
[
0
],
ModuleNode
):
if
isinstance
(
self
.
outputs
[
0
],
ModuleNode
):
out_type
=
self
.
outputs
[
0
].
module_type
.
__name__
m_type
=
self
.
outputs
[
0
].
module_type
out_type
=
m_type
.
__name__
if
isinstance
(
m_type
,
type
)
else
m_type
[
1
]
return
'%{}:
\t
{} = getattr({}, "{}") -> ({})'
.
format
(
return
'%{}:
\t
{} = getattr({}, "{}") -> ({})'
.
format
(
self
.
_id
,
self
.
outputs
[
0
],
self
.
inputs
[
0
],
self
.
name
,
out_type
self
.
_id
,
self
.
outputs
[
0
],
self
.
inputs
[
0
],
self
.
name
,
out_type
)
)
def
__getstate__
(
self
):
state
=
{
"_id"
:
self
.
_id
,
"_disable_remove"
:
self
.
_disable_remove
,
"inputs"
:
self
.
inputs
,
"outputs"
:
self
.
outputs
,
"name"
:
self
.
name
,
}
_check_obj_attr
(
state
)
return
state
# expr: outputs = inputs[0].__call__(*inputs[1:])
# expr: outputs = inputs[0].__call__(*inputs[1:])
class
CallMethod
(
Expr
):
class
CallMethod
(
Expr
):
...
@@ -307,6 +439,7 @@ class CallMethod(Expr):
...
@@ -307,6 +439,7 @@ class CallMethod(Expr):
node
,
node
,
]
]
self
.
const_val
=
[]
self
.
const_val
=
[]
self
.
arg_def
=
tree_flatten
(((
node
,),
{}))[
1
]
self
.
method
=
method
self
.
method
=
method
@
classmethod
@
classmethod
...
@@ -342,6 +475,27 @@ class CallMethod(Expr):
...
@@ -342,6 +475,27 @@ class CallMethod(Expr):
outputs
,
_
=
tree_flatten
(
outputs
,
is_leaf
=
lambda
x
:
isinstance
(
x
,
RawTensor
))
outputs
,
_
=
tree_flatten
(
outputs
,
is_leaf
=
lambda
x
:
isinstance
(
x
,
RawTensor
))
return
outputs
return
outputs
def
_get_func
(
self
):
if
isinstance
(
self
.
args
[
0
],
type
):
obj_type
=
self
.
args
[
0
]
elif
isinstance
(
self
.
args
[
0
],
ModuleNode
):
obj_type
=
self
.
args
[
0
].
module_type
else
:
assert
isinstance
(
self
.
args
[
0
],
TensorNode
)
obj_type
=
Tensor
meth
=
getattr
(
obj_type
,
"forward"
if
issubclass
(
obj_type
,
Module
)
else
self
.
method
)
return
meth
@
property
def
_support_set_args_kwargs
(
self
):
# only expr call tensor method or builtin module support modify args/kwargs
return
(
isinstance
(
self
.
args
[
0
],
(
TensorNode
,
type
))
or
self
.
args
[
0
].
module_type
is
not
Module
)
def
__repr__
(
self
):
def
__repr__
(
self
):
args
=
", "
.
join
(
str
(
i
)
for
i
in
self
.
args
[
1
:])
args
=
", "
.
join
(
str
(
i
)
for
i
in
self
.
args
[
1
:])
kwargs
=
", "
.
join
(
"{}={}"
.
format
(
k
,
v
)
for
k
,
v
in
self
.
kwargs
.
items
())
kwargs
=
", "
.
join
(
"{}={}"
.
format
(
k
,
v
)
for
k
,
v
in
self
.
kwargs
.
items
())
...
@@ -359,6 +513,21 @@ class CallMethod(Expr):
...
@@ -359,6 +513,21 @@ class CallMethod(Expr):
", "
.
join
([
args
,
kwargs
]),
", "
.
join
([
args
,
kwargs
]),
)
)
def
__getstate__
(
self
):
state
=
{
"_id"
:
self
.
_id
,
"_disable_remove"
:
self
.
_disable_remove
,
"inputs"
:
self
.
inputs
,
"const_val"
:
self
.
const_val
,
"method"
:
self
.
method
,
"arg_def"
:
self
.
arg_def
,
"out_def"
:
self
.
out_def
,
"outputs"
:
self
.
outputs
,
"version"
:
__version__
,
}
_check_obj_attr
(
state
)
return
state
# expr: outputs = apply(self.opdef, *inputs)
# expr: outputs = apply(self.opdef, *inputs)
class
Apply
(
Expr
):
class
Apply
(
Expr
):
...
@@ -394,14 +563,32 @@ class Apply(Expr):
...
@@ -394,14 +563,32 @@ class Apply(Expr):
)
)
def
__getstate__
(
self
):
def
__getstate__
(
self
):
state
=
super
().
__getstate__
()
opdef_state
=
self
.
opdef
.
__getstate__
()
state
[
"opdef"
]
=
get_opdef_state
(
state
[
"opdef"
])
opdef_state
[
"opdef_type"
]
=
type
(
self
.
opdef
)
state
=
{
"_id"
:
self
.
_id
,
"_disable_remove"
:
self
.
_disable_remove
,
"opdef_state"
:
opdef_state
,
"inputs"
:
self
.
inputs
,
"outputs"
:
self
.
outputs
,
"version"
:
__version__
,
}
_check_obj_attr
(
state
)
return
state
return
state
def
__setstate__
(
self
,
state
):
def
__setstate__
(
self
,
state
):
state
[
"opdef"
]
=
load_opdef_from_state
(
state
[
"opdef"
])
# compat with mge 1.6
for
k
,
v
in
state
.
items
():
if
"opdef"
in
state
and
"opdef_state"
not
in
state
:
setattr
(
self
,
k
,
v
)
opdef_state
=
state
.
pop
(
"opdef"
)
opdef_state
[
"opdef_type"
]
=
opdef_state
.
pop
(
"type"
)
state
[
"opdef_state"
]
=
opdef_state
self
.
__dict__
.
update
(
state
)
assert
isinstance
(
state
[
"opdef_state"
],
dict
)
opdef_state
=
state
[
"opdef_state"
].
copy
()
opdef_type
=
opdef_state
.
pop
(
"opdef_type"
)
opdef_obj
=
opdef_type
()
opdef_obj
.
__setstate__
(
opdef_state
)
setattr
(
self
,
"opdef"
,
opdef_obj
)
@
classmethod
@
classmethod
def
apply_module_trace_hook
(
cls
,
opdef
,
*
inputs
):
def
apply_module_trace_hook
(
cls
,
opdef
,
*
inputs
):
...
@@ -458,12 +645,24 @@ class CallFunction(Expr):
...
@@ -458,12 +645,24 @@ class CallFunction(Expr):
def
interpret
(
self
,
*
inputs
):
def
interpret
(
self
,
*
inputs
):
args
,
kwargs
=
self
.
unflatten_args
(
inputs
)
args
,
kwargs
=
self
.
unflatten_args
(
inputs
)
outputs
=
self
.
func
(
*
args
,
**
kwargs
)
func
=
(
self
.
func
if
not
is_tracing_module
()
else
active_module_tracer
().
patcher
.
wrap_fn
(
self
.
func
)
)
outputs
=
func
(
*
args
,
**
kwargs
)
if
outputs
is
None
:
if
outputs
is
None
:
return
outputs
return
outputs
outputs
,
_
=
tree_flatten
(
outputs
,
is_leaf
=
lambda
x
:
isinstance
(
x
,
RawTensor
))
outputs
,
_
=
tree_flatten
(
outputs
,
is_leaf
=
lambda
x
:
isinstance
(
x
,
RawTensor
))
return
outputs
return
outputs
def
_get_func
(
self
):
return
self
.
func
@
property
def
_support_set_args_kwargs
(
self
):
return
True
def
__repr__
(
self
):
def
__repr__
(
self
):
args
=
", "
.
join
(
str
(
i
)
for
i
in
self
.
args
)
args
=
", "
.
join
(
str
(
i
)
for
i
in
self
.
args
)
kwargs
=
", "
.
join
(
"{}={}"
.
format
(
k
,
v
)
for
k
,
v
in
self
.
kwargs
.
items
())
kwargs
=
", "
.
join
(
"{}={}"
.
format
(
k
,
v
)
for
k
,
v
in
self
.
kwargs
.
items
())
...
@@ -477,6 +676,33 @@ class CallFunction(Expr):
...
@@ -477,6 +676,33 @@ class CallFunction(Expr):
", "
.
join
([
args
,
kwargs
]),
", "
.
join
([
args
,
kwargs
]),
)
)
def
__getstate__
(
self
):
state
=
{
"_id"
:
self
.
_id
,
"_disable_remove"
:
self
.
_disable_remove
,
"func"
:
(
self
.
func
.
__module__
,
self
.
func
.
__qualname__
),
"const_val"
:
self
.
const_val
,
"inputs"
:
self
.
inputs
,
"arg_def"
:
self
.
arg_def
,
"out_def"
:
self
.
out_def
,
"outputs"
:
self
.
outputs
,
"version"
:
__version__
,
}
_check_obj_attr
(
state
)
return
state
def
__setstate__
(
self
,
state
):
self
.
__dict__
.
update
(
state
)
try
:
if
isinstance
(
self
.
func
,
tuple
):
mname
,
fname
=
self
.
func
f
=
import_module
(
mname
)
for
i
in
fname
.
split
(
"."
):
f
=
getattr
(
f
,
i
)
self
.
func
=
f
except
Exception
:
pass
# expr outputs = self.value
# expr outputs = self.value
class
Constant
(
Expr
):
class
Constant
(
Expr
):
...
@@ -496,6 +722,13 @@ class Constant(Expr):
...
@@ -496,6 +722,13 @@ class Constant(Expr):
assert
isinstance
(
c
,
(
RawTensor
,
Module
))
assert
isinstance
(
c
,
(
RawTensor
,
Module
))
if
isinstance
(
c
,
Module
):
if
isinstance
(
c
,
Module
):
assert
module_tracer
.
is_builtin
(
c
)
or
c
.
is_qat
assert
module_tracer
.
is_builtin
(
c
)
or
c
.
is_qat
if
isinstance
(
c
,
RawTensor
):
if
is_tracing_module
():
unset_module_tracing
()
c
=
Tensor
(
c
)
set_module_tracing
()
else
:
c
=
Tensor
(
c
)
self
.
value
=
c
self
.
value
=
c
self
.
name
=
name
self
.
name
=
name
self
.
inputs
=
[]
self
.
inputs
=
[]
...
@@ -530,9 +763,25 @@ class Constant(Expr):
...
@@ -530,9 +763,25 @@ class Constant(Expr):
)
)
def
__getstate__
(
self
):
def
__getstate__
(
self
):
state
=
self
.
__dict__
.
copy
()
state
=
{
if
"_top_graph"
in
state
:
"_id"
:
self
.
_id
,
state
.
pop
(
"_top_graph"
)
"_disable_remove"
:
self
.
_disable_remove
,
"value"
:
self
.
value
,
"name"
:
self
.
name
,
"inputs"
:
self
.
inputs
,
"outputs"
:
self
.
outputs
,
}
_check_obj_attr
(
state
)
if
isinstance
(
self
.
value
,
RawTensor
):
if
isinstance
(
self
.
value
,
RawTensor
):
state
[
"value"
]
=
Tensor
(
self
.
value
)
state
[
"value"
]
=
Tensor
(
self
.
value
)
if
isinstance
(
self
.
value
,
Module
)
and
module_tracer
.
is_builtin
(
self
.
value
):
_check_builtin_module_attr
(
self
.
value
)
state
[
"value"
]
=
_ModuleState
.
get_module_state
(
self
.
value
)
return
state
return
state
def
__setstate__
(
self
,
state
):
for
k
,
v
in
state
.
items
():
if
isinstance
(
v
,
_ModuleState
):
state
[
k
]
=
v
.
to_module
()
self
.
__dict__
.
update
(
state
)
imperative/python/megengine/traced_module/module_tracer.py
浏览文件 @
7b19bc76
...
@@ -72,7 +72,6 @@ BUILTIN_ARRAY_METHOD = [
...
@@ -72,7 +72,6 @@ BUILTIN_ARRAY_METHOD = [
"astype"
,
"astype"
,
"reshape"
,
"reshape"
,
"_broadcast"
,
"_broadcast"
,
"transpose"
,
"flatten"
,
"flatten"
,
"sum"
,
"sum"
,
"prod"
,
"prod"
,
...
...
imperative/python/megengine/traced_module/node.py
浏览文件 @
7b19bc76
...
@@ -6,7 +6,9 @@
...
@@ -6,7 +6,9 @@
# software distributed under the License is distributed on an
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
abc
import
abc
import
copy
import
weakref
import
weakref
from
importlib
import
import_module
from
typing
import
Any
,
Dict
,
List
,
Tuple
,
Type
from
typing
import
Any
,
Dict
,
List
,
Tuple
,
Type
import
numpy
import
numpy
...
@@ -14,7 +16,9 @@ import numpy
...
@@ -14,7 +16,9 @@ import numpy
from
..
import
get_logger
from
..
import
get_logger
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
from
..module
import
Module
from
..module
import
Module
from
..quantization.utils
import
QParams
from
..tensor
import
Tensor
from
..tensor
import
Tensor
from
.utils
import
_check_obj_attr
logger
=
get_logger
(
__name__
)
logger
=
get_logger
(
__name__
)
...
@@ -145,6 +149,23 @@ class Node:
...
@@ -145,6 +149,23 @@ class Node:
assert
isinstance
(
id
,
int
)
assert
isinstance
(
id
,
int
)
cls
.
__total_id
=
id
cls
.
__total_id
=
id
def
__copy__
(
self
):
cls
=
self
.
__class__
result
=
cls
.
__new__
(
cls
)
result
.
__dict__
.
update
(
self
.
__dict__
)
return
result
def
__deepcopy__
(
self
,
memo
):
cls
=
self
.
__class__
result
=
cls
.
__new__
(
cls
)
state
=
{}
memo
[
id
(
self
)]
=
result
for
k
,
v
in
self
.
__dict__
.
items
():
if
not
isinstance
(
v
,
weakref
.
ReferenceType
)
and
k
!=
"actual_node"
:
state
[
k
]
=
copy
.
deepcopy
(
v
,
memo
)
result
.
__dict__
.
update
(
state
)
return
result
class
ModuleNode
(
Node
):
class
ModuleNode
(
Node
):
r
"""``ModuleNode`` represents the Module objects."""
r
"""``ModuleNode`` represents the Module objects."""
...
@@ -157,19 +178,28 @@ class ModuleNode(Node):
...
@@ -157,19 +178,28 @@ class ModuleNode(Node):
super
().
__init__
(
expr
,
name
,
qualname
)
super
().
__init__
(
expr
,
name
,
qualname
)
def
__getstate__
(
self
):
def
__getstate__
(
self
):
return
{
state
=
{
"expr"
:
self
.
expr
,
"expr"
:
self
.
expr
,
"users"
:
self
.
users
,
"users"
:
self
.
users
,
"_id"
:
self
.
_id
,
"_id"
:
self
.
_id
,
"_name"
:
self
.
_name
,
"_name"
:
self
.
_name
,
"_qualname"
:
self
.
_qualname
,
"_qualname"
:
self
.
_qualname
,
"module_type"
:
self
.
module_type
,
"module_type"
:
(
self
.
module_type
.
__module__
,
self
.
module_type
.
__qualname__
)
,
}
}
_check_obj_attr
(
state
)
return
state
def
__setstate__
(
self
,
state
):
def
__setstate__
(
self
,
state
):
if
"_orig_name"
in
state
:
if
"_orig_name"
in
state
:
state
[
"_qualname"
]
=
state
.
pop
(
"_orig_name"
)
state
[
"_qualname"
]
=
state
.
pop
(
"_orig_name"
)
self
.
__dict__
.
update
(
state
)
self
.
__dict__
.
update
(
state
)
try
:
if
isinstance
(
self
.
module_type
,
tuple
):
mname
,
classname
=
self
.
module_type
mtype
=
getattr
(
import_module
(
mname
),
classname
)
self
.
module_type
=
mtype
except
Exception
:
pass
@
property
@
property
def
owner
(
self
):
def
owner
(
self
):
...
@@ -185,12 +215,26 @@ class TensorNode(Node):
...
@@ -185,12 +215,26 @@ class TensorNode(Node):
_shape
=
None
# type: Tuple[int]
_shape
=
None
# type: Tuple[int]
_dtype
=
None
# type: numpy.dtype
_dtype
=
None
# type: numpy.dtype
_qparams
=
None
_qparams
=
None
# type: QParams
_device
=
None
_device
=
None
_value
=
None
# type: Tensor
_value
=
None
# type: Tensor
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
=
None
,
qualname
:
str
=
None
,
shape
:
Tuple
[
int
]
=
None
,
dtype
:
numpy
.
dtype
=
None
,
qparams
:
QParams
=
None
,
):
super
().
__init__
(
expr
,
name
,
qualname
)
self
.
_shape
=
shape
self
.
_dtype
=
shape
self
.
_qparams
=
qparams
def
__getstate__
(
self
):
def
__getstate__
(
self
):
return
{
state
=
{
"expr"
:
self
.
expr
,
"expr"
:
self
.
expr
,
"users"
:
self
.
users
,
"users"
:
self
.
users
,
"_id"
:
self
.
_id
,
"_id"
:
self
.
_id
,
...
@@ -201,6 +245,8 @@ class TensorNode(Node):
...
@@ -201,6 +245,8 @@ class TensorNode(Node):
"_name"
:
self
.
_name
,
"_name"
:
self
.
_name
,
"_qualname"
:
self
.
_qualname
,
"_qualname"
:
self
.
_qualname
,
}
}
_check_obj_attr
(
state
)
return
state
def
__setstate__
(
self
,
state
):
def
__setstate__
(
self
,
state
):
if
"_orig_name"
in
state
:
if
"_orig_name"
in
state
:
...
@@ -276,7 +322,10 @@ class NodeMixin(abc.ABC):
...
@@ -276,7 +322,10 @@ class NodeMixin(abc.ABC):
assert
isinstance
(
node
,
TensorNode
)
assert
isinstance
(
node
,
TensorNode
)
assert
isinstance
(
value
,
RawTensor
)
assert
isinstance
(
value
,
RawTensor
)
if
isinstance
(
value
,
RawTensor
):
if
isinstance
(
value
,
RawTensor
):
node
.
_dtype
=
value
.
dtype
try
:
node
.
_dtype
=
value
.
dtype
except
RuntimeError
:
node
.
_dtype
=
None
node
.
_shape
=
(
node
.
_shape
=
(
value
.
_tuple_shape
if
isinstance
(
value
,
Tensor
)
else
value
.
shape
value
.
_tuple_shape
if
isinstance
(
value
,
Tensor
)
else
value
.
shape
)
)
...
...
imperative/python/megengine/traced_module/pytree.py
浏览文件 @
7b19bc76
...
@@ -7,15 +7,18 @@
...
@@ -7,15 +7,18 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
collections
import
collections
from
collections
import
OrderedDict
from
collections
import
OrderedDict
,
defaultdict
from
functools
import
partial
from
typing
import
Callable
,
NamedTuple
from
typing
import
Callable
,
NamedTuple
import
numpy
as
np
import
numpy
as
np
from
..core._imperative_rt
import
OpDef
from
..core._imperative_rt.common
import
CompNode
from
..core._imperative_rt.common
import
CompNode
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
from
..core._imperative_rt.core2
import
Tensor
as
RawTensor
from
..core._wrap
import
Device
from
..core._wrap
import
Device
from
..core.tensor.dtype
import
QuantDtypeMeta
from
..core.tensor.dtype
import
QuantDtypeMeta
from
..distributed
import
Group
from
..module
import
Module
from
..module
import
Module
from
..quantization.utils
import
LSQParams
,
QParams
,
QuantMode
from
..quantization.utils
import
LSQParams
,
QParams
,
QuantMode
from
..tensor
import
Parameter
,
Tensor
from
..tensor
import
Parameter
,
Tensor
...
@@ -49,45 +52,54 @@ SUPPORTED_LEAF_TYPE = {
...
@@ -49,45 +52,54 @@ SUPPORTED_LEAF_TYPE = {
type
(
Ellipsis
),
type
(
Ellipsis
),
QuantMode
,
QuantMode
,
ArgsIndex
,
ArgsIndex
,
Group
,
}
}
USER_REGISTERED_LEAF_TYPE
=
[]
USER_REGISTERED_CONTAINER_TYPE
=
[]
# if isinstance(object, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS) is True, the object could be threated as leaf node of pytree
# if isinstance(object, SUPPORTED_LEAF_CLS) or issubclass(obj, SUPPORTED_LEAF_CLS) is True, the object could be threated as leaf node of pytree
SUPPORTED_LEAF_CLS
=
[
Module
,
Node
,
NodeMixin
,
np
.
dtype
,
np
.
ndarray
,
np
.
number
]
SUPPORTED_LEAF_CLS
=
[
Module
,
Node
,
NodeMixin
,
np
.
dtype
,
np
.
ndarray
,
np
.
number
,
np
.
bool_
,
OpDef
,
]
NodeType
=
NamedTuple
(
"NodeType"
,
[(
"flatten"
,
Callable
),
(
"unflatten"
,
Callable
)])
NodeType
=
NamedTuple
(
"NodeType"
,
[(
"flatten"
,
Callable
),
(
"unflatten"
,
Callable
)])
def
register_supported_type
(
type
,
flatten
=
None
,
unflatten
=
None
):
def
register_supported_type
(
type
,
flatten
=
None
,
unflatten
=
None
):
tp_info
=
(
type
.
__module__
,
type
.
__qualname__
)
if
flatten
and
unflatten
:
if
flatten
and
unflatten
:
SUPPORTED_TYPE
[
type
]
=
NodeType
(
flatten
,
unflatten
)
USER_REGISTERED_CONTAINER_TYPE
.
append
(
tp_info
)
else
:
else
:
SUPPORTED_LEAF_CLS
.
append
(
type
)
USER_REGISTERED_LEAF_TYPE
.
append
(
tp_info
)
_register_supported_type
(
type
,
flatten
,
unflatten
)
def
_dict_flatten
(
inp
):
aux_data
=
[]
results
=
[]
for
key
,
value
in
sorted
(
inp
.
items
()):
results
.
append
(
value
)
aux_data
.
append
(
key
)
return
results
,
tuple
(
aux_data
)
def
_dict_unflatten
(
inps
,
aux_data
):
def
_register_supported_type
(
type
,
flatten
=
None
,
unflatten
=
None
):
return
dict
(
zip
(
aux_data
,
inps
))
if
flatten
and
unflatten
:
SUPPORTED_TYPE
[
type
]
=
NodeType
(
flatten
,
unflatten
)
else
:
SUPPORTED_LEAF_CLS
.
append
(
type
)
def
_
ordereddict_flatten
(
inp
):
def
_
dict_flatten
(
ordered
,
inp
):
aux_data
=
[]
aux_data
=
[]
results
=
[]
results
=
[]
for
key
,
value
in
inp
.
items
():
dict_items
=
inp
.
items
()
if
ordered
else
sorted
(
inp
.
items
())
for
key
,
value
in
dict_items
:
results
.
append
(
value
)
results
.
append
(
value
)
aux_data
.
append
(
key
)
aux_data
.
append
(
key
)
return
results
,
tuple
(
aux_data
)
return
results
,
tuple
(
aux_data
)
def
_
ordereddict_unflatten
(
inps
,
aux_data
):
def
_
dict_unflatten
(
dict_type
,
inps
,
aux_data
):
return
OrderedDict
(
zip
(
aux_data
,
inps
))
return
dict_type
(
zip
(
aux_data
,
inps
))
def
qparams_flatten
(
inp
):
def
qparams_flatten
(
inp
):
...
@@ -99,33 +111,41 @@ def qparams_flatten(inp):
...
@@ -99,33 +111,41 @@ def qparams_flatten(inp):
return
results
,
tuple
(
aux_data
)
return
results
,
tuple
(
aux_data
)
def
qparams_unflatten
(
inp
,
aux_data
):
def
qparams_unflatten
(
qparam_type
,
inp
,
aux_data
):
obj
=
QParams
.
__new__
(
QParams
)
obj
=
qparam_type
.
__new__
(
qparam_type
)
for
k
,
v
in
zip
(
aux_data
,
inp
):
for
k
,
v
in
zip
(
aux_data
,
inp
):
setattr
(
obj
,
k
,
v
)
setattr
(
obj
,
k
,
v
)
return
obj
return
obj
register_supported_type
(
list
,
lambda
x
:
(
x
,
None
),
lambda
x
,
aux_data
:
list
(
x
))
_register_supported_type
(
list
,
lambda
x
:
(
x
,
None
),
lambda
x
,
aux_data
:
list
(
x
))
register_supported_type
(
tuple
,
lambda
x
:
(
x
,
None
),
lambda
x
,
aux_data
:
tuple
(
x
))
_register_supported_type
(
tuple
,
lambda
x
:
(
x
,
None
),
lambda
x
,
aux_data
:
tuple
(
x
))
register_supported_type
(
dict
,
_dict_flatten
,
_dict_unflatten
)
_register_supported_type
(
register_supported_type
(
dict
,
partial
(
_dict_flatten
,
False
),
partial
(
_dict_unflatten
,
dict
)
collections
.
OrderedDict
,
_ordereddict_flatten
,
_ordereddict_unflatten
)
_register_supported_type
(
defaultdict
,
partial
(
_dict_flatten
,
False
),
partial
(
_dict_unflatten
,
defaultdict
)
)
)
register_supported_type
(
_register_supported_type
(
OrderedDict
,
partial
(
_dict_flatten
,
True
),
partial
(
_dict_unflatten
,
OrderedDict
)
)
_register_supported_type
(
slice
,
slice
,
lambda
x
:
([
x
.
start
,
x
.
stop
,
x
.
step
],
None
),
lambda
x
:
([
x
.
start
,
x
.
stop
,
x
.
step
],
None
),
lambda
x
,
aux_data
:
slice
(
x
[
0
],
x
[
1
],
x
[
2
]),
lambda
x
,
aux_data
:
slice
(
x
[
0
],
x
[
1
],
x
[
2
]),
)
)
register_supported_type
(
QParams
,
qparams_flatten
,
qparams_unflatten
)
_register_supported_type
(
QParams
,
qparams_flatten
,
partial
(
qparams_unflatten
,
QParams
))
_register_supported_type
(
LSQParams
,
qparams_flatten
,
partial
(
qparams_unflatten
,
LSQParams
)
)
def
_is_leaf
(
obj
):
def
_is_leaf
(
obj
):
if
isinstance
(
obj
,
type
):
obj_type
=
obj
if
isinstance
(
obj
,
type
)
else
type
(
obj
)
return
issubclass
(
obj
,
tuple
(
SUPPORTED_LEAF_CLS
))
or
obj
in
SUPPORTED_LEAF_TYPE
return
(
return
(
isinstance
(
obj
,
tuple
(
SUPPORTED_LEAF_CLS
))
or
type
(
obj
)
in
SUPPORTED_LEAF_TYPE
issubclass
(
obj_type
,
tuple
(
SUPPORTED_LEAF_CLS
))
or
obj_type
in
SUPPORTED_LEAF_TYPE
)
)
...
...
imperative/python/megengine/traced_module/serialization.py
浏览文件 @
7b19bc76
...
@@ -5,30 +5,158 @@
...
@@ -5,30 +5,158 @@
# Unless required by applicable law or agreed to in writing,
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
typing
import
Dict
from
importlib
import
import_module
from
typing
import
Dict
,
Tuple
from
..core._imperative_rt
import
OpDef
from
..core._imperative_rt
import
OpDef
from
..core.ops
import
builtin
from
..core.ops
import
builtin
from
..tensor
import
Tensor
from
..version
import
__version__
from
..version
import
__version__
from
.utils
import
_convert_kwargs_to_args
OPDEF_PARAM_LOADER
=
{}
OPDEF_LOADER
=
{}
FUNCTIONAL_LOADER
=
{}
TENSORMETHOD_LOADER
=
{}
MODULE_LOADER
=
{}
def
get_opdef_state
(
obj
:
OpDef
)
->
Dict
:
class
_ModuleState
:
state
=
obj
.
__getstate__
()
obj
=
None
state
[
"type"
]
=
type
(
obj
)
state
[
"version"
]
=
__version__
return
state
def
__init__
(
self
,
module
:
Tuple
,
state
:
Dict
,
version
:
str
):
self
.
module
=
module
self
.
state
=
state
self
.
version
=
version
def
load_opdef_from_state
(
state
:
Dict
)
->
OpDef
:
@
classmethod
assert
"type"
in
state
and
issubclass
(
state
[
"type"
],
OpDef
)
def
get_module_state
(
cls
,
module
):
assert
"version"
in
state
typem
=
(
type
(
module
).
__module__
,
type
(
module
).
__qualname__
)
opdef_type
=
state
.
pop
(
"type"
)
state
=
module
.
__dict__
.
copy
()
if
opdef_type
in
OPDEF_PARAM_LOADER
:
state
.
pop
(
"_m_dump_modulestate"
,
None
)
loader
=
OPDEF_PARAM_LOADER
[
opdef_type
]
if
hasattr
(
module
,
"_m_dump_modulestate"
):
state
=
loader
(
state
)
assert
isinstance
(
module
.
_m_dump_modulestate
,
cls
)
state
.
pop
(
"version"
)
module
.
_m_dump_modulestate
.
__init__
(
typem
,
state
,
__version__
)
opdef_obj
=
opdef_type
()
else
:
opdef_obj
.
__setstate__
(
state
)
module
.
__dict__
[
"_m_dump_modulestate"
]
=
_ModuleState
(
return
opdef_obj
typem
,
state
,
__version__
)
return
module
.
_m_dump_modulestate
def
__getstate__
(
self
):
return
{
"module"
:
self
.
module
,
"state"
:
self
.
state
,
"version"
:
self
.
version
}
def
to_module
(
self
):
if
self
.
obj
is
None
:
typem
=
getattr
(
import_module
(
self
.
module
[
0
]),
self
.
module
[
1
])
m_obj
=
typem
.
__new__
(
typem
)
m_obj
.
__dict__
.
update
(
self
.
state
)
self
.
obj
=
m_obj
return
self
.
obj
def
register_opdef_loader
(
*
opdefs
):
def
callback
(
loader
):
for
opdef
in
opdefs
:
assert
opdef
not
in
OPDEF_LOADER
OPDEF_LOADER
[
opdef
]
=
loader
return
loader
return
callback
def
register_functional_loader
(
*
funcs
):
def
callback
(
loader
):
for
func
in
funcs
:
assert
func
not
in
FUNCTIONAL_LOADER
FUNCTIONAL_LOADER
[
func
]
=
loader
return
loader
return
callback
def
register_module_loader
(
*
module_types
):
def
callback
(
loader
):
for
module_type
in
module_types
:
assert
module_type
not
in
MODULE_LOADER
MODULE_LOADER
[
module_type
]
=
loader
return
loader
return
callback
def
register_tensor_method_loader
(
*
methods
):
def
callback
(
loader
):
for
method
in
methods
:
assert
method
not
in
TENSORMETHOD_LOADER
TENSORMETHOD_LOADER
[
method
]
=
loader
return
loader
return
callback
def
_replace_args_kwargs
(
expr
,
new_args
,
new_kwargs
):
if
len
(
new_args
)
!=
len
(
expr
.
args
)
or
set
(
new_kwargs
.
keys
())
!=
set
(
expr
.
kwargs
.
keys
()
):
expr
.
set_args_kwargs
(
*
new_args
,
**
new_kwargs
)
def
load_functional
(
expr
):
func
=
(
(
expr
.
func
.
__module__
,
expr
.
func
.
__qualname__
)
if
callable
(
expr
.
func
)
else
expr
.
func
)
assert
isinstance
(
func
,
tuple
)
if
func
in
FUNCTIONAL_LOADER
:
loader
=
FUNCTIONAL_LOADER
[
func
]
loader
(
expr
)
mname
,
fname
=
func
f
=
import_module
(
mname
)
for
i
in
fname
.
split
(
"."
):
f
=
getattr
(
f
,
i
)
expr
.
func
=
f
assert
callable
(
expr
.
func
)
if
not
hasattr
(
expr
,
"version"
)
or
expr
.
version
!=
__version__
:
args
,
kwargs
=
_convert_kwargs_to_args
(
expr
.
func
,
expr
.
args
,
expr
.
kwargs
)
_replace_args_kwargs
(
expr
,
args
,
kwargs
)
def
load_call_module_expr
(
expr
):
m_type
=
expr
.
inputs
[
0
].
module_type
if
isinstance
(
m_type
,
type
):
m_type
=
(
m_type
.
__module__
,
m_type
.
__qualname__
)
if
m_type
in
MODULE_LOADER
:
MODULE_LOADER
[
m_type
](
expr
)
if
isinstance
(
expr
.
inputs
[
0
].
module_type
,
tuple
):
mname
,
classname
=
expr
.
inputs
[
0
].
module_type
expr
.
inputs
[
0
].
module_type
=
getattr
(
import_module
(
mname
),
classname
)
if
not
hasattr
(
expr
,
"version"
)
or
expr
.
version
!=
__version__
:
fwd_func
=
getattr
(
expr
.
inputs
[
0
].
module_type
,
"forward"
)
args
,
kwargs
=
_convert_kwargs_to_args
(
fwd_func
,
expr
.
args
,
expr
.
kwargs
)
_replace_args_kwargs
(
expr
,
args
,
kwargs
)
def
load_call_tensor_method_expr
(
expr
):
if
expr
.
method
in
TENSORMETHOD_LOADER
:
loader
=
TENSORMETHOD_LOADER
[
expr
.
method
]
loader
(
expr
)
if
not
hasattr
(
expr
,
"version"
)
or
expr
.
version
!=
__version__
:
tmethod
=
(
getattr
(
expr
.
args
[
0
],
expr
.
method
)
if
isinstance
(
expr
.
args
[
0
],
type
)
else
getattr
(
Tensor
,
expr
.
method
)
)
args
,
kwargs
=
_convert_kwargs_to_args
(
tmethod
,
expr
.
args
,
expr
.
kwargs
)
_replace_args_kwargs
(
expr
,
args
,
kwargs
)
def
load_apply_expr
(
expr
):
opdef_type
=
type
(
expr
.
opdef
)
if
opdef_type
in
OPDEF_LOADER
:
OPDEF_LOADER
[
opdef_type
](
expr
)
opdef_state
=
expr
.
opdef_state
opdef_obj
=
opdef_state
.
pop
(
"opdef_type"
)()
opdef_obj
.
__setstate__
(
opdef_state
)
expr
.
opdef
=
opdef_obj
imperative/python/megengine/traced_module/traced_module.py
浏览文件 @
7b19bc76
...
@@ -14,6 +14,7 @@ import inspect
...
@@ -14,6 +14,7 @@ import inspect
import
keyword
import
keyword
import
re
import
re
import
weakref
import
weakref
from
importlib
import
import_module
from
inspect
import
getcallargs
,
getmembers
,
isclass
,
ismethod
from
inspect
import
getcallargs
,
getmembers
,
isclass
,
ismethod
from
itertools
import
chain
from
itertools
import
chain
from
types
import
FunctionType
from
types
import
FunctionType
...
@@ -53,6 +54,7 @@ from ..quantization.observer import (
...
@@ -53,6 +54,7 @@ from ..quantization.observer import (
SyncMinMaxObserver
,
SyncMinMaxObserver
,
)
)
from
..tensor
import
Tensor
from
..tensor
import
Tensor
from
..version
import
__version__
from
.expr
import
(
from
.expr
import
(
Apply
,
Apply
,
CallFunction
,
CallFunction
,
...
@@ -80,8 +82,27 @@ from .module_tracer import (
...
@@ -80,8 +82,27 @@ from .module_tracer import (
set_active_module_tracer
,
set_active_module_tracer
,
)
)
from
.node
import
ModuleNode
,
Node
,
NodeMixin
,
TensorNode
from
.node
import
ModuleNode
,
Node
,
NodeMixin
,
TensorNode
from
.pytree
import
ArgsIndex
,
tree_flatten
from
.pytree
import
(
from
.utils
import
replace_container_with_module_container
USER_REGISTERED_CONTAINER_TYPE
,
USER_REGISTERED_LEAF_TYPE
,
ArgsIndex
,
TreeDef
,
_register_supported_type
,
tree_flatten
,
)
from
.serialization
import
(
_ModuleState
,
load_apply_expr
,
load_call_module_expr
,
load_call_tensor_method_expr
,
load_functional
,
)
from
.utils
import
(
_check_builtin_module_attr
,
_check_obj_attr
,
_convert_kwargs_to_args
,
replace_container_with_module_container
,
)
logger
=
get_logger
(
__name__
)
logger
=
get_logger
(
__name__
)
...
@@ -341,7 +362,7 @@ class NameSpace:
...
@@ -341,7 +362,7 @@ class NameSpace:
def
create_unique_name
(
self
,
name
:
str
,
node
:
Any
=
None
)
->
str
:
def
create_unique_name
(
self
,
name
:
str
,
node
:
Any
=
None
)
->
str
:
assert
isinstance
(
name
,
str
),
"The name must be a string"
assert
isinstance
(
name
,
str
),
"The name must be a string"
if
name
in
self
.
_used_names
and
self
.
_used_names
[
name
]
is
node
:
if
name
in
self
.
_used_names
and
(
self
.
_used_names
[
name
]
is
node
)
:
return
name
return
name
name
=
re
.
sub
(
"[^0-9a-zA-Z_]+"
,
"_"
,
name
)
name
=
re
.
sub
(
"[^0-9a-zA-Z_]+"
,
"_"
,
name
)
...
@@ -1067,6 +1088,7 @@ class InternalGraph:
...
@@ -1067,6 +1088,7 @@ class InternalGraph:
if
node2value
[
n
][
1
]
==
0
:
if
node2value
[
n
][
1
]
==
0
:
node2value
.
pop
(
n
)
node2value
.
pop
(
n
)
if
values
is
not
None
:
if
values
is
not
None
:
assert
len
(
values
)
==
len
(
expr
.
outputs
)
for
n
,
v
in
zip
(
expr
.
outputs
,
values
):
for
n
,
v
in
zip
(
expr
.
outputs
,
values
):
if
ref_count
(
n
)
>
0
:
if
ref_count
(
n
)
>
0
:
node2value
[
n
]
=
[
v
,
ref_count
(
n
)]
node2value
[
n
]
=
[
v
,
ref_count
(
n
)]
...
@@ -1105,13 +1127,27 @@ class InternalGraph:
...
@@ -1105,13 +1127,27 @@ class InternalGraph:
return
res
return
res
def
__getstate__
(
self
):
def
__getstate__
(
self
):
state
=
self
.
__dict__
.
copy
()
state
=
{
if
"_top_graph"
in
state
:
"_exprs"
:
self
.
_exprs
,
state
.
pop
(
"_top_graph"
)
"_inputs"
:
self
.
_inputs
,
"_outputs"
:
self
.
_outputs
,
"_watch_point"
:
[],
"_end_point"
:
[],
"_namespace"
:
self
.
_namespace
,
"_rst"
:
collections
.
defaultdict
(
list
),
"_name"
:
self
.
_name
,
"_qualname"
:
self
.
_qualname
,
}
if
self
.
_total_ids
:
state
[
"_total_ids"
]
=
self
.
_total_ids
_check_obj_attr
(
state
)
return
state
return
state
def
__setstate__
(
self
,
state
):
def
__setstate__
(
self
,
state
):
old_version
=
False
old_version
=
False
if
"_module_name"
in
state
:
if
"_module_name"
in
state
:
old_version
=
True
old_version
=
True
state
[
"_qualname"
]
=
state
.
pop
(
"_module_name"
)
state
[
"_qualname"
]
=
state
.
pop
(
"_module_name"
)
...
@@ -1144,6 +1180,25 @@ class InternalGraph:
...
@@ -1144,6 +1180,25 @@ class InternalGraph:
self
.
_namespace
=
NameSpace
(
self
.
_name
,
self
.
_qualname
)
self
.
_namespace
=
NameSpace
(
self
.
_name
,
self
.
_qualname
)
self
.
_re_associate_name
()
self
.
_re_associate_name
()
def
__copy__
(
self
):
cls
=
self
.
__class__
result
=
cls
.
__new__
(
cls
)
result
.
__dict__
.
update
(
self
.
__dict__
)
return
result
def
__deepcopy__
(
self
,
memo
):
if
id
(
self
)
in
memo
:
return
memo
[
id
(
self
)]
cls
=
self
.
__class__
result
=
cls
.
__new__
(
cls
)
state
=
{}
memo
[
id
(
self
)]
=
result
for
k
,
v
in
self
.
__dict__
.
items
():
if
not
isinstance
(
v
,
weakref
.
ReferenceType
):
state
[
k
]
=
copy
.
deepcopy
(
v
,
memo
)
result
.
__dict__
.
update
(
state
)
return
result
def
_get_meth_name
(
obj
,
func
):
def
_get_meth_name
(
obj
,
func
):
tp
=
obj
if
isinstance
(
obj
,
type
)
else
type
(
obj
)
tp
=
obj
if
isinstance
(
obj
,
type
)
else
type
(
obj
)
...
@@ -1157,9 +1212,7 @@ def _get_meth_name(obj, func):
...
@@ -1157,9 +1212,7 @@ def _get_meth_name(obj, func):
def
_wrapped_function
(
orig_func
):
def
_wrapped_function
(
orig_func
):
@
functools
.
wraps
(
orig_func
)
@
functools
.
wraps
(
orig_func
)
def
wrapped_fn
(
*
args
,
**
kwargs
):
def
wrapped_fn
(
*
args
,
**
kwargs
):
method_func
=
wrapped_fn
method_func
=
kwargs
.
pop
(
"method_func"
,
wrapped_fn
)
if
"method_func"
in
kwargs
:
method_func
=
kwargs
.
pop
(
"method_func"
)
if
is_tracing_module
():
if
is_tracing_module
():
unset_module_tracing
()
unset_module_tracing
()
inputs
,
tree_def
=
tree_flatten
((
args
,
kwargs
))
inputs
,
tree_def
=
tree_flatten
((
args
,
kwargs
))
...
@@ -1167,11 +1220,11 @@ def _wrapped_function(orig_func):
...
@@ -1167,11 +1220,11 @@ def _wrapped_function(orig_func):
if
not
NodeMixin
.
get
(
i
,
None
):
if
not
NodeMixin
.
get
(
i
,
None
):
if
isinstance
(
i
,
(
RawTensor
,
NodeMixin
)):
if
isinstance
(
i
,
(
RawTensor
,
NodeMixin
)):
NodeMixin
.
wrap_safe
(
i
,
Constant
.
make
(
i
))
NodeMixin
.
wrap_safe
(
i
,
Constant
.
make
(
i
))
meth_name
,
arg_type
=
None
,
None
args
,
kwargs
=
_convert_kwargs_to_args
(
orig_func
,
args
,
kwargs
)
if
args
:
meth_name
=
_get_meth_name
(
args
[
0
],
method_func
)
meth_name
=
_get_meth_name
(
args
[
0
],
method_func
)
arg_type
=
args
[
0
]
if
isinstance
(
args
[
0
],
type
)
else
type
(
args
[
0
])
arg_type
=
args
[
0
]
if
isinstance
(
args
[
0
],
type
)
else
type
(
args
[
0
])
if
meth_name
and
arg_type
and
issubclass
(
arg_type
,
RawTensor
):
if
meth_name
and
arg_type
and
issubclass
(
arg_type
,
RawTensor
):
inputs
,
tree_def
=
tree_flatten
((
args
,
kwargs
))
self
=
inputs
[
0
]
self
=
inputs
[
0
]
if
meth_name
==
"__new__"
:
if
meth_name
==
"__new__"
:
if
all
([
not
isinstance
(
i
,
RawTensor
)
for
i
in
inputs
]):
if
all
([
not
isinstance
(
i
,
RawTensor
)
for
i
in
inputs
]):
...
@@ -1190,6 +1243,7 @@ def _wrapped_function(orig_func):
...
@@ -1190,6 +1243,7 @@ def _wrapped_function(orig_func):
call_node
=
CallMethod
.
make
(
NodeMixin
.
get
(
self
),
meth_name
)
call_node
=
CallMethod
.
make
(
NodeMixin
.
get
(
self
),
meth_name
)
call_node
.
add_inputs
(
inputs
[
1
:])
call_node
.
add_inputs
(
inputs
[
1
:])
else
:
else
:
inputs
,
tree_def
=
tree_flatten
((
args
,
kwargs
))
call_node
=
CallFunction
.
make
(
orig_func
)
call_node
=
CallFunction
.
make
(
orig_func
)
call_node
.
add_inputs
(
inputs
)
call_node
.
add_inputs
(
inputs
)
...
@@ -1228,9 +1282,11 @@ class TracedModuleBuilder(NodeMixin):
...
@@ -1228,9 +1282,11 @@ class TracedModuleBuilder(NodeMixin):
"_record_wrapped_nodes"
,
"_record_wrapped_nodes"
,
"_argdef_graph_map"
,
"_argdef_graph_map"
,
"_argdef_outdef_map"
,
"_argdef_outdef_map"
,
"_check_qat_module"
,
"nodes"
,
"nodes"
,
"__class__"
,
"__class__"
,
"__dict__"
,
"__dict__"
,
"_is_top"
,
]
]
def
__init__
(
self
,
mod
,
is_top_module
=
False
):
def
__init__
(
self
,
mod
,
is_top_module
=
False
):
...
@@ -1301,22 +1357,18 @@ class TracedModuleBuilder(NodeMixin):
...
@@ -1301,22 +1357,18 @@ class TracedModuleBuilder(NodeMixin):
qat_module
.
weight_fake_quant
.
set_qparams
(
qparams
)
qat_module
.
weight_fake_quant
.
set_qparams
(
qparams
)
def
build
(
self
):
def
build
(
self
):
if
self
.
_is_builtin
or
isinstance
(
self
.
_mod
,
TracedModule
):
if
self
.
_is_builtin
:
if
module_tracer
.
is_builtin
(
self
.
_mod
)
or
isinstance
(
assert
module_tracer
.
is_builtin
(
self
.
_mod
)
self
.
_mod
,
TracedModule
mod_type
=
type
(
self
.
_mod
)
):
mod_type
=
type
(
self
.
_mod
)
else
:
assert
isinstance
(
self
.
_mod
,
(
Observer
,
_FakeQuantize
))
mod_type
=
(
Observer
if
isinstance
(
self
.
_mod
,
Observer
)
else
_FakeQuantize
)
for
node
in
self
.
nodes
:
for
node
in
self
.
nodes
:
node
.
module_type
=
mod_type
node
.
module_type
=
mod_type
return
self
.
_mod
return
self
.
_mod
else
:
else
:
is_qat
=
isinstance
(
self
.
_mod
,
QATModule
)
is_qat
=
isinstance
(
self
.
_mod
,
QATModule
)
or
(
isinstance
(
self
.
_mod
,
TracedModule
)
and
self
.
_mod
.
is_qat
)
traced_module
=
TracedModule
(
traced_module
=
TracedModule
(
self
.
_is_top
,
self
.
_argdef_graph_map
,
self
.
_argdef_outdef_map
,
is_qat
self
.
_is_top
,
self
.
_argdef_graph_map
,
self
.
_argdef_outdef_map
,
is_qat
)
)
...
@@ -1338,15 +1390,18 @@ class TracedModuleBuilder(NodeMixin):
...
@@ -1338,15 +1390,18 @@ class TracedModuleBuilder(NodeMixin):
traced_module
.
with_act
=
self
.
_mod
.
with_act
traced_module
.
with_act
=
self
.
_mod
.
with_act
traced_module
.
with_weight
=
self
.
_mod
.
with_weight
traced_module
.
with_weight
=
self
.
_mod
.
with_weight
if
not
hasattr
(
traced_module
,
"act_fake_quant"
):
if
not
hasattr
(
traced_module
,
"act_fake_quant"
):
traced_module
.
act_fakequant
=
None
traced_module
.
act_fake
_
quant
=
None
if
not
hasattr
(
traced_module
,
"act_observer"
):
if
not
hasattr
(
traced_module
,
"act_observer"
):
traced_module
.
act_observer
=
None
traced_module
.
act_observer
=
None
if
not
hasattr
(
traced_module
,
"weight_fake_quant"
):
if
not
hasattr
(
traced_module
,
"weight_fake_quant"
):
traced_module
.
weight_fakequant
=
None
traced_module
.
weight_fake
_
quant
=
None
if
not
hasattr
(
traced_module
,
"weight_observer"
):
if
not
hasattr
(
traced_module
,
"weight_observer"
):
traced_module
.
weight_observer
=
None
traced_module
.
weight_observer
=
None
set_module_tracing
()
set_module_tracing
()
if
self
.
_is_top
:
traced_module
.
_update_ref
()
return
traced_module
return
traced_module
def
_record_wrapped_nodes
(
self
,
node
):
def
_record_wrapped_nodes
(
self
,
node
):
...
@@ -1357,6 +1412,7 @@ class TracedModuleBuilder(NodeMixin):
...
@@ -1357,6 +1412,7 @@ class TracedModuleBuilder(NodeMixin):
# prepare args and kwargs for inner graph
# prepare args and kwargs for inner graph
if
"method_func"
in
kwargs
:
if
"method_func"
in
kwargs
:
kwargs
.
pop
(
"method_func"
)
kwargs
.
pop
(
"method_func"
)
args
,
kwargs
=
_convert_kwargs_to_args
(
self
.
_mod
.
forward
,
args
,
kwargs
,
True
)
def
mark_constant
(
x
):
def
mark_constant
(
x
):
node
=
NodeMixin
.
get
(
x
,
None
)
node
=
NodeMixin
.
get
(
x
,
None
)
...
@@ -1372,11 +1428,7 @@ class TracedModuleBuilder(NodeMixin):
...
@@ -1372,11 +1428,7 @@ class TracedModuleBuilder(NodeMixin):
callnode
.
arg_def
=
tree_def
callnode
.
arg_def
=
tree_def
if
(
if
self
.
_is_builtin
or
tree_def
in
self
.
_argdef_graph_map
:
self
.
_is_builtin
or
tree_def
in
self
.
_argdef_graph_map
or
isinstance
(
self
.
_mod
,
TracedModule
)
):
unset_module_tracing
()
unset_module_tracing
()
rst
=
self
.
_mod
(
*
args
,
**
kwargs
)
rst
=
self
.
_mod
(
*
args
,
**
kwargs
)
outputs
,
out_def
=
tree_flatten
(
rst
,
is_leaf
=
_is_leaf
)
outputs
,
out_def
=
tree_flatten
(
rst
,
is_leaf
=
_is_leaf
)
...
@@ -1385,33 +1437,7 @@ class TracedModuleBuilder(NodeMixin):
...
@@ -1385,33 +1437,7 @@ class TracedModuleBuilder(NodeMixin):
self
.
_body
=
None
self
.
_body
=
None
elif
tree_def
in
self
.
_argdef_graph_map
:
elif
tree_def
in
self
.
_argdef_graph_map
:
self
.
_body
=
self
.
_argdef_graph_map
[
tree_def
]
self
.
_body
=
self
.
_argdef_graph_map
[
tree_def
]
else
:
self
.
_mod
.
_is_top
=
False
self
.
_body
=
self
.
_mod
.
argdef_graph_map
[
tree_def
]
module_qualname
=
NodeMixin
.
get
(
self
).
qualname
if
module_qualname
!=
self
.
_body
.
qualname
:
src_name
,
dst_name
=
self
.
_body
.
qualname
,
module_qualname
def
replace_qualname
(
g
):
attr_name
=
get_suffix_name
(
src_name
,
g
.
qualname
)
if
attr_name
is
not
None
:
g
.
_qualname
=
(
(
"%s.%s"
%
(
dst_name
,
attr_name
))
if
attr_name
else
dst_name
)
assert
get_suffix_name
(
dst_name
,
g
.
qualname
)
is
not
None
for
mod
in
self
.
_mod
.
modules
():
if
not
hasattr
(
mod
,
"argdef_graph_map"
):
continue
for
g
in
mod
.
argdef_graph_map
.
values
():
replace_qualname
(
g
)
g
.
_namespace
.
qualname
=
g
.
qualname
for
n
in
g
.
nodes
(
False
):
replace_qualname
(
n
)
else
:
else
:
self_node
=
None
orig_self
=
NodeMixin
.
get
(
self
)
orig_self
=
NodeMixin
.
get
(
self
)
parent_graph
=
active_module_tracer
().
current_scope
()
parent_graph
=
active_module_tracer
().
current_scope
()
module_qualname
=
orig_self
.
_qualname
module_qualname
=
orig_self
.
_qualname
...
@@ -1423,20 +1449,14 @@ class TracedModuleBuilder(NodeMixin):
...
@@ -1423,20 +1449,14 @@ class TracedModuleBuilder(NodeMixin):
active_module_tracer
().
push_scope
(
self
.
_body
)
active_module_tracer
().
push_scope
(
self
.
_body
)
# rebind self to new input node
# rebind self to new input node
if
self_node
:
NodeMixin
.
wrap_safe
(
NodeMixin
.
wrap_safe
(
self
,
self_node
)
self
,
active_module_tracer
().
current_scope
().
_add_input
(
self_node
)
Input
.
make
(
else
:
name
=
"self"
,
NodeMixin
.
wrap_safe
(
qualname
=
module_qualname
,
self
,
type
=
NodeMixin
.
get_wrapped_type
(
self
),
self_node
),
if
self_node
)
else
Input
.
make
(
name
=
"self"
,
qualname
=
module_qualname
,
type
=
NodeMixin
.
get_wrapped_type
(
self
),
),
)
origin_inp_node
=
[
NodeMixin
.
get
(
i
,
None
)
for
i
in
inputs
[
1
:]]
origin_inp_node
=
[
NodeMixin
.
get
(
i
,
None
)
for
i
in
inputs
[
1
:]]
# prepare args and kwargs for inner graph
# prepare args and kwargs for inner graph
...
@@ -1470,8 +1490,23 @@ class TracedModuleBuilder(NodeMixin):
...
@@ -1470,8 +1490,23 @@ class TracedModuleBuilder(NodeMixin):
return
x
return
x
args
=
[
self
]
args
=
[
self
]
for
i
,
v
in
enumerate
(
inputs
[
1
:]):
orig_traced_inputs
=
(
args
.
append
(
wrap
(
v
,
idx2key
[
i
+
1
]))
None
if
not
isinstance
(
self
.
_mod
,
TracedModule
)
else
self
.
_mod
.
argdef_graph_map
[
tree_def
].
inputs
)
ind
=
1
for
v
in
inputs
[
1
:]:
if
isinstance
(
v
,
(
RawTensor
,
NodeMixin
)):
args_name
=
(
orig_traced_inputs
[
ind
].
_name
if
orig_traced_inputs
else
idx2key
[
ind
]
)
ind
+=
1
args
.
append
(
wrap
(
v
,
args_name
))
else
:
args
.
append
(
v
)
args
,
kwargs
=
tree_def
.
unflatten
(
args
)
args
,
kwargs
=
tree_def
.
unflatten
(
args
)
active_module_tracer
().
patcher
.
auto_patch
(
active_module_tracer
().
patcher
.
auto_patch
(
...
@@ -1514,7 +1549,6 @@ class TracedModuleBuilder(NodeMixin):
...
@@ -1514,7 +1549,6 @@ class TracedModuleBuilder(NodeMixin):
attr
=
getattr
(
type
(
self
.
_mod
),
name
).
__get__
(
self
,
type
(
self
))
attr
=
getattr
(
type
(
self
.
_mod
),
name
).
__get__
(
self
,
type
(
self
))
else
:
else
:
attr
=
getattr
(
self
.
_mod
,
name
)
attr
=
getattr
(
self
.
_mod
,
name
)
if
(
if
(
isinstance
(
attr
,
FunctionType
)
isinstance
(
attr
,
FunctionType
)
and
id
(
attr
)
in
active_module_tracer
().
patcher
.
patched_fn_ids
and
id
(
attr
)
in
active_module_tracer
().
patcher
.
patched_fn_ids
...
@@ -1568,7 +1602,7 @@ class TracedModuleBuilder(NodeMixin):
...
@@ -1568,7 +1602,7 @@ class TracedModuleBuilder(NodeMixin):
wrapped
=
self
.
__getattr__
(
name
)
wrapped
=
self
.
__getattr__
(
name
)
if
isinstance
(
wrapped
,
TracedModuleBuilder
):
if
isinstance
(
wrapped
,
TracedModuleBuilder
):
if
not
isinstance
(
mod_attr
,
(
List
,
Dict
)):
if
not
isinstance
(
mod_attr
,
(
List
,
Dict
,
QATModule
)):
assert
mod_attr
is
wrapped
.
_mod
assert
mod_attr
is
wrapped
.
_mod
else
:
else
:
assert
mod_attr
is
wrapped
assert
mod_attr
is
wrapped
...
@@ -1977,8 +2011,6 @@ class TracedModule(Module):
...
@@ -1977,8 +2011,6 @@ class TracedModule(Module):
def
graph
(
self
)
->
InternalGraph
:
def
graph
(
self
)
->
InternalGraph
:
"""Return the ``InternalGraph`` of this ``TracedModule``.
"""Return the ``InternalGraph`` of this ``TracedModule``.
"""
"""
if
self
.
_is_top
:
self
.
_update_ref
()
assert
len
(
self
.
argdef_graph_map
)
==
1
assert
len
(
self
.
argdef_graph_map
)
==
1
return
list
(
self
.
argdef_graph_map
.
values
())[
0
]
return
list
(
self
.
argdef_graph_map
.
values
())[
0
]
...
@@ -2112,7 +2144,7 @@ class TracedModule(Module):
...
@@ -2112,7 +2144,7 @@ class TracedModule(Module):
if
hasattr
(
obj
,
"argdef_graph_map"
)
if
hasattr
(
obj
,
"argdef_graph_map"
)
else
None
else
None
)
)
if
expr_graph
is
not
None
:
if
expr_graph
is
not
None
and
not
obj
.
is_qat
:
exprs
=
_flatten_subgraph
(
graph
,
expr_graph
,
expr
,
obj
)
exprs
=
_flatten_subgraph
(
graph
,
expr_graph
,
expr
,
obj
)
if
parent_graph
is
not
None
:
if
parent_graph
is
not
None
:
...
@@ -2137,26 +2169,119 @@ class TracedModule(Module):
...
@@ -2137,26 +2169,119 @@ class TracedModule(Module):
)
)
new_module
.
graph
.
_re_associate_name
()
new_module
.
graph
.
_re_associate_name
()
new_module
.
graph
.
compile
()
new_module
.
graph
.
compile
()
new_module
.
_update_ref
()
new_module
.
graph
.
_reset_ids
()
new_module
.
graph
.
_reset_ids
()
return
new_module
return
new_module
def
__getstate__
(
self
):
def
__getstate__
(
self
):
d
=
self
.
__dict__
d
=
self
.
__dict__
.
copy
()
for
k
in
Module
.
__dict__
:
for
k
in
Module
.
__dict__
:
d
.
pop
(
k
,
None
)
d
.
pop
(
k
,
None
)
_check_obj_attr
(
d
)
for
k
in
d
:
if
module_tracer
.
is_builtin
(
d
[
k
]):
assert
_check_builtin_module_attr
(
d
[
k
]
),
"Module {} can not be serialized. "
.
format
(
type
(
d
[
k
]))
d
[
k
]
=
_ModuleState
.
get_module_state
(
d
[
k
])
dump_info
=
{
"version"
:
__version__
,
"register_type"
:
USER_REGISTERED_LEAF_TYPE
,
"register_container_type"
:
USER_REGISTERED_CONTAINER_TYPE
,
"register_mdule"
:
USER_REGISTERED_MODULE
,
"register_function"
:
USER_REGISTERED_FUNCTION
,
}
d
[
"dump_info"
]
=
dump_info
return
d
return
d
def
__setstate__
(
self
,
state
):
for
k
,
v
in
state
.
items
():
if
isinstance
(
v
,
_ModuleState
):
state
[
k
]
=
v
.
to_module
()
self
.
__dict__
.
update
(
state
)
self
.
_update_ref
()
for
_
,
graph
in
self
.
argdef_graph_map
.
items
():
for
expr
in
graph
.
_exprs
:
if
isinstance
(
expr
,
CallFunction
):
load_functional
(
expr
)
if
isinstance
(
expr
,
CallMethod
):
if
expr
.
method
==
"__call__"
:
load_call_module_expr
(
expr
)
else
:
load_call_tensor_method_expr
(
expr
)
if
isinstance
(
expr
,
Apply
):
load_apply_expr
(
expr
)
for
_
,
graph
in
self
.
argdef_graph_map
.
items
():
ind
=
0
while
ind
<
len
(
graph
.
_exprs
):
cur_expr
=
graph
.
_exprs
[
ind
]
has_new_expr
=
False
for
i
in
cur_expr
.
inputs
:
if
i
.
expr
not
in
graph
.
_exprs
and
not
isinstance
(
i
.
expr
,
Input
):
graph
.
_exprs
.
insert
(
ind
,
i
.
expr
)
has_new_expr
=
True
if
not
has_new_expr
:
ind
+=
1
for
expr
in
graph
.
_exprs
:
for
i
in
expr
.
inputs
:
if
expr
.
inputs
.
count
(
i
)
!=
i
.
users
.
count
(
expr
):
add_or_del_count
=
expr
.
inputs
.
count
(
i
)
-
i
.
users
.
count
(
expr
)
if
add_or_del_count
>
0
:
i
.
users
.
extend
([
expr
]
*
add_or_del_count
)
else
:
[
i
.
users
.
remove
(
expr
)
for
i
in
range
(
-
add_or_del_count
)]
for
o
in
expr
.
outputs
:
if
o
.
expr
is
not
expr
:
assert
o
not
in
o
.
expr
.
outputs
o
.
expr
=
expr
for
node
in
graph
.
nodes
(
False
):
# remove users of node which doesn't use node as input
node
.
users
=
[
e
for
e
in
node
.
users
if
node
in
e
.
inputs
]
for
expr
in
graph
.
_exprs
:
graph
.
_namespace
.
auto_naming_for_outputs
(
expr
)
self
.
_update_ref
()
for
_
,
graph
in
self
.
argdef_graph_map
.
items
():
graph
.
_reset_ids
()
def
__copy__
(
self
):
cls
=
self
.
__class__
result
=
cls
.
__new__
(
cls
)
result
.
__dict__
.
update
(
self
.
__dict__
)
return
result
def
__deepcopy__
(
self
,
memo
):
cls
=
self
.
__class__
result
=
cls
.
__new__
(
cls
)
state
=
{}
memo
[
id
(
self
)]
=
result
for
k
,
v
in
self
.
__dict__
.
items
():
if
not
isinstance
(
v
,
weakref
.
ReferenceType
):
state
[
k
]
=
copy
.
deepcopy
(
v
,
memo
)
result
.
__dict__
.
update
(
state
)
result
.
_update_ref
()
return
result
def
cpp_apply_module_trace
(
opdef
,
*
args
):
def
cpp_apply_module_trace
(
opdef
,
*
args
):
return
Apply
.
apply_module_trace_hook
(
opdef
,
*
args
)
return
Apply
.
apply_module_trace_hook
(
opdef
,
*
args
)
USER_REGISTERED_MODULE
=
[]
USER_REGISTERED_FUNCTION
=
[]
def
register_as_builtin
(
mod_cls
:
Type
[
Module
])
->
None
:
def
register_as_builtin
(
mod_cls
:
Type
[
Module
])
->
None
:
r
"""Registers class ``mod_cls`` (subclass of :class:`~.Module`) as builtin module.
r
"""Registers class ``mod_cls`` (subclass of :class:`~.Module`) as builtin module.
Args:
Args:
mod_cls: the module class which will be treated as builtin module in tracing.
mod_cls: the module class which will be treated as builtin module in tracing.
"""
"""
USER_REGISTERED_MODULE
.
append
((
mod_cls
.
__module__
,
mod_cls
.
__qualname__
))
module_tracer
.
register_as_builtin
(
mod_cls
)
module_tracer
.
register_as_builtin
(
mod_cls
)
...
@@ -2181,6 +2306,7 @@ def wrap(func: Callable):
...
@@ -2181,6 +2306,7 @@ def wrap(func: Callable):
Args:
Args:
func: the function of the global function to insert into the graph when it's called.
func: the function of the global function to insert into the graph when it's called.
"""
"""
USER_REGISTERED_FUNCTION
.
append
((
func
.
__module__
,
func
.
__qualname__
))
assert
callable
(
func
),
"func must be a callable"
assert
callable
(
func
),
"func must be a callable"
assert
hasattr
(
func
,
"__code__"
)
assert
hasattr
(
func
,
"__code__"
)
fn_name
=
func
.
__code__
.
co_name
fn_name
=
func
.
__code__
.
co_name
...
@@ -2247,6 +2373,8 @@ def trace_module(
...
@@ -2247,6 +2373,8 @@ def trace_module(
NodeMixin
.
wrap_safe
(
NodeMixin
.
wrap_safe
(
builder
,
Input
.
make
(
name
=
"top"
,
type
=
ModuleNode
,
qualname
=
net_name
)
builder
,
Input
.
make
(
name
=
"top"
,
type
=
ModuleNode
,
qualname
=
net_name
)
)
)
args
,
kwargs
=
_convert_kwargs_to_args
(
mod
.
forward
,
args
,
kwargs
,
True
)
inputs
,
_
=
tree_flatten
((
args
,
kwargs
))
inputs
,
_
=
tree_flatten
((
args
,
kwargs
))
for
_
,
i
in
enumerate
(
inputs
):
for
_
,
i
in
enumerate
(
inputs
):
# assert isinstance(i, Tensor), "not support "
# assert isinstance(i, Tensor), "not support "
...
...
imperative/python/megengine/traced_module/utils.py
浏览文件 @
7b19bc76
...
@@ -5,12 +5,17 @@
...
@@ -5,12 +5,17 @@
# Unless required by applicable law or agreed to in writing,
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
collections
import
copy
import
copy
import
inspect
from
collections.abc
import
MutableMapping
,
MutableSequence
from
collections.abc
import
MutableMapping
,
MutableSequence
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Sequence
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Sequence
,
Type
from
..
import
get_logger
from
..module
import
Module
from
..module
import
Module
logger
=
get_logger
(
__name__
)
def
replace_container_with_module_container
(
container
):
def
replace_container_with_module_container
(
container
):
has_module
=
False
has_module
=
False
...
@@ -52,6 +57,101 @@ def replace_container_with_module_container(container):
...
@@ -52,6 +57,101 @@ def replace_container_with_module_container(container):
return
has_module
,
module_container
return
has_module
,
module_container
def
_convert_kwargs_to_args
(
func
,
args
,
kwargs
,
is_bounded
=
False
):
# is_bounded = True when func is a method and provided args don't include 'self'
arg_specs
=
inspect
.
getfullargspec
(
func
)
arg_specs_args
=
arg_specs
.
args
if
is_bounded
:
arg_specs_args
=
arg_specs
.
args
[
1
:]
new_args
=
[]
new_kwargs
=
{}
new_args
.
extend
(
args
)
if
set
(
arg_specs_args
[
0
:
len
(
new_args
)])
&
set
(
kwargs
.
keys
()):
repeated_arg_name
=
set
(
arg_specs_args
[
0
:
len
(
new_args
)])
&
set
(
kwargs
.
keys
())
raise
TypeError
(
"{} got multiple values for argument {}"
.
format
(
func
.
__qualname__
,
", "
.
join
(
repeated_arg_name
)
)
)
if
len
(
new_args
)
<
len
(
arg_specs
.
args
):
for
ind
in
range
(
len
(
new_args
),
len
(
arg_specs_args
)):
arg_name
=
arg_specs_args
[
ind
]
if
arg_name
in
kwargs
:
new_args
.
append
(
kwargs
[
arg_name
])
else
:
index
=
ind
-
len
(
arg_specs_args
)
+
len
(
arg_specs
.
defaults
)
assert
index
<
len
(
arg_specs
.
defaults
)
and
index
>=
0
new_args
.
append
(
arg_specs
.
defaults
[
index
])
for
kwarg_name
in
arg_specs
.
kwonlyargs
:
if
kwarg_name
in
kwargs
:
new_kwargs
[
kwarg_name
]
=
kwargs
[
kwarg_name
]
else
:
assert
kwarg_name
in
arg_specs
.
kwonlydefaults
new_kwargs
[
kwarg_name
]
=
arg_specs
.
kwonlydefaults
[
kwarg_name
]
for
k
,
v
in
kwargs
.
items
():
if
k
not
in
arg_specs
.
args
and
k
not
in
arg_specs
.
kwonlyargs
:
if
arg_specs
.
varkw
is
None
:
raise
TypeError
(
"{} got an unexpected keyword argument {}"
.
format
(
func
.
__qualname__
,
k
)
)
new_kwargs
[
k
]
=
v
return
tuple
(
new_args
),
new_kwargs
def
_check_obj_attr
(
obj
):
# check if all the attributes of a obj is serializable
from
.pytree
import
tree_flatten
from
.pytree
import
SUPPORTED_LEAF_CLS
,
SUPPORTED_LEAF_TYPE
,
TreeDef
from
.expr
import
Expr
from
.traced_module
import
TracedModule
,
InternalGraph
,
NameSpace
def
_check_leaf_type
(
leaf
):
leaf_type
=
leaf
if
isinstance
(
leaf
,
type
)
else
type
(
leaf
)
traced_module_types
=
[
Expr
,
TreeDef
,
TracedModule
,
InternalGraph
,
NameSpace
]
return
(
issubclass
(
leaf_type
,
tuple
(
SUPPORTED_LEAF_CLS
+
traced_module_types
))
or
leaf_type
in
SUPPORTED_LEAF_TYPE
)
for
_
,
v
in
obj
.
items
():
leafs
,
_
=
tree_flatten
(
v
,
is_leaf
=
lambda
_
:
True
)
for
leaf
in
leafs
:
assert
_check_leaf_type
(
leaf
),
"Type {} is not supported by traced module"
.
format
(
leaf
if
isinstance
(
leaf
,
type
)
else
type
(
leaf
)
)
def
_check_builtin_module_attr
(
mod
):
from
.pytree
import
_is_leaf
as
_check_leaf_type
from
.pytree
import
tree_flatten
# check if all the attributes of a builtin module is serializable
is_non_serializable_module
=
lambda
m
:
isinstance
(
m
,
Module
)
and
not
_check_builtin_module_attr
(
m
)
for
k
,
v
in
mod
.
__dict__
.
items
():
if
k
==
"_m_dump_modulestate"
:
continue
if
is_non_serializable_module
(
v
):
return
False
elif
not
isinstance
(
v
,
Module
):
leafs
,
_
=
tree_flatten
(
v
,
is_leaf
=
lambda
_
:
True
)
for
leaf
in
leafs
:
if
not
_check_leaf_type
(
leaf
)
or
is_non_serializable_module
(
leaf
):
logger
.
warn
(
"Type {} is not supported by traced module"
.
format
(
leaf
if
isinstance
(
leaf
,
type
)
else
type
(
leaf
)
)
)
return
False
return
True
class
_ModuleList
(
Module
,
MutableSequence
):
class
_ModuleList
(
Module
,
MutableSequence
):
r
"""A List-like container.
r
"""A List-like container.
...
...
imperative/python/test/unit/core/test_serialization.py
浏览文件 @
7b19bc76
...
@@ -15,7 +15,6 @@ import numpy as np
...
@@ -15,7 +15,6 @@ import numpy as np
import
megengine
as
mge
import
megengine
as
mge
from
megengine
import
Parameter
,
Tensor
from
megengine
import
Parameter
,
Tensor
from
megengine.core.ops
import
builtin
from
megengine.core.ops
import
builtin
from
megengine.traced_module.serialization
import
get_opdef_state
,
load_opdef_from_state
def
test_tensor_serialization
():
def
test_tensor_serialization
():
...
@@ -88,25 +87,3 @@ def test_compatibility():
...
@@ -88,25 +87,3 @@ def test_compatibility():
test_old_tensor
(
"tensor_v1_1.mge"
)
test_old_tensor
(
"tensor_v1_1.mge"
)
test_old_tensor
(
"tensor_v1_2.mge"
)
test_old_tensor
(
"tensor_v1_2.mge"
)
def
test_opdef_serialization
():
with
TemporaryFile
()
as
f
:
x
=
builtin
.
Elemwise
(
mode
=
"Add"
)
pickle
.
dump
(
get_opdef_state
(
x
),
f
)
f
.
seek
(
0
)
load_x
=
load_opdef_from_state
(
pickle
.
load
(
f
))
assert
x
==
load_x
with
TemporaryFile
()
as
f
:
x
=
builtin
.
Convolution
(
stride_h
=
9
,
compute_mode
=
"float32"
)
x
.
strategy
=
(
builtin
.
Convolution
.
Strategy
.
PROFILE
|
builtin
.
Convolution
.
Strategy
.
HEURISTIC
|
builtin
.
Convolution
.
Strategy
.
REPRODUCIBLE
)
pickle
.
dump
(
get_opdef_state
(
x
),
f
)
f
.
seek
(
0
)
load_x
=
load_opdef_from_state
(
pickle
.
load
(
f
))
assert
x
.
strategy
==
load_x
.
strategy
assert
x
==
load_x
imperative/python/test/unit/traced_module/test_modification.py
浏览文件 @
7b19bc76
...
@@ -85,12 +85,12 @@ class NewModule(M.Module):
...
@@ -85,12 +85,12 @@ class NewModule(M.Module):
return
x
return
x
def
_check_expr_users
(
trac
ed_module
):
def
_check_expr_users
(
flatten
ed_module
):
node_user
=
defaultdict
(
list
)
node_user
=
defaultdict
(
list
)
for
expr
in
trac
ed_module
.
graph
.
_exprs
:
for
expr
in
flatten
ed_module
.
graph
.
_exprs
:
for
node
in
expr
.
inputs
:
for
node
in
expr
.
inputs
:
node_user
[
node
].
append
(
expr
)
node_user
[
node
].
append
(
expr
)
for
node
in
trac
ed_module
.
graph
.
nodes
():
for
node
in
flatten
ed_module
.
graph
.
nodes
():
node
.
users
.
sort
(
key
=
lambda
m
:
m
.
_id
)
node
.
users
.
sort
(
key
=
lambda
m
:
m
.
_id
)
node_user
[
node
].
sort
(
key
=
lambda
m
:
m
.
_id
)
node_user
[
node
].
sort
(
key
=
lambda
m
:
m
.
_id
)
assert
node
.
users
==
node_user
[
node
]
assert
node
.
users
==
node_user
[
node
]
...
...
imperative/python/test/unit/traced_module/test_qat_module.py
浏览文件 @
7b19bc76
...
@@ -8,6 +8,7 @@ import numpy as np
...
@@ -8,6 +8,7 @@ import numpy as np
import
megengine
as
mge
import
megengine
as
mge
import
megengine.functional
as
F
import
megengine.functional
as
F
import
megengine.module
as
M
import
megengine.module
as
M
import
megengine.module.qat
as
QM
import
megengine.quantization
as
Q
import
megengine.quantization
as
Q
from
megengine
import
Tensor
from
megengine
import
Tensor
from
megengine.module.qat.module
import
QATModule
from
megengine.module.qat.module
import
QATModule
...
@@ -28,10 +29,18 @@ def get_subattr(self: M.Module, name: str):
...
@@ -28,10 +29,18 @@ def get_subattr(self: M.Module, name: str):
return
getattr
(
self
,
name
)
return
getattr
(
self
,
name
)
class
MyConvBnRelu2d
(
M
.
ConvBnRelu2d
):
pass
class
MyQATConvBnRelu2d
(
QM
.
ConvBnRelu2d
):
pass
class
Myblcok
(
M
.
Module
):
class
Myblcok
(
M
.
Module
):
def
__init__
(
self
,):
def
__init__
(
self
,):
super
().
__init__
()
super
().
__init__
()
self
.
conv0
=
M
.
ConvBnRelu2d
(
3
,
3
,
3
,
1
,
1
)
self
.
conv0
=
M
y
ConvBnRelu2d
(
3
,
3
,
3
,
1
,
1
)
self
.
conv1
=
M
.
ConvBn2d
(
3
,
3
,
1
,
1
,
0
)
self
.
conv1
=
M
.
ConvBn2d
(
3
,
3
,
1
,
1
,
0
)
self
.
conv2
=
M
.
ConvBn2d
(
3
,
3
,
1
,
1
,
0
)
self
.
conv2
=
M
.
ConvBn2d
(
3
,
3
,
1
,
1
,
0
)
self
.
add
=
M
.
Elemwise
(
"FUSE_ADD_RELU"
)
self
.
add
=
M
.
Elemwise
(
"FUSE_ADD_RELU"
)
...
@@ -106,7 +115,11 @@ def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams):
...
@@ -106,7 +115,11 @@ def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams):
def
build_observered_net
(
net
:
M
.
Module
,
observer_cls
):
def
build_observered_net
(
net
:
M
.
Module
,
observer_cls
):
qat_net
=
Q
.
quantize_qat
(
net
,
qconfig
=
get_observer_config
(
observer_cls
))
qat_net
=
Q
.
quantize_qat
(
net
,
qconfig
=
get_observer_config
(
observer_cls
),
mapping
=
{
MyConvBnRelu2d
:
MyQATConvBnRelu2d
},
)
Q
.
enable_observer
(
qat_net
)
Q
.
enable_observer
(
qat_net
)
inp
=
Tensor
(
np
.
random
.
random
(
size
=
(
5
,
3
,
32
,
32
)))
inp
=
Tensor
(
np
.
random
.
random
(
size
=
(
5
,
3
,
32
,
32
)))
qat_net
(
inp
)
qat_net
(
inp
)
...
@@ -134,6 +147,15 @@ def test_trace_qat():
...
@@ -134,6 +147,15 @@ def test_trace_qat():
check_qparams
(
weight_qparams
,
traced_weight_qparams
)
check_qparams
(
weight_qparams
,
traced_weight_qparams
)
if
act_qparams
:
if
act_qparams
:
check_qparams
(
act_qparams
,
traced_act_qparams
)
check_qparams
(
act_qparams
,
traced_act_qparams
)
flatten_traced_net
=
traced_net
.
flatten
()
conv0_node
=
flatten_traced_net
.
graph
.
get_node_by_name
(
"MyModule_block0_conv0"
).
as_unique
()
conv0_out_node
=
flatten_traced_net
.
graph
.
get_node_by_name
(
"MyModule_block0_conv0_out"
).
as_unique
()
assert
isinstance
(
conv0_node
.
owner
,
TracedModule
)
assert
conv0_out_node
.
expr
.
inputs
[
0
]
is
conv0_node
_check_qat_module
(
build_observered_net
(
MyModule
(),
Q
.
MinMaxObserver
))
_check_qat_module
(
build_observered_net
(
MyModule
(),
Q
.
MinMaxObserver
))
_check_qat_module
(
build_observered_net
(
MyModule
(),
MyMinMaxObserver
))
_check_qat_module
(
build_observered_net
(
MyModule
(),
MyMinMaxObserver
))
...
...
imperative/python/test/unit/traced_module/test_serialization.py
浏览文件 @
7b19bc76
...
@@ -6,14 +6,59 @@
...
@@ -6,14 +6,59 @@
# software distributed under the License is distributed on an
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
pickle
import
pickle
from
collections
import
defaultdict
from
tempfile
import
TemporaryFile
import
numpy
as
np
import
numpy
as
np
import
megengine.functional
as
F
import
megengine.functional
as
F
import
megengine.module
as
M
import
megengine.module
as
M
import
megengine.traced_module.serialization
as
S
from
megengine
import
Tensor
from
megengine
import
Tensor
from
megengine.core._imperative_rt.core2
import
apply
from
megengine.core.ops
import
builtin
from
megengine.core.ops.builtin
import
Elemwise
from
megengine.module
import
Module
from
megengine.module
import
Module
from
megengine.traced_module
import
trace_module
from
megengine.traced_module
import
trace_module
from
megengine.traced_module.expr
import
CallMethod
,
Constant
from
megengine.traced_module.node
import
TensorNode
from
megengine.traced_module.serialization
import
(
register_functional_loader
,
register_module_loader
,
register_opdef_loader
,
register_tensor_method_loader
,
)
from
megengine.traced_module.utils
import
_convert_kwargs_to_args
def
_check_id
(
traced_module
):
_total_ids
=
traced_module
.
graph
.
_total_ids
node_ids
=
[
n
.
_id
for
n
in
traced_module
.
graph
.
nodes
().
as_list
()]
assert
len
(
set
(
node_ids
))
==
len
(
node_ids
)
assert
max
(
node_ids
)
+
1
==
_total_ids
[
0
]
expr_ids
=
[
n
.
_id
for
n
in
traced_module
.
graph
.
exprs
().
as_list
()]
assert
len
(
set
(
expr_ids
))
==
len
(
expr_ids
)
assert
max
(
expr_ids
)
+
1
==
_total_ids
[
1
]
def
_check_name
(
flatened_module
):
node_names
=
[
n
.
_name
for
n
in
flatened_module
.
graph
.
nodes
().
as_list
()]
assert
len
(
set
(
node_names
))
==
len
(
node_names
)
def
_check_expr_users
(
traced_module
):
node_user
=
defaultdict
(
list
)
for
expr
in
traced_module
.
graph
.
_exprs
:
for
node
in
expr
.
inputs
:
node_user
[
node
].
append
(
expr
)
if
isinstance
(
expr
,
CallMethod
)
and
expr
.
graph
:
_check_expr_users
(
expr
.
inputs
[
0
].
owner
)
for
node
in
traced_module
.
graph
.
nodes
(
False
):
node
.
users
.
sort
(
key
=
lambda
m
:
m
.
_id
)
node_user
[
node
].
sort
(
key
=
lambda
m
:
m
.
_id
)
assert
node
.
users
==
node_user
[
node
]
class
MyBlock
(
Module
):
class
MyBlock
(
Module
):
...
@@ -48,5 +93,274 @@ def test_dump_and_load():
...
@@ -48,5 +93,274 @@ def test_dump_and_load():
traced_module
=
trace_module
(
module
,
x
)
traced_module
=
trace_module
(
module
,
x
)
np
.
testing
.
assert_array_equal
(
expect
,
traced_module
(
x
))
np
.
testing
.
assert_array_equal
(
expect
,
traced_module
(
x
))
obj
=
pickle
.
dumps
(
traced_module
)
obj
=
pickle
.
dumps
(
traced_module
)
pickle
.
loads
(
obj
)
new_tm
=
pickle
.
loads
(
obj
)
_check_id
(
new_tm
)
_check_expr_users
(
new_tm
)
traced_module
.
graph
.
_reset_ids
()
old_nodes
=
traced_module
.
graph
.
nodes
().
as_list
()
new_nodes
=
new_tm
.
graph
.
nodes
().
as_list
()
old_exprs
=
traced_module
.
graph
.
exprs
().
as_list
()
new_exprs
=
new_tm
.
graph
.
exprs
().
as_list
()
assert
len
(
old_nodes
)
==
len
(
new_nodes
)
for
i
,
j
in
zip
(
old_nodes
,
new_nodes
):
assert
i
.
_name
==
j
.
_name
assert
i
.
_qualname
==
j
.
_qualname
assert
i
.
_id
==
j
.
_id
assert
len
(
old_exprs
)
==
len
(
new_exprs
)
for
i
,
j
in
zip
(
old_exprs
,
new_exprs
):
assert
i
.
_id
==
j
.
_id
np
.
testing
.
assert_array_equal
(
expect
,
traced_module
(
x
))
np
.
testing
.
assert_array_equal
(
expect
,
traced_module
(
x
))
def
test_opdef_loader
():
class
MyModule1
(
Module
):
def
forward
(
self
,
x
,
y
):
op
=
Elemwise
(
"ADD"
)
return
apply
(
op
,
x
,
y
)[
0
]
m
=
MyModule1
()
x
=
Tensor
(
np
.
ones
((
20
)))
y
=
Tensor
(
np
.
ones
((
20
)))
traced_module
=
trace_module
(
m
,
x
,
y
)
orig_loader_dict
=
S
.
OPDEF_LOADER
S
.
OPDEF_LOADER
=
{}
@
register_opdef_loader
(
Elemwise
)
def
add_opdef_loader
(
expr
):
if
expr
.
opdef_state
[
"mode"
]
==
"ADD"
:
expr
.
opdef_state
[
"mode"
]
=
"MUL"
node
=
expr
.
inputs
[
1
]
astype_expr
=
CallMethod
(
node
,
"astype"
)
oup
=
TensorNode
(
astype_expr
,
shape
=
node
.
shape
,
dtype
=
expr
.
inputs
[
0
].
dtype
,
qparams
=
node
.
qparams
,
)
astype_expr
.
set_args_kwargs
(
node
,
expr
.
inputs
[
0
].
dtype
)
astype_expr
.
return_val
=
(
oup
,)
expr
.
inputs
[
1
]
=
oup
obj
=
pickle
.
dumps
(
traced_module
)
new_module
=
pickle
.
loads
(
obj
)
_check_id
(
new_module
)
_check_expr_users
(
new_module
)
_check_name
(
new_module
.
flatten
())
assert
(
isinstance
(
new_module
.
graph
.
_exprs
[
0
],
CallMethod
)
and
new_module
.
graph
.
_exprs
[
1
].
opdef
.
mode
==
"MUL"
and
len
(
new_module
.
graph
.
_exprs
)
==
2
)
result
=
new_module
(
x
,
y
)
np
.
testing
.
assert_equal
(
result
.
numpy
(),
x
.
numpy
())
S
.
OPDEF_LOADER
=
orig_loader_dict
def
test_functional_loader
():
class
MyModule2
(
Module
):
def
forward
(
self
,
x
,
y
):
return
F
.
conv2d
(
x
,
y
)
m
=
MyModule2
()
x
=
Tensor
(
np
.
random
.
random
((
1
,
3
,
32
,
32
)))
y
=
Tensor
(
np
.
random
.
random
((
3
,
3
,
3
,
3
)))
traced_module
=
trace_module
(
m
,
x
,
y
)
orig_loader_dict
=
S
.
FUNCTIONAL_LOADER
S
.
FUNCTIONAL_LOADER
=
{}
@
register_functional_loader
((
"megengine.functional.nn"
,
"conv2d"
))
def
conv2df_loader
(
expr
):
# expr.func = ("megengine.functional.nn","conv2d")
kwargs
=
expr
.
kwargs
orig_weight
=
expr
.
named_args
[
"weight"
]
astype_expr
=
CallMethod
(
orig_weight
,
"astype"
)
oup
=
TensorNode
(
astype_expr
,
shape
=
orig_weight
.
shape
,
dtype
=
orig_weight
.
dtype
,
qparams
=
orig_weight
.
qparams
,
)
astype_expr
.
set_args_kwargs
(
orig_weight
,
expr
.
named_args
[
"inp"
].
dtype
)
astype_expr
.
return_val
=
(
oup
,)
expr
.
set_arg
(
"weight"
,
oup
)
obj
=
pickle
.
dumps
(
traced_module
)
new_module
=
pickle
.
loads
(
obj
)
_check_expr_users
(
new_module
)
_check_id
(
new_module
)
result
=
new_module
(
x
,
y
)
gt
=
m
(
x
,
y
)
assert
(
isinstance
(
new_module
.
graph
.
_exprs
[
0
],
CallMethod
)
and
len
(
new_module
.
graph
.
_exprs
)
==
2
)
np
.
testing
.
assert_equal
(
result
.
numpy
(),
gt
.
numpy
())
S
.
FUNCTIONAL_LOADER
=
orig_loader_dict
def
test_tensor_method_loader
():
class
MyModule3
(
Module
):
def
forward
(
self
,
x
):
return
x
+
1
m
=
MyModule3
()
x
=
Tensor
(
np
.
ones
((
20
)))
traced_module
=
trace_module
(
m
,
x
)
orig_loader_dict
=
S
.
TENSORMETHOD_LOADER
S
.
TENSORMETHOD_LOADER
=
{}
@
register_tensor_method_loader
(
"__add__"
)
def
add_loader
(
expr
):
args
=
list
(
expr
.
args
)
if
not
isinstance
(
args
[
1
],
TensorNode
):
args
[
1
]
=
Tensor
(
args
[
1
])
node
=
Constant
(
args
[
1
],
"const"
).
outputs
[
0
]
astype_expr
=
CallMethod
(
node
,
"astype"
)
oup
=
TensorNode
(
astype_expr
,
shape
=
node
.
shape
,
dtype
=
node
.
dtype
,
qparams
=
node
.
qparams
,
)
astype_expr
.
set_args_kwargs
(
node
,
expr
.
inputs
[
0
].
dtype
)
astype_expr
.
return_val
=
(
oup
,)
add_expr
=
CallMethod
(
oup
,
"__add__"
)
add_expr
.
set_args_kwargs
(
oup
,
oup
)
oup1
=
TensorNode
(
add_expr
,
shape
=
oup
.
shape
,
dtype
=
oup
.
dtype
,
qparams
=
node
.
qparams
,
)
add_expr
.
return_val
=
oup1
args
[
1
]
=
oup1
expr
.
set_args_kwargs
(
*
args
)
obj
=
pickle
.
dumps
(
traced_module
)
new_module
=
pickle
.
loads
(
obj
)
_check_expr_users
(
new_module
)
_check_id
(
new_module
)
result
=
new_module
(
x
)
gt
=
m
(
x
)
assert
(
isinstance
(
new_module
.
graph
.
_exprs
[
0
],
Constant
)
and
len
(
new_module
.
graph
.
_exprs
)
==
4
)
np
.
testing
.
assert_equal
(
result
.
numpy
(),
(
x
+
2
).
numpy
())
S
.
TENSORMETHOD_LOADER
=
orig_loader_dict
def
test_module_loader
():
class
MyModule4
(
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv
=
M
.
Conv2d
(
3
,
3
,
3
)
def
forward
(
self
,
x
):
return
self
.
conv
(
x
)
m
=
MyModule4
()
x
=
Tensor
(
np
.
random
.
random
((
1
,
3
,
32
,
32
)))
traced_module
=
trace_module
(
m
,
x
)
orig_loader_dict
=
S
.
MODULE_LOADER
S
.
MODULE_LOADER
=
{}
@
register_module_loader
((
"megengine.module.conv"
,
"Conv2d"
))
def
conv2dm_loader
(
expr
):
module
=
expr
.
inputs
[
0
].
owner
args
=
list
(
expr
.
args
)
orig_inp
=
args
[
1
]
astype_expr
=
CallMethod
(
orig_inp
,
"astype"
)
oup
=
TensorNode
(
astype_expr
,
shape
=
orig_inp
.
shape
,
dtype
=
orig_inp
.
dtype
,
qparams
=
orig_inp
.
qparams
,
)
astype_expr
.
set_args_kwargs
(
orig_inp
,
module
.
weight
.
dtype
)
astype_expr
.
return_val
=
(
oup
,)
args
[
1
]
=
oup
expr
.
set_args_kwargs
(
*
args
)
obj
=
pickle
.
dumps
(
traced_module
)
new_module
=
pickle
.
loads
(
obj
)
result
=
new_module
(
x
)
gt
=
m
(
x
)
assert
(
isinstance
(
new_module
.
graph
.
_exprs
[
1
],
CallMethod
)
and
len
(
new_module
.
graph
.
_exprs
)
==
3
)
np
.
testing
.
assert_equal
(
result
.
numpy
(),
gt
.
numpy
())
S
.
MODULE_LOADER
=
orig_loader_dict
def
test_shared_module
():
class
MyModule
(
M
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
a
=
M
.
Elemwise
(
"ADD"
)
self
.
b
=
self
.
a
def
forward
(
self
,
x
,
y
):
z
=
self
.
a
(
x
,
y
)
z
=
self
.
b
(
z
,
y
)
return
z
x
=
Tensor
(
1
)
y
=
Tensor
(
2
)
m
=
MyModule
()
tm
=
trace_module
(
m
,
x
,
y
)
obj
=
pickle
.
dumps
(
tm
)
load_tm
=
pickle
.
loads
(
obj
)
_check_expr_users
(
load_tm
)
_check_name
(
load_tm
.
flatten
())
_check_id
(
load_tm
)
assert
load_tm
.
a
is
load_tm
.
b
def
test_convert_kwargs_to_args
():
def
func
(
a
,
b
,
c
=
4
,
*
,
d
,
e
=
3
,
f
=
4
):
pass
args
=
(
1
,)
kwargs
=
{
"b"
:
1
,
"d"
:
6
}
new_args
,
new_kwargs
=
_convert_kwargs_to_args
(
func
,
args
,
kwargs
)
assert
new_args
==
(
1
,
1
,
4
)
assert
new_kwargs
==
{
"d"
:
6
,
"e"
:
3
,
"f"
:
4
}
args
=
(
1
,)
kwargs
=
{
"d"
:
6
}
new_args
,
new_kwargs
=
_convert_kwargs_to_args
(
func
,
args
,
kwargs
,
is_bounded
=
True
)
assert
new_args
==
(
1
,
4
)
assert
new_kwargs
==
{
"d"
:
6
,
"e"
:
3
,
"f"
:
4
}
def
func1
(
a
,
b
,
c
,
d
,
e
,
*
,
f
):
pass
args
=
()
kwargs
=
{
"a"
:
1
,
"b"
:
2
,
"c"
:
3
,
"d"
:
4
,
"e"
:
5
,
"f"
:
6
}
new_args
,
new_kwargs
=
_convert_kwargs_to_args
(
func1
,
args
,
kwargs
)
assert
new_args
==
(
1
,
2
,
3
,
4
,
5
)
assert
new_kwargs
==
{
"f"
:
6
}
def
test_opdef_serialization
():
with
TemporaryFile
()
as
f
:
x
=
builtin
.
Elemwise
(
mode
=
"Add"
)
pickle
.
dump
(
x
,
f
)
f
.
seek
(
0
)
load_x
=
pickle
.
load
(
f
)
assert
x
==
load_x
with
TemporaryFile
()
as
f
:
x
=
builtin
.
Convolution
(
stride_h
=
9
,
compute_mode
=
"float32"
)
x
.
strategy
=
(
builtin
.
Convolution
.
Strategy
.
PROFILE
|
builtin
.
Convolution
.
Strategy
.
HEURISTIC
|
builtin
.
Convolution
.
Strategy
.
REPRODUCIBLE
)
pickle
.
dump
(
x
,
f
)
f
.
seek
(
0
)
load_x
=
pickle
.
load
(
f
)
assert
x
.
strategy
==
load_x
.
strategy
assert
x
==
load_x
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录