提交 4130dcd3 编写于 作者: M Megvii Engine Team

fix(mge/quantization): fix QATModule filter in `reset_qconfig` and `hook_qat_module`

GitOrigin-RevId: 92e9f36ca4d2fc1d70e0f5ea74702e71dd49c683
上级 97beae2f
......@@ -51,10 +51,6 @@ _float2qat_dict, _qat2quantized_dict = _get_convert_dict()
qat_modules = tuple(_qat2quantized_dict.keys())
def is_qat(mod: Module):
return isinstance(mod, qat_modules)
def quantize(module: Module, inplace: bool = True, mapping: dict = None):
r"""
Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule`
......@@ -157,6 +153,9 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
inst.set_qparams(q_dict)
return inst
def is_qat(mod: Module):
return isinstance(mod, QATModule)
for m in list(module._flatten(predicate=is_qat)):
if m.with_weight:
weight_q_dict = m.get_weight_qparams()
......@@ -193,6 +192,9 @@ def hook_qat_module(module: Module, func: Callable):
Add hooks for all :class:`~.QATModule` submodule
"""
def is_qat(mod: Module):
return isinstance(mod, QATModule)
hooks = []
for submodule in list(module._flatten(predicate=is_qat)):
hooks.append(submodule.register_forward_hook(func))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册