提交 e718a498 编写于 作者: M Megvii Engine Team

fix(mge/quantization): support for initializing LSQ from Qparams

GitOrigin-RevId: 860e378271f02466e18ad6329d10c1a776815124
上级 f7cbb065
......@@ -139,22 +139,22 @@ class LSQ(_FakeQuantize, QParamsModuleMixin):
self.zero_point = Tensor(0.0, dtype="float32")
self.grad_scale = Tensor(1.0, dtype="float32")
def set_qparams(self, qparams: LSQParams):
def set_qparams(self, qparams: QParams):
self.mode = qparams.mode
if qparams.mode == QuantMode.ASYMMERTIC:
self.zero_point = qparams.zero_point
else:
self.zero_point = Tensor([0.0], dtype="float32")
self.zero_point = Tensor(0.0, dtype="float32")
if qparams.scale is None:
raise AssertionError("Can not get an initialized scale")
init_step_size = qparams.scale
if init_step_size < self.eps:
init_step_size = 0
init_step_size = Tensor(0.0, dtype="float32")
else:
init_step_size = init_step_size - self.eps
self.step_size = Parameter(init_step_size, dtype="float32")
self.grad_scale = qparams.grad_scale
init_step_size = Tensor(init_step_size - self.eps)
self.step_size = Parameter(init_step_size.item(), dtype="float32")
if isinstance(qparams, LSQParams):
self.grad_scale = qparams.grad_scale
def fake_quant_forward(self, inp, qparams: LSQParams = None):
step_size = F.abs(self.step_size) + self.eps
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册