Commit b245e4ed authored by Megvii Engine Team

fix(mge/quantization): remove assert in fake_quant_bias to support more QAT mode

GitOrigin-RevId: 8c7f268480e703c896fbb93ca19516aa970a4901
Parent 5aecef5d
......
@@ -206,13 +206,19 @@ def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
     ):
         inp_params = inp.qparams
         w_params = w_qat.qparams
-        if inp_params.scale is not None and w_params.scale is not None:
-            assert inp_params.mode == w_params.mode, "incompatible QuantMode"
+        if inp_params.scale is None or w_params.scale is None:
+            return b_qat
+        # TODO: support different mode
+        if inp_params.mode != w_params.mode:
+            return b_qat
         # TODO: support quint8 dtype.
-        assert (
-            inp_params.dtype_meta.np_dtype_str == "int8"
-            and w_params.dtype_meta.np_dtype_str == "int8"
-        ), "fake_quant_bias only support int8 like dtype now"
+        if inp_params.dtype_meta.np_dtype_str != "int8":
+            return b_qat
+        if w_params.dtype_meta.np_dtype_str != "int8":
+            return b_qat
         # use the same mode with weight.
         # TODO: avoid hardcode
......
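For readers skimming the change: the asserts are replaced by early returns, so fake_quant_bias(bias, inp, w_qat) now falls back to the unquantized QAT bias whenever the observed scales are missing, the input and weight QuantModes differ, or either dtype is not int8, instead of raising an AssertionError. Below is a minimal standalone sketch of that guard-clause flow; the helper bias_fake_quant_supported and the QParams/DTypeMeta classes are hypothetical stand-ins, not MegEngine's real quantization metadata.

# Standalone sketch only: QParams and DTypeMeta are hypothetical stand-ins.
from dataclasses import dataclass
from typing import Optional

@dataclass
class DTypeMeta:
    np_dtype_str: str

@dataclass
class QParams:
    mode: str
    dtype_meta: DTypeMeta
    scale: Optional[float]

def bias_fake_quant_supported(inp_params: QParams, w_params: QParams) -> bool:
    # Mirrors the new guard clauses: any unsupported combination means
    # "skip fake quantization and return the bias unchanged".
    if inp_params.scale is None or w_params.scale is None:
        return False  # scales not observed yet
    if inp_params.mode != w_params.mode:
        return False  # mixed QuantMode, not supported yet (TODO in the diff)
    if inp_params.dtype_meta.np_dtype_str != "int8":
        return False  # only int8-like dtypes handled for now
    if w_params.dtype_meta.np_dtype_str != "int8":
        return False
    return True

# Mismatched modes no longer raise; the bias is simply left unquantized.
inp = QParams(mode="symmetric", dtype_meta=DTypeMeta("int8"), scale=0.02)
w = QParams(mode="asymmetric", dtype_meta=DTypeMeta("int8"), scale=0.01)
assert not bias_fake_quant_supported(inp, w)

With this behavior, QAT configurations that previously crashed (for example, mixed modes or observers that have not produced a scale yet) now proceed with an unquantized bias, which matches the commit's stated goal of supporting more QAT modes.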