提交 ab9fa48e 作者: Megvii Engine Team 提交者: Xinran Xu

feat(mge/quantization): make `q_dict` a kwarg rather than an arg

GitOrigin-RevId: 38e3b2bfaf252761f87254c1060d3d1c1441b8c2
上级 f8810f73
...@@ -50,17 +50,17 @@ class _FakeQuantize(Module): ...@@ -50,17 +50,17 @@ class _FakeQuantize(Module):
def disable(self): def disable(self):
self.enabled = False self.enabled = False
def fake_quant_forward(self, inp, q_dict): def fake_quant_forward(self, inp, q_dict=None):
return inp return inp
def normal_foward(self, inp, q_dict): def normal_foward(self, inp, q_dict=None):
return inp return inp
def forward(self, inp, q_dict): def forward(self, inp, q_dict=None):
if self.enabled: if self.enabled:
return self.fake_quant_forward(inp, q_dict) return self.fake_quant_forward(inp, q_dict=q_dict)
else: else:
return self.normal_foward(inp, q_dict) return self.normal_foward(inp, q_dict=q_dict)
class TQT_Function(Function): class TQT_Function(Function):
...@@ -110,11 +110,11 @@ class TQT(_FakeQuantize): ...@@ -110,11 +110,11 @@ class TQT(_FakeQuantize):
super().__init__(dtype, narrow_range, enable) super().__init__(dtype, narrow_range, enable)
self.scale = Parameter(0.0, dtype=np.float32) self.scale = Parameter(0.0, dtype=np.float32)
def fake_quant_forward(self, inp, q_dict): def fake_quant_forward(self, inp, q_dict=None):
# when enable, TQT will do fakequant forward, finetune the scale # when enable, TQT will do fakequant forward, finetune the scale
return TQT_Function(self.qmin, self.qmax)(inp, self.scale) return TQT_Function(self.qmin, self.qmax)(inp, self.scale)
def normal_foward(self, inp, q_dict): def normal_foward(self, inp, q_dict=None):
if q_dict["enable_observer"]: if q_dict["enable_observer"]:
# when disable, TQT will do normal forward, initialize scale weight # when disable, TQT will do normal forward, initialize scale weight
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"])) tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
...@@ -123,9 +123,9 @@ class TQT(_FakeQuantize): ...@@ -123,9 +123,9 @@ class TQT(_FakeQuantize):
return inp return inp
def get_qparams(self): def get_qparams(self):
qdict = get_qparam_dict(QuantMode.TQT) q_dict = get_qparam_dict(QuantMode.TQT)
qdict["scale"] = 2 ** self.scale q_dict["scale"] = 2 ** self.scale
return qdict return q_dict
def get_dtype(self): def get_dtype(self):
q_dict = self.get_qparams() q_dict = self.get_qparams()
...@@ -141,5 +141,5 @@ class FakeQuantize(_FakeQuantize): ...@@ -141,5 +141,5 @@ class FakeQuantize(_FakeQuantize):
A module to do quant and dequant according to observer's scale and zero_point. A module to do quant and dequant according to observer's scale and zero_point.
""" """
def fake_quant_forward(self, inp, q_dict): def fake_quant_forward(self, inp, q_dict=None):
return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict) return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册