提交 ab913025 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(mge/quantization): add `quantize_disabled` attribute in Module

GitOrigin-RevId: f108f03c5a5ac2142d5cb7bb9c24cbe43dce6fe6
上级 f4ead788
......@@ -57,6 +57,7 @@ class Module(metaclass=ABCMeta):
def __init__(self):
self.training = True
self.quantize_diabled = False
@abstractmethod
def forward(self, inputs):
......@@ -312,6 +313,16 @@ class Module(metaclass=ABCMeta):
"""
self.train(False)
def disable_quantize(self, value=True):
r"""
Set ``module``'s ``quantize_diabled`` attribute and return ``module``.
Could be used as a decorator.
"""
def fn(module: Module) -> None:
module.quantize_diabled = value
self.apply(fn)
def state_dict(self, rst=None, prefix="", keep_var=False):
r"""Returns a dictionary containing whole states of the module.
"""
......
......@@ -26,8 +26,6 @@ class QATModule(Module):
def __init__(self):
super().__init__()
self.scale = None
self.weight_observer = None # type: Observer
self.act_observer = None # type: Observer
......
......@@ -6,7 +6,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy
from typing import Dict, Tuple
from typing import Callable, Dict, Tuple
from .. import module as Float
from ..module import Module
......@@ -48,7 +48,7 @@ def _get_convert_dict() -> Tuple[
_float2qat_dict, _qat2quantized_dict = _get_convert_dict()
def quantize(module: Module, inplace=True):
def quantize(module: Module, inplace: bool = True):
r"""
Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule`
through :meth:`~.Module.apply`.
......@@ -80,7 +80,9 @@ def quantize(module: Module, inplace=True):
def quantize_qat(
module: Module, inplace=True, qconfig: QConfig = ema_fakequant_qconfig
module: Module,
inplace: bool = True,
qconfig: QConfig = ema_fakequant_qconfig,
):
r"""
Recursively convert float :class:`~.Module` to :class:`~.QATModule`
......@@ -105,7 +107,7 @@ def quantize_qat(
module._flatten(with_key=True, with_parent=True, predicate=is_quantable)
):
# only convert top quantable module.
if is_quantable(parent):
if is_quantable(parent) or submodule.quantize_diabled:
continue
new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule)
......@@ -136,12 +138,12 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig):
def disable_fake_quant(module: Module):
r"""
Recursively disable `module` fake quantization in QATModule through :meth:`~.Module.apply`
Recursively disable ``module`` fake quantization in QATModule through :meth:`~.Module.apply`
:param module: root module to do disable fake quantization recursively.
"""
def fn(mod):
def fn(mod: Module):
if isinstance(mod, QATModule):
mod.act_fake_quant.disable()
mod.weight_fake_quant.disable()
......@@ -151,12 +153,12 @@ def disable_fake_quant(module: Module):
def disable_observer(module: Module):
r"""
Recursively disable `module` observer in QATModule through :meth:`~.Module.apply`
Recursively disable ``module`` observer in QATModule through :meth:`~.Module.apply`
:param module: root module to do disable observer recursively.
"""
def fn(mod):
def fn(mod: Module):
if isinstance(mod, QATModule):
mod.act_observer.disable()
mod.weight_observer.disable()
......@@ -166,12 +168,12 @@ def disable_observer(module: Module):
def enable_fake_quant(module: Module):
r"""
Recursively enable `module` fake quantization in QATModule through :meth:`~.Module.apply`
Recursively enable ``module`` fake quantization in QATModule through :meth:`~.Module.apply`
:param module: root module to do enable fake quantization recursively.
"""
def fn(mod):
def fn(mod: Module):
if isinstance(mod, QATModule):
mod.act_fake_quant.enable()
mod.weight_fake_quant.enable()
......@@ -181,12 +183,12 @@ def enable_fake_quant(module: Module):
def enable_observer(module: Module):
r"""
Recursively enable `module` observer in QATModule through :meth:`~.Module.apply`
Recursively enable ``module`` observer in QATModule through :meth:`~.Module.apply`
:param module: root module to do enable observer recursively.
"""
def fn(mod):
def fn(mod: Module):
if isinstance(mod, QATModule):
mod.act_observer.enable()
mod.weight_observer.enable()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册