From 3af16297603b6bc16eb7f0be9f38a9cf73d6b0ca Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Thu, 17 Jun 2021 10:57:22 +0800 Subject: [PATCH] fix the error of qat unit test (#33574) --- .../contrib/slim/tests/test_imperative_qat.py | 16 +++------------- .../tests/test_imperative_qat_channelwise.py | 3 --- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py index bf411e5b38e..3cc61ce8c58 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py @@ -64,11 +64,11 @@ class TestImperativeQat(unittest.TestCase): print("Failed to delete {} due to {}".format(cls.root_path, str(e))) def set_vars(self): - self.weight_quantize_type = None - self.activation_quantize_type = None + self.weight_quantize_type = 'abs_max' + self.activation_quantize_type = 'moving_average_abs_max' print('weight_quantize_type', self.weight_quantize_type) - def run_qat_save(self): + def test_qat(self): self.set_vars() imperative_qat = ImperativeQuantAware( @@ -200,15 +200,5 @@ class TestImperativeQat(unittest.TestCase): msg='Failed to save the inference quantized model.') -class TestImperativeQatAbsMax(TestImperativeQat): - def set_vars(self): - self.weight_quantize_type = 'abs_max' - self.activation_quantize_type = 'moving_average_abs_max' - print('weight_quantize_type', self.weight_quantize_type) - - def test_qat(self): - self.run_qat_save() - - if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py index 3d2cad388d1..1a6c9c41638 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_channelwise.py @@ -43,9 +43,6 @@ class TestImperativeQatChannelWise(TestImperativeQat): self.activation_quantize_type = 'moving_average_abs_max' print('weight_quantize_type', self.weight_quantize_type) - def test_qat(self): - self.run_qat_save() - if __name__ == '__main__': unittest.main() -- GitLab