From 44bafd3f58883834366f57dae338456f4cfa6274 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 2 Apr 2021 18:17:46 +0800 Subject: [PATCH] fix(imperative/quantization): fix zero scale bug of easy quant GitOrigin-RevId: f45e19b3e4d1988330b137386863bc9b80ffab48 --- .../python/megengine/quantization/quantize.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/imperative/python/megengine/quantization/quantize.py b/imperative/python/megengine/quantization/quantize.py index 1011b867..22fe8d15 100644 --- a/imperative/python/megengine/quantization/quantize.py +++ b/imperative/python/megengine/quantization/quantize.py @@ -13,6 +13,7 @@ import numpy as np from .. import module as Float from ..functional import concat, norm +from ..logger import get_logger from ..module import Module from ..module import qat as QAT from ..module import quantized as Quantized @@ -22,6 +23,8 @@ from ..tensor import Tensor from ..utils.module_utils import set_expand_structure from .qconfig import QConfig, ema_fakequant_qconfig +logger = get_logger(__name__) + def _get_quantable_module_names(): def is_quantable(key: str): @@ -236,16 +239,18 @@ def apply_easy_quant( return orig_scale = ob.orig_scale - distance = 0 - best_scale = 0 + cosine = optimal = 0 for scale in np.linspace(start * orig_scale, stop * orig_scale, num): ob.scale = scale fakequant_out = mod(*fakequant_in) dis = get_cosine(normal_out, fakequant_out) - if dis > distance: - distance = dis - best_scale = scale - ob.scale = best_scale + if dis > cosine: + cosine = dis + optimal = scale + if optimal == 0: + logger.warning("EasyQuant finds no better scale") + else: + ob.scale = optimal fakequant_out = outputs[batch_size:] return concat([normal_out, fakequant_out]) -- GitLab