提交 4130dcd3 编写于 作者: M Megvii Engine Team

fix(mge/quantization): fix QATModule filter in `reset_qconfig` and `hook_qat_module`

GitOrigin-RevId: 92e9f36ca4d2fc1d70e0f5ea74702e71dd49c683
上级 97beae2f
...@@ -51,10 +51,6 @@ _float2qat_dict, _qat2quantized_dict = _get_convert_dict() ...@@ -51,10 +51,6 @@ _float2qat_dict, _qat2quantized_dict = _get_convert_dict()
qat_modules = tuple(_qat2quantized_dict.keys()) qat_modules = tuple(_qat2quantized_dict.keys())
def is_qat(mod: Module):
return isinstance(mod, qat_modules)
def quantize(module: Module, inplace: bool = True, mapping: dict = None): def quantize(module: Module, inplace: bool = True, mapping: dict = None):
r""" r"""
Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule` Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule`
...@@ -157,6 +153,9 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True): ...@@ -157,6 +153,9 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
inst.set_qparams(q_dict) inst.set_qparams(q_dict)
return inst return inst
def is_qat(mod: Module):
return isinstance(mod, QATModule)
for m in list(module._flatten(predicate=is_qat)): for m in list(module._flatten(predicate=is_qat)):
if m.with_weight: if m.with_weight:
weight_q_dict = m.get_weight_qparams() weight_q_dict = m.get_weight_qparams()
...@@ -193,6 +192,9 @@ def hook_qat_module(module: Module, func: Callable): ...@@ -193,6 +192,9 @@ def hook_qat_module(module: Module, func: Callable):
Add hooks for all :class:`~.QATModule` submodule Add hooks for all :class:`~.QATModule` submodule
""" """
def is_qat(mod: Module):
return isinstance(mod, QATModule)
hooks = [] hooks = []
for submodule in list(module._flatten(predicate=is_qat)): for submodule in list(module._flatten(predicate=is_qat)):
hooks.append(submodule.register_forward_hook(func)) hooks.append(submodule.register_forward_hook(func))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册