提交 b245e4ed 编写于 作者: M Megvii Engine Team

fix(mge/quantization): remove assert in fake_quant_bias to support more QAT mode

GitOrigin-RevId: 8c7f268480e703c896fbb93ca19516aa970a4901
上级 5aecef5d
......@@ -206,20 +206,26 @@ def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
):
inp_params = inp.qparams
w_params = w_qat.qparams
if inp_params.scale is not None and w_params.scale is not None:
assert inp_params.mode == w_params.mode, "incompatible QuantMode"
# TODO: support quint8 dtype.
assert (
inp_params.dtype_meta.np_dtype_str == "int8"
and w_params.dtype_meta.np_dtype_str == "int8"
), "fake_quant_bias only support int8 like dtype now"
# use the same mode with weight.
# TODO: avoid hardcode
b_dtype = _builtin_quant_dtypes["qint32"]
b_param = create_qparams(
w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale
)
b_qat = fake_quant_tensor(bias, b_param)
b_qat.qparams.update(b_param)
if inp_params.scale is None or w_params.scale is None:
return b_qat
# TODO: support different mode
if inp_params.mode != w_params.mode:
return b_qat
# TODO: support quint8 dtype.
if inp_params.dtype_meta.np_dtype_str != "int8":
return b_qat
if w_params.dtype_meta.np_dtype_str != "int8":
return b_qat
# use the same mode with weight.
# TODO: avoid hardcode
b_dtype = _builtin_quant_dtypes["qint32"]
b_param = create_qparams(
w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale
)
b_qat = fake_quant_tensor(bias, b_param)
b_qat.qparams.update(b_param)
return b_qat
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册