From 7a023c059a49453ba41867fa075c606728e3728f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 26 Nov 2021 15:45:23 +0800 Subject: [PATCH] feat(mge/traced_module): add optimization api GitOrigin-RevId: eaa74026404a8b3fa7a8d33826612e859a445b0c --- .../megengine/traced_module/__init__.py | 9 ++ .../traced_module/_passes/__init__.py | 12 ++ .../traced_module/_passes/optimization.py | 70 ++++++++++++ .../test/unit/traced_module/test_passes.py | 106 ++++++++++++++++++ 4 files changed, 197 insertions(+) create mode 100644 imperative/python/megengine/traced_module/_passes/__init__.py create mode 100644 imperative/python/megengine/traced_module/_passes/optimization.py create mode 100644 imperative/python/test/unit/traced_module/test_passes.py diff --git a/imperative/python/megengine/traced_module/__init__.py b/imperative/python/megengine/traced_module/__init__.py index 741ec59d2..4256f5e03 100644 --- a/imperative/python/megengine/traced_module/__init__.py +++ b/imperative/python/megengine/traced_module/__init__.py @@ -8,6 +8,7 @@ from ..core._imperative_rt.core2 import set_cpp_apply_module_trace from . import compat +from ._passes import optimize from .traced_module import ( TracedModule, _register_all_builtin_module, @@ -19,3 +20,11 @@ from .traced_module import ( _register_all_builtin_module() set_cpp_apply_module_trace(cpp_apply_module_trace) + +__all__ = { + "register_as_builtin", + "trace_module", + "wrap", + "TracedModule", + "optimize", +} diff --git a/imperative/python/megengine/traced_module/_passes/__init__.py b/imperative/python/megengine/traced_module/_passes/__init__.py new file mode 100644 index 000000000..c5f7772f0 --- /dev/null +++ b/imperative/python/megengine/traced_module/_passes/__init__.py @@ -0,0 +1,12 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +from . import const_pass, fold_scale_pass, fuse_pass +from .optimization import optimize + +__all__ = ["optimize"] diff --git a/imperative/python/megengine/traced_module/_passes/optimization.py b/imperative/python/megengine/traced_module/_passes/optimization.py new file mode 100644 index 000000000..88d47baf1 --- /dev/null +++ b/imperative/python/megengine/traced_module/_passes/optimization.py @@ -0,0 +1,70 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +from copy import deepcopy +from typing import List, Set + +from ...logger import get_logger +from ..traced_module import TracedModule +from .pass_base import get_default_pass_context, get_registered_pass + +logger = get_logger(__name__) + + +def optimize( + module: TracedModule, enabled_pass: List[str] = ["FuseConvBn"], +) -> TracedModule: + r"""Performs a set of optimization passes to optimize a `TracedModule` for inference. + + The following passes are currently supported: + + * FuseConvBn: fuse BN layers into to conv2d + * FuseAddMul: fold adjacent const add or mul binary operations + * BackwardFoldScale: backward fold const scaling into weights of conv2d + + Args: + module: the :class:`TracedModule` to be optimized. + enabled_pass: optimization passes to be enabled during optimization. + Default: ["FuseConvBn"] + + Returns: + the optimized :class:`TracedModule`. + """ + + defalut_passes_list = [ + "FuseConvBn", + "FuseAddMul", + ] + + if isinstance(enabled_pass, str): + enabled_pass = [enabled_pass] + + if "BackwardFoldScale" in enabled_pass: + if "FuseConvBn" not in enabled_pass: + logger.warning( + "Since BackwardFoldScale requires FuseConvBn" + ", FuseConvBn will be enabled." + ) + enabled_pass.append("FuseConvBn") + defalut_passes_list.extend( + ["BackwardFoldScale", "FuseAddMul",] + ) + + pass_ctx = get_default_pass_context() + + def run_pass(mod: TracedModule): + for pass_name in defalut_passes_list: + if pass_name in enabled_pass: + pass_func = get_registered_pass(pass_name)() + mod = pass_func(mod, pass_ctx) + return mod + + module = deepcopy(module) + module = run_pass(module) + + return module diff --git a/imperative/python/test/unit/traced_module/test_passes.py b/imperative/python/test/unit/traced_module/test_passes.py new file mode 100644 index 000000000..18dc13266 --- /dev/null +++ b/imperative/python/test/unit/traced_module/test_passes.py @@ -0,0 +1,106 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +import types + +import numpy as np +import pytest + +import megengine as mge +import megengine.functional as F +import megengine.module as M +import megengine.traced_module as tm + + +class myconv(M.Conv2d): + pass + + +class mybn(M.BatchNorm2d): + pass + + +class MyBlock(M.Module): + def __init__(self, conv_cls, bn_cls): + super().__init__() + self.conv = conv_cls(3, 3, 1, 1, 0) + self.bn = bn_cls(3) + self.conv2 = conv_cls(3, 3, 1, 1, 0) + self.bn2 = bn_cls(3) + self.scale = mge.Tensor([3, 4]) + + def forward(self, x): + x1 = self.conv(x) + x1 = self.bn(x1) + x1 = F.relu(x1) + x1 = x1 * self.scale[0] + x2 = self.conv2(x) + x2 = self.bn2(x2) + x2 = F.relu(x2) + x2 = x2 * self.scale[1] + y = x1 + x2 + y = y + 4 + y = self.scale[0] + y + y = F.relu(y) * 3 + return y + + +class MyModule(M.Module): + def __init__(self, conv_cls, bn_cls): + super().__init__() + self.block_0 = MyBlock(conv_cls, bn_cls) + self.block_1 = MyBlock(conv_cls, bn_cls) + + def forward(self, x): + x1 = self.block_0(x) + x2 = self.block_1(x) + y = x1 + x2 + y = F.reshape(y, (-1)) + y = y * 3 + return y + + +@pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv]) +@pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn]) +def test_backward_fold_scale(conv_cls, bn_cls): + module = MyModule(conv_cls, bn_cls) + module.eval() + inp = mge.Tensor(np.random.random((1, 3, 32, 32))) + desired = module(inp) + traced_net = tm.trace_module(module, inp) + + traced_net = traced_net.flatten() + optimized_net = tm.optimize(traced_net, "BackwardFoldScale") + + actual = optimized_net(inp) + np.testing.assert_allclose(desired=desired, actual=actual, atol=1e-4) + # fuse all mul to conv + mul_list = optimized_net.graph.get_method_by_type("__mul__").as_list() + assert len(mul_list) == 0 + + +@pytest.mark.parametrize("conv_cls", [M.Conv2d, myconv]) +@pytest.mark.parametrize("bn_cls", [M.BatchNorm2d, mybn]) +def test_fuse_bn(conv_cls, bn_cls): + module = MyModule(conv_cls, bn_cls) + module.eval() + inp = mge.Tensor(np.random.random((1, 3, 32, 32))) + desired = module(inp) + traced_net = tm.trace_module(module, inp) + + traced_net = traced_net.flatten() + optimized_net = tm.optimize(traced_net, "FuseConvBn") + + actual = optimized_net(inp) + np.testing.assert_allclose(desired=desired, actual=actual, atol=1e-4) + # fuse all mul to conv + bn_list = optimized_net.graph.get_function_by_type(F.batch_norm).as_list() + assert len(bn_list) == 0 + + bn_list = optimized_net.graph.get_module_by_type(M.BatchNorm2d).as_list() + assert len(bn_list) == 0 -- GitLab