提交 22853fa2 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(mge/quantization): add `mapping` parameter for custom modules

GitOrigin-RevId: a4de4261d0ed3179d3ad236eca8526f0368aaea1
上级 6e70fa7a
......@@ -60,7 +60,7 @@ class Module(metaclass=ABCMeta):
def __init__(self):
# runtime attributes
self.training = True
self.quantize_diabled = False
self.quantize_disabled = False
# hooks
self._forward_pre_hooks = OrderedDict()
......@@ -328,12 +328,12 @@ class Module(metaclass=ABCMeta):
def disable_quantize(self, value=True):
r"""
Set ``module``'s ``quantize_diabled`` attribute and return ``module``.
Set ``module``'s ``quantize_disabled`` attribute and return ``module``.
Could be used as a decorator.
"""
def fn(module: Module) -> None:
module.quantize_diabled = value
module.quantize_disabled = value
self.apply(fn)
......
......@@ -5,7 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy
from copy import copy, deepcopy
from typing import Callable, Dict, Tuple
from .. import module as Float
......@@ -49,19 +49,24 @@ def _get_convert_dict() -> Tuple[
_float2qat_dict, _qat2quantized_dict = _get_convert_dict()
def quantize(module: Module, inplace: bool = True):
def quantize(module: Module, inplace: bool = True, mapping: dict = None):
r"""
Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule`
through :meth:`~.Module.apply`.
:param module: root module to do convert recursively.
:param inplace: whether to convert submodules in-place.
:param mapping: a dict indicating how to convert custom modules from QATModule to
QuantizedModule. Will be combined with internal default convert mapping dict.
"""
if not inplace:
module = deepcopy(module)
qat_modules = tuple(_qat2quantized_dict.keys())
convert_dict = copy(_qat2quantized_dict)
if mapping is not None:
convert_dict.update(mapping)
qat_modules = tuple(convert_dict.keys())
def is_qat(mod: Module):
return isinstance(mod, qat_modules)
......@@ -70,7 +75,7 @@ def quantize(module: Module, inplace: bool = True):
for key, submodule, parent in list(
module._flatten(with_key=True, with_parent=True, predicate=is_qat)
):
new_mod = _qat2quantized_dict[type(submodule)].from_qat_module(submodule)
new_mod = convert_dict[type(submodule)].from_qat_module(submodule)
if isinstance(parent, Float.Sequential):
# cannnot use setattr to be compatible with Sequential's ``__setitem__``
parent[int(key.split(".")[-1])] = new_mod
......@@ -81,7 +86,10 @@ def quantize(module: Module, inplace: bool = True):
def quantize_qat(
module: Module, inplace: bool = True, qconfig: QConfig = ema_fakequant_qconfig,
module: Module,
inplace: bool = True,
qconfig: QConfig = ema_fakequant_qconfig,
mapping: dict = None,
):
r"""
Recursively convert float :class:`~.Module` to :class:`~.QATModule`
......@@ -91,12 +99,17 @@ def quantize_qat(
:param inplace: whether to convert submodules in-place.
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
default is ``ema_fakequant_qconfig``.
:param mapping: a dict indicating how to convert custom modules from Module to QATModule.
Will be combined with internal default convert mapping dict.
"""
if not inplace:
module = deepcopy(module)
quantable_modules = tuple(_float2qat_dict.keys())
convert_dict = copy(_float2qat_dict)
if mapping is not None:
convert_dict.update(mapping)
quantable_modules = tuple(convert_dict.keys())
def is_quantable(mod: Module):
return isinstance(mod, quantable_modules)
......@@ -106,10 +119,10 @@ def quantize_qat(
module._flatten(with_key=True, with_parent=True, predicate=is_quantable)
):
# only convert top quantable module.
if is_quantable(parent) or submodule.quantize_diabled:
if is_quantable(parent) or submodule.quantize_disabled:
continue
new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule)
new_mod = convert_dict[type(submodule)].from_float_module(submodule)
if isinstance(parent, Float.Sequential):
# cannnot use setattr to be compatible with Sequential's ``__setitem__``
parent[int(key.split(".")[-1])] = new_mod
......
......@@ -52,3 +52,29 @@ def test_disable_quantize():
qat_net = quantize_qat(net, inplace=False)
assert isinstance(qat_net.conv, Float.ConvBnRelu2d)
assert isinstance(qat_net.conv.conv, Float.Conv2d)
def test_convert_with_custom_mapping():
class FloatExample(Float.Module):
def forward(self, x):
return x
class QATExample(QAT.QATModule):
def forward(self, x):
return x
@classmethod
def from_float_module(cls, float_module):
return cls()
class Net(Float.Module):
def __init__(self):
super().__init__()
self.example = FloatExample()
def forward(self, x):
return self.example(x)
net = Net()
qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
assert isinstance(qat_net.example, QATExample)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册