diff --git a/python_module/megengine/quantization/quantize.py b/python_module/megengine/quantization/quantize.py index 2424e22ff0ad867ab0389d7270458791d4fa5852..358c6cc50785050fe6ce556ac049f97957a732a5 100644 --- a/python_module/megengine/quantization/quantize.py +++ b/python_module/megengine/quantization/quantize.py @@ -65,9 +65,9 @@ def quantize(module: Module, inplace=True): def is_qat(mod: Module): return isinstance(mod, qat_modules) - # no need to pass prefix and get pure key of parent Module. - for key, submodule, parent in module._flatten( - with_key=True, with_parent=True, predicate=is_qat + # must use list to avoid replacement influencing successor modules + for key, submodule, parent in list( + module._flatten(with_key=True, with_parent=True, predicate=is_qat) ): new_mod = _qat2quantized_dict[type(submodule)].from_qat_module(submodule) if isinstance(parent, Float.Sequential): @@ -100,9 +100,9 @@ def quantize_qat( def is_quantable(mod: Module): return isinstance(mod, quantable_modules) - # no need to pass prefix and get pure key of parent Module. - for key, submodule, parent in module._flatten( - with_key=True, with_parent=True, predicate=is_quantable + # must use list to avoid replacement influencing successor modules + for key, submodule, parent in list( + module._flatten(with_key=True, with_parent=True, predicate=is_quantable) ): # only convert top quantable module. if is_quantable(parent):