diff --git a/python_module/megengine/module/qat/module.py b/python_module/megengine/module/qat/module.py index 5c510cca54629d84549567f6b1667f03a11bd429..7ebae774df72e46d9ec2a245ee242054b8cb86df 100644 --- a/python_module/megengine/module/qat/module.py +++ b/python_module/megengine/module/qat/module.py @@ -70,18 +70,22 @@ class QATModule(Module): def _apply_fakequant_with_observer( self, target: Tensor, fake_quant: FakeQuantize, observer: Observer ): + # do observer if observer is None: - return target - oup = observer(target) - q_dict = observer.get_qparams() + q_dict = None + oup = target + else: + q_dict = observer.get_qparams() + oup = observer(target) # do fake quant if fake_quant is not None: oup = fake_quant(oup, q_dict) - # use qparams of fake_quant if have. - if hasattr(fake_quant, "get_qparams"): - q_dict = fake_quant.get_qparams() + # use qparams of fake_quant if have. + if hasattr(fake_quant, "get_qparams"): + q_dict = fake_quant.get_qparams() # set to tensor qparams. - oup.q_dict.update(q_dict) + if q_dict is not None: + oup.q_dict.update(q_dict) return oup def apply_quant_weight(self, target: Tensor): @@ -100,42 +104,46 @@ class QATModule(Module): target, self.act_fake_quant, self.act_observer ) + def _get_method_result( + self, method: str, fake_quant: FakeQuantize, observer: Observer + ): + if hasattr(fake_quant, method): + return getattr(fake_quant, method)() + elif hasattr(observer, method): + return getattr(observer, method)() + return None + def get_weight_dtype(self): r""" Get weight's quantization dtype as the method from ``qconfig``. """ - if hasattr(self.weight_fake_quant, "get_dtype"): - return self.weight_fake_quant.get_dtype() - else: - return self.weight_observer.get_dtype() + return self._get_method_result( + "get_dtype", self.weight_fake_quant, self.weight_observer + ) def get_activation_dtype(self): r""" Get activation's quantization dtype as the method from ``qconfig``. """ - if hasattr(self.act_fake_quant, "get_dtype"): - return self.act_fake_quant.get_dtype() - else: - return self.act_observer.get_dtype() - - def _get_qparams(self, fake_quant: FakeQuantize, observer: Observer): - if hasattr(fake_quant, "get_qparams"): - return fake_quant.get_qparams() - elif observer is not None: - return observer.get_qparams() - return None + return self._get_method_result( + "get_dtype", self.act_fake_quant, self.act_observer + ) def get_weight_qparams(self): r""" Get weight's quantization parameters. """ - return self._get_qparams(self.weight_fake_quant, self.weight_observer) + return self._get_method_result( + "get_qparams", self.weight_fake_quant, self.weight_observer + ) def get_activation_qparams(self): r""" Get activation's quantization parameters. """ - return self._get_qparams(self.act_fake_quant, self.act_observer) + return self._get_method_result( + "get_qparams", self.act_fake_quant, self.act_observer + ) @classmethod @abstractmethod