diff --git a/imperative/python/megengine/quantization/utils.py b/imperative/python/megengine/quantization/utils.py
index 2cd1eecb38818a9a6dbc6e15d3181c26899da46f..00699d3fba3de85d414bc80fae23dd759d53fc76 100644
--- a/imperative/python/megengine/quantization/utils.py
+++ b/imperative/python/megengine/quantization/utils.py
@@ -206,20 +206,26 @@ def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
     ):
         inp_params = inp.qparams
         w_params = w_qat.qparams
-        if inp_params.scale is not None and w_params.scale is not None:
-            assert inp_params.mode == w_params.mode, "incompatible QuantMode"
-            # TODO: support quint8 dtype.
-            assert (
-                inp_params.dtype_meta.np_dtype_str == "int8"
-                and w_params.dtype_meta.np_dtype_str == "int8"
-            ), "fake_quant_bias only support int8 like dtype now"
-
-            # use the same mode with weight.
-            # TODO: avoid hardcode
-            b_dtype = _builtin_quant_dtypes["qint32"]
-            b_param = create_qparams(
-                w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale
-            )
-            b_qat = fake_quant_tensor(bias, b_param)
-            b_qat.qparams.update(b_param)
+
+        if inp_params.scale is None or w_params.scale is None:
+            return b_qat
+
+        # TODO: support different mode
+        if inp_params.mode != w_params.mode:
+            return b_qat
+
+        # TODO: support quint8 dtype.
+        if inp_params.dtype_meta.np_dtype_str != "int8":
+            return b_qat
+        if w_params.dtype_meta.np_dtype_str != "int8":
+            return b_qat
+
+        # use the same mode with weight.
+        # TODO: avoid hardcode
+        b_dtype = _builtin_quant_dtypes["qint32"]
+        b_param = create_qparams(
+            w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale
+        )
+        b_qat = fake_quant_tensor(bias, b_param)
+        b_qat.qparams.update(b_param)
     return b_qat
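
For context, the net behavioral change of this patch is that `fake_quant_bias` now degrades gracefully instead of raising: any unsupported combination (a missing scale, mismatched `QuantMode`, or a non-`int8` dtype) returns the bias unchanged, where the old code tripped an `assert`. The following standalone sketch mimics that guard-clause structure; `FakeQParams` and `fake_quant_bias_sketch` are hypothetical illustrative names, not MegEngine APIs, and the happy-path scaling is a placeholder for the real `fake_quant_tensor` call:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeQParams:
    # Hypothetical stand-in for MegEngine's QParams: just the fields
    # the guard clauses below inspect.
    mode: str
    np_dtype_str: str
    scale: Optional[float]


def fake_quant_bias_sketch(bias, inp_params: FakeQParams, w_params: FakeQParams):
    """Illustrates the refactored control flow: every unsupported case
    falls through to returning the bias unchanged, with no asserts."""
    b_qat = bias

    # Guard 1: both scales must be known before a bias scale can be derived.
    if inp_params.scale is None or w_params.scale is None:
        return b_qat

    # Guard 2: mixed quantization modes are not supported yet.
    if inp_params.mode != w_params.mode:
        return b_qat

    # Guard 3: only int8-like input/weight dtypes are supported yet.
    if inp_params.np_dtype_str != "int8" or w_params.np_dtype_str != "int8":
        return b_qat

    # Happy path: the bias scale is the product of the input and weight
    # scales, matching the qint32 convention in the real implementation.
    bias_scale = inp_params.scale * w_params.scale
    return bias * bias_scale  # placeholder for fake_quant_tensor(bias, b_param)


# Mismatched modes used to raise "incompatible QuantMode"; now the bias
# passes through untouched.
sym = FakeQParams(mode="symmetric", np_dtype_str="int8", scale=0.5)
asym = FakeQParams(mode="asymmetric", np_dtype_str="int8", scale=0.25)
assert fake_quant_bias_sketch(1.0, sym, asym) == 1.0
```

Note that callers relying on the old asserts to surface configuration mistakes will now silently get an unquantized bias, which is worth keeping in mind when reviewing this change.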