From b245e4edd9a732d25ec9d341a937f96f9cba85ec Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 8 Dec 2022 17:44:48 +0800 Subject: [PATCH] fix(mge/quantization): remove assert in fake_quant_bias to support more QAT mode GitOrigin-RevId: 8c7f268480e703c896fbb93ca19516aa970a4901 --- .../python/megengine/quantization/utils.py | 38 +++++++++++-------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/imperative/python/megengine/quantization/utils.py b/imperative/python/megengine/quantization/utils.py index 2cd1eecb3..00699d3fb 100644 --- a/imperative/python/megengine/quantization/utils.py +++ b/imperative/python/megengine/quantization/utils.py @@ -206,20 +206,26 @@ def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor: ): inp_params = inp.qparams w_params = w_qat.qparams - if inp_params.scale is not None and w_params.scale is not None: - assert inp_params.mode == w_params.mode, "incompatible QuantMode" - # TODO: support quint8 dtype. - assert ( - inp_params.dtype_meta.np_dtype_str == "int8" - and w_params.dtype_meta.np_dtype_str == "int8" - ), "fake_quant_bias only support int8 like dtype now" - - # use the same mode with weight. - # TODO: avoid hardcode - b_dtype = _builtin_quant_dtypes["qint32"] - b_param = create_qparams( - w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale - ) - b_qat = fake_quant_tensor(bias, b_param) - b_qat.qparams.update(b_param) + + if inp_params.scale is None or w_params.scale is None: + return b_qat + + # TODO: support different mode + if inp_params.mode != w_params.mode: + return b_qat + + # TODO: support quint8 dtype. + if inp_params.dtype_meta.np_dtype_str != "int8": + return b_qat + if w_params.dtype_meta.np_dtype_str != "int8": + return b_qat + + # use the same mode with weight. + # TODO: avoid hardcode + b_dtype = _builtin_quant_dtypes["qint32"] + b_param = create_qparams( + w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale + ) + b_qat = fake_quant_tensor(bias, b_param) + b_qat.qparams.update(b_param) return b_qat -- GitLab