提交 b28ad4e8 编写于 作者: M Megvii Engine Team

feat(mge/traced_module): add pattern match for TracedModule

GitOrigin-RevId: 0af7b076e6740db30fab7126f6f496e88ef91b48
上级 2318ea3f
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from collections import OrderedDict, defaultdict
from functools import partial
from ...logger import get_logger
from ..expr import (
Expr,
is_apply_def,
is_call_function,
is_call_module,
is_call_tensor_method,
is_constant,
)
from .pattern import (
AnyPattern,
ApplyDefPattern,
CallPattern,
ConstantPattern,
ExprPattern,
FunctionPattern,
ModulePattern,
OrPattern,
TensorMethodPattern,
VarPattern,
)
from .utils import register_obj
logger = get_logger(__name__)
class PatternMatcher:
method_dict = {}
register_visiter_func = partial(register_obj, _dict=method_dict)
def __init__(self) -> None:
self.matched_patterns = []
self.matched_exprs = OrderedDict()
def match(self, pattern: ExprPattern, expr: Expr) -> bool:
self.matched_exprs.clear()
self.matched_patterns.clear()
pattern.check_users(False)
res = self.visit_pattern(pattern, expr)
if res and not self._check_users():
self.clear_map(0)
res = False
self._clear_pattern_users()
return res
def clear_map(self, mark):
for _ in range(len(self.matched_patterns) - mark):
p = self.matched_patterns.pop()
self.matched_exprs.pop(p)
p._clear_users()
def _clear_pattern_users(self):
for p in self.matched_patterns:
p._clear_users()
def _check_users(self) -> bool:
for pat, expr in self.matched_exprs.items():
if pat._check_users:
pattern_users = pat._users
if len(expr.outputs) != 1:
logger.warning(
"only support single output, and the matching "
"result may be wrong"
)
continue
expr_users = expr.outputs[0].users
if len(pattern_users) != len(expr_users):
return False
for pat, expr in zip(pattern_users, expr_users):
if self.matched_exprs[pat] != expr:
return False
return True
def visit_pattern(self, pattern: ExprPattern, expr: Expr) -> bool:
if pattern in self.matched_exprs:
if self.matched_exprs[pattern] is expr:
if isinstance(pattern, (OrPattern)):
assert self._visit_or_pattern(pattern, expr) == True
return True
else:
return False
else:
mark = len(self.matched_patterns)
visiter = self.method_dict.get(type(pattern))
matched = visiter(self, pattern, expr)
if matched:
self.matched_patterns.append(pattern)
self.matched_exprs[pattern] = expr
else:
self.clear_map(mark)
return matched
@register_visiter_func(OrPattern)
def _visit_or_pattern(self, pattern: OrPattern, expr: Expr) -> bool:
if self.visit_pattern(pattern.left, expr):
if pattern._users:
pattern.left._add_users(pattern._users[-1])
return True
if self.visit_pattern(pattern.right, expr):
if pattern._users:
pattern.right._add_users(pattern._users[-1])
return True
return False
@register_visiter_func(CallPattern)
def _visit_call_pattern(self, pattern: CallPattern, expr: Expr) -> bool:
mark = len(self.matched_patterns)
match_res = self.visit_pattern(pattern.op, expr)
if not match_res:
self.clear_map(mark)
return False
inputs = expr.inputs
if isinstance(pattern.op, ModulePattern):
inputs = inputs[1:]
if (pattern._match_all_args and len(pattern.args) != len(inputs)) or (
not pattern._match_all_args and len(pattern.args) > len(inputs)
):
self.clear_map(mark)
return False
for i, pat in enumerate(pattern.args):
pat._add_users(pattern)
match_res = self.visit_pattern(pat, inputs[i].expr)
if not match_res:
pat._clear_users()
self.clear_map(mark)
return False
return True
@register_visiter_func(ModulePattern)
def _visit_module_pattern(self, pattern: ModulePattern, expr: Expr) -> bool:
if not is_call_module(expr, pattern.target):
return False
module = expr.inputs[0].owner
for key, target in pattern.attrs.items():
value = getattr(module, key, None)
if target != value:
return False
return True
@register_visiter_func(FunctionPattern)
def _visit_function_pattern(self, pattern: FunctionPattern, expr: Expr) -> bool:
if not is_call_function(expr, pattern.target):
return False
kwargs = expr.kwargs
for key, target in pattern.params.items():
value = kwargs.get(key, None)
if target != value:
return False
return True
@register_visiter_func(TensorMethodPattern)
def _visit_tensor_method_pattern(
self, pattern: TensorMethodPattern, expr: Expr
) -> bool:
return is_call_tensor_method(expr, pattern.target)
@register_visiter_func(ApplyDefPattern)
def _visit_apply_pattern(self, pattern: ApplyDefPattern, expr: Expr) -> bool:
return is_apply_def(expr, pattern.target)
@register_visiter_func(ConstantPattern)
def _visit_const_pattern(self, pattern: ConstantPattern, expr: Expr) -> bool:
return is_constant(expr)
@register_visiter_func(VarPattern)
def _visit_var_pattern(self, pattern: VarPattern, expr: Expr) -> bool:
return not is_constant(expr)
@register_visiter_func(AnyPattern)
def _visit_any_pattern(self, pattern: AnyPattern, expr: Expr) -> bool:
return True
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod
from typing import Any, Callable, Dict, List
from ...core._imperative_rt import OpDef
from ...logger import get_logger
from ...module import Module
from ..expr import Expr
from ..node import Node
logger = get_logger(__name__)
class ExprPattern:
def __init__(self):
self._check_users = True
self._users = []
def __call__(self, *args):
args = list(args)
if len(args) == 1 and args[0] is None:
args = None
return CallPattern(self, *args)
def __add__(self, other):
return is_op("__add__")(self, other)
def __iadd__(self, other):
return is_op("__iadd__")(self, other)
def __radd__(self, other):
return is_op("__radd__")(self, other)
def __sub__(self, other):
return is_op("__sub__")(self, other)
def __isub__(self, other):
return is_op("__isub__")(self, other)
def __rsub__(self, other):
return is_op("__rsub__")(self, other)
def __mul__(self, other):
return is_op("__mul__")(self, other)
def __imul__(self, other):
return is_op("__imul__")(self, other)
def __rmul__(self, other):
return is_op("__rmul__")(self, other)
def __truediv__(self, other):
return is_op("__truediv__")(self, other)
def __itruediv__(self, other):
return is_op("__itruediv__")(self, other)
def __rtruediv__(self, other):
return is_op("__rtruediv__")(self, other)
def __or__(self, other):
assert isinstance(other, ExprPattern)
return OrPattern(self, other)
def get_output(self, index):
raise NotImplementedError
def check_users(self, check: bool = True):
self._check_users = check
return self
def _add_users(self, pattern: "ExprPattern"):
self._users.append(pattern)
def _clear_users(self,):
self._users.clear()
def __getitem__(self, index):
return is_op("__getitem__")(self, index)
def has_attr(self, **attrs):
logger.warning("has_param only support ModulePattern")
return self
def has_param(self, **params):
logger.warning("has_param only support FunctionPattern")
return self
@abstractmethod
def __repr__(self) -> str:
raise NotImplementedError
class CallPattern(ExprPattern):
def __init__(self, op: ExprPattern, *args: List[ExprPattern]):
super().__init__()
self.op = op
self.args = list(filter(lambda x: isinstance(x, ExprPattern), args))
self._match_all_args = True
def __repr__(self) -> str:
return "{}({})".format(self.op, ",".join(str(x) for x in self.args))
def not_all_args(self):
self._match_all_args = False
def check_users(self, check: bool = True):
self._check_users = check
self.op.check_users(check)
return self
def _add_users(self, pattern: "ExprPattern"):
self._users.append(pattern)
self.op._add_users(pattern)
def _clear_users(self):
self._users.clear()
self.op._clear_users()
class OrPattern(ExprPattern):
def __init__(self, left: ExprPattern, right: ExprPattern):
super().__init__()
self.left = left
self.right = right
def __repr__(self) -> str:
return "({}|{})".format(self.left, self.right)
def check_users(self, check: bool = True):
self._check_users = check
self.left.check_users(check)
self.right.check_users(check)
return self
def _clear_users(self):
self._users.clear()
self.left._clear_users()
self.right._clear_users()
class GetOutputPaterrn(ExprPattern):
def __init__(self, op, index):
super().__init__()
self.op = op
self.index = index
def __repr__(self) -> str:
return "{}[{}]".format(self.op, self.index)
class ModulePattern(ExprPattern):
def __init__(self, module_cls: Module) -> None:
super().__init__()
self.attrs = {}
self.target = module_cls
def has_attr(self, **attrs):
self.attrs.update(attrs)
return self
def __repr__(self) -> str:
return "{}".format(self.target.__name__)
class FunctionPattern(ExprPattern):
def __init__(self, func: Callable):
super().__init__()
self.params = {}
self.target = func
def has_params(self, **params):
self.params.update(params)
return self
def __repr__(self) -> str:
return "{}".format(self.target.__name__)
class TensorMethodPattern(ExprPattern):
def __init__(self, method: str):
super().__init__()
self.target = method
def __repr__(self) -> str:
return self.target
class ApplyDefPattern(ExprPattern):
def __init__(self, opdef: OpDef):
super().__init__()
self.target = opdef
def __repr__(self) -> str:
return "{}".format(self.target.__name__)
class VarPattern(ExprPattern):
def __init__(self):
super().__init__()
def __repr__(self) -> str:
return "var"
class ConstantPattern(ExprPattern):
def __init__(self):
super().__init__()
def __repr__(self) -> str:
return "const"
class AnyPattern(ExprPattern):
def __init__(self):
super().__init__()
def __repr__(self) -> str:
return "any"
def is_op(target):
if isinstance(target, type):
if issubclass(target, Module):
return ModulePattern(target)
if issubclass(target, OpDef):
return ApplyDefPattern(target)
elif callable(target):
return FunctionPattern(target)
elif isinstance(target, str):
return TensorMethodPattern(target)
else:
raise ValueError("not support")
def is_const():
return ConstantPattern().check_users(False)
def any_node():
return AnyPattern()
def is_var():
return VarPattern()
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
from typing import Any, Dict, List
from ..expr import Expr, is_constant, is_getattr
from ..node import Node, TensorNode
def register_obj(objs: List[Any], _dict: Dict):
if not isinstance(objs, List):
objs = [objs]
def _register(any_obj: Any):
for obj in objs:
_dict[obj] = any_obj
return any_obj
return _register
def get_const_value(expr: Expr, fall_back: Any = None):
value = fall_back
if isinstance(expr, Node):
expr = expr.expr
if is_getattr(expr) and isinstance(expr.outputs[0], TensorNode):
module = expr.inputs[0].owner
assert module is not None
value = copy.deepcopy(expr.interpret(module)[0])
elif is_constant(expr):
value = copy.deepcopy(expr.interpret()[0])
return value
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册