提交 b43f6a26 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(mge/quantization): handle empty Observer in QATModule

GitOrigin-RevId: e8a62297bc513a30be900743c3c199ccc2b30273
上级 13e8f00a
@@ -70,18 +70,22 @@ class QATModule(Module):
     def _apply_fakequant_with_observer(
         self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
     ):
+        # do observer
         if observer is None:
-            return target
-        oup = observer(target)
-        q_dict = observer.get_qparams()
+            q_dict = None
+            oup = target
+        else:
+            q_dict = observer.get_qparams()
+            oup = observer(target)
         # do fake quant
         if fake_quant is not None:
             oup = fake_quant(oup, q_dict)
             # use qparams of fake_quant if have.
             if hasattr(fake_quant, "get_qparams"):
                 q_dict = fake_quant.get_qparams()
         # set to tensor qparams.
-        oup.q_dict.update(q_dict)
+        if q_dict is not None:
+            oup.q_dict.update(q_dict)
         return oup
 
     def apply_quant_weight(self, target: Tensor):
@@ -100,42 +104,46 @@ class QATModule(Module):
             target, self.act_fake_quant, self.act_observer
         )
 
+    def _get_method_result(
+        self, method: str, fake_quant: FakeQuantize, observer: Observer
+    ):
+        if hasattr(fake_quant, method):
+            return getattr(fake_quant, method)()
+        elif hasattr(observer, method):
+            return getattr(observer, method)()
+        return None
+
     def get_weight_dtype(self):
         r"""
         Get weight's quantization dtype as the method from ``qconfig``.
         """
-        if hasattr(self.weight_fake_quant, "get_dtype"):
-            return self.weight_fake_quant.get_dtype()
-        else:
-            return self.weight_observer.get_dtype()
+        return self._get_method_result(
+            "get_dtype", self.weight_fake_quant, self.weight_observer
+        )
 
     def get_activation_dtype(self):
         r"""
         Get activation's quantization dtype as the method from ``qconfig``.
         """
-        if hasattr(self.act_fake_quant, "get_dtype"):
-            return self.act_fake_quant.get_dtype()
-        else:
-            return self.act_observer.get_dtype()
-
-    def _get_qparams(self, fake_quant: FakeQuantize, observer: Observer):
-        if hasattr(fake_quant, "get_qparams"):
-            return fake_quant.get_qparams()
-        elif observer is not None:
-            return observer.get_qparams()
-        return None
+        return self._get_method_result(
+            "get_dtype", self.act_fake_quant, self.act_observer
+        )
 
     def get_weight_qparams(self):
         r"""
         Get weight's quantization parameters.
         """
-        return self._get_qparams(self.weight_fake_quant, self.weight_observer)
+        return self._get_method_result(
+            "get_qparams", self.weight_fake_quant, self.weight_observer
+        )
 
     def get_activation_qparams(self):
         r"""
         Get activation's quantization parameters.
         """
-        return self._get_qparams(self.act_fake_quant, self.act_observer)
+        return self._get_method_result(
+            "get_qparams", self.act_fake_quant, self.act_observer
+        )
 
     @classmethod
     @abstractmethod
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册