Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b1c46ba4
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
b1c46ba4
编写于
7月 08, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(traced_module): add some functions of graph modification
GitOrigin-RevId: ac0603057adaedf864f2d0ceb7bfb6d3c5a50640
上级
4bb25369
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
729 addition
and
73 deletion
+729
-73
imperative/python/megengine/experimental/traced_module/expr.py
...ative/python/megengine/experimental/traced_module/expr.py
+42
-27
imperative/python/megengine/experimental/traced_module/module_tracer.py
...hon/megengine/experimental/traced_module/module_tracer.py
+11
-2
imperative/python/megengine/experimental/traced_module/node.py
...ative/python/megengine/experimental/traced_module/node.py
+4
-1
imperative/python/megengine/experimental/traced_module/pytree.py
...ive/python/megengine/experimental/traced_module/pytree.py
+7
-2
imperative/python/megengine/experimental/traced_module/traced_module.py
...hon/megengine/experimental/traced_module/traced_module.py
+274
-41
imperative/python/test/unit/traced_module/test_haoruitao.py
imperative/python/test/unit/traced_module/test_haoruitao.py
+90
-0
imperative/python/test/unit/traced_module/test_modification.py
...ative/python/test/unit/traced_module/test_modification.py
+113
-0
imperative/python/test/unit/traced_module/test_serialization.py
...tive/python/test/unit/traced_module/test_serialization.py
+52
-0
imperative/python/test/unit/traced_module/test_trace_module.py
...ative/python/test/unit/traced_module/test_trace_module.py
+42
-0
imperative/python/test/unit/traced_module/test_wujianan.py
imperative/python/test/unit/traced_module/test_wujianan.py
+94
-0
未找到文件。
imperative/python/megengine/experimental/traced_module/expr.py
浏览文件 @
b1c46ba4
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
import
builtins
import
builtins
import
collections
import
collections
import
inspect
from
typing
import
Callable
,
List
from
typing
import
Callable
,
List
from
...core._imperative_rt
import
OpDef
from
...core._imperative_rt
import
OpDef
...
@@ -16,10 +17,10 @@ from ...core._imperative_rt.core2 import Tensor as RawTensor
...
@@ -16,10 +17,10 @@ from ...core._imperative_rt.core2 import Tensor as RawTensor
from
...core._imperative_rt.core2
import
apply
,
set_module_tracing
,
unset_module_tracing
from
...core._imperative_rt.core2
import
apply
,
set_module_tracing
,
unset_module_tracing
from
...core.ops.special
import
Const
from
...core.ops.special
import
Const
from
...module
import
Module
from
...module
import
Module
from
...tensor
import
Tensor
from
...tensor
import
Parameter
,
Tensor
from
.module_tracer
import
active_module_tracer
,
module_tracer
from
.module_tracer
import
active_module_tracer
,
module_tracer
from
.node
import
ModuleNode
,
Node
,
NodeMixin
,
TensorNode
from
.node
import
ModuleNode
,
Node
,
NodeMixin
,
TensorNode
from
.pytree
import
TreeDef
from
.pytree
import
TreeDef
,
tree_flatten
class
Expr
:
class
Expr
:
...
@@ -38,25 +39,28 @@ class Expr:
...
@@ -38,25 +39,28 @@ class Expr:
for
val
in
vals
:
for
val
in
vals
:
node
=
NodeMixin
.
get
(
val
,
None
)
node
=
NodeMixin
.
get
(
val
,
None
)
if
isinstance
(
node
,
(
TensorNode
,
ModuleNode
)):
if
isinstance
(
node
,
(
TensorNode
,
ModuleNode
)):
if
node
not
in
self
.
inputs
:
self
.
inputs
.
append
(
node
)
self
.
inputs
.
append
(
node
)
node
.
users
.
append
(
self
)
else
:
else
:
assert
node
is
None
assert
node
is
None
assert
type
(
val
)
in
builtins
.
__dict__
.
values
()
idx
=
len
(
self
.
inputs
)
+
len
(
self
.
const_val
)
idx
=
len
(
self
.
inputs
)
+
len
(
self
.
const_val
)
self
.
const_val
.
append
((
idx
,
val
))
self
.
const_val
.
append
((
idx
,
val
))
def
add_outputs
(
self
,
outputs
):
def
add_outputs
(
self
,
outputs
,
check_inplace
=
True
):
self
.
outputs
=
[]
self
.
outputs
=
[]
if
not
isinstance
(
outputs
,
collections
.
Sequence
):
if
outputs
is
not
None
:
outputs
=
(
outputs
,)
if
not
isinstance
(
outputs
,
collections
.
Sequence
):
outputs
=
(
outputs
,)
for
i
in
outputs
:
for
i
in
outputs
:
assert
isinstance
(
i
,
RawTensor
)
assert
isinstance
(
i
,
RawTensor
)
self
.
outputs
.
append
(
NodeMixin
.
get_wrapped_type
(
i
)(
self
))
node
=
NodeMixin
.
get
(
i
,
None
)
if
check_inplace
else
None
self
.
outputs
.
append
(
node
if
node
else
NodeMixin
.
get_wrapped_type
(
i
)(
self
)
)
for
i
,
node
in
zip
(
outputs
,
self
.
outputs
,):
for
i
,
node
in
zip
(
outputs
,
self
.
outputs
,):
NodeMixin
.
wrap_safe
(
i
,
node
)
NodeMixin
.
wrap_safe
(
i
,
node
)
def
unflatten_args
(
self
,
inputs
):
def
unflatten_args
(
self
,
inputs
):
if
self
.
arg_def
is
not
None
:
if
self
.
arg_def
is
not
None
:
...
@@ -110,6 +114,7 @@ class GetAttr(Expr):
...
@@ -110,6 +114,7 @@ class GetAttr(Expr):
self
.
inputs
=
[
self
.
inputs
=
[
module
,
module
,
]
]
module
.
users
.
append
(
self
)
self
.
name
=
name
self
.
name
=
name
node_cls
=
type
if
type
else
Node
node_cls
=
type
if
type
else
Node
self
.
outputs
=
[
self
.
outputs
=
[
...
@@ -134,12 +139,20 @@ class GetAttr(Expr):
...
@@ -134,12 +139,20 @@ class GetAttr(Expr):
# expr: outputs = inputs[0].__call__(*inputs[1:])
# expr: outputs = inputs[0].__call__(*inputs[1:])
class
CallMethod
(
Expr
):
class
CallMethod
(
Expr
):
def
__init__
(
self
,
module
,
method
=
"__call__"
):
def
__init__
(
self
,
node
,
method
=
"__call__"
):
assert
isinstance
(
module
,
(
TensorNode
,
ModuleNode
))
if
isinstance
(
node
,
type
):
self
.
inputs
=
[
assert
issubclass
(
node
,
Tensor
)
module
,
cls
=
Parameter
if
issubclass
(
node
,
Parameter
)
else
Tensor
]
self
.
const_val
=
[]
self
.
inputs
=
[]
self
.
const_val
=
[(
0
,
cls
)]
else
:
assert
isinstance
(
node
,
(
TensorNode
,
ModuleNode
))
node
.
users
.
append
(
self
)
self
.
inputs
=
[
node
,
]
self
.
const_val
=
[]
self
.
method
=
method
self
.
method
=
method
@
classmethod
@
classmethod
...
@@ -160,10 +173,13 @@ class CallMethod(Expr):
...
@@ -160,10 +173,13 @@ class CallMethod(Expr):
def
interpret
(
self
,
*
inputs
):
def
interpret
(
self
,
*
inputs
):
args
,
kwargs
=
self
.
unflatten_args
(
inputs
)
args
,
kwargs
=
self
.
unflatten_args
(
inputs
)
obj
=
args
[
0
]
obj
=
args
[
0
]
args
=
args
[
1
:]
meth
=
getattr
(
obj
,
self
.
method
)
if
inspect
.
ismethod
(
meth
):
args
=
args
[
1
:]
outputs
=
getattr
(
obj
,
self
.
method
)(
*
args
,
**
kwargs
)
outputs
=
getattr
(
obj
,
self
.
method
)(
*
args
,
**
kwargs
)
if
isinstance
(
outputs
,
RawTensor
):
if
outputs
is
None
:
outputs
=
(
outputs
,)
return
outputs
outputs
,
_
=
tree_flatten
(
outputs
,
is_leaf
=
lambda
x
:
isinstance
(
x
,
RawTensor
))
return
outputs
return
outputs
def
__repr__
(
self
):
def
__repr__
(
self
):
...
@@ -171,7 +187,7 @@ class CallMethod(Expr):
...
@@ -171,7 +187,7 @@ class CallMethod(Expr):
kwargs
=
", "
.
join
(
"{}={}"
.
format
(
k
,
v
)
for
k
,
v
in
self
.
kwargs
.
items
())
kwargs
=
", "
.
join
(
"{}={}"
.
format
(
k
,
v
)
for
k
,
v
in
self
.
kwargs
.
items
())
return
"{} = {}.{}({})"
.
format
(
return
"{} = {}.{}({})"
.
format
(
", "
.
join
(
str
(
i
)
for
i
in
self
.
outputs
),
", "
.
join
(
str
(
i
)
for
i
in
self
.
outputs
),
self
.
input
s
[
0
],
self
.
arg
s
[
0
],
self
.
method
,
self
.
method
,
", "
.
join
([
args
,
kwargs
]),
", "
.
join
([
args
,
kwargs
]),
)
)
...
@@ -209,9 +225,8 @@ class Apply(Expr):
...
@@ -209,9 +225,8 @@ class Apply(Expr):
if
node
is
None
:
# capture as constant
if
node
is
None
:
# capture as constant
NodeMixin
.
wrap_safe
(
i
,
Constant
.
make
(
i
))
NodeMixin
.
wrap_safe
(
i
,
Constant
.
make
(
i
))
apply_node
=
cls
.
make
(
opdef
)
apply_node
=
cls
.
make
(
opdef
)
for
i
in
inputs
:
apply_node
.
add_inputs
(
inputs
)
assert
isinstance
(
i
,
RawTensor
)
assert
not
apply_node
.
const_val
apply_node
.
inputs
.
append
(
NodeMixin
.
get
(
i
))
unset_module_tracing
()
unset_module_tracing
()
outputs
=
apply
(
opdef
,
*
inputs
)
outputs
=
apply
(
opdef
,
*
inputs
)
...
@@ -283,7 +298,7 @@ class Constant(Expr):
...
@@ -283,7 +298,7 @@ class Constant(Expr):
return
(
self
.
value
,)
return
(
self
.
value
,)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
"{} = Constant({})"
.
format
(
self
.
outputs
[
0
],
self
.
value
)
return
"{} = Constant({})"
.
format
(
self
.
outputs
[
0
],
type
(
self
.
value
)
)
def
__getstate__
(
self
):
def
__getstate__
(
self
):
state
=
self
.
__dict__
.
copy
()
state
=
self
.
__dict__
.
copy
()
...
...
imperative/python/megengine/experimental/traced_module/module_tracer.py
浏览文件 @
b1c46ba4
...
@@ -79,6 +79,8 @@ BUILTIN_ARRAY_METHOD = [
...
@@ -79,6 +79,8 @@ BUILTIN_ARRAY_METHOD = [
"min"
,
"min"
,
"max"
,
"max"
,
"mean"
,
"mean"
,
"__getitem__"
,
"__setitem__"
,
]
]
...
@@ -176,7 +178,8 @@ class Patcher:
...
@@ -176,7 +178,8 @@ class Patcher:
self
.
patch_module
(
module
)
self
.
patch_module
(
module
)
for
meth
in
BUILTIN_ARRAY_METHOD
:
for
meth
in
BUILTIN_ARRAY_METHOD
:
self
.
patch_method
(
ArrayMethodMixin
,
meth
,
self
.
wrap_fn
)
self
.
patch_method
(
ArrayMethodMixin
,
meth
,
self
.
wrap_fn
)
self
.
patch_method
(
Tensor
,
"detach"
,
self
.
wrap_fn
)
self
.
patch_method
(
Tensor
,
"__new__"
,
self
.
wrap_fn
)
for
i
,
j
in
self
.
_builtin_functions
:
for
i
,
j
in
self
.
_builtin_functions
:
if
id
(
i
)
not
in
self
.
visited_frames_ids
:
if
id
(
i
)
not
in
self
.
visited_frames_ids
:
self
.
patch_function
(
i
,
j
,
self
.
wrap_fn
)
self
.
patch_function
(
i
,
j
,
self
.
wrap_fn
)
...
@@ -203,7 +206,13 @@ class Patcher:
...
@@ -203,7 +206,13 @@ class Patcher:
import
inspect
import
inspect
if
id
(
module
.
__dict__
)
not
in
self
.
visited_frames_ids
:
if
id
(
module
.
__dict__
)
not
in
self
.
visited_frames_ids
:
for
k
,
v
in
module
.
__dict__
.
items
():
keys
=
(
getattr
(
module
,
"__all__"
)
if
hasattr
(
module
,
"__all__"
)
else
module
.
__dict__
.
keys
()
)
for
k
in
keys
:
v
=
getattr
(
module
,
k
)
if
inspect
.
isfunction
(
v
)
and
not
k
.
startswith
(
"_"
):
if
inspect
.
isfunction
(
v
)
and
not
k
.
startswith
(
"_"
):
self
.
patch_function
(
module
.
__dict__
,
k
,
self
.
wrap_fn
)
self
.
patch_function
(
module
.
__dict__
,
k
,
self
.
wrap_fn
)
self
.
visited_frames_ids
.
add
(
id
(
module
.
__dict__
))
self
.
visited_frames_ids
.
add
(
id
(
module
.
__dict__
))
...
...
imperative/python/megengine/experimental/traced_module/node.py
浏览文件 @
b1c46ba4
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
typing
import
Any
,
Dict
,
Tuple
,
Type
from
typing
import
Any
,
Dict
,
List
,
Tuple
,
Type
import
numpy
import
numpy
...
@@ -31,6 +31,7 @@ class Node:
...
@@ -31,6 +31,7 @@ class Node:
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
=
None
):
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
=
None
):
self
.
expr
=
expr
self
.
expr
=
expr
self
.
users
=
[]
# List[Expr]
self
.
_id
=
Node
.
__total_id
self
.
_id
=
Node
.
__total_id
Node
.
__total_id
+=
1
Node
.
__total_id
+=
1
self
.
_name
=
name
self
.
_name
=
name
...
@@ -59,11 +60,13 @@ class ModuleNode(Node):
...
@@ -59,11 +60,13 @@ class ModuleNode(Node):
module_type
=
Module
# type: Type[Module]
module_type
=
Module
# type: Type[Module]
attr_type_map
=
None
# type: Dict[str, Type[Any]]
attr_type_map
=
None
# type: Dict[str, Type[Any]]
argdef_graph_map
=
None
# type: Dict[Treedef, "InternalGraph"]
argdef_graph_map
=
None
# type: Dict[Treedef, "InternalGraph"]
argdef_outdef_map
=
None
# type: Dict[Treedef, Treedef]
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
=
None
):
def
__init__
(
self
,
expr
:
"Expr"
,
name
:
str
=
None
):
super
().
__init__
(
expr
,
name
)
super
().
__init__
(
expr
,
name
)
self
.
attr_type_map
=
{}
self
.
attr_type_map
=
{}
self
.
argdef_graph_map
=
{}
self
.
argdef_graph_map
=
{}
self
.
argdef_outdef_map
=
{}
def
__repr__
(
self
):
def
__repr__
(
self
):
if
self
.
_name
is
None
:
if
self
.
_name
is
None
:
...
...
imperative/python/megengine/experimental/traced_module/pytree.py
浏览文件 @
b1c46ba4
...
@@ -10,6 +10,8 @@
...
@@ -10,6 +10,8 @@
import
collections
import
collections
from
typing
import
Callable
,
NamedTuple
from
typing
import
Callable
,
NamedTuple
import
numpy
as
np
SUPPORTED_TYPE
=
{}
SUPPORTED_TYPE
=
{}
NodeType
=
NamedTuple
(
"NodeType"
,
[(
"flatten"
,
Callable
),
(
"unflatten"
,
Callable
)])
NodeType
=
NamedTuple
(
"NodeType"
,
[(
"flatten"
,
Callable
),
(
"unflatten"
,
Callable
)])
...
@@ -33,7 +35,7 @@ def _dict_unflatten(inps, aux_data):
...
@@ -33,7 +35,7 @@ def _dict_unflatten(inps, aux_data):
register_supported_type
(
list
,
lambda
x
:
(
x
,
None
),
lambda
x
,
aux_data
:
list
(
x
))
register_supported_type
(
list
,
lambda
x
:
(
x
,
None
),
lambda
x
,
aux_data
:
list
(
x
))
register_supported_type
(
tuple
,
lambda
x
:
(
x
,
None
),
lambda
x
,
aux_data
:
list
(
x
))
register_supported_type
(
tuple
,
lambda
x
:
(
x
,
None
),
lambda
x
,
aux_data
:
tuple
(
x
))
register_supported_type
(
dict
,
_dict_flatten
,
_dict_unflatten
)
register_supported_type
(
dict
,
_dict_flatten
,
_dict_unflatten
)
register_supported_type
(
register_supported_type
(
slice
,
slice
,
...
@@ -52,7 +54,10 @@ def tree_flatten(
...
@@ -52,7 +54,10 @@ def tree_flatten(
assert
is_leaf
(
values
)
assert
is_leaf
(
values
)
node
=
LeafDef
(
leaf_type
(
values
))
node
=
LeafDef
(
leaf_type
(
values
))
if
is_const_leaf
(
values
):
if
is_const_leaf
(
values
):
node
.
const_val
=
values
if
isinstance
(
values
,
np
.
ndarray
):
node
.
const_val
=
str
(
values
)
else
:
node
.
const_val
=
values
return
[
values
,],
node
return
[
values
,],
node
rst
=
[]
rst
=
[]
...
...
imperative/python/megengine/experimental/traced_module/traced_module.py
浏览文件 @
b1c46ba4
...
@@ -10,8 +10,13 @@ import collections
...
@@ -10,8 +10,13 @@ import collections
import
copy
import
copy
import
functools
import
functools
from
inspect
import
getmembers
,
isclass
,
ismethod
from
inspect
import
getmembers
,
isclass
,
ismethod
from
typing
import
Dict
,
List
,
Type
from
typing
import
Callable
,
Dict
,
Iterable
,
List
,
Sequence
,
Type
import
numpy
as
np
from
numpy.lib.arraysetops
import
isin
from
...
import
functional
as
F
from
...
import
get_logger
from
...
import
module
as
M
from
...
import
module
as
M
from
...core._imperative_rt.core2
import
Tensor
as
RawTensor
from
...core._imperative_rt.core2
import
Tensor
as
RawTensor
from
...core._imperative_rt.core2
import
(
from
...core._imperative_rt.core2
import
(
...
@@ -19,6 +24,7 @@ from ...core._imperative_rt.core2 import (
...
@@ -19,6 +24,7 @@ from ...core._imperative_rt.core2 import (
set_module_tracing
,
set_module_tracing
,
unset_module_tracing
,
unset_module_tracing
,
)
)
from
...core._trace_option
import
set_symbolic_shape
from
...core.tensor.array_method
import
ArrayMethodMixin
from
...core.tensor.array_method
import
ArrayMethodMixin
from
...module
import
Module
from
...module
import
Module
from
...tensor
import
Tensor
from
...tensor
import
Tensor
...
@@ -32,6 +38,8 @@ from .module_tracer import (
...
@@ -32,6 +38,8 @@ from .module_tracer import (
from
.node
import
ModuleNode
,
Node
,
NodeMixin
,
TensorNode
from
.node
import
ModuleNode
,
Node
,
NodeMixin
,
TensorNode
from
.pytree
import
tree_flatten
from
.pytree
import
tree_flatten
logger
=
get_logger
(
__name__
)
def
_leaf_type
(
node
):
def
_leaf_type
(
node
):
if
isinstance
(
node
,
RawTensor
):
if
isinstance
(
node
,
RawTensor
):
...
@@ -42,6 +50,11 @@ def _leaf_type(node):
...
@@ -42,6 +50,11 @@ def _leaf_type(node):
return
type
(
node
)
return
type
(
node
)
def
_is_leaf
(
node
):
assert
isinstance
(
node
,
RawTensor
),
type
(
node
)
return
isinstance
(
node
,
RawTensor
)
def
_is_const_leaf
(
node
):
def
_is_const_leaf
(
node
):
if
isinstance
(
node
,
(
RawTensor
,
NodeMixin
,
Module
)):
if
isinstance
(
node
,
(
RawTensor
,
NodeMixin
,
Module
)):
return
False
return
False
...
@@ -80,7 +93,13 @@ class InternalGraph:
...
@@ -80,7 +93,13 @@ class InternalGraph:
@
property
@
property
def
exprs
(
self
):
def
exprs
(
self
):
return
_expr_list
(
self
)
return
ExprFilter
(
_expr_iter
(
self
))
def
get_call_function
(
self
,
func
:
Callable
=
None
):
return
self
.
exprs
.
call_function
(
func
)
def
get_call_method
(
self
,
method
:
str
=
None
):
return
self
.
exprs
.
call_method
(
method
)
def
add_input
(
self
,
i
):
def
add_input
(
self
,
i
):
self
.
_inputs
.
append
(
i
)
self
.
_inputs
.
append
(
i
)
...
@@ -88,16 +107,131 @@ class InternalGraph:
...
@@ -88,16 +107,131 @@ class InternalGraph:
def
add_output
(
self
,
o
):
def
add_output
(
self
,
o
):
self
.
_outputs
.
append
(
o
)
self
.
_outputs
.
append
(
o
)
def
get_dep_exprs
(
self
,
nodes
:
Sequence
[
Node
])
->
List
[
Expr
]:
if
not
isinstance
(
nodes
,
Sequence
):
nodes
=
(
nodes
,)
ret
=
list
()
queue
=
list
(
nodes
)
while
queue
:
node
=
queue
.
pop
()
expr
=
node
.
expr
if
expr
not
in
ret
:
ret
.
append
(
expr
)
for
i
in
expr
.
inputs
:
if
i
not
in
queue
:
queue
.
append
(
i
)
return
ret
def
insert_call_function
(
self
,
func
:
Callable
,
nodes
:
Sequence
[
Node
]):
if
not
isinstance
(
nodes
,
Sequence
):
nodes
=
[
nodes
]
assert
isinstance
(
func
,
Callable
)
for
i
in
nodes
:
assert
isinstance
(
i
,
TensorNode
),
"CallFunction only accept TensorNode as inputs"
expr
=
CallFunction
(
func
)
expr
.
inputs
=
nodes
for
i
in
nodes
:
i
.
users
.
append
(
expr
)
idx
=
max
(
self
.
_exprs
.
index
(
i
.
expr
)
for
i
in
nodes
)
+
1
self
.
_exprs
.
insert
(
idx
,
expr
)
fake_inp_val
=
tuple
(
F
.
zeros
(
shape
=
i
.
shape
,
dtype
=
i
.
dtype
)
for
i
in
nodes
)
fake_out_val
=
func
(
*
fake_inp_val
)
def
create_node
(
val
:
Tensor
):
node
=
TensorNode
(
expr
)
node
.
shape
=
val
.
shape
node
.
dtype
=
val
.
dtype
return
node
out_nodes
=
list
(
create_node
(
i
)
for
i
in
fake_out_val
)
expr
.
outputs
=
out_nodes
return
out_nodes
def
insert_call_method
(
self
,
target
,
method
,
args
):
if
not
isinstance
(
args
,
Sequence
):
args
=
[
args
]
assert
isinstance
(
target
,
(
TensorNode
,
ModuleNode
))
assert
isinstance
(
method
,
str
)
for
i
in
args
:
assert
isinstance
(
i
,
TensorNode
)
expr
=
CallMethod
(
method
)
expr
.
inputs
=
[
target
,
*
args
]
if
isinstance
(
target
,
TensorNode
):
fake_target_val
=
F
.
zeros
(
shape
=
target
.
shape
,
dtype
=
target
.
dtype
)
fake_inp_val
=
tuple
(
F
.
zeros
(
shape
=
i
.
shape
,
dtype
=
i
.
dtype
)
for
i
in
args
)
fake_out_val
=
getattr
(
fake_target_val
,
method
)(
fake_inp_val
)
def
create_node
(
val
:
Tensor
):
node
=
TensorNode
(
expr
)
node
.
shape
=
val
.
shape
node
.
dtype
=
val
.
dtype
return
node
out_nodes
=
list
(
create_node
(
i
)
for
i
in
fake_out_val
)
expr
.
outputs
=
out_nodes
else
:
raise
NotImplementedError
()
return
out_nodes
def
replace_node
(
self
,
repl_dict
:
Dict
[
Node
,
Node
]):
while
repl_dict
:
node
,
repl_node
=
repl_dict
.
popitem
()
# check graph inputs and outputs
assert
node
not
in
self
.
inputs
,
"Cannot replace inputs"
for
i
,
n
in
enumerate
(
self
.
outputs
):
if
n
is
node
:
self
.
outputs
[
i
]
=
repl_node
# update users of node and repl_node
# update inputs of expr in node.users
dep_exprs
=
self
.
get_dep_exprs
(
repl_node
)
i
=
0
while
i
<
len
(
node
.
users
):
n
=
node
.
users
[
i
]
if
n
in
dep_exprs
:
logger
.
info
(
"Find a loop: ignore this replacement once"
)
logger
.
info
(
"node: %s"
%
node
.
__repr__
())
logger
.
info
(
"repl_node: %s"
%
repl_node
.
__repr__
())
i
+=
1
continue
repl_node
.
users
.
append
(
n
)
node
.
users
.
pop
(
i
)
idx
=
n
.
inputs
.
index
(
node
)
n
.
inputs
[
idx
]
=
repl_node
def
compile
(
self
):
"""
Delete unused expr.
"""
dep_exprs
=
self
.
get_dep_exprs
(
self
.
outputs
)
i
=
0
while
i
<
len
(
self
.
_exprs
):
expr
=
self
.
_exprs
[
i
]
if
expr
in
dep_exprs
:
i
+=
1
continue
for
n
in
expr
.
inputs
:
n
.
users
.
remove
(
expr
)
self
.
_exprs
.
remove
(
expr
)
def
interpret
(
self
,
*
inputs
):
def
interpret
(
self
,
*
inputs
):
# TODO: support kwargs ?
# TODO: skip expressions which are independent and have no side effect
node2value
=
{}
node2value
=
{}
for
n
,
v
in
zip
(
self
.
_inputs
,
inputs
):
for
n
,
v
in
zip
(
self
.
_inputs
,
inputs
):
node2value
[
n
]
=
v
node2value
[
n
]
=
v
for
expr
in
self
.
_exprs
:
for
expr
in
self
.
_exprs
:
values
=
expr
.
interpret
(
*
list
(
node2value
[
i
]
for
i
in
expr
.
inputs
))
values
=
expr
.
interpret
(
*
list
(
node2value
[
i
]
for
i
in
expr
.
inputs
))
for
n
,
v
in
zip
(
expr
.
outputs
,
values
):
if
values
is
not
None
:
node2value
[
n
]
=
v
for
n
,
v
in
zip
(
expr
.
outputs
,
values
):
node2value
[
n
]
=
v
return
list
(
node2value
[
i
]
for
i
in
self
.
_outputs
)
return
list
(
node2value
[
i
]
for
i
in
self
.
_outputs
)
def
__repr__
(
self
):
def
__repr__
(
self
):
...
@@ -109,7 +243,8 @@ class InternalGraph:
...
@@ -109,7 +243,8 @@ class InternalGraph:
def
_get_meth_name
(
obj
,
func
):
def
_get_meth_name
(
obj
,
func
):
for
cls
in
type
(
obj
).
mro
():
tp
=
obj
if
isinstance
(
obj
,
type
)
else
type
(
obj
)
for
cls
in
tp
.
mro
():
for
k
,
v
in
cls
.
__dict__
.
items
():
for
k
,
v
in
cls
.
__dict__
.
items
():
if
v
==
func
:
if
v
==
func
:
return
k
return
k
...
@@ -131,15 +266,31 @@ def _wrapped_function(orig_func):
...
@@ -131,15 +266,31 @@ def _wrapped_function(orig_func):
meth_name
=
_get_meth_name
(
args
[
0
],
wrapped_fn
)
meth_name
=
_get_meth_name
(
args
[
0
],
wrapped_fn
)
if
meth_name
:
if
meth_name
:
self
=
inputs
[
0
]
self
=
inputs
[
0
]
call_node
=
CallMethod
.
make
(
NodeMixin
.
get
(
self
),
meth_name
)
if
meth_name
==
"__new__"
:
if
all
([
not
isinstance
(
i
,
RawTensor
)
for
i
in
inputs
]):
# only trace Tensor.__new__() when there are tensors in args
set_module_tracing
()
return
orig_func
(
*
args
,
**
kwargs
)
if
isinstance
(
args
[
1
],
RawTensor
):
node
=
NodeMixin
.
get
(
inputs
[
1
])
inputs
[
1
]
=
copy
.
copy
(
inputs
[
1
])
# copy inputs[1] to avoid tensor and Tensor(tensor) share same m_tensor, which will cause they have same _NodeMixin__node in tracing.
NodeMixin
.
wrap_safe
(
inputs
[
1
],
node
)
args
,
kwargs
=
tree_def
.
unflatten
(
inputs
)
call_node
=
CallMethod
.
make
(
self
,
meth_name
)
else
:
call_node
=
CallMethod
.
make
(
NodeMixin
.
get
(
self
),
meth_name
)
call_node
.
add_inputs
(
inputs
[
1
:])
else
:
else
:
call_node
=
CallFunction
.
make
(
orig_func
)
call_node
=
CallFunction
.
make
(
orig_func
)
call_node
.
add_inputs
(
inputs
)
call_node
.
add_inputs
(
inputs
)
call_node
.
arg_def
=
tree_def
call_node
.
arg_def
=
tree_def
outputs
=
orig_func
(
*
args
,
**
kwargs
)
outputs
=
orig_func
(
*
args
,
**
kwargs
)
call_node
.
add_outputs
(
outputs
)
if
meth_name
==
"__new__"
:
call_node
.
add_outputs
(
outputs
,
False
)
else
:
call_node
.
add_outputs
(
outputs
)
set_module_tracing
()
set_module_tracing
()
return
outputs
return
outputs
return
orig_func
(
*
args
,
**
kwargs
)
return
orig_func
(
*
args
,
**
kwargs
)
...
@@ -197,13 +348,14 @@ class TracedModuleBuilder(NodeMixin):
...
@@ -197,13 +348,14 @@ class TracedModuleBuilder(NodeMixin):
mark_constant
(
i
)
mark_constant
(
i
)
callnode
=
CallMethod
.
make
(
NodeMixin
.
get
(
self
))
callnode
=
CallMethod
.
make
(
NodeMixin
.
get
(
self
))
callnode
.
add_inputs
(
inputs
)
callnode
.
add_inputs
(
inputs
[
1
:]
)
callnode
.
arg_def
=
tree_def
callnode
.
arg_def
=
tree_def
if
self
.
_is_builtin
:
if
self
.
_is_builtin
:
unset_module_tracing
()
unset_module_tracing
()
outputs
=
self
.
_mod
(
*
args
,
**
kwargs
)
rst
=
self
.
_mod
(
*
args
,
**
kwargs
)
outputs
,
out_def
=
tree_flatten
(
rst
,
leaf_type
=
_leaf_type
,
is_leaf
=
_is_leaf
)
set_module_tracing
()
set_module_tracing
()
if
self
.
_is_builtin
:
if
self
.
_is_builtin
:
self
.
_body
=
None
self
.
_body
=
None
...
@@ -215,14 +367,13 @@ class TracedModuleBuilder(NodeMixin):
...
@@ -215,14 +367,13 @@ class TracedModuleBuilder(NodeMixin):
NodeMixin
.
wrap_safe
(
NodeMixin
.
wrap_safe
(
self
,
Input
.
make
(
"self"
,
NodeMixin
.
get_wrapped_type
(
self
))
self
,
Input
.
make
(
"self"
,
NodeMixin
.
get_wrapped_type
(
self
))
)
)
origin_inp_node
=
[
NodeMixin
.
get
(
i
,
None
)
for
i
in
inputs
[
1
:]]
# prepare args and kwargs for inner graph
# prepare args and kwargs for inner graph
def
wrap
(
x
):
def
wrap
(
x
):
wrapped
=
copy
.
copy
(
x
)
# FIXME
NodeMixin
.
wrap
(
NodeMixin
.
wrap
(
wrapped
,
x
,
lambda
:
Input
.
make
(
type
=
NodeMixin
.
get_wrapped_type
(
x
)),
lambda
:
Input
.
make
(
type
=
NodeMixin
.
get_wrapped_type
(
wrapped
)),
)
)
return
wrapped
return
x
args
=
[
self
]
args
=
[
self
]
for
i
in
inputs
[
1
:]:
for
i
in
inputs
[
1
:]:
...
@@ -231,21 +382,25 @@ class TracedModuleBuilder(NodeMixin):
...
@@ -231,21 +382,25 @@ class TracedModuleBuilder(NodeMixin):
active_module_tracer
().
patcher
.
auto_patch
(
active_module_tracer
().
patcher
.
auto_patch
(
getattr
(
getattr
(
self
.
_mod
,
"forward"
,
self
.
_mod
),
"__globals__"
,
{})
getattr
(
getattr
(
self
.
_mod
,
"forward"
,
self
.
_mod
),
"__globals__"
,
{})
)
)
outputs
=
type
(
self
.
_mod
).
forward
(
*
args
,
**
kwargs
)
rst
=
type
(
self
.
_mod
).
forward
(
*
args
,
**
kwargs
)
outputs
,
out_def
=
tree_flatten
(
rst
,
leaf_type
=
_leaf_type
,
is_leaf
=
_is_leaf
)
for
i
in
(
for
i
in
(
outputs
if
isinstance
(
outputs
,
collections
.
abc
.
Sequence
)
else
(
outputs
,)
outputs
if
isinstance
(
outputs
,
collections
.
abc
.
Sequence
)
else
(
outputs
,)
):
):
active_module_tracer
().
current_scope
().
add_output
(
NodeMixin
.
get
(
i
))
active_module_tracer
().
current_scope
().
add_output
(
NodeMixin
.
get
(
i
))
NodeMixin
.
wrap_safe
(
self
,
orig_self
)
NodeMixin
.
wrap_safe
(
self
,
orig_self
)
for
arg
,
node
in
zip
(
inputs
[
1
:],
origin_inp_node
):
if
node
:
NodeMixin
.
wrap_safe
(
arg
,
node
)
active_module_tracer
().
pop_scope
()
active_module_tracer
().
pop_scope
()
# rebind output to outer graph
# rebind output to outer graph
callnode
.
add_outputs
(
outputs
)
callnode
.
add_outputs
(
outputs
)
self_node
=
NodeMixin
.
get
(
self
)
self_node
=
NodeMixin
.
get
(
self
)
self_node
.
argdef_graph_map
[
callnode
.
arg_def
]
=
self
.
_body
self_node
.
argdef_graph_map
[
callnode
.
arg_def
]
=
self
.
_body
return
outputs
self_node
.
argdef_outdef_map
[
callnode
.
arg_def
]
=
out_def
return
rst
def
__getattr__
(
self
,
name
):
def
__getattr__
(
self
,
name
):
if
name
not
in
self
.
_mod
.
__dict__
:
if
name
not
in
self
.
_mod
.
__dict__
:
...
@@ -268,20 +423,29 @@ class TracedModuleBuilder(NodeMixin):
...
@@ -268,20 +423,29 @@ class TracedModuleBuilder(NodeMixin):
return
super
().
__getattribute__
(
name
)
return
super
().
__getattribute__
(
name
)
else
:
else
:
wrapped
=
super
().
__getattribute__
(
name
)
wrapped
=
super
().
__getattribute__
(
name
)
if
name
in
self
.
_mod
.
__dict__
and
not
NodeMixin
.
get
(
wrapped
,
None
):
if
name
in
self
.
_mod
.
__dict__
:
assert
not
self
.
_is_builtin
if
not
NodeMixin
.
get
(
wrapped
,
None
):
NodeMixin
.
wrap
(
assert
not
self
.
_is_builtin
wrapped
,
NodeMixin
.
wrap
(
lambda
:
GetAttr
.
make
(
wrapped
,
lambda
:
GetAttr
.
make
(
NodeMixin
.
get
(
self
),
name
,
type
=
NodeMixin
.
get_wrapped_type
(
wrapped
),
),
)
else
:
node
=
NodeMixin
.
get
(
wrapped
)
expr
=
GetAttr
.
make
(
NodeMixin
.
get
(
self
),
NodeMixin
.
get
(
self
),
name
,
name
,
type
=
NodeMixin
.
get_wrapped_type
(
wrapped
),
type
=
NodeMixin
.
get_wrapped_type
(
wrapped
),
)
,
)
.
expr
)
expr
.
outputs
[
0
]
=
node
return
wrapped
return
wrapped
class
_expr_
list
:
class
_expr_
iter
:
def
__init__
(
self
,
graph
:
InternalGraph
):
def
__init__
(
self
,
graph
:
InternalGraph
):
self
.
graph
=
graph
self
.
graph
=
graph
...
@@ -295,6 +459,59 @@ class _expr_list:
...
@@ -295,6 +459,59 @@ class _expr_list:
yield
expr
yield
expr
class
ExprFilter
:
def
__init__
(
self
,
expr_iter
:
Iterable
):
self
.
_iter
=
expr_iter
def
__iter__
(
self
):
return
iter
(
self
.
_iter
)
def
call_function
(
self
,
func
):
return
ExprFilterCallFunction
(
self
,
func
)
def
call_method
(
self
,
method
):
return
ExprFilterCallMethod
(
self
,
method
)
def
as_list
(
self
):
return
list
(
self
)
def
as_dict
(
self
):
raise
NotImplementedError
(
"need key"
)
def
as_unique
(
self
):
(
expr
,)
=
self
return
expr
def
as_count
(
self
):
return
sum
(
1
for
_
in
self
)
class
ExprFilterCallFunction
(
ExprFilter
):
def
__init__
(
self
,
expr_iter
,
func
:
Callable
=
None
):
super
().
__init__
(
expr_iter
)
self
.
func
=
func
def
__iter__
(
self
):
for
i
in
self
.
_iter
:
if
not
isinstance
(
i
,
CallFunction
):
continue
if
self
.
func
is
None
or
i
.
func
==
self
.
func
:
yield
i
class
ExprFilterCallMethod
(
ExprFilter
):
def
__init__
(
self
,
expr_iter
,
method
:
str
=
None
):
super
().
__init__
(
expr_iter
)
self
.
method
=
method
def
__iter__
(
self
):
for
i
in
self
.
_iter
:
if
not
isinstance
(
i
,
CallMethod
):
continue
if
self
.
method
is
None
or
i
.
method
==
self
.
method
:
yield
i
class
TracedModule
(
Module
):
class
TracedModule
(
Module
):
"""
"""
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be
`TracedModule` is the Module created by tracing normal module. It owns a ModuleNode(m_node). `TracedModule` can not be called directly. It can be
...
@@ -312,10 +529,12 @@ class TracedModule(Module):
...
@@ -312,10 +529,12 @@ class TracedModule(Module):
((
self
,
*
args
),
kwargs
),
_leaf_type
,
is_const_leaf
=
_is_const_leaf
((
self
,
*
args
),
kwargs
),
_leaf_type
,
is_const_leaf
=
_is_const_leaf
)
)
assert
treedef
in
self
.
m_node
.
argdef_graph_map
assert
treedef
in
self
.
m_node
.
argdef_graph_map
inputs
=
[
i
for
i
in
inputs
if
isinstance
(
i
,
(
Module
,
RawTensor
))]
inputs
=
filter
(
lambda
i
:
isinstance
(
i
,
(
Module
,
TracedModuleBuilder
,
RawTensor
)),
inputs
)
# allow TracedModuleBuilder for retrace.
outputs
=
self
.
m_node
.
argdef_graph_map
[
treedef
].
interpret
(
*
inputs
)
outputs
=
self
.
m_node
.
argdef_graph_map
[
treedef
].
interpret
(
*
inputs
)
if
len
(
outputs
)
==
1
:
out_def
=
self
.
m_node
.
argdef_outdef_map
[
treedef
]
return
outputs
[
0
]
outputs
=
out_def
.
unflatten
(
outputs
)
return
outputs
return
outputs
@
property
@
property
...
@@ -339,9 +558,8 @@ class TracedModule(Module):
...
@@ -339,9 +558,8 @@ class TracedModule(Module):
if
graph
is
None
:
if
graph
is
None
:
assert
not
isinstance
(
module
,
TracedModule
)
assert
not
isinstance
(
module
,
TracedModule
)
const
=
Constant
(
module
)
const
=
Constant
(
module
)
modulenode
=
const
.
outputs
[
0
]
const
.
outputs
[
0
]
=
call
.
inputs
[
0
]
modulenode
.
module_type
=
type
(
module
)
const
.
outputs
[
0
].
expr
=
const
call
.
inputs
[
0
]
=
modulenode
return
[
const
,
call
]
return
[
const
,
call
]
exprs
=
[]
exprs
=
[]
for
expr
in
graph
.
_exprs
:
for
expr
in
graph
.
_exprs
:
...
@@ -350,30 +568,41 @@ class TracedModule(Module):
...
@@ -350,30 +568,41 @@ class TracedModule(Module):
if
call
and
inp
in
graph
.
_inputs
:
if
call
and
inp
in
graph
.
_inputs
:
inp_idx
=
graph
.
_inputs
.
index
(
inp
)
inp_idx
=
graph
.
_inputs
.
index
(
inp
)
expr
.
inputs
[
idx
]
=
call
.
inputs
[
inp_idx
]
expr
.
inputs
[
idx
]
=
call
.
inputs
[
inp_idx
]
call
.
inputs
[
inp_idx
].
users
.
append
(
expr
)
# replace outputs for submodule's expr
# replace outputs for submodule's expr
for
idx
,
outp
in
enumerate
(
expr
.
outputs
):
for
idx
,
outp
in
enumerate
(
expr
.
outputs
):
if
call
and
outp
in
graph
.
_outputs
:
if
call
and
outp
in
graph
.
_outputs
:
oup_idx
=
graph
.
_outputs
.
index
(
outp
)
oup_idx
=
graph
.
_outputs
.
index
(
outp
)
expr
.
outputs
[
idx
]
=
call
.
outputs
[
oup_idx
]
expr
.
outputs
[
idx
]
=
call
.
outputs
[
oup_idx
]
call
.
outputs
[
oup_idx
].
expr
=
expr
if
isinstance
(
expr
,
GetAttr
):
if
isinstance
(
expr
,
GetAttr
):
# replace GetAttr with Constant
# replace GetAttr with Constant
if
isinstance
(
expr
.
outputs
[
0
],
TensorNode
):
if
isinstance
(
expr
.
outputs
[
0
],
TensorNode
):
const
=
Constant
(
getattr
(
module
,
expr
.
name
))
const
=
Constant
(
getattr
(
module
,
expr
.
name
))
const
.
outputs
=
expr
.
outputs
const
.
outputs
=
expr
.
outputs
const
.
outputs
[
0
].
expr
=
const
exprs
.
append
(
const
)
exprs
.
append
(
const
)
elif
isinstance
(
expr
,
CallMethod
):
elif
isinstance
(
expr
,
CallMethod
):
obj_node
=
expr
.
inputs
[
0
]
obj_node
=
expr
.
inputs
[
0
]
if
isinstance
(
obj_node
,
ModuleNode
):
if
isinstance
(
obj_node
,
ModuleNode
):
assert
isinstance
(
expr
.
inputs
[
0
].
expr
,
GetAttr
)
pre_expr
=
expr
.
inputs
[
0
].
expr
(
obj
,)
=
expr
.
inputs
[
0
].
expr
.
interpret
(
module
)
if
isinstance
(
pre_expr
,
GetAttr
):
exprs
.
extend
(
_flatten_subgraph
(
expr
.
graph
,
obj
,
expr
))
(
obj
,)
=
expr
.
inputs
[
0
].
expr
.
interpret
(
module
)
exprs
.
extend
(
_flatten_subgraph
(
expr
.
graph
,
obj
,
expr
))
else
:
# module has been replaced.
assert
isinstance
(
pre_expr
,
Constant
)
else
:
else
:
exprs
.
append
(
expr
)
exprs
.
append
(
expr
)
else
:
else
:
exprs
.
append
(
expr
)
exprs
.
append
(
expr
)
if
call
is
not
None
:
for
i
in
call
.
inputs
:
i
.
users
.
remove
(
call
)
return
exprs
return
exprs
new_module
.
graph
.
_exprs
=
_flatten_subgraph
(
new_module
.
graph
,
new_module
)
new_module
.
graph
.
_exprs
=
_flatten_subgraph
(
new_module
.
graph
,
new_module
)
...
@@ -422,22 +651,26 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
...
@@ -422,22 +651,26 @@ def trace_module(mod: Module, *args: Tensor, **kwargs: Tensor) -> TracedModule:
"""
"""
assert
active_module_tracer
()
is
None
assert
active_module_tracer
()
is
None
try
:
try
:
use_sym_shape
=
set_symbolic_shape
(
True
)
set_module_tracing
()
set_module_tracing
()
set_active_module_tracer
(
module_tracer
(
_wrapped_function
))
set_active_module_tracer
(
module_tracer
(
_wrapped_function
))
with
active_module_tracer
().
patcher
:
with
active_module_tracer
().
patcher
:
global_scope
=
InternalGraph
()
global_scope
=
InternalGraph
()
active_module_tracer
().
push_scope
(
global_scope
)
active_module_tracer
().
push_scope
(
global_scope
)
builder
=
TracedModuleBuilder
(
mod
,
True
)
builder
=
TracedModuleBuilder
(
mod
,
True
)
NodeMixin
.
wrap_safe
(
builder
,
Input
.
make
(
"TopModule"
,
ModuleNode
))
NodeMixin
.
wrap_safe
(
builder
,
Input
.
make
(
"TopModule"
,
ModuleNode
))
inputs
,
_
=
tree_flatten
((
args
,
kwargs
))
inputs
,
_
=
tree_flatten
((
args
,
kwargs
)
,
is_const_leaf
=
_is_const_leaf
)
for
_
,
i
in
enumerate
(
inputs
):
for
_
,
i
in
enumerate
(
inputs
):
NodeMixin
.
wrap_safe
(
if
isinstance
(
i
,
RawTensor
):
i
,
Input
.
make
(
"arg_{}"
.
format
(
_
),
NodeMixin
.
get_wrapped_type
(
i
))
NodeMixin
.
wrap_safe
(
)
i
,
Input
.
make
(
"arg_{}"
.
format
(
_
),
NodeMixin
.
get_wrapped_type
(
i
))
)
builder
(
*
args
,
**
kwargs
)
builder
(
*
args
,
**
kwargs
)
active_module_tracer
().
pop_scope
()
active_module_tracer
().
pop_scope
()
return
builder
.
build
()
return
builder
.
build
()
finally
:
finally
:
set_symbolic_shape
(
use_sym_shape
)
set_active_module_tracer
(
None
)
set_active_module_tracer
(
None
)
unset_module_tracing
()
unset_module_tracing
()
imperative/python/test/unit/traced_module/test_haoruitao.py
0 → 100644
浏览文件 @
b1c46ba4
import
io
import
pickle
import
numpy
as
np
import
megengine.functional
as
F
import
megengine.module
as
M
import
megengine.utils.comp_graph_tools
as
cgtools
from
megengine.core._trace_option
import
set_symbolic_shape
from
megengine.experimental.traced_module
import
trace_module
from
megengine.jit
import
trace
set_symbolic_shape
(
True
)
class
Main
(
M
.
Module
):
def
forward
(
self
,
x
):
return
x
class
PreProcess
(
M
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
I
=
F
.
ones
((
1
,))
self
.
M
=
F
.
zeros
((
1
,))
def
forward
(
self
,
data
,
idx
,
roi
):
N
,
H
,
W
,
C
=
data
.
shape
xmax
=
roi
[:,
1
,
0
]
xmin
=
roi
[:,
0
,
0
]
ymax
=
roi
[:,
1
,
1
]
ymin
=
roi
[:,
0
,
1
]
scale
=
F
.
maximum
((
xmax
-
xmin
)
/
W
,
(
ymax
-
ymin
)
/
H
)
I
=
F
.
broadcast_to
(
self
.
I
,
(
N
,))
M
=
F
.
broadcast_to
(
self
.
M
,
(
N
,
3
,
3
))
M
[:,
0
,
0
]
=
scale
M
[:,
0
,
2
]
=
xmin
M
[:,
1
,
1
]
=
scale
M
[:,
1
,
2
]
=
ymin
M
[:,
2
,
2
]
=
I
resized
=
(
F
.
warp_perspective
(
data
,
M
,
(
H
,
W
),
mat_idx
=
idx
,
border_mode
=
"CONSTANT"
,
format
=
"NHWC"
)
.
transpose
(
0
,
3
,
1
,
2
)
.
astype
(
np
.
float32
)
)
return
resized
class
Net
(
M
.
Module
):
def
__init__
(
self
,
traced_module
):
super
().
__init__
()
self
.
pre_process
=
PreProcess
()
self
.
traced_module
=
traced_module
def
forward
(
self
,
data
,
idx
,
roi
):
x
=
self
.
pre_process
(
data
,
idx
,
roi
)
x
=
self
.
traced_module
(
x
)
return
x
def
test_preprocess
():
module
=
Main
()
data
=
F
.
ones
((
1
,
14
,
8
,
8
),
dtype
=
np
.
uint8
)
traced_module
=
trace_module
(
module
,
data
)
obj
=
pickle
.
dumps
(
traced_module
)
traced_module
=
pickle
.
loads
(
obj
)
module
=
Net
(
traced_module
)
module
.
eval
()
idx
=
F
.
zeros
((
1
,),
dtype
=
np
.
int32
)
roi
=
F
.
ones
((
1
,
2
,
2
),
dtype
=
np
.
float32
)
y
=
module
(
data
,
idx
,
roi
)
traced_module
=
trace_module
(
module
,
data
,
idx
,
roi
)
np
.
testing
.
assert_array_equal
(
traced_module
(
data
,
idx
,
roi
),
y
)
func
=
trace
(
traced_module
,
capture_as_const
=
True
)
np
.
testing
.
assert_array_equal
(
func
(
data
,
idx
,
roi
),
y
)
model
=
io
.
BytesIO
()
func
.
dump
(
model
,
arg_names
=
(
"data"
,
"idx"
,
"roi"
))
model
.
seek
(
0
)
infer_cg
=
cgtools
.
GraphInference
(
model
)
np
.
testing
.
assert_allclose
(
list
(
infer_cg
.
run
(
inp_dict
=
{
"data"
:
data
.
numpy
(),
"idx"
:
idx
.
numpy
(),
"roi"
:
roi
.
numpy
()}
).
values
()
)[
0
],
y
,
atol
=
1e-6
,
)
imperative/python/test/unit/traced_module/test_modification.py
0 → 100644
浏览文件 @
b1c46ba4
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
numpy
as
np
import
megengine.functional
as
F
import
megengine.module
as
M
from
megengine.experimental.traced_module
import
trace_module
from
megengine.experimental.traced_module.expr
import
CallFunction
,
GetAttr
class
MyBlock
(
M
.
Module
):
def
__init__
(
self
,
in_channels
=
3
,
channels
=
3
):
super
(
MyBlock
,
self
).
__init__
()
self
.
conv1
=
M
.
Conv2d
(
in_channels
,
channels
,
3
,
1
,
padding
=
1
,
bias
=
False
)
self
.
bn1
=
M
.
BatchNorm2d
(
channels
)
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
F
.
relu
(
x
)
+
1
return
x
class
MyModule
(
M
.
Module
):
def
__init__
(
self
):
super
(
MyModule
,
self
).
__init__
()
self
.
block0
=
MyBlock
()
self
.
block1
=
MyBlock
()
def
forward
(
self
,
x
):
x
=
self
.
block0
(
x
)
x
=
self
.
block1
(
x
)
return
x
def
_init_cls
(
cls
):
module
=
cls
()
x
=
F
.
ones
((
1
,
3
,
3
,
3
))
y
=
module
(
x
)
traced_module
=
trace_module
(
module
,
x
)
return
traced_module
,
x
,
y
def
_init_block
():
return
_init_cls
(
MyBlock
)
def
_init_module
():
return
_init_cls
(
MyModule
)
def
test_search
():
traced_module
,
*
_
=
_init_block
()
graph
=
traced_module
.
graph
relu_expr
=
graph
.
get_call_function
(
F
.
relu
).
as_unique
()
assert
isinstance
(
relu_expr
,
CallFunction
)
and
relu_expr
.
func
==
F
.
relu
def
test_insert
():
traced_module
,
x
,
expect
=
_init_block
()
graph
=
traced_module
.
graph
relu_node
=
graph
.
get_call_function
(
F
.
relu
).
as_unique
().
outputs
neg_node
=
graph
.
insert_call_function
(
F
.
neg
,
relu_node
)
graph
.
replace_node
({
relu_node
[
0
]:
neg_node
[
0
]})
graph
.
compile
()
np
.
testing
.
assert_allclose
(
expect
-
1
,
1
-
traced_module
(
x
),
atol
=
1e-6
)
def
test_delete
():
traced_module
,
x
,
expect
=
_init_block
()
graph
=
traced_module
.
graph
relu_expr
=
graph
.
get_call_function
(
F
.
relu
).
as_unique
()
node
=
relu_expr
.
outputs
repl_node
=
relu_expr
.
inputs
graph
.
replace_node
({
node
[
0
]:
repl_node
[
0
]})
graph
.
compile
()
np
.
testing
.
assert_allclose
(
expect
-
1
,
F
.
relu
(
traced_module
(
x
)
-
1
),
atol
=
1e-6
)
def
test_flatten
():
traced_module
,
x
,
expect
=
_init_module
()
traced_module
=
traced_module
.
flatten
()
traced_module
.
graph
.
compile
()
assert
all
(
not
isinstance
(
i
,
GetAttr
)
for
i
in
traced_module
.
graph
.
_exprs
)
assert
len
(
traced_module
.
graph
.
_exprs
)
==
12
def
test_extra_block
():
class
PostProcess
(
M
.
Module
):
def
forward
(
self
,
x
):
return
x
*
2
class
Net
(
M
.
Module
):
def
__init__
(
self
,
traced_module
):
super
().
__init__
()
self
.
post_process
=
PostProcess
()
self
.
traced_module
=
traced_module
def
forward
(
self
,
x
):
x
=
self
.
traced_module
(
x
)
x
=
self
.
post_process
(
x
)
return
x
traced_module
,
x
,
expect
=
_init_block
()
module
=
Net
(
traced_module
)
np
.
testing
.
assert_allclose
(
2
*
expect
,
module
(
x
),
atol
=
1e-6
)
traced_module
=
trace_module
(
module
,
x
)
np
.
testing
.
assert_allclose
(
2
*
expect
,
traced_module
(
x
),
atol
=
1e-6
)
imperative/python/test/unit/traced_module/test_serialization.py
0 → 100644
浏览文件 @
b1c46ba4
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
pickle
import
numpy
as
np
import
megengine.functional
as
F
import
megengine.module
as
M
from
megengine
import
Tensor
from
megengine.experimental.traced_module
import
trace_module
from
megengine.module
import
Module
class
MyBlock
(
Module
):
def
__init__
(
self
,
in_channels
,
channels
):
super
(
MyBlock
,
self
).
__init__
()
self
.
conv1
=
M
.
Conv2d
(
in_channels
,
channels
,
3
,
1
,
padding
=
1
,
bias
=
False
)
self
.
bn1
=
M
.
BatchNorm2d
(
channels
)
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
F
.
relu
(
x
)
+
1
return
x
class
MyModule
(
Module
):
def
__init__
(
self
):
super
(
MyModule
,
self
).
__init__
()
self
.
block0
=
MyBlock
(
8
,
4
)
self
.
block1
=
MyBlock
(
4
,
2
)
def
forward
(
self
,
x
):
x
=
self
.
block0
(
x
)
x
=
self
.
block1
(
x
)
return
x
def
test_dump_and_load
():
module
=
MyModule
()
x
=
Tensor
(
np
.
ones
((
1
,
8
,
14
,
14
)))
expect
=
module
(
x
)
traced_module
=
trace_module
(
module
,
x
)
np
.
testing
.
assert_array_equal
(
expect
,
traced_module
(
x
))
obj
=
pickle
.
dumps
(
traced_module
)
pickle
.
loads
(
obj
)
np
.
testing
.
assert_array_equal
(
expect
,
traced_module
(
x
))
imperative/python/test/unit/traced_module/test_trace_module.py
0 → 100644
浏览文件 @
b1c46ba4
import
numpy
as
np
from
megengine
import
Tensor
from
megengine.experimental.traced_module
import
trace_module
from
megengine.module
import
Module
as
M
class
MyModule1
(
M
):
def
forward
(
self
,
x
):
y
=
Tensor
(
x
)
y
+=
1
x
=
x
+
2
return
x
,
y
class
MyModule2
(
M
):
def
forward
(
self
,
x
):
y
=
Tensor
([
1
,
x
,
1
])
y
+=
1
x
=
x
+
2
return
x
,
y
def
test_trace_module
():
x
=
Tensor
(
1
)
m1
=
MyModule1
()
tm1
=
trace_module
(
m1
,
x
)
m2
=
MyModule2
()
tm2
=
trace_module
(
m2
,
x
)
inp
=
Tensor
(
2
)
gt
=
m1
(
inp
)
output
=
tm1
(
inp
)
for
a
,
b
in
zip
(
output
,
gt
):
np
.
testing
.
assert_equal
(
a
.
numpy
(),
b
.
numpy
())
gt1
=
m2
(
inp
)
output1
=
tm2
(
inp
)
for
a
,
b
in
zip
(
output1
,
gt1
):
np
.
testing
.
assert_equal
(
a
.
numpy
(),
b
.
numpy
())
imperative/python/test/unit/traced_module/test_wujianan.py
0 → 100644
浏览文件 @
b1c46ba4
import
io
import
pickle
import
numpy
as
np
import
megengine
as
mge
import
megengine.functional
as
F
import
megengine.module
as
M
import
megengine.utils.comp_graph_tools
as
cgtools
from
megengine.core._trace_option
import
set_symbolic_shape
from
megengine.experimental.traced_module
import
trace_module
from
megengine.jit
import
trace
set_symbolic_shape
(
True
)
class
Main
(
M
.
Module
):
def
forward
(
self
,
x
):
return
x
[
"data"
]
class
PreProcess
(
M
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
A
=
F
.
zeros
((
1
,))
self
.
I
=
F
.
ones
((
1
,))
self
.
bb_out
=
mge
.
tensor
(
np
.
array
([[[
0
,
0
],
[
160
,
0
],
[
160
,
48
],
[
0
,
48
]]],
dtype
=
"float32"
)
)
def
forward
(
self
,
data
,
quad
):
"""
data: (1, 3, 48, 160)
quad: (1, 4, 2)
"""
N
=
quad
.
shape
[
0
]
dst
=
F
.
repeat
(
self
.
bb_out
,
N
,
axis
=
0
).
reshape
(
-
1
,
4
,
2
)
I
=
F
.
broadcast_to
(
self
.
I
,
quad
.
shape
)
A
=
F
.
broadcast_to
(
self
.
A
,
(
N
,
8
,
8
))
A
[:,
0
:
4
,
0
:
2
]
=
quad
A
[:,
4
:
8
,
5
:
6
]
=
I
[:,
:,
0
:
1
]
A
[:,
0
:
4
,
6
:
8
]
=
-
quad
*
dst
[:,
:,
0
:
1
]
A
[:,
4
:
8
,
3
:
5
]
=
quad
A
[:,
0
:
4
,
2
:
3
]
=
I
[:,
:,
0
:
1
]
A
[:,
4
:
8
,
6
:
8
]
=
-
quad
*
dst
[:,
:,
1
:
2
]
B
=
dst
.
transpose
(
0
,
2
,
1
).
reshape
(
-
1
,
8
,
1
)
M
=
F
.
concat
([
F
.
matmul
(
F
.
matinv
(
A
),
B
)[:,
:,
0
],
I
[:,
0
:
1
,
0
]],
axis
=
1
).
reshape
(
-
1
,
3
,
3
)
new_data
=
F
.
warp_perspective
(
data
,
M
,
(
48
,
160
))
# (N, 3, 48, 160)
return
{
"data"
:
new_data
}
class
Net
(
M
.
Module
):
def
__init__
(
self
,
traced_module
):
super
().
__init__
()
self
.
pre_process
=
PreProcess
()
self
.
traced_module
=
traced_module
def
forward
(
self
,
data
,
quad
):
x
=
self
.
pre_process
(
data
,
quad
)
x
=
self
.
traced_module
(
x
)
return
x
def
test_preprocess
():
batch_size
=
2
module
=
Main
()
data
=
mge
.
tensor
(
np
.
random
.
randint
(
0
,
256
,
size
=
(
batch_size
,
3
,
48
,
160
)),
dtype
=
np
.
float32
)
traced_module
=
trace_module
(
module
,
{
"data"
:
data
})
obj
=
pickle
.
dumps
(
traced_module
)
traced_module
=
pickle
.
loads
(
obj
)
module
=
Net
(
traced_module
)
module
.
eval
()
quad
=
mge
.
tensor
(
np
.
random
.
normal
(
size
=
(
batch_size
,
4
,
2
)),
dtype
=
np
.
float32
)
expect
=
module
(
data
,
quad
)
traced_module
=
trace_module
(
module
,
data
,
quad
)
actual
=
traced_module
(
data
,
quad
)
for
i
,
j
in
zip
(
expect
,
actual
):
np
.
testing
.
assert_array_equal
(
i
,
j
)
func
=
trace
(
traced_module
,
capture_as_const
=
True
)
actual
=
func
(
data
,
quad
)
for
i
,
j
in
zip
(
expect
,
actual
):
np
.
testing
.
assert_array_equal
(
i
,
j
)
model
=
io
.
BytesIO
()
func
.
dump
(
model
,
arg_names
=
(
"data"
,
"quad"
))
model
.
seek
(
0
)
infer_cg
=
cgtools
.
GraphInference
(
model
)
actual
=
list
(
infer_cg
.
run
(
inp_dict
=
{
"data"
:
data
.
numpy
(),
"quad"
:
quad
.
numpy
()}).
values
()
)[
0
]
np
.
testing
.
assert_allclose
(
expect
,
actual
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录