From 4130dcd3556f7defb7e82f8cc224f78b82eec5c9 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 22 Feb 2021 15:38:54 +0800 Subject: [PATCH] fix(mge/quantization): fix QATModule filter in `reset_qconfig` and `hook_qat_module` GitOrigin-RevId: 92e9f36ca4d2fc1d70e0f5ea74702e71dd49c683 --- imperative/python/megengine/quantization/quantize.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/imperative/python/megengine/quantization/quantize.py b/imperative/python/megengine/quantization/quantize.py index e3a01942d..730947c90 100644 --- a/imperative/python/megengine/quantization/quantize.py +++ b/imperative/python/megengine/quantization/quantize.py @@ -51,10 +51,6 @@ _float2qat_dict, _qat2quantized_dict = _get_convert_dict() qat_modules = tuple(_qat2quantized_dict.keys()) -def is_qat(mod: Module): - return isinstance(mod, qat_modules) - - def quantize(module: Module, inplace: bool = True, mapping: dict = None): r""" Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule` @@ -157,6 +153,9 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True): inst.set_qparams(q_dict) return inst + def is_qat(mod: Module): + return isinstance(mod, QATModule) + for m in list(module._flatten(predicate=is_qat)): if m.with_weight: weight_q_dict = m.get_weight_qparams() @@ -193,6 +192,9 @@ def hook_qat_module(module: Module, func: Callable): Add hooks for all :class:`~.QATModule` submodule """ + def is_qat(mod: Module): + return isinstance(mod, QATModule) + hooks = [] for submodule in list(module._flatten(predicate=is_qat)): hooks.append(submodule.register_forward_hook(func)) -- GitLab