未验证 提交 b848bd37 编写于 作者: G Guanghua Yu 提交者: GitHub

fix post_training_quantization typo (#43845)

上级 7d14613d
...@@ -657,7 +657,7 @@ class PostTrainingQuantization(object): ...@@ -657,7 +657,7 @@ class PostTrainingQuantization(object):
s += 0.02 s += 0.02
bins = 2**(self._activation_bits - 1) - 1 bins = 2**(self._activation_bits - 1) - 1
if self._onnx_format: if self._onnx_format:
quant_var = np.clip(distribution(var_tensor / scale * bins), quant_var = np.clip(np.round(var_tensor / scale * bins),
-bins - 1, bins) -bins - 1, bins)
quant_dequant_var = quant_var / bins * scale quant_dequant_var = quant_var / bins * scale
else: else:
...@@ -701,7 +701,7 @@ class PostTrainingQuantization(object): ...@@ -701,7 +701,7 @@ class PostTrainingQuantization(object):
s += 0.02 s += 0.02
bins = 2**(self._activation_bits - 1) - 1 bins = 2**(self._activation_bits - 1) - 1
if self._onnx_format: if self._onnx_format:
quant_var = np.clip(distribution(var_tensor / scale * bins), quant_var = np.clip(np.round(var_tensor / scale * bins),
-bins - 1, bins) -bins - 1, bins)
quant_dequant_var = quant_var / bins * scale quant_dequant_var = quant_var / bins * scale
else: else:
......
...@@ -425,7 +425,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization): ...@@ -425,7 +425,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_onnx_format_mobilenetv1(self): def test_post_training_onnx_format_mobilenetv1(self):
model = "MobileNet-V1" model = "MobileNet-V1"
algo = "avg" algo = "emd"
round_type = "round" round_type = "round"
data_urls = [ data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz' 'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册