Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b28ad4e8
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
b28ad4e8
编写于
11月 23, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/traced_module): add pattern match for TracedModule
GitOrigin-RevId: 0af7b076e6740db30fab7126f6f496e88ef91b48
上级
2318ea3f
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
473 addition
and
0 deletion
+473
-0
imperative/python/megengine/traced_module/_passes/matcher.py
imperative/python/megengine/traced_module/_passes/matcher.py
+183
-0
imperative/python/megengine/traced_module/_passes/pattern.py
imperative/python/megengine/traced_module/_passes/pattern.py
+252
-0
imperative/python/megengine/traced_module/_passes/utils.py
imperative/python/megengine/traced_module/_passes/utils.py
+38
-0
未找到文件。
imperative/python/megengine/traced_module/_passes/matcher.py
0 → 100644
浏览文件 @
b28ad4e8
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
collections
import
OrderedDict
,
defaultdict
from
functools
import
partial
from
...logger
import
get_logger
from
..expr
import
(
Expr
,
is_apply_def
,
is_call_function
,
is_call_module
,
is_call_tensor_method
,
is_constant
,
)
from
.pattern
import
(
AnyPattern
,
ApplyDefPattern
,
CallPattern
,
ConstantPattern
,
ExprPattern
,
FunctionPattern
,
ModulePattern
,
OrPattern
,
TensorMethodPattern
,
VarPattern
,
)
from
.utils
import
register_obj
logger
=
get_logger
(
__name__
)
class
PatternMatcher
:
method_dict
=
{}
register_visiter_func
=
partial
(
register_obj
,
_dict
=
method_dict
)
def
__init__
(
self
)
->
None
:
self
.
matched_patterns
=
[]
self
.
matched_exprs
=
OrderedDict
()
def
match
(
self
,
pattern
:
ExprPattern
,
expr
:
Expr
)
->
bool
:
self
.
matched_exprs
.
clear
()
self
.
matched_patterns
.
clear
()
pattern
.
check_users
(
False
)
res
=
self
.
visit_pattern
(
pattern
,
expr
)
if
res
and
not
self
.
_check_users
():
self
.
clear_map
(
0
)
res
=
False
self
.
_clear_pattern_users
()
return
res
def
clear_map
(
self
,
mark
):
for
_
in
range
(
len
(
self
.
matched_patterns
)
-
mark
):
p
=
self
.
matched_patterns
.
pop
()
self
.
matched_exprs
.
pop
(
p
)
p
.
_clear_users
()
def
_clear_pattern_users
(
self
):
for
p
in
self
.
matched_patterns
:
p
.
_clear_users
()
def
_check_users
(
self
)
->
bool
:
for
pat
,
expr
in
self
.
matched_exprs
.
items
():
if
pat
.
_check_users
:
pattern_users
=
pat
.
_users
if
len
(
expr
.
outputs
)
!=
1
:
logger
.
warning
(
"only support single output, and the matching "
"result may be wrong"
)
continue
expr_users
=
expr
.
outputs
[
0
].
users
if
len
(
pattern_users
)
!=
len
(
expr_users
):
return
False
for
pat
,
expr
in
zip
(
pattern_users
,
expr_users
):
if
self
.
matched_exprs
[
pat
]
!=
expr
:
return
False
return
True
def
visit_pattern
(
self
,
pattern
:
ExprPattern
,
expr
:
Expr
)
->
bool
:
if
pattern
in
self
.
matched_exprs
:
if
self
.
matched_exprs
[
pattern
]
is
expr
:
if
isinstance
(
pattern
,
(
OrPattern
)):
assert
self
.
_visit_or_pattern
(
pattern
,
expr
)
==
True
return
True
else
:
return
False
else
:
mark
=
len
(
self
.
matched_patterns
)
visiter
=
self
.
method_dict
.
get
(
type
(
pattern
))
matched
=
visiter
(
self
,
pattern
,
expr
)
if
matched
:
self
.
matched_patterns
.
append
(
pattern
)
self
.
matched_exprs
[
pattern
]
=
expr
else
:
self
.
clear_map
(
mark
)
return
matched
@
register_visiter_func
(
OrPattern
)
def
_visit_or_pattern
(
self
,
pattern
:
OrPattern
,
expr
:
Expr
)
->
bool
:
if
self
.
visit_pattern
(
pattern
.
left
,
expr
):
if
pattern
.
_users
:
pattern
.
left
.
_add_users
(
pattern
.
_users
[
-
1
])
return
True
if
self
.
visit_pattern
(
pattern
.
right
,
expr
):
if
pattern
.
_users
:
pattern
.
right
.
_add_users
(
pattern
.
_users
[
-
1
])
return
True
return
False
@
register_visiter_func
(
CallPattern
)
def
_visit_call_pattern
(
self
,
pattern
:
CallPattern
,
expr
:
Expr
)
->
bool
:
mark
=
len
(
self
.
matched_patterns
)
match_res
=
self
.
visit_pattern
(
pattern
.
op
,
expr
)
if
not
match_res
:
self
.
clear_map
(
mark
)
return
False
inputs
=
expr
.
inputs
if
isinstance
(
pattern
.
op
,
ModulePattern
):
inputs
=
inputs
[
1
:]
if
(
pattern
.
_match_all_args
and
len
(
pattern
.
args
)
!=
len
(
inputs
))
or
(
not
pattern
.
_match_all_args
and
len
(
pattern
.
args
)
>
len
(
inputs
)
):
self
.
clear_map
(
mark
)
return
False
for
i
,
pat
in
enumerate
(
pattern
.
args
):
pat
.
_add_users
(
pattern
)
match_res
=
self
.
visit_pattern
(
pat
,
inputs
[
i
].
expr
)
if
not
match_res
:
pat
.
_clear_users
()
self
.
clear_map
(
mark
)
return
False
return
True
@
register_visiter_func
(
ModulePattern
)
def
_visit_module_pattern
(
self
,
pattern
:
ModulePattern
,
expr
:
Expr
)
->
bool
:
if
not
is_call_module
(
expr
,
pattern
.
target
):
return
False
module
=
expr
.
inputs
[
0
].
owner
for
key
,
target
in
pattern
.
attrs
.
items
():
value
=
getattr
(
module
,
key
,
None
)
if
target
!=
value
:
return
False
return
True
@
register_visiter_func
(
FunctionPattern
)
def
_visit_function_pattern
(
self
,
pattern
:
FunctionPattern
,
expr
:
Expr
)
->
bool
:
if
not
is_call_function
(
expr
,
pattern
.
target
):
return
False
kwargs
=
expr
.
kwargs
for
key
,
target
in
pattern
.
params
.
items
():
value
=
kwargs
.
get
(
key
,
None
)
if
target
!=
value
:
return
False
return
True
@
register_visiter_func
(
TensorMethodPattern
)
def
_visit_tensor_method_pattern
(
self
,
pattern
:
TensorMethodPattern
,
expr
:
Expr
)
->
bool
:
return
is_call_tensor_method
(
expr
,
pattern
.
target
)
@
register_visiter_func
(
ApplyDefPattern
)
def
_visit_apply_pattern
(
self
,
pattern
:
ApplyDefPattern
,
expr
:
Expr
)
->
bool
:
return
is_apply_def
(
expr
,
pattern
.
target
)
@
register_visiter_func
(
ConstantPattern
)
def
_visit_const_pattern
(
self
,
pattern
:
ConstantPattern
,
expr
:
Expr
)
->
bool
:
return
is_constant
(
expr
)
@
register_visiter_func
(
VarPattern
)
def
_visit_var_pattern
(
self
,
pattern
:
VarPattern
,
expr
:
Expr
)
->
bool
:
return
not
is_constant
(
expr
)
@
register_visiter_func
(
AnyPattern
)
def
_visit_any_pattern
(
self
,
pattern
:
AnyPattern
,
expr
:
Expr
)
->
bool
:
return
True
imperative/python/megengine/traced_module/_passes/pattern.py
0 → 100644
浏览文件 @
b28ad4e8
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
abc
import
abstractmethod
from
typing
import
Any
,
Callable
,
Dict
,
List
from
...core._imperative_rt
import
OpDef
from
...logger
import
get_logger
from
...module
import
Module
from
..expr
import
Expr
from
..node
import
Node
logger
=
get_logger
(
__name__
)
class
ExprPattern
:
def
__init__
(
self
):
self
.
_check_users
=
True
self
.
_users
=
[]
def
__call__
(
self
,
*
args
):
args
=
list
(
args
)
if
len
(
args
)
==
1
and
args
[
0
]
is
None
:
args
=
None
return
CallPattern
(
self
,
*
args
)
def
__add__
(
self
,
other
):
return
is_op
(
"__add__"
)(
self
,
other
)
def
__iadd__
(
self
,
other
):
return
is_op
(
"__iadd__"
)(
self
,
other
)
def
__radd__
(
self
,
other
):
return
is_op
(
"__radd__"
)(
self
,
other
)
def
__sub__
(
self
,
other
):
return
is_op
(
"__sub__"
)(
self
,
other
)
def
__isub__
(
self
,
other
):
return
is_op
(
"__isub__"
)(
self
,
other
)
def
__rsub__
(
self
,
other
):
return
is_op
(
"__rsub__"
)(
self
,
other
)
def
__mul__
(
self
,
other
):
return
is_op
(
"__mul__"
)(
self
,
other
)
def
__imul__
(
self
,
other
):
return
is_op
(
"__imul__"
)(
self
,
other
)
def
__rmul__
(
self
,
other
):
return
is_op
(
"__rmul__"
)(
self
,
other
)
def
__truediv__
(
self
,
other
):
return
is_op
(
"__truediv__"
)(
self
,
other
)
def
__itruediv__
(
self
,
other
):
return
is_op
(
"__itruediv__"
)(
self
,
other
)
def
__rtruediv__
(
self
,
other
):
return
is_op
(
"__rtruediv__"
)(
self
,
other
)
def
__or__
(
self
,
other
):
assert
isinstance
(
other
,
ExprPattern
)
return
OrPattern
(
self
,
other
)
def
get_output
(
self
,
index
):
raise
NotImplementedError
def
check_users
(
self
,
check
:
bool
=
True
):
self
.
_check_users
=
check
return
self
def
_add_users
(
self
,
pattern
:
"ExprPattern"
):
self
.
_users
.
append
(
pattern
)
def
_clear_users
(
self
,):
self
.
_users
.
clear
()
def
__getitem__
(
self
,
index
):
return
is_op
(
"__getitem__"
)(
self
,
index
)
def
has_attr
(
self
,
**
attrs
):
logger
.
warning
(
"has_param only support ModulePattern"
)
return
self
def
has_param
(
self
,
**
params
):
logger
.
warning
(
"has_param only support FunctionPattern"
)
return
self
@
abstractmethod
def
__repr__
(
self
)
->
str
:
raise
NotImplementedError
class
CallPattern
(
ExprPattern
):
def
__init__
(
self
,
op
:
ExprPattern
,
*
args
:
List
[
ExprPattern
]):
super
().
__init__
()
self
.
op
=
op
self
.
args
=
list
(
filter
(
lambda
x
:
isinstance
(
x
,
ExprPattern
),
args
))
self
.
_match_all_args
=
True
def
__repr__
(
self
)
->
str
:
return
"{}({})"
.
format
(
self
.
op
,
","
.
join
(
str
(
x
)
for
x
in
self
.
args
))
def
not_all_args
(
self
):
self
.
_match_all_args
=
False
def
check_users
(
self
,
check
:
bool
=
True
):
self
.
_check_users
=
check
self
.
op
.
check_users
(
check
)
return
self
def
_add_users
(
self
,
pattern
:
"ExprPattern"
):
self
.
_users
.
append
(
pattern
)
self
.
op
.
_add_users
(
pattern
)
def
_clear_users
(
self
):
self
.
_users
.
clear
()
self
.
op
.
_clear_users
()
class
OrPattern
(
ExprPattern
):
def
__init__
(
self
,
left
:
ExprPattern
,
right
:
ExprPattern
):
super
().
__init__
()
self
.
left
=
left
self
.
right
=
right
def
__repr__
(
self
)
->
str
:
return
"({}|{})"
.
format
(
self
.
left
,
self
.
right
)
def
check_users
(
self
,
check
:
bool
=
True
):
self
.
_check_users
=
check
self
.
left
.
check_users
(
check
)
self
.
right
.
check_users
(
check
)
return
self
def
_clear_users
(
self
):
self
.
_users
.
clear
()
self
.
left
.
_clear_users
()
self
.
right
.
_clear_users
()
class
GetOutputPaterrn
(
ExprPattern
):
def
__init__
(
self
,
op
,
index
):
super
().
__init__
()
self
.
op
=
op
self
.
index
=
index
def
__repr__
(
self
)
->
str
:
return
"{}[{}]"
.
format
(
self
.
op
,
self
.
index
)
class
ModulePattern
(
ExprPattern
):
def
__init__
(
self
,
module_cls
:
Module
)
->
None
:
super
().
__init__
()
self
.
attrs
=
{}
self
.
target
=
module_cls
def
has_attr
(
self
,
**
attrs
):
self
.
attrs
.
update
(
attrs
)
return
self
def
__repr__
(
self
)
->
str
:
return
"{}"
.
format
(
self
.
target
.
__name__
)
class
FunctionPattern
(
ExprPattern
):
def
__init__
(
self
,
func
:
Callable
):
super
().
__init__
()
self
.
params
=
{}
self
.
target
=
func
def
has_params
(
self
,
**
params
):
self
.
params
.
update
(
params
)
return
self
def
__repr__
(
self
)
->
str
:
return
"{}"
.
format
(
self
.
target
.
__name__
)
class
TensorMethodPattern
(
ExprPattern
):
def
__init__
(
self
,
method
:
str
):
super
().
__init__
()
self
.
target
=
method
def
__repr__
(
self
)
->
str
:
return
self
.
target
class
ApplyDefPattern
(
ExprPattern
):
def
__init__
(
self
,
opdef
:
OpDef
):
super
().
__init__
()
self
.
target
=
opdef
def
__repr__
(
self
)
->
str
:
return
"{}"
.
format
(
self
.
target
.
__name__
)
class
VarPattern
(
ExprPattern
):
def
__init__
(
self
):
super
().
__init__
()
def
__repr__
(
self
)
->
str
:
return
"var"
class
ConstantPattern
(
ExprPattern
):
def
__init__
(
self
):
super
().
__init__
()
def
__repr__
(
self
)
->
str
:
return
"const"
class
AnyPattern
(
ExprPattern
):
def
__init__
(
self
):
super
().
__init__
()
def
__repr__
(
self
)
->
str
:
return
"any"
def
is_op
(
target
):
if
isinstance
(
target
,
type
):
if
issubclass
(
target
,
Module
):
return
ModulePattern
(
target
)
if
issubclass
(
target
,
OpDef
):
return
ApplyDefPattern
(
target
)
elif
callable
(
target
):
return
FunctionPattern
(
target
)
elif
isinstance
(
target
,
str
):
return
TensorMethodPattern
(
target
)
else
:
raise
ValueError
(
"not support"
)
def
is_const
():
return
ConstantPattern
().
check_users
(
False
)
def
any_node
():
return
AnyPattern
()
def
is_var
():
return
VarPattern
()
imperative/python/megengine/traced_module/_passes/utils.py
0 → 100644
浏览文件 @
b28ad4e8
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
copy
from
typing
import
Any
,
Dict
,
List
from
..expr
import
Expr
,
is_constant
,
is_getattr
from
..node
import
Node
,
TensorNode
def
register_obj
(
objs
:
List
[
Any
],
_dict
:
Dict
):
if
not
isinstance
(
objs
,
List
):
objs
=
[
objs
]
def
_register
(
any_obj
:
Any
):
for
obj
in
objs
:
_dict
[
obj
]
=
any_obj
return
any_obj
return
_register
def
get_const_value
(
expr
:
Expr
,
fall_back
:
Any
=
None
):
value
=
fall_back
if
isinstance
(
expr
,
Node
):
expr
=
expr
.
expr
if
is_getattr
(
expr
)
and
isinstance
(
expr
.
outputs
[
0
],
TensorNode
):
module
=
expr
.
inputs
[
0
].
owner
assert
module
is
not
None
value
=
copy
.
deepcopy
(
expr
.
interpret
(
module
)[
0
])
elif
is_constant
(
expr
):
value
=
copy
.
deepcopy
(
expr
.
interpret
()[
0
])
return
value
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录