提交 84865b80 编写于 作者: J juncaipeng 提交者: Tao Luo

add resnet50 test for post training quantization, test=develop (#21272)

上级 9a7832f8
......@@ -99,7 +99,7 @@ class PostTrainingQuantization(object):
params_filename = None
save_model_path = path/to/save_model_path
# prepare the sample generator according to the model, and the
# sample generator must return a simple every time. The reference
# sample generator must return a sample every time. The reference
# document: https://www.paddlepaddle.org.cn/documentation/docs/zh
# /user_guides/howto/prepare_data/use_py_reader.html
sample_generator = your_sample_generator
......
......@@ -48,7 +48,8 @@ endfunction()
if(WIN32)
list(REMOVE_ITEM TEST_OPS test_light_nas)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50)
endif()
# int8 image classification python api test
......
......@@ -110,10 +110,9 @@ class TestPostTrainingQuantization(unittest.TestCase):
self.int8_download = 'int8/download'
self.cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' +
self.int8_download)
self.data_cache_folder = ''
data_urls = []
data_md5s = []
self.data_cache_folder = ''
if os.environ.get('DATASET') == 'full':
data_urls.append(
'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partaa'
......@@ -145,7 +144,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
'DATASET') == 'full' else 1
self.timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
self.int8_model = ''
self.int8_model = os.path.join(os.getcwd(),
"post_training_" + self.timestamp)
def tearDown(self):
try:
......@@ -191,14 +191,14 @@ class TestPostTrainingQuantization(unittest.TestCase):
def download_model(self):
pass
def run_program(self, model_path):
def run_program(self, model_path, batch_size, infer_iterations):
image_shape = [3, 224, 224]
place = fluid.CPUPlace()
exe = fluid.Executor(place)
[infer_program, feed_dict, fetch_targets] = \
fluid.io.load_inference_model(model_path, exe)
val_reader = paddle.batch(val(), self.batch_size)
iterations = self.infer_iterations
val_reader = paddle.batch(val(), batch_size)
iterations = infer_iterations
test_info = []
cnt = 0
......@@ -237,8 +237,6 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_path,
algo="KL",
is_full_quantize=False):
self.int8_model = os.path.join(os.getcwd(),
"post_training_" + self.timestamp)
try:
os.system("mkdir " + self.int8_model)
except Exception as e:
......@@ -264,52 +262,50 @@ class TestPostTrainingQuantization(unittest.TestCase):
ptq.quantize()
ptq.save_quantized_model(self.int8_model)
def run_test(self, model, algo, data_urls, data_md5s):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
sample_iterations = self.sample_iterations
class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
def download_model(self):
# mobilenetv1 fp32 data
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
self.model_cache_folder = self.download_data(data_urls, data_md5s,
"mobilenetv1_fp32")
self.model = "MobileNet-V1"
self.algo = "KL"
def test_post_training_mobilenetv1(self):
self.download_model()
model_cache_folder = self.download_data(data_urls, data_md5s, model)
print("Start FP32 inference for {0} on {1} images ...".format(
self.model, self.infer_iterations * self.batch_size))
(fp32_throughput, fp32_latency,
fp32_acc1) = self.run_program(self.model_cache_folder + "/model")
model, infer_iterations * batch_size))
(fp32_throughput, fp32_latency, fp32_acc1) = self.run_program(
model_cache_folder + "/model", batch_size, infer_iterations)
print("Start INT8 post training quantization for {0} on {1} images ...".
format(self.model, self.sample_iterations * self.batch_size))
format(model, sample_iterations * batch_size))
self.generate_quantized_model(
self.model_cache_folder + "/model",
algo=self.algo,
is_full_quantize=True)
model_cache_folder + "/model", algo=algo, is_full_quantize=True)
print("Start INT8 inference for {0} on {1} images ...".format(
self.model, self.infer_iterations * self.batch_size))
(int8_throughput, int8_latency,
int8_acc1) = self.run_program(self.int8_model)
model, infer_iterations * batch_size))
(int8_throughput, int8_latency, int8_acc1) = self.run_program(
self.int8_model, batch_size, infer_iterations)
print(
"FP32 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}".
format(self.model, self.batch_size, fp32_throughput, fp32_latency,
fp32_acc1))
format(model, batch_size, fp32_throughput, fp32_latency, fp32_acc1))
print(
"INT8 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}".
format(self.model, self.batch_size, int8_throughput, int8_latency,
int8_acc1))
format(model, batch_size, int8_throughput, int8_latency, int8_acc1))
sys.stdout.flush()
delta_value = fp32_acc1 - int8_acc1
self.assertLess(delta_value, 0.025)
class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
    """Post-training quantization test for the MobileNet-V1 model.

    Downloads the FP32 MobileNet-V1 inference model and relies on the
    shared ``run_test`` helper (inherited from
    ``TestPostTrainingQuantization``) to compare FP32 vs INT8 accuracy.
    """

    def test_post_training_mobilenetv1(self):
        # FP32 model archive hosted on the Paddle inference CDN, with
        # its MD5 checksum for download verification.
        fp32_model_urls = [
            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
        ]
        fp32_model_md5s = ['13892b0716d26443a8cdea15b3c6438b']
        # NOTE(review): "KL" presumably selects KL-divergence calibration
        # in PostTrainingQuantization — confirm against the quantizer docs.
        self.run_test("MobileNet-V1", "KL", fp32_model_urls, fp32_model_md5s)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
from test_post_training_quantization_mobilenetv1 import TestPostTrainingQuantization
class TestPostTrainingForResnet50(TestPostTrainingQuantization):
    """Post-training quantization test for the ResNet-50 model.

    Downloads the FP32 ResNet-50 inference model and relies on the
    shared ``run_test`` helper (inherited from
    ``TestPostTrainingQuantization``) to compare FP32 vs INT8 accuracy.
    """

    def test_post_training_resnet50(self):
        # FP32 model archive hosted on the Paddle inference CDN, with
        # its MD5 checksum for download verification.
        fp32_model_urls = [
            'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
        ]
        fp32_model_md5s = ['4a5194524823d9b76da6e738e1367881']
        # NOTE(review): "direct" presumably selects abs-max (non-KL)
        # activation calibration — confirm against the quantizer docs.
        self.run_test("ResNet-50", "direct", fp32_model_urls, fp32_model_md5s)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册