diff --git a/python_module/megengine/module/qat/module.py b/python_module/megengine/module/qat/module.py index 8546e4c66f7a431803a64be70697f6d591f679ed..c7cb80cbef69ebb14344f373aaae8742fadd2155 100644 --- a/python_module/megengine/module/qat/module.py +++ b/python_module/megengine/module/qat/module.py @@ -52,7 +52,7 @@ class QATModule(Module): self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) def _enable_exec(self, with_module, func, enable): - if not with_module: + if not with_module or not func: return if enable: func.enable() diff --git a/python_module/megengine/module/quantized/linear.py b/python_module/megengine/module/quantized/linear.py index 4c798929785f00bacbbed447543cad8397452b84..a6e61a6e3115150e05c07a0a5bb7afaf7b3dff09 100644 --- a/python_module/megengine/module/quantized/linear.py +++ b/python_module/megengine/module/quantized/linear.py @@ -32,11 +32,13 @@ class Linear(QuantizedModule): inp_scale = mgb.dtype.get_scale(inp.dtype) w_scale = mgb.dtype.get_scale(self.weight.dtype) bias_dtype = mgb.dtype.qint32(inp_scale * w_scale) - return F.linear( + ret = F.linear( inp, self.weight, None if self.bias is None else self.bias.astype(bias_dtype), - ).astype(self.output_dtype) + ) + ret = ret if self.output_dtype is None else ret.astype(self.output_dtype) + return ret @classmethod def from_qat_module(cls, qat_module: QAT.Linear):