提交 ab9fa48e 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

feat(mge/quantization): make `q_dict` a kwarg rather than an arg

GitOrigin-RevId: 38e3b2bfaf252761f87254c1060d3d1c1441b8c2
上级 f8810f73
......@@ -50,17 +50,17 @@ class _FakeQuantize(Module):
def disable(self):
self.enabled = False
def fake_quant_forward(self, inp, q_dict):
def fake_quant_forward(self, inp, q_dict=None):
return inp
def normal_foward(self, inp, q_dict):
def normal_foward(self, inp, q_dict=None):
return inp
def forward(self, inp, q_dict):
def forward(self, inp, q_dict=None):
if self.enabled:
return self.fake_quant_forward(inp, q_dict)
return self.fake_quant_forward(inp, q_dict=q_dict)
else:
return self.normal_foward(inp, q_dict)
return self.normal_foward(inp, q_dict=q_dict)
class TQT_Function(Function):
......@@ -110,11 +110,11 @@ class TQT(_FakeQuantize):
super().__init__(dtype, narrow_range, enable)
self.scale = Parameter(0.0, dtype=np.float32)
def fake_quant_forward(self, inp, q_dict):
def fake_quant_forward(self, inp, q_dict=None):
# when enable, TQT will do fakequant forward, finetune the scale
return TQT_Function(self.qmin, self.qmax)(inp, self.scale)
def normal_foward(self, inp, q_dict):
def normal_foward(self, inp, q_dict=None):
if q_dict["enable_observer"]:
# when disable, TQT will do normal forward, initialize scale weight
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
......@@ -123,9 +123,9 @@ class TQT(_FakeQuantize):
return inp
def get_qparams(self):
qdict = get_qparam_dict(QuantMode.TQT)
qdict["scale"] = 2 ** self.scale
return qdict
q_dict = get_qparam_dict(QuantMode.TQT)
q_dict["scale"] = 2 ** self.scale
return q_dict
def get_dtype(self):
q_dict = self.get_qparams()
......@@ -141,5 +141,5 @@ class FakeQuantize(_FakeQuantize):
A module to do quant and dequant according to observer's scale and zero_point.
"""
def fake_quant_forward(self, inp, q_dict):
def fake_quant_forward(self, inp, q_dict=None):
return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册