diff --git a/imperative/python/megengine/quantization/quantize.py b/imperative/python/megengine/quantization/quantize.py index e3a01942ddb8ae866185b8f140fe6668f555b1ce..730947c9027ebfbc4c81ba7372706acec006f434 100644 --- a/imperative/python/megengine/quantization/quantize.py +++ b/imperative/python/megengine/quantization/quantize.py @@ -51,10 +51,6 @@ _float2qat_dict, _qat2quantized_dict = _get_convert_dict() qat_modules = tuple(_qat2quantized_dict.keys()) -def is_qat(mod: Module): - return isinstance(mod, qat_modules) - - def quantize(module: Module, inplace: bool = True, mapping: dict = None): r""" Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule` @@ -157,6 +153,9 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True): inst.set_qparams(q_dict) return inst + def is_qat(mod: Module): + return isinstance(mod, QATModule) + for m in list(module._flatten(predicate=is_qat)): if m.with_weight: weight_q_dict = m.get_weight_qparams() @@ -193,6 +192,9 @@ def hook_qat_module(module: Module, func: Callable): Add hooks for all :class:`~.QATModule` submodule """ + def is_qat(mod: Module): + return isinstance(mod, QATModule) + hooks = [] for submodule in list(module._flatten(predicate=is_qat)): hooks.append(submodule.register_forward_hook(func))