From b848bd3713a3da21fedf522591b40f9d034d9ce5 Mon Sep 17 00:00:00 2001 From: Guanghua Yu <742925032@qq.com> Date: Mon, 27 Jun 2022 14:54:18 +0800 Subject: [PATCH] fix post_training_quantization typo (#43845) --- .../contrib/slim/quantization/post_training_quantization.py | 4 ++-- .../slim/tests/test_post_training_quantization_mobilenetv1.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py index 3926ee9503..f1da3990a3 100644 --- a/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py +++ b/python/paddle/fluid/contrib/slim/quantization/post_training_quantization.py @@ -657,7 +657,7 @@ class PostTrainingQuantization(object): s += 0.02 bins = 2**(self._activation_bits - 1) - 1 if self._onnx_format: - quant_var = np.clip(distribution(var_tensor / scale * bins), + quant_var = np.clip(np.round(var_tensor / scale * bins), -bins - 1, bins) quant_dequant_var = quant_var / bins * scale else: @@ -701,7 +701,7 @@ class PostTrainingQuantization(object): s += 0.02 bins = 2**(self._activation_bits - 1) - 1 if self._onnx_format: - quant_var = np.clip(distribution(var_tensor / scale * bins), + quant_var = np.clip(np.round(var_tensor / scale * bins), -bins - 1, bins) quant_dequant_var = quant_var / bins * scale else: diff --git a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py index 9c076d85fd..25707d0c8c 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py +++ b/python/paddle/fluid/contrib/slim/tests/test_post_training_quantization_mobilenetv1.py @@ -425,7 +425,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization): def test_post_training_onnx_format_mobilenetv1(self): model = "MobileNet-V1" - algo = "avg" + algo = "emd" round_type = "round" data_urls = [ 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' -- GitLab