Commit 6c692b26 authored by Megvii Engine Team

feat(mge/traced_module): add some fuse passes

GitOrigin-RevId: 065f9df32eaead53544989c826910f8c326ba738
Parent b28ad4e8
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import functional as F
from ... import module as M
from ...core.ops.builtin import GetVarShape
from ...logger import get_logger
from ...tensor import Tensor
from ..expr import Constant, Expr, is_apply_def, is_constant, is_getattr
from ..node import Node, TensorNode
from .matcher import PatternMatcher
from .pass_base import BackwardPass, ForwardPass, register_pass
from .pattern import is_op
from .utils import get_const_value
logger = get_logger(__name__)
@register_pass("AttrToConstant")
class AttrToConstant(BackwardPass):
r"""Convert :class:`~.GetAttr` to :class:`~.Constant` expr."""
name = "AttrToConstant"
run_once = True
def run_transform(self, expr: Expr):
if not (is_getattr(expr) and isinstance(expr.outputs[0], TensorNode)):
return expr
graph = expr.top_graph
value = get_const_value(expr)
orig_node = expr.outputs[0]
name = orig_node.name
with graph.insert_exprs(expr):
const_node = Constant.make(value, name=name)
graph.replace_node({orig_node: const_node})
graph.compile()
const_node.name = name
return const_node.expr
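# Illustrative note (not part of the original sources): after this pass runs, a
# GetAttr expr whose output is a TensorNode, e.g. one produced by reading
# `self.conv.weight` during tracing, is replaced by a Constant expr holding the
# tensor value returned by get_const_value, so later passes can treat the
# attribute as a plain constant.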
@register_pass("FixInputShape")
class FixInputShape(BackwardPass):
name = "FixInputShape"
run_once = True
def run_transform(self, expr: Expr):
if not is_apply_def(expr, GetVarShape):
return expr
shape = Tensor(expr.inputs[0].shape, dtype="int32")
graph = expr.top_graph
with graph.insert_exprs(expr):
const_shape = Constant.make(shape)
graph.replace_node({expr.outputs[0]: const_shape})
graph.compile()
const_shape.name = expr.outputs[0].name
return const_shape.expr
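# Descriptive note: replacing GetVarShape results with int32 constants bakes the
# shapes seen at transform time into the graph, so the transformed module is no
# longer shape-polymorphic.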
@register_pass("FlodConstant")
class FlodConstant(ForwardPass):
r"""Constant folding."""
name = "FlodConstant"
required_pass = ["AttrToConstant"]
run_once = False
def run_transform(self, expr: Expr):
if len(expr.inputs) == 0 or any(not is_constant(n.expr) for n in expr.inputs):
return expr
const_var = expr.interpret(*[get_const_value(n.expr) for n in expr.inputs])[0]
graph = expr.top_graph
with graph.insert_exprs(expr):
const_node = Constant.make(const_var)
graph.replace_node({expr.outputs[0]: const_node})
graph.compile()
const_node.name = expr.outputs[0].name
return const_node.expr
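# Illustrative example (assumed, not from the original sources): an expr whose
# inputs are all Constant exprs, e.g. `y = F.mul(two, three)` where `two` and
# `three` came from AttrToConstant, is evaluated eagerly via expr.interpret and
# replaced by a single Constant holding the result.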
@register_pass("NormElemWise")
class NormElemWise(BackwardPass):
r"""Transform add/sub or mul/div expr to add-only or mul-only chains.
For example, the following code
.. code-block::
b = 1 - a
c = 2 * b
d = 1 / c
will be changed to
.. code-block::
a1 = F.neg(a)
b = a1 + 1
c = b * 2
d = F.pow(c, -1)
"""
name = "NormElemWise"
required_pass = ["FlodConstant"]
run_once = False
def __init__(self,):
super().__init__()
self.pattern = is_op(F.add)
for op in [F.sub, F.mul, F.div]:
self.pattern |= is_op(op)
for op in ["__add__", "__iadd__", "__radd__"]:
self.pattern |= is_op(op)
for op in ["__sub__", "__isub__", "__rsub__"]:
self.pattern |= is_op(op)
for op in ["__mul__", "__imul__", "__rmul__"]:
self.pattern |= is_op(op)
for op in ["__truediv__", "__itruediv__", "__rtruediv__"]:
self.pattern |= is_op(op)
def run_transform(self, expr: Expr):
matcher = PatternMatcher()
if not matcher.match(self.pattern, expr):
return expr
pattern = matcher.matched_patterns[0]
target = pattern.target
coef, left_node, right_node = 1, None, None
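# `coef` is -1 when the variable operand is the subtrahend/divisor (e.g. `1 - a`
# or `1 / a`); the rewrite below then applies F.neg / F.pow(x, -1) to the
# variable instead of negating or inverting the constant.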
if len(expr.inputs) == 1 and target not in ["__add__", "__mul__"]:
left_node = expr.inputs[0]
right_node = expr.const_val[0][-1]
if target in ["__rsub__", "__rtruediv__"]:
coef = -1
if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]:
coef = -1
elif len(expr.inputs) == 2 and (
target not in ["__add__", "__mul__"] or is_constant(expr.inputs[0].expr)
):
left_node, right_node = expr.inputs
if target in ["__rsub__", "__rtruediv__"]:
left_node, right_node = right_node, left_node
if target in [F.sub, F.div] and left_node is not expr.kwargs["x"]:
left_node, right_node = right_node, left_node
if is_constant(left_node.expr):
left_node, right_node = right_node, left_node
coef = -1
if left_node is None:
return expr
if isinstance(right_node, TensorNode):
right_node = get_const_value(right_node.expr, right_node)
graph = expr.top_graph
with graph.insert_exprs():
if target in ["__mul__", "__imul__", "__rmul__", F.mul]:
out_node = left_node * right_node
elif target in ["__add__", "__iadd__", "__radd__", F.add]:
out_node = left_node + right_node
elif target in ["__sub__", "__isub__", "__rsub__", F.sub]:
if coef == -1:
left_node = F.neg(left_node)
else:
if isinstance(right_node, TensorNode):
right_node = F.neg(right_node)
else:
right_node = -1 * right_node
out_node = left_node + right_node
elif target in ["__truediv__", "__itruediv__", "__rtruediv__", F.div]:
if coef == -1:
left_node = F.pow(left_node, -1)
else:
if isinstance(right_node, TensorNode):
right_node = F.pow(right_node, -1)
else:
right_node = 1 / right_node
out_node = left_node * right_node
graph.replace_node({expr.outputs[0]: out_node})
graph.compile()
return out_node.expr
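# A minimal usage sketch (assumption, not part of the original commit): a pass
# registered above can be fetched by name and applied directly to a
# TracedModule, following register_pass/get_registered_pass and
# BasePass.__call__ defined in pass_base below, e.g.
#
#     norm_pass = get_registered_pass("NormElemWise")()
#     traced_module = norm_pass(traced_module)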
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from collections import OrderedDict, defaultdict
from copy import deepcopy
from typing import Any, Dict, List, Set
from ... import functional as F
from ... import module as M
from ...core.ops.builtin import GetVarShape
from ...logger import get_logger
from ...tensor import Parameter, Tensor
from ..expr import (
Expr,
is_apply_def,
is_call_function,
is_call_module,
is_call_tensor_method,
is_constant,
is_getattr,
)
from ..traced_module import InternalGraph
from ..utils import assign_attr, get_subattr
from .matcher import PatternMatcher
from .pass_base import BackwardPass, register_pass
from .pattern import is_const, is_op, is_var
from .utils import get_const_value
logger = get_logger(__name__)
@register_pass("BackwardFoldScale")
class BackwardFoldScale(BackwardPass):
r"""Backward fold const scaling into weights of conv2d.
For example, the following code
.. code-block::
x = conv(x, w, b)
x = relu(x)
x1 = x + 3
x2 = x + 4
y = (x1 + x2) * 3
will be changed to
.. code-block::
x = conv(x, w * 3, b * 3)
x = relu(x)
x1 = x + 9
x2 = x + 12
y = x1 + x2
"""
name = "BackwardFoldScale"
required_pass = ["AttrToConstant", "NormElemWise"]
run_once = True
def __init__(self):
super().__init__()
# todo: support more axes
self.scale_message = OrderedDict()
self.used_names = defaultdict(int)
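# Descriptive note: before_visit_graph walks the graph in reverse and records,
# for every expr, the constant scale that can still be pushed backwards through
# it (None means the scale cannot propagate); run_transform then consumes
# self.scale_message, multiplying conv weights/bias by the final scale,
# rescaling add constants, and deleting the now-redundant mul exprs.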
def run_transform(self, expr: Expr) -> Expr:
if expr not in self.scale_message:
return expr
var = is_var().check_users(False)
mul_const_pattern = var * is_const() | var * "*" | is_op(F.neg)
add_const_pattern = var + is_const() | var + "*"
conv_pattern = is_op(F.conv2d) | is_op(M.Conv2d)
pattern = conv_pattern | add_const_pattern | mul_const_pattern
matcher = PatternMatcher()
if not matcher.match(pattern, expr):
return expr
matched_exprs = matcher.matched_exprs
if conv_pattern in matched_exprs:
return self.fold_conv_mul(expr)
if mul_const_pattern in matched_exprs:
return self.fold_mul(expr)
if add_const_pattern in matched_exprs:
return self.fold_add_mul(expr)
return expr
def fold_add_mul(self, expr: Expr):
if self.scale_message[expr] is None:
return expr
scale = self.scale_message[expr]
if len(expr.inputs) == 1:
const = expr.const_val[0][-1]
else:
const = get_const_value(expr.inputs[1])
const = const * scale
inp_node = expr.inputs[0]
graph = expr.top_graph
with graph.insert_exprs():
add_node = inp_node + const
graph.replace_node({expr.outputs[0]: add_node})
graph.compile()
add_node.name = expr.outputs[0].name
return add_node.expr
def fold_mul(self, expr: Expr):
if self.scale_message[expr] is None:
return expr
graph = expr.top_graph
graph.replace_node({expr.outputs[0]: expr.inputs[0]})
graph.compile()
return expr
def fold_conv_mul(self, expr: Expr):
graph = expr.top_graph
scale = self.scale_message[expr]
if scale is None:
return expr
if is_call_function(expr, F.conv2d):
named_args = expr.named_args
weight = get_const_value(named_args["weight"], named_args["weight"]) * scale
bias = get_const_value(named_args["bias"], named_args["bias"]) * scale
named_args["weight"] = weight
named_args["bias"] = bias
with graph.insert_exprs():
out_node = F.conv2d(**named_args)
graph.replace_node({expr.outputs[0]: out_node})
graph.compile()
out_node.name = expr.outputs[0].name
return out_node.expr
else:
mnode = expr.inputs[0]
attr_name = expr.inputs[0].expr.name
graph = expr.top_graph
if len(mnode.users) > 1:
self.used_names[mnode.qualname] += 1
attr_name = "{}_{}".format(attr_name, self.used_names[mnode.qualname])
logger.warning(
"{} is used {} times and its name will be reset to {}.{}".format(
mnode.qualname, len(mnode.users), graph.qualname, attr_name
)
)
conv_module = mnode.owner
if len(mnode.users) > 1:
conv_module = deepcopy(conv_module)
conv_module._name = None
conv_module.weight = Parameter(conv_module.weight * scale)
if conv_module.bias is not None:
conv_module.bias = Parameter(conv_module.bias * scale)
if len(mnode.users) > 1:
self_node = mnode.expr.inputs[0]
assign_attr(conv_module, self_node.owner, attr_name)
with graph.insert_exprs(mnode.expr):
new_conv_node = get_subattr(self_node, attr_name)
expr.replace_inputs({mnode: new_conv_node})
return expr
def reset_expr_message_to_none(
self, expr: Expr, scale_message: Dict[Expr, Any], skip_exprs: Set[Expr],
):
if expr in skip_exprs:
return
scale_message[expr] = None
if is_call_function(expr, F.conv2d) or is_call_module(expr, M.Conv2d):
return
for out_node in expr.outputs:
for user in out_node.users:
if user in scale_message:
self.reset_expr_message_to_none(user, scale_message, skip_exprs)
def before_visit_graph(self, graph: InternalGraph):
var = is_var().check_users(False)
mul_const_pattern = var * is_const() | var * "*" | is_op(F.neg)
relu_pattern = (
is_op(F.relu) | is_op(M.ReLU) | is_op(F.leaky_relu) | is_op(M.LeakyReLU)
)
# The parameters of conv must be constants; dynamic conv is not supported
conv_pattern = (
is_op(F.conv2d)(var, is_const(), is_const())
| is_op(F.conv2d)(var, is_const())
| is_op(M.Conv2d)
)
pattern = mul_const_pattern | relu_pattern | conv_pattern
for op in [
"__add__",
F.reshape,
"reshape",
F.transpose,
"tranpose",
F.min,
"min",
F.max,
"max",
F.max_pool2d,
M.MaxPool2d,
F.avg_pool2d,
M.AvgPool2d,
F.adaptive_avg_pool2d,
M.AdaptiveAvgPool2d,
F.adaptive_max_pool2d,
M.AdaptiveMaxPool2d,
F.expand_dims,
F.concat,
"__getitem__",
]:
pattern |= is_op(op)
matcher = PatternMatcher()
scale_message = OrderedDict()
mem_conv_scale_message = OrderedDict()
skip_exprs = self.init_skip_exprs(graph)
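# scale_message maps each expr to the scale that still has to be folded above
# it; mem_conv_scale_message remembers the final scale per conv expr so it
# survives the reset of the conv's own message; skip_exprs holds shape-only and
# constant subgraphs that the propagation ignores.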
for expr in reversed(graph._exprs):
if expr in skip_exprs:
continue
if len(expr.outputs) > 1 or not matcher.match(pattern, expr):
self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
if is_call_function(expr, F.conv2d):
for user in expr.outputs[0].users:
self.reset_expr_message_to_none(user, scale_message, skip_exprs)
continue
matched_exprs = matcher.matched_exprs
const = None
if mul_const_pattern in matched_exprs:
if is_call_function(expr, F.neg):
const = -1
elif len(expr.inputs) == 1:
const = expr.const_val[0][-1]
else:
const = get_const_value(expr.inputs[1])
if isinstance(const, Tensor) and const._tuple_shape not in [(1,), tuple()]:
self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
continue
users_const = [
scale_message[e] for e in expr.outputs[0].users if e not in skip_exprs
]
if len(users_const) == 0:
scale_message[expr] = const
continue
if any(c is None or c != users_const[0] for c in users_const):
self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
scale_message[expr] = const
continue
const = 1 if const is None else const
const = const * users_const[0]
if relu_pattern in matched_exprs and const < 0:
self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
continue
if conv_pattern in matched_exprs:
self.reset_expr_message_to_none(expr, scale_message, skip_exprs)
mem_conv_scale_message[expr] = const
continue
scale_message[expr] = const
self.scale_message.update(scale_message)
self.scale_message.update(mem_conv_scale_message)
def init_skip_exprs(self, graph: InternalGraph):
skip_exprs = set()
for expr in graph._exprs:
if is_apply_def(expr, GetVarShape):
skip_exprs.add(expr)
elif is_call_tensor_method(expr, "__getitem__") and expr in skip_exprs:
skip_exprs.add(expr)
elif is_getattr(expr):
skip_exprs.add(expr)
elif is_constant(expr):
skip_exprs.add(expr)
elif all(n.expr in skip_exprs for n in expr.inputs):
skip_exprs.add(expr)
return skip_exprs
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import operator
from collections import defaultdict
from typing import Any, Callable, List
from ... import functional as F
from ... import module as M
from ...logger import get_logger
from ...tensor import Parameter, Tensor
from ...utils.bn_fusion import fold_weight_bias
from ..expr import Expr, is_call_function
from ..utils import assign_attr, get_subattr
from .matcher import PatternMatcher
from .pass_base import BackwardPass, register_pass
from .pattern import ExprPattern, any_node, is_const, is_op, is_var
from .utils import get_const_value, register_obj
logger = get_logger(__name__)
@register_pass("FuseAddMul")
class FuseAddMul(BackwardPass):
"""Fold adjacent const add or mul binary operations.
For example, the following code
.. code-block::
x = x + 1
x = 2 + x
x = x * 4
x = x * 0.25
will be changed to
.. code-block::
x = x + 3
"""
name = "FuseAddMul"
required_pass = ["NormElemWise"]
run_once = False
def __init__(self,):
super().__init__()
def _make_pattern(op_0, op_1) -> ExprPattern:
x = is_var().check_users(False)
if op_0 not in [operator.add, operator.mul]:
op_0 = is_op(op_0)
if op_1 not in [operator.add, operator.mul]:
op_1 = is_op(op_1)
pattern = op_0(x, is_const()) | op_0(x, "*")
pattern = op_1(pattern, is_const()) | op_1(pattern, "*")
return pattern
self.pattern_dict = {}
for op, func in zip([operator.add, F.pow], [self.fold_add, self.fold_pow],):
self.pattern_dict[_make_pattern(op, op)] = func
for op_0 in [F.neg, operator.mul]:
for op_1 in [F.neg, operator.mul]:
self.pattern_dict[_make_pattern(op_0, op_1)] = self.fold_mul
def run_transform(self, expr: Expr):
matcher = PatternMatcher()
for pattern, func in self.pattern_dict.items():
res = matcher.match(pattern, expr)
if res:
break
if not res:
return expr
return func(expr)
def _fold_helper(self, expr: Expr, op_c: Callable, op_t: Callable):
const_0 = self.get_const_value(expr)
# todo: support more shapes
if isinstance(const_0, Tensor) and const_0._tuple_shape not in [(1,), tuple()]:
return expr
const_1 = self.get_const_value(expr.inputs[0].expr)
if isinstance(const_1, Tensor) and const_1._tuple_shape not in [(1,), tuple()]:
return expr
inp_node = expr.inputs[0].expr.inputs[0]
const = op_c(const_0, const_1)
graph = expr.top_graph
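# If the folded constant is the identity element (1 for mul/pow, 0 for add),
# the whole chain collapses and the input node is wired directly to the users.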
if (const == 1 and op_t in [operator.pow, operator.mul]) or (
const == 0 and op_t in [operator.add]
):
graph.replace_node({expr.outputs[0]: inp_node})
graph.compile()
return expr
with expr.top_graph.insert_exprs():
out_node = op_t(inp_node, const)
graph.replace_node({expr.outputs[0]: out_node})
graph.compile()
return out_node.expr
def fold_add(self, expr: Expr):
return self._fold_helper(expr, operator.add, operator.add)
def fold_mul(self, expr):
return self._fold_helper(expr, operator.mul, operator.mul)
def fold_pow(self, expr):
return self._fold_helper(expr, operator.mul, F.pow)
def get_const_value(self, expr: Expr):
if is_call_function(expr, F.neg):
return -1
if len(expr.inputs) == 2:
value = get_const_value(expr.inputs[1].expr, None)
assert value is not None, " "
return value
value = expr.const_val[0][-1]
return value
@register_pass("FuseConvBn")
class FuseConvBn(BackwardPass):
r"""Fuse BN layers into conv2d."""
name = "FuseConvBn"
required_pass = ["AttrToConstant"]
run_once = True
def __init__(self):
super().__init__()
self.used_name = defaultdict(int)
def run_transform(self, expr: Expr):
conv_pat_0 = is_op(M.Conv2d)
conv_pat_1 = is_op(F.conv2d)
bn_pat_0 = is_op(M.BatchNorm2d)(conv_pat_0 | conv_pat_1)
bn_pat_1 = is_op(F.batch_norm)
# inp, running_mean, running_var, weight, bias
bn_inps = (
conv_pat_0 | conv_pat_1,
is_const(),
is_const(),
is_const(),
is_const(),
)
bn_pat = (
(bn_pat_1(*bn_inps[:3]))
| (bn_pat_1(*bn_inps[:4]))
| (bn_pat_1(*bn_inps))
| bn_pat_0
)
matcher = PatternMatcher()
if not matcher.match(bn_pat, expr):
return expr
matched_exprs = matcher.matched_exprs
if conv_pat_0 in matched_exprs:
return self.fold_convm_bn(matched_exprs[conv_pat_0], matched_exprs[bn_pat])
else:
return self.fold_convf_bn(matched_exprs[conv_pat_1], matched_exprs[bn_pat])
def fold_convm_bn(self, conv: Expr, bn: Expr):
mnode, inp_node = conv.inputs[:2]
self_node = mnode.expr.inputs[0]
attr_name = conv.inputs[0].expr.name
graph = conv.top_graph
if len(mnode.users) > 1:
self.used_name[mnode.qualname] += 1
attr_name = "{}_{}".format(attr_name, self.used_name[mnode.qualname])
logger.warning(
"{} is used {} times and its name will be reset to {}.{}".format(
mnode.qualname, len(mnode.users), graph.qualname, attr_name
)
)
conv_module = mnode.owner
weight, bias = conv_module.weight, conv_module.bias
mean, var, gamma, beta, eps = self.get_bn_params(bn)
weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps)
new_conv = M.Conv2d(
in_channels=conv_module.in_channels,
out_channels=conv_module.out_channels,
kernel_size=conv_module.kernel_size,
stride=conv_module.stride,
padding=conv_module.padding,
dilation=conv_module.dilation,
groups=conv_module.groups,
bias=conv_module.bias is not None,
conv_mode=conv_module.conv_mode,
compute_mode=conv_module.compute_mode,
name=conv_module.name,
)
new_conv.weight = Parameter(weight)
new_conv.bias = Parameter(bias)
new_conv.training = conv_module.training
assign_attr(new_conv, self_node.owner, attr_name)
with graph.insert_exprs(mnode.expr):
out_node = get_subattr(self_node, attr_name)(inp_node)
graph.replace_node({bn.outputs[0]: out_node})
graph.compile()
out_node.name = conv.outputs[0].name
return out_node.expr
def fold_convf_bn(self, conv: Expr, bn: Expr):
named_args = conv.named_args
weight = get_const_value(named_args["weight"], named_args["weight"])
bias = get_const_value(named_args["bias"], named_args["bias"])
mean, var, gamma, beta, eps = self.get_bn_params(bn)
weight, bias = fold_weight_bias(weight, bias, gamma, beta, mean, var, eps)
named_args["weight"] = weight
named_args["bias"] = bias
graph = conv.top_graph
with graph.insert_exprs():
out_node = F.conv2d(**named_args)
graph.replace_node({bn.outputs[0]: out_node})
graph.compile()
out_node.name = conv.outputs[0].name
return out_node.expr
def get_bn_params(self, bn: Expr):
if is_call_function(bn):
named_args = bn.named_args
mean = get_const_value(
named_args["running_mean"], named_args["running_mean"]
)
var = get_const_value(named_args["running_var"], named_args["running_var"])
gamma = get_const_value(named_args["weight"], named_args["weight"])
beta = get_const_value(named_args["bias"], named_args["bias"])
eps = named_args["eps"]
return mean, var, gamma, beta, eps
else:
bn_module = bn.inputs[0].owner
mean = bn_module.running_mean
var = bn_module.running_var
gamma = bn_module.weight
beta = bn_module.bias
eps = bn_module.eps
return mean, var, gamma, beta, eps
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
from abc import abstractmethod
from collections import OrderedDict, namedtuple
from functools import partial
from typing import Any, Callable, Dict, Iterable, List, Union
from ...logger import get_logger
from ..expr import Expr
from ..traced_module import InternalGraph, TracedModule
from .utils import register_obj
logger = get_logger(__name__)
class PassContext:
def __init__(
self, disabled_pass: Iterable[str] = None, pass_config: Dict[str, Any] = None
):
self._disabled_pass = set()
self._config = pass_config
self._handle = None
if disabled_pass:
self.add_diabled_pass(disabled_pass)
def add_diabled_pass(self, passes: Iterable[str]):
if isinstance(passes, str):
passes = [passes]
for pas in passes:
self._disabled_pass.add(pas)
def pass_enabled(self, pas: Union["BasePass", str]):
pass_name = pas.name if isinstance(pas, BasePass) else pas
return pass_name not in self._disabled_pass
_default_context = PassContext()
def get_default_pass_context():
return _default_context
_pass_dict = OrderedDict()
register_pass = partial(register_obj, _dict=_pass_dict)
def get_registered_pass(pass_name: str):
pas = _pass_dict.get(pass_name, None)
assert (
pas is not None
), "{} is not found, please call `register_pass` to register it".format(pass_name)
return pas
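# Usage sketch (assumption, not part of the original sources): a PassContext can
# disable individual passes by name before one is applied, e.g.
#
#     ctx = PassContext(disabled_pass=["FuseAddMul"])
#     fuse_bn = get_registered_pass("FuseConvBn")()
#     traced_module = fuse_bn(traced_module, ctx)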
class BasePass:
run_once = True # bool
required_pass = [] # Iterable[str]
name = "" # str
def __init__(self):
super().__init__()
def __call__(
self, mod: TracedModule, pass_ctx: PassContext = get_default_pass_context()
) -> TracedModule:
assert isinstance(pass_ctx, PassContext)
return self.apply_optimization(mod, pass_ctx)
def apply_optimization(
self, mod: TracedModule, pass_ctx: PassContext
) -> TracedModule:
new_mod = mod
for pass_name in self.required_pass + [self.name]:
if not pass_ctx.pass_enabled(pass_name):
logger.warning(
"Since {} is disabled, {} will skipped".format(pass_name, self.name)
)
return mod
for pass_name in self.required_pass:
pass_func = get_registered_pass(pass_name)()
new_mod = pass_func(new_mod, pass_ctx)
iter_num = 1
graph_changed = self.visit_graph(new_mod.graph)
while not self.run_once and graph_changed:
graph_changed = self.visit_graph(new_mod.graph)
iter_num += 1
if iter_num == 100:
break
assert iter_num < 100, "{} was run 100 times, please check for pass conflicts.".format(self.name)
return new_mod
@abstractmethod
def visit_graph(self, graph: InternalGraph):
raise NotImplementedError
def before_visit_graph(self, graph: InternalGraph):
pass
def run_transform(self, expr: Expr) -> Expr:
return expr
def __repr__(self) -> str:
return self.name
class ForwardPass(BasePass):
def visit_graph(self, graph: InternalGraph):
class Item:
def __init__(self, expr: Expr, child_expanded: bool = False):
self.expr = expr
self.child_expanded = child_expanded
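# Descriptive note: this is a DFS from the graph outputs that queues an expr's
# inputs before transforming the expr itself, so run_transform sees exprs in
# topological (forward) order; Item.child_expanded marks whether the inputs of
# an expr have already been pushed.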
self.before_visit_graph(graph)
graph_changed = False
queue = [Item(n.expr) for n in graph.outputs]
visited_expr, visited_graph = set(), set()
while queue:
item = queue[-1]
if item.expr in visited_expr:
queue.pop()
elif item.child_expanded:
if item.expr not in graph._exprs:
queue.pop()
continue
new_expr = self.run_transform(item.expr)
if new_expr is not item.expr:
graph_changed = True
assert new_expr not in visited_expr
queue.append(Item(new_expr))
continue
if (
hasattr(item.expr, "graph")
and item.expr.graph is not None
and item.expr.graph not in visited_graph
):
graph_changed |= self.visit_graph(item.expr.graph)
visited_graph.add(item.expr.graph)
visited_expr.add(item.expr)
else:
item.child_expanded = True
for i in item.expr.inputs:
expr = i.expr
if expr not in queue and expr not in visited_expr:
queue.append(Item(expr))
return graph_changed
class BackwardPass(BasePass):
def visit_graph(self, graph: InternalGraph):
self.before_visit_graph(graph)
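# Descriptive note: exprs are visited from the graph outputs back towards the
# inputs; when run_transform returns a new expr it is pushed back onto the
# queue so the rewritten subgraph is revisited first, and nested graphs of
# module calls are recursed into once each.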
graph_changed = False
queue = [n.expr for n in graph.outputs]
visited_expr, visited_graph = set(), set()
while queue:
expr = queue.pop()
if expr not in graph._exprs:
continue
new_expr = self.run_transform(expr)
if new_expr is not expr:
graph_changed = True
queue.append(new_expr)
continue
else:
visited_expr.add(expr)
if (
hasattr(expr, "graph")
and expr.graph is not None
and expr.graph not in visited_graph
):
graph_changed |= self.visit_graph(expr.graph)
visited_graph.add(expr.graph)
for i in expr.inputs:
expr = i.expr
if expr not in queue and expr not in visited_expr:
queue.append(expr)
return graph_changed
@@ -13,7 +13,7 @@ import inspect
import re
import weakref
from importlib import import_module
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
from typing import Callable, Dict, Iterable, List, Optional, Sequence, Union
from ..core._imperative_rt import OpDef
from ..core._imperative_rt.core2 import Tensor as RawTensor
@@ -50,20 +50,30 @@ def get_suffix_name(prefix: str, name: str):
return matchd.group(1)
def is_call_module(expr):
def is_call_module(expr, module_cls: Module = None):
return (
isinstance(expr, CallMethod)
and isinstance(expr.inputs[0], ModuleNode)
and expr.method == "__call__"
)
) and (module_cls is None or isinstance(expr.inputs[0].owner, module_cls))
def is_call_tensor_method(expr):
return isinstance(expr, CallMethod) and not is_call_module(expr)
def is_call_tensor_method(expr, method: Iterable[str] = None):
if method and isinstance(method, str):
method = (method,)
return (
isinstance(expr, CallMethod)
and not is_call_module(expr)
and (method is None or any(expr.method == f for f in method))
)
def is_call_function(expr):
return isinstance(expr, CallFunction)
def is_call_function(expr, func: Iterable[Callable] = None):
if func and not isinstance(func, Iterable):
func = (func,)
return isinstance(expr, CallFunction) and (
func is None or any(expr.func == f for f in func)
)
def is_constant(expr):
@@ -74,8 +84,8 @@ def is_getattr(expr):
return isinstance(expr, GetAttr)
def is_apply_def(expr):
return isinstance(expr, Apply)
def is_apply_def(expr, opdef=None):
return isinstance(expr, Apply) and (opdef is None or isinstance(expr.opdef, opdef))
def is_input(expr):
@@ -78,6 +78,7 @@ class Node:
"The name(%s) is already in use. Please try a different one again."
% (new_name)
)
graph._namespace.unassociate_name_with_obj(self)
self._name = graph._namespace.create_unique_name(new_name, self)
@property
@@ -14,6 +14,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Sequence, Type, Uni
from .. import get_logger
from ..module import Module
from ..tensor import Parameter, Tensor
logger = get_logger(__name__)
@@ -301,3 +302,26 @@ class _ModuleDict(Module, MutableMapping):
def forward(self):
raise RuntimeError("ModuleList is not callable")
def assign_attr(obj: Union[Module, Tensor], module: Module, target: str):
*prefix, name = target.split(".")
for item in prefix:
module = getattr(module, item)
if not isinstance(module, Module):
raise AttributeError("`{}` is not an Module".format(item))
setattr(module, name, obj)
def get_subattr(module: Module, target: str):
# todo : remove this import
from .node import ModuleNode
if target == "":
return module
*prefix, name = target.split(".")
for item in prefix:
module = getattr(module, item)
if not isinstance(module, (Module, ModuleNode)):
raise AttributeError("`{}` is not an Module".format(item))
return getattr(module, name)
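# Illustrative example (names are assumptions): for a module with a sub-module
# at `net.block.conv`, get_subattr(net, "block.conv") returns that sub-module
# and assign_attr(new_conv, net, "block.conv") replaces it in place; both walk
# the dotted path one attribute at a time.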
from copy import deepcopy
from ..functional import ones, sqrt, zeros
from ..module import BatchNorm2d, Conv2d, ConvBn2d, ConvBnRelu2d, ConvRelu2d, ReLU
from ..tensor import Parameter
_MAP_TO_FUSED_MODULE = {
(Conv2d, BatchNorm2d, ReLU, False): ConvRelu2d,
(Conv2d, BatchNorm2d, ReLU, True): ConvBnRelu2d,
(Conv2d, BatchNorm2d, False): Conv2d,
(Conv2d, BatchNorm2d, True): ConvBn2d,
(Conv2d, ReLU): ConvRelu2d,
}
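# Keys are the types of the modules that are present (conv, bn, relu) plus,
# when a BN is given, conv.training: in training mode the fusion keeps a
# ConvBn2d / ConvBnRelu2d so BN statistics keep updating, while in eval mode
# the BN is folded into the conv and a plain Conv2d / ConvRelu2d is produced.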
def fold_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5):
# compute the folded conv weight/bias from the bn statistics
kernel_shape = weight.shape
if len(kernel_shape) == 5:
groups, num_features = kernel_shape[0], kernel_shape[1]
else:
groups, num_features = 1, kernel_shape[0]
if gamma is None:
gamma = ones((num_features), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
if beta is None:
beta = zeros((num_features), dtype="float32")
beta = beta.reshape(1, -1, 1, 1)
if bn_mean is None:
bn_mean = zeros((1, num_features, 1, 1), dtype="float32")
if bn_var is None:
bn_var = ones((1, num_features, 1, 1), dtype="float32")
if bias is None:
bias = zeros((1, num_features, 1, 1), dtype="float32")
bn_istd = 1.0 / sqrt(bn_var + eps)
scale_factor = gamma * bn_istd
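# BN folding identity: y = gamma * (conv(x, w) + b - mean) / sqrt(var + eps) + beta
# equals a single conv with
#     w_fold = w * gamma / sqrt(var + eps)
#     b_fold = beta + gamma * (b - mean) / sqrt(var + eps)
# which is exactly what scale_factor (= gamma * bn_istd) computes below.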
if groups == 1:
w_fold = weight * scale_factor.reshape(-1, 1, 1, 1)
else:
w_fold = weight * scale_factor.reshape(groups, -1, 1, 1, 1)
b_fold = beta + gamma * (bias - bn_mean) * bn_istd
return w_fold, b_fold
def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
module_key = tuple([type(m) for m in [conv, bn, relu] if m])
if bn:
assert (
conv.training == bn.training
), "Conv and BN both must be in the same mode (train or eval)."
assert (
bn.num_features == conv.out_channels
), "Output channel of Conv2d must match num_features of BatchNorm2d"
module_key = module_key + (conv.training,)
module = _MAP_TO_FUSED_MODULE[module_key](
in_channels=conv.in_channels,
out_channels=conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=conv.groups,
bias=conv.bias is not None,
conv_mode=conv.conv_mode,
compute_mode=conv.compute_mode,
name=conv.name,
)
new_conv = module if bn is None or not conv.training else module.conv
weight, bias = conv.weight, conv.bias
if not conv.training and bn is not None:
weight, bias = fold_weight_bias(
weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps,
)
new_conv.weight = Parameter(weight)
if bias is not None:
new_conv.bias = Parameter(bias)
if bn is not None and conv.training:
module.bn = deepcopy(bn)
new_conv.training = conv.training
return module
@@ -13,20 +13,7 @@ import megengine.quantization as Q
from megengine import Tensor
from megengine.module.qat.module import QATModule
from megengine.traced_module import TracedModule, trace_module
def get_subattr(self: M.Module, name: str):
if name == "":
return self
module_path, _, name = name.rpartition(".")
if module_path == "":
return getattr(self, name)
module_names = module_path.split(".")
for item in module_names:
self = getattr(self, item)
if not isinstance(self, M.Module):
raise AttributeError("`{}` is not an Module".format(item))
return getattr(self, name)
from megengine.traced_module.utils import get_subattr
class MyConvBnRelu2d(M.ConvBnRelu2d):