未验证 提交 7b6e9b6a 编写于 作者: Z zhouzj 提交者: GitHub

[cherry-pick] fix qat test (#55045)

* fix qat test.

* decrease test time.
上级 1d8116a5
......@@ -12,6 +12,7 @@
# see the license for the specific language governing permissions and
# limitations under the license.
import functools
import logging
import os
import random
import sys
......@@ -24,6 +25,7 @@ from PIL import Image
import paddle
from paddle.dataset.common import download
from paddle.static.log_helper import get_logger
from paddle.static.quantization import PostTrainingQuantization
paddle.enable_static()
......@@ -39,6 +41,10 @@ DATA_DIR = 'data/ILSVRC2012'
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
_logger = get_logger(
__name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)
def resize_short(img, target_size):
percent = float(target_size) / min(img.size[0], img.size[1])
......@@ -193,7 +199,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
file_name = data_urls[0].split('/')[-1]
zip_path = os.path.join(self.cache_folder, file_name)
print(f'Data is downloaded at {zip_path}')
_logger.info(f'Data is downloaded at {zip_path}')
self.cache_unzipping(data_cache_folder, zip_path)
return data_cache_folder
......@@ -253,7 +259,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
cnt += len(data)
if (batch_id + 1) % 100 == 0:
print(f"{batch_id + 1} images,")
_logger.info(f"{batch_id + 1} images,")
sys.stdout.flush()
if (batch_id + 1) == iterations:
break
......@@ -275,14 +281,14 @@ class TestPostTrainingQuantization(unittest.TestCase):
is_full_quantize=False,
is_use_cache_file=False,
is_optimize_model=False,
batch_nums=10,
batch_nums=1,
onnx_format=False,
deploy_backend=None,
):
try:
os.system("mkdir " + self.int8_model)
except Exception as e:
print(f"Failed to create {self.int8_model} due to {str(e)}")
_logger.info(f"Failed to create {self.int8_model} due to {str(e)}")
sys.exit(-1)
place = paddle.CPUPlace()
......@@ -309,8 +315,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
ptq.quantize()
ptq.save_quantized_model(
self.int8_model,
model_filename="inference.pdmodel",
params_filename="inference.pdiparams",
model_filename=model_filename,
params_filename=params_filename,
)
def run_test(
......@@ -322,27 +328,28 @@ class TestPostTrainingQuantization(unittest.TestCase):
round_type,
data_urls,
data_md5s,
data_name,
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
onnx_format=False,
batch_nums=10,
batch_nums=1,
deploy_backend=None,
):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
model_cache_folder = self.download_data(data_urls, data_md5s, model)
print(
model_path = os.path.join(model_cache_folder, data_name)
_logger.info(
"Start FP32 inference for {} on {} images ...".format(
model, infer_iterations * batch_size
)
)
(fp32_throughput, fp32_latency, fp32_acc1) = self.run_program(
os.path.join(model_cache_folder, "MobileNetV1_infer"),
model_path,
model_filename,
params_filename,
batch_size,
......@@ -350,7 +357,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
)
self.generate_quantized_model(
os.path.join(model_cache_folder, "MobileNetV1_infer"),
model_path,
model_filename,
params_filename,
quantizable_op_type,
......@@ -365,7 +372,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
deploy_backend,
)
print(
_logger.info(
"Start INT8 inference for {} on {} images ...".format(
model, infer_iterations * batch_size
)
......@@ -378,13 +385,13 @@ class TestPostTrainingQuantization(unittest.TestCase):
infer_iterations,
)
print(f"---Post training quantization of {algo} method---")
print(
_logger.info(f"---Post training quantization of {algo} method---")
_logger.info(
"FP32 {}: batch_size {}, throughput {} images/second, latency {} second, accuracy {}.".format(
model, batch_size, fp32_throughput, fp32_latency, fp32_acc1
)
)
print(
_logger.info(
"INT8 {}: batch_size {}, throughput {} images/second, latency {} second, accuracy {}.\n".format(
model, batch_size, int8_throughput, int8_latency, int8_acc1
)
......@@ -414,7 +421,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.025
batch_nums = 3
batch_nums = 1
self.run_test(
model,
'inference.pdmodel',
......@@ -423,6 +430,7 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
round_type,
data_urls,
data_md5s,
"MobileNetV1_infer",
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
......@@ -457,11 +465,13 @@ class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
round_type,
data_urls,
data_md5s,
"MobileNetV1_infer",
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
is_optimize_model,
diff_threshold,
batch_nums=2,
)
......@@ -483,7 +493,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
is_use_cache_file = False
is_optimize_model = True
diff_threshold = 0.03
batch_nums = 3
batch_nums = 1
self.run_test(
model,
'inference.pdmodel',
......@@ -492,6 +502,7 @@ class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
round_type,
data_urls,
data_md5s,
"MobileNetV1_infer",
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
......@@ -527,6 +538,7 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
round_type,
data_urls,
data_md5s,
"MobileNetV1_infer",
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
......@@ -554,7 +566,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
is_optimize_model = True
onnx_format = True
diff_threshold = 0.05
batch_nums = 3
batch_nums = 1
self.run_test(
model,
'inference.pdmodel',
......@@ -563,6 +575,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1(TestPostTrainingQuantization):
round_type,
data_urls,
data_md5s,
"MobileNetV1_infer",
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
......@@ -594,7 +607,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1TensorRT(
is_optimize_model = False
onnx_format = True
diff_threshold = 0.05
batch_nums = 10
batch_nums = 2
deploy_backend = "tensorrt"
self.run_test(
model,
......@@ -604,6 +617,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1TensorRT(
round_type,
data_urls,
data_md5s,
"MobileNetV1_infer",
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
......@@ -636,7 +650,7 @@ class TestPostTrainingKLONNXFormatForMobilenetv1MKLDNN(
is_optimize_model = False
onnx_format = True
diff_threshold = 0.05
batch_nums = 2
batch_nums = 1
deploy_backend = "mkldnn"
self.run_test(
model,
......@@ -646,6 +660,7 @@ class TestPostTrainingKLONNXFormatForMobilenetv1MKLDNN(
round_type,
data_urls,
data_md5s,
"MobileNetV1_infer",
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
......@@ -678,7 +693,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1ARMCPU(
is_optimize_model = True
onnx_format = True
diff_threshold = 0.05
batch_nums = 3
batch_nums = 1
deploy_backend = "arm"
self.run_test(
model,
......@@ -688,6 +703,7 @@ class TestPostTrainingAvgONNXFormatForMobilenetv1ARMCPU(
round_type,
data_urls,
data_md5s,
"MobileNetV1_infer",
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
......
......@@ -12,10 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import unittest
import numpy as np
from test_post_training_quantization_mobilenetv1 import (
TestPostTrainingQuantization,
val,
)
import paddle
......@@ -45,6 +49,7 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
round_type,
data_urls,
data_md5s,
"model",
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
......@@ -52,8 +57,66 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
diff_threshold,
)
def run_program(
    self,
    model_path,
    model_filename,
    params_filename,
    batch_size,
    infer_iterations,
):
    """Run an inference model on the validation reader and measure it.

    The model is loaded on CPU with ``paddle.static.load_inference_model``
    and is fed two inputs per batch (images and labels). It is expected to
    produce three fetch results, of which the second is used as the
    per-batch top-1 accuracy.

    Args:
        model_path: Directory holding the inference model.
        model_filename: File name of the serialized program.
        params_filename: File name of the serialized parameters.
        batch_size: Number of samples grouped into one batch.
        infer_iterations: Number of batches to run before stopping.

    Returns:
        A ``(throughput, latency, acc1)`` tuple: samples processed per
        second of executor time, mean executor time per batch, and the
        sample-weighted mean of the per-batch accuracy.
    """
    image_shape = [3, 224, 224]
    executor = paddle.static.Executor(paddle.CPUPlace())
    (
        infer_program,
        feed_names,
        fetch_targets,
    ) = paddle.static.load_inference_model(
        model_path,
        executor,
        model_filename=model_filename,
        params_filename=params_filename,
    )
    batched_reader = paddle.batch(val(), batch_size)

    weighted_acc = []  # per-batch accuracy scaled by the batch's sample count
    elapsed = []  # wall-clock seconds spent inside each executor call
    seen = 0

    for batch_no, batch in enumerate(batched_reader(), start=1):
        images = np.array(
            [sample[0].reshape(image_shape) for sample in batch]
        ).astype("float32")
        labels = np.array([sample[1] for sample in batch]).astype("int64")
        labels = labels.reshape([-1, 1])

        # Only the executor call itself is timed; batch preparation is not.
        started = time.time()
        _, acc1, _ = executor.run(
            infer_program,
            feed={feed_names[0]: images, feed_names[1]: labels},
            fetch_list=fetch_targets,
        )
        finished = time.time()
        elapsed.append(finished - started)

        batch_len = len(batch)
        weighted_acc.append(np.mean(acc1) * batch_len)
        seen += batch_len

        if batch_no % 100 == 0:
            print(f"{batch_no} images,")
            sys.stdout.flush()

        if batch_no == infer_iterations:
            break

    throughput = seen / np.sum(elapsed)
    latency = np.average(elapsed)
    acc1 = np.sum(weighted_acc) / seen
    return (throughput, latency, acc1)
class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingForResnet50):
def test_post_training_resnet50(self):
model = "ResNet-50"
algo = "min_max"
......@@ -76,6 +139,7 @@ class TestPostTrainingForResnet50ONNXFormat(TestPostTrainingQuantization):
round_type,
data_urls,
data_md5s,
"model",
quantizable_op_type,
is_full_quantize,
is_use_cache_file,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册