diff --git a/imperative/python/megengine/quantization/quantize.py b/imperative/python/megengine/quantization/quantize.py index 1011b867db36eba1719c6e2bb92df0b2688a483c..22fe8d15c424152cc5ca9d262c315565c60e5749 100644 --- a/imperative/python/megengine/quantization/quantize.py +++ b/imperative/python/megengine/quantization/quantize.py @@ -13,6 +13,7 @@ import numpy as np from .. import module as Float from ..functional import concat, norm +from ..logger import get_logger from ..module import Module from ..module import qat as QAT from ..module import quantized as Quantized @@ -22,6 +23,8 @@ from ..tensor import Tensor from ..utils.module_utils import set_expand_structure from .qconfig import QConfig, ema_fakequant_qconfig +logger = get_logger(__name__) + def _get_quantable_module_names(): def is_quantable(key: str): @@ -236,16 +239,18 @@ def apply_easy_quant( return orig_scale = ob.orig_scale - distance = 0 - best_scale = 0 + cosine = optimal = 0 for scale in np.linspace(start * orig_scale, stop * orig_scale, num): ob.scale = scale fakequant_out = mod(*fakequant_in) dis = get_cosine(normal_out, fakequant_out) - if dis > distance: - distance = dis - best_scale = scale - ob.scale = best_scale + if dis > cosine: + cosine = dis + optimal = scale + if optimal == 0: + logger.warning("EasyQuant finds no better scale") + else: + ob.scale = optimal fakequant_out = outputs[batch_size:] return concat([normal_out, fakequant_out])