Commit e55c5db0 authored by Megvii Engine Team

fix(mge/quantization): remove assert in fake_quant_bias to support more QAT mode

GitOrigin-RevId: 8c7f268480e703c896fbb93ca19516aa970a4901
Parent bc581d59
@@ -206,20 +206,26 @@ def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
     ):
         inp_params = inp.qparams
         w_params = w_qat.qparams
-        if inp_params.scale is not None and w_params.scale is not None:
-            assert inp_params.mode == w_params.mode, "incompatible QuantMode"
-            # TODO: support quint8 dtype.
-            assert (
-                inp_params.dtype_meta.np_dtype_str == "int8"
-                and w_params.dtype_meta.np_dtype_str == "int8"
-            ), "fake_quant_bias only support int8 like dtype now"
-            # use the same mode with weight.
-            # TODO: avoid hardcode
-            b_dtype = _builtin_quant_dtypes["qint32"]
-            b_param = create_qparams(
-                w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale
-            )
-            b_qat = fake_quant_tensor(bias, b_param)
-            b_qat.qparams.update(b_param)
+        if inp_params.scale is None or w_params.scale is None:
+            return b_qat
+
+        # TODO: support different mode
+        if inp_params.mode != w_params.mode:
+            return b_qat
+        # TODO: support quint8 dtype.
+        if inp_params.dtype_meta.np_dtype_str != "int8":
+            return b_qat
+        if w_params.dtype_meta.np_dtype_str != "int8":
+            return b_qat
+
+        # use the same mode with weight.
+        # TODO: avoid hardcode
+        b_dtype = _builtin_quant_dtypes["qint32"]
+        b_param = create_qparams(
+            w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale
+        )
+        b_qat = fake_quant_tensor(bias, b_param)
+        b_qat.qparams.update(b_param)
     return b_qat
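The qint32/scale choice in the unchanged tail of the hunk follows standard integer-inference arithmetic: the bias is added into the 32-bit accumulator of int8 x int8 products, so its fake-quant scale must equal inp_scale * w_params scale. Below is a minimal standalone sketch of the guard logic this commit introduces, assuming a simplified stand-in for MegEngine's QParams; SimpleQParams, should_fake_quant_bias, and the sample values are hypothetical and for illustration only, not part of the MegEngine API.

# Sketch only: SimpleQParams is a hypothetical stand-in for MegEngine's QParams.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SimpleQParams:
    mode: str               # e.g. "symmetric" / "asymmetric"
    np_dtype_str: str       # e.g. "int8", "uint8"
    scale: Optional[float]  # None until an observer has recorded statistics


def should_fake_quant_bias(inp: SimpleQParams, w: SimpleQParams) -> bool:
    # Each unmet condition now means "leave the bias untouched";
    # the old code raised AssertionError instead.
    if inp.scale is None or w.scale is None:
        return False
    if inp.mode != w.mode:           # upstream TODO: support different modes
        return False
    if inp.np_dtype_str != "int8":   # upstream TODO: support quint8
        return False
    if w.np_dtype_str != "int8":
        return False
    return True


inp = SimpleQParams(mode="asymmetric", np_dtype_str="uint8", scale=0.02)
w = SimpleQParams(mode="symmetric", np_dtype_str="int8", scale=0.01)
assert not should_fake_quant_bias(inp, w)  # skipped, no AssertionError raised

With the old asserts, a mixed mode or non-int8 dtype like the case above aborted QAT mid-run; the new guards let such configurations proceed, merely skipping fake quantization of the bias.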