diff --git a/imperative/python/megengine/quantization/fake_quant.py b/imperative/python/megengine/quantization/fake_quant.py index c1469d33fbd86fe8e4bf0c388270cb4b4c318d00..f6625d7083e063893b532e40218a6942798055e5 100644 --- a/imperative/python/megengine/quantization/fake_quant.py +++ b/imperative/python/megengine/quantization/fake_quant.py @@ -132,7 +132,7 @@ class LSQ(_FakeQuantize, QParamsModuleMixin): :param eps:a small value to avoid division by zero. Default: 1e-5 """ - def init( + def __init__( self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, @@ -142,6 +142,9 @@ class LSQ(_FakeQuantize, QParamsModuleMixin): super().__init__(dtype=dtype, enable=enable, **kwargs) self.eps = Tensor(eps, dtype="float32") self.step_size = Parameter(1.0, dtype="float32") + self.mode = None + self.zero_point = Tensor(0.0, dtype="float32") + self.grad_scale = Tensor(1.0, dtype="float32") def set_qparams(self, qparams: LSQParams): self.mode = qparams.mode