提交 3e58cbb8 编写于 作者: M Megvii Engine Team

fix(mge/quantization): fix lsq init format error

GitOrigin-RevId: 032de9c7738622fd6d83e706e5a5b85355645dfd
上级 c88a4e5b
......@@ -132,7 +132,7 @@ class LSQ(_FakeQuantize, QParamsModuleMixin):
:param eps:a small value to avoid division by zero. Default: 1e-5
"""
def init(
def __init__(
self,
dtype: Union[str, QuantDtypeMeta],
enable: bool = True,
......@@ -142,6 +142,9 @@ class LSQ(_FakeQuantize, QParamsModuleMixin):
super().__init__(dtype=dtype, enable=enable, **kwargs)
self.eps = Tensor(eps, dtype="float32")
self.step_size = Parameter(1.0, dtype="float32")
self.mode = None
self.zero_point = Tensor(0.0, dtype="float32")
self.grad_scale = Tensor(1.0, dtype="float32")
def set_qparams(self, qparams: LSQParams):
self.mode = qparams.mode
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册