提交 b43f6a26 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(mge/quantization): handle empty Observer in QATModule

GitOrigin-RevId: e8a62297bc513a30be900743c3c199ccc2b30273
上级 13e8f00a
......@@ -70,18 +70,22 @@ class QATModule(Module):
def _apply_fakequant_with_observer(
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
):
# do observer
if observer is None:
return target
oup = observer(target)
q_dict = observer.get_qparams()
q_dict = None
oup = target
else:
q_dict = observer.get_qparams()
oup = observer(target)
# do fake quant
if fake_quant is not None:
oup = fake_quant(oup, q_dict)
# use qparams of fake_quant if have.
if hasattr(fake_quant, "get_qparams"):
q_dict = fake_quant.get_qparams()
# use qparams of fake_quant if have.
if hasattr(fake_quant, "get_qparams"):
q_dict = fake_quant.get_qparams()
# set to tensor qparams.
oup.q_dict.update(q_dict)
if q_dict is not None:
oup.q_dict.update(q_dict)
return oup
def apply_quant_weight(self, target: Tensor):
......@@ -100,42 +104,46 @@ class QATModule(Module):
target, self.act_fake_quant, self.act_observer
)
def _get_method_result(
self, method: str, fake_quant: FakeQuantize, observer: Observer
):
if hasattr(fake_quant, method):
return getattr(fake_quant, method)()
elif hasattr(observer, method):
return getattr(observer, method)()
return None
def get_weight_dtype(self):
r"""
Get weight's quantization dtype as the method from ``qconfig``.
"""
if hasattr(self.weight_fake_quant, "get_dtype"):
return self.weight_fake_quant.get_dtype()
else:
return self.weight_observer.get_dtype()
return self._get_method_result(
"get_dtype", self.weight_fake_quant, self.weight_observer
)
def get_activation_dtype(self):
r"""
Get activation's quantization dtype as the method from ``qconfig``.
"""
if hasattr(self.act_fake_quant, "get_dtype"):
return self.act_fake_quant.get_dtype()
else:
return self.act_observer.get_dtype()
def _get_qparams(self, fake_quant: FakeQuantize, observer: Observer):
if hasattr(fake_quant, "get_qparams"):
return fake_quant.get_qparams()
elif observer is not None:
return observer.get_qparams()
return None
return self._get_method_result(
"get_dtype", self.act_fake_quant, self.act_observer
)
def get_weight_qparams(self):
r"""
Get weight's quantization parameters.
"""
return self._get_qparams(self.weight_fake_quant, self.weight_observer)
return self._get_method_result(
"get_qparams", self.weight_fake_quant, self.weight_observer
)
def get_activation_qparams(self):
r"""
Get activation's quantization parameters.
"""
return self._get_qparams(self.act_fake_quant, self.act_observer)
return self._get_method_result(
"get_qparams", self.act_fake_quant, self.act_observer
)
@classmethod
@abstractmethod
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册