提交 e718a498 编写于 作者: M Megvii Engine Team

fix(mge/quantization): support for initializing LSQ from Qparams

GitOrigin-RevId: 860e378271f02466e18ad6329d10c1a776815124
上级 f7cbb065
...@@ -139,21 +139,21 @@ class LSQ(_FakeQuantize, QParamsModuleMixin): ...@@ -139,21 +139,21 @@ class LSQ(_FakeQuantize, QParamsModuleMixin):
self.zero_point = Tensor(0.0, dtype="float32") self.zero_point = Tensor(0.0, dtype="float32")
self.grad_scale = Tensor(1.0, dtype="float32") self.grad_scale = Tensor(1.0, dtype="float32")
def set_qparams(self, qparams: LSQParams): def set_qparams(self, qparams: QParams):
self.mode = qparams.mode self.mode = qparams.mode
if qparams.mode == QuantMode.ASYMMERTIC: if qparams.mode == QuantMode.ASYMMERTIC:
self.zero_point = qparams.zero_point self.zero_point = qparams.zero_point
else: else:
self.zero_point = Tensor([0.0], dtype="float32") self.zero_point = Tensor(0.0, dtype="float32")
if qparams.scale is None: if qparams.scale is None:
raise AssertionError("Can not get an initialized scale") raise AssertionError("Can not get an initialized scale")
init_step_size = qparams.scale init_step_size = qparams.scale
if init_step_size < self.eps: if init_step_size < self.eps:
init_step_size = 0 init_step_size = Tensor(0.0, dtype="float32")
else: else:
init_step_size = init_step_size - self.eps init_step_size = Tensor(init_step_size - self.eps)
self.step_size = Parameter(init_step_size, dtype="float32") self.step_size = Parameter(init_step_size.item(), dtype="float32")
if isinstance(qparams, LSQParams):
self.grad_scale = qparams.grad_scale self.grad_scale = qparams.grad_scale
def fake_quant_forward(self, inp, qparams: LSQParams = None): def fake_quant_forward(self, inp, qparams: LSQParams = None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册