提交 387867f8 编写于 作者: M Megvii Engine Team

feat(mge/quantization): add cambricon-quantization-example

GitOrigin-RevId: ed578ca92bccca84367a6bd5c4492ddcf759f50e
上级 0067dcf0
......@@ -52,7 +52,7 @@ class QATModule(Module):
self.weight_fake_quant = safe_call(qconfig.weight_fake_quant)
def _enable_exec(self, with_module, func, enable):
if not with_module:
if not with_module or not func:
return
if enable:
func.enable()
......
......@@ -32,11 +32,13 @@ class Linear(QuantizedModule):
inp_scale = mgb.dtype.get_scale(inp.dtype)
w_scale = mgb.dtype.get_scale(self.weight.dtype)
bias_dtype = mgb.dtype.qint32(inp_scale * w_scale)
return F.linear(
ret = F.linear(
inp,
self.weight,
None if self.bias is None else self.bias.astype(bias_dtype),
).astype(self.output_dtype)
)
ret = ret if self.output_dtype is None else ret.astype(self.output_dtype)
return ret
@classmethod
def from_qat_module(cls, qat_module: QAT.Linear):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册