Commit 7a023c05 authored by Megvii Engine Team

feat(mge/traced_module): add optimization api

GitOrigin-RevId: eaa74026404a8b3fa7a8d33826612e859a445b0c
Parent 6c692b26
@@ -8,6 +8,7 @@
from ..core._imperative_rt.core2 import set_cpp_apply_module_trace
from . import compat
from ._passes import optimize
from .traced_module import (
TracedModule,
_register_all_builtin_module,
@@ -19,3 +20,11 @@ from .traced_module import (
_register_all_builtin_module()
set_cpp_apply_module_trace(cpp_apply_module_trace)
__all__ = [
    "register_as_builtin",
    "trace_module",
    "wrap",
    "TracedModule",
    "optimize",
]
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from . import const_pass, fold_scale_pass, fuse_pass
from .optimization import optimize
__all__ = ["optimize"]
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy
from typing import List
from ...logger import get_logger
from ..traced_module import TracedModule
from .pass_base import get_default_pass_context, get_registered_pass
logger = get_logger(__name__)
def optimize(
module: TracedModule, enabled_pass: List[str] = ["FuseConvBn"],
) -> TracedModule:
r"""Performs a set of optimization passes to optimize a `TracedModule` for inference.
The following passes are currently supported:
* FuseConvBn: fuse BN layers into to conv2d
* FuseAddMul: fold adjacent const add or mul binary operations
* BackwardFoldScale: backward fold const scaling into weights of conv2d
Args:
module: the :class:`TracedModule` to be optimized.
enabled_pass: optimization passes to be enabled during optimization.
Default: ["FuseConvBn"]
Returns:
the optimized :class:`TracedModule`.
"""
    default_passes_list = [
        "FuseConvBn",
        "FuseAddMul",
    ]
    if isinstance(enabled_pass, str):
        enabled_pass = [enabled_pass]
    else:
        # copy so that the append below never mutates the caller's list
        enabled_pass = list(enabled_pass)
if "BackwardFoldScale" in enabled_pass:
if "FuseConvBn" not in enabled_pass:
logger.warning(
"Since BackwardFoldScale requires FuseConvBn"
", FuseConvBn will be enabled."
)
enabled_pass.append("FuseConvBn")
        # rerun FuseAddMul after BackwardFoldScale so the const ops it exposes are folded
        default_passes_list.extend(["BackwardFoldScale", "FuseAddMul"])
pass_ctx = get_default_pass_context()
def run_pass(mod: TracedModule):
        for pass_name in default_passes_list:
if pass_name in enabled_pass:
pass_func = get_registered_pass(pass_name)()
mod = pass_func(mod, pass_ctx)
return mod
    module = deepcopy(module)  # optimize a copy; the input module is left unmodified
module = run_pass(module)
return module
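For reference, a minimal usage sketch of the new API. `Net` and the input shape are illustrative placeholders, not part of this commit; the trace/flatten/optimize sequence mirrors the tests below.

import numpy as np
import megengine as mge
import megengine.traced_module as tm

net = Net()  # placeholder: any Module that trace_module can capture
net.eval()   # optimization targets inference, so switch to eval mode
inp = mge.Tensor(np.random.random((1, 3, 32, 32)))

traced = tm.trace_module(net, inp)  # capture the module as a TracedModule
traced = traced.flatten()           # inline sub-modules so passes see one flat graph
optimized = tm.optimize(traced, enabled_pass=["FuseConvBn"])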
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import types
import numpy as np
import pytest
import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.traced_module as tm
class myconv(M.Conv2d):
pass
class mybn(M.BatchNorm2d):
pass
class MyBlock(M.Module):
def __init__(self, conv_cls, bn_cls):
super().__init__()
self.conv = conv_cls(3, 3, 1, 1, 0)
self.bn = bn_cls(3)
self.conv2 = conv_cls(3, 3, 1, 1, 0)
self.bn2 = bn_cls(3)
self.scale = mge.Tensor([3, 4])
def forward(self, x):
x1 = self.conv(x)
x1 = self.bn(x1)
x1 = F.relu(x1)
x1 = x1 * self.scale[0]
x2 = self.conv2(x)
x2 = self.bn2(x2)
x2 = F.relu(x2)
x2 = x2 * self.scale[1]
y = x1 + x2
y = y + 4
y = self.scale[0] + y
y = F.relu(y) * 3
return y
class MyModule(M.Module):
def __init__(self, conv_cls, bn_cls):
super().__init__()
self.block_0 = MyBlock(conv_cls, bn_cls)
self.block_1 = MyBlock(conv_cls, bn_cls)
def forward(self, x):
x1 = self.block_0(x)
x2 = self.block_1(x)
y = x1 + x2
        y = F.reshape(y, (-1,))
y = y * 3
return y
@pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv])
@pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn])
def test_backward_fold_scale(conv_cls, bn_cls):
module = MyModule(conv_cls, bn_cls)
module.eval()
inp = mge.Tensor(np.random.random((1, 3, 32, 32)))
desired = module(inp)
traced_net = tm.trace_module(module, inp)
traced_net = traced_net.flatten()
optimized_net = tm.optimize(traced_net, "BackwardFoldScale")
actual = optimized_net(inp)
np.testing.assert_allclose(desired=desired, actual=actual, atol=1e-4)
    # all const muls should have been folded into the convs
mul_list = optimized_net.graph.get_method_by_type("__mul__").as_list()
assert len(mul_list) == 0
@pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv])
@pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn])
def test_fuse_bn(conv_cls, bn_cls):
module = MyModule(conv_cls, bn_cls)
module.eval()
inp = mge.Tensor(np.random.random((1, 3, 32, 32)))
desired = module(inp)
traced_net = tm.trace_module(module, inp)
traced_net = traced_net.flatten()
optimized_net = tm.optimize(traced_net, "FuseConvBn")
actual = optimized_net(inp)
np.testing.assert_allclose(desired=desired, actual=actual, atol=1e-4)
    # all batch norms should have been fused into the convs
bn_list = optimized_net.graph.get_function_by_type(F.batch_norm).as_list()
assert len(bn_list) == 0
bn_list = optimized_net.graph.get_module_by_type(M.BatchNorm2d).as_list()
assert len(bn_list) == 0
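As background on the arithmetic that a conv-BN fusion pass like FuseConvBn relies on, here is a small self-contained numpy sketch of standard conv-BN folding. This is an illustration of the technique, not code from this commit; `fold_bn_into_conv` and its parameter names are hypothetical.

import numpy as np

def fold_bn_into_conv(w, b, gamma, beta, mean, var, eps=1e-5):
    # BN applied after a conv computes gamma * (conv(x) - mean) / sqrt(var + eps) + beta;
    # absorbing it yields equivalent conv weights and bias.
    scale = gamma / np.sqrt(var + eps)         # per-output-channel scale, shape (C_out,)
    w_folded = w * scale.reshape(-1, 1, 1, 1)  # scale each output channel's filter
    b_folded = (b - mean) * scale + beta       # fold running stats into the bias
    return w_folded, b_folded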