未验证 提交 0548aac2 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #15532 from hshen14/calibration_api_refine

Refine INT8 calibration API
...@@ -32,10 +32,13 @@ class Calibrator(object): ...@@ -32,10 +32,13 @@ class Calibrator(object):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.program = kwargs['program'] self.program = kwargs['program']
self.iterations = kwargs['iterations']
self.pretrained_model = kwargs['pretrained_model'] self.pretrained_model = kwargs['pretrained_model']
self.debug = kwargs['debug'] self.debug = kwargs['debug'] if 'debug' in kwargs else False
self.algo = kwargs['algo'] self.algo = kwargs['algo']
self.output = kwargs['output']
self.feed_var_names = kwargs['feed_var_names']
self.fetch_list = kwargs['fetch_list']
self.exe = kwargs['exe']
self._conv_input_var_name = [] self._conv_input_var_name = []
self._conv_output_var_name = [] self._conv_output_var_name = []
...@@ -54,17 +57,38 @@ class Calibrator(object): ...@@ -54,17 +57,38 @@ class Calibrator(object):
self._u8_output_var = [] self._u8_output_var = []
self._s8_output_var = [] self._s8_output_var = []
self._persistable_vars = [] self._persistable_vars = []
self._sampling_data = {}
def generate_sampling_program(self):
self.__init_analysis() self.__init_analysis()
self.__generate_output_program() self.__generate_output_program()
def generate_quantized_data(self, sampling_data): def save_int8_model(self):
self.__sampling(sampling_data) self.__sampling(self._sampling_data)
self.__save_scale() self.__save_scale()
self.__update_program() self.__update_program()
self.__update_output_program_attr() self.__update_output_program_attr()
self.__display_debug() self.__display_debug()
self.__save_offline_model()
def sample_data(self):
'''
Sampling the tensor data of variable.
'''
for i in self.sampling_program.list_vars():
if i.name in self.sampling_vars:
np_data = np.array(fluid.global_scope().find_var(i.name)
.get_tensor())
if i.name not in self._sampling_data:
self._sampling_data[i.name] = []
self._sampling_data[i.name].append(np_data)
def __save_offline_model(self):
'''
Save the quantized model to the disk.
'''
fluid.io.save_inference_model(self.output, self.feed_var_names,
self.fetch_list, self.exe,
self.sampling_program)
def __display_debug(self): def __display_debug(self):
if self.debug: if self.debug:
......
...@@ -26,7 +26,7 @@ import paddle.fluid.profiler as profiler ...@@ -26,7 +26,7 @@ import paddle.fluid.profiler as profiler
from PIL import Image, ImageEnhance from PIL import Image, ImageEnhance
import math import math
sys.path.append('..') sys.path.append('..')
import int8_inference.utility as ut import int8_inference.utility as int8_utility
random.seed(0) random.seed(0)
np.random.seed(0) np.random.seed(0)
...@@ -120,13 +120,13 @@ class TestCalibration(unittest.TestCase): ...@@ -120,13 +120,13 @@ class TestCalibration(unittest.TestCase):
def setUp(self): def setUp(self):
# TODO(guomingz): Put the download process in the cmake. # TODO(guomingz): Put the download process in the cmake.
# Download and unzip test data set # Download and unzip test data set
imagenet_dl_url = 'http://paddle-inference-dist.bj.bcebos.com/int8/calibration_test_data.tar.gz' imagenet_dl_url = 'http://paddle-inference-dist.cdn.bcebos.com/int8/calibration_test_data.tar.gz'
zip_file_name = imagenet_dl_url.split('/')[-1] zip_file_name = imagenet_dl_url.split('/')[-1]
cmd = 'rm -rf data {} && mkdir data && wget {} && tar xvf {} -C data'.format( cmd = 'rm -rf data {} && mkdir data && wget {} && tar xvf {} -C data'.format(
zip_file_name, imagenet_dl_url, zip_file_name) zip_file_name, imagenet_dl_url, zip_file_name)
os.system(cmd) os.system(cmd)
# resnet50 fp32 data # resnet50 fp32 data
resnet50_fp32_model_url = 'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz' resnet50_fp32_model_url = 'http://paddle-inference-dist.cdn.bcebos.com/int8/resnet50_int8_model.tar.gz'
resnet50_zip_name = resnet50_fp32_model_url.split('/')[-1] resnet50_zip_name = resnet50_fp32_model_url.split('/')[-1]
resnet50_unzip_folder_name = 'resnet50_fp32' resnet50_unzip_folder_name = 'resnet50_fp32'
cmd = 'rm -rf {} {} && mkdir {} && wget {} && tar xvf {} -C {}'.format( cmd = 'rm -rf {} {} && mkdir {} && wget {} && tar xvf {} -C {}'.format(
...@@ -135,8 +135,7 @@ class TestCalibration(unittest.TestCase): ...@@ -135,8 +135,7 @@ class TestCalibration(unittest.TestCase):
resnet50_zip_name, resnet50_unzip_folder_name) resnet50_zip_name, resnet50_unzip_folder_name)
os.system(cmd) os.system(cmd)
self.iterations = 100 self.iterations = 50
self.skip_batch_num = 5
def run_program(self, model_path, generate_int8=False, algo='direct'): def run_program(self, model_path, generate_int8=False, algo='direct'):
image_shape = [3, 224, 224] image_shape = [3, 224, 224]
...@@ -163,16 +162,15 @@ class TestCalibration(unittest.TestCase): ...@@ -163,16 +162,15 @@ class TestCalibration(unittest.TestCase):
print("Start calibration ...") print("Start calibration ...")
calibrator = ut.Calibrator( calibrator = int8_utility.Calibrator(
program=infer_program, program=infer_program,
pretrained_model=model_path, pretrained_model=model_path,
iterations=100, algo=algo,
debug=False, exe=exe,
algo=algo) output=int8_model,
feed_var_names=feed_dict,
sampling_data = {} fetch_list=fetch_targets)
calibrator.generate_sampling_program()
test_info = [] test_info = []
cnt = 0 cnt = 0
for batch_id, data in enumerate(val_reader()): for batch_id, data in enumerate(val_reader()):
...@@ -192,13 +190,7 @@ class TestCalibration(unittest.TestCase): ...@@ -192,13 +190,7 @@ class TestCalibration(unittest.TestCase):
feed_dict[1]: label}, feed_dict[1]: label},
fetch_list=fetch_targets) fetch_list=fetch_targets)
if generate_int8: if generate_int8:
for i in calibrator.sampling_program.list_vars(): calibrator.sample_data()
if i.name in calibrator.sampling_vars:
np_data = np.array(fluid.global_scope().find_var(i.name)
.get_tensor())
if i.name not in sampling_data:
sampling_data[i.name] = []
sampling_data[i.name].append(np_data)
test_info.append(np.mean(acc1) * len(data)) test_info.append(np.mean(acc1) * len(data))
cnt += len(data) cnt += len(data)
...@@ -209,9 +201,8 @@ class TestCalibration(unittest.TestCase): ...@@ -209,9 +201,8 @@ class TestCalibration(unittest.TestCase):
break break
if generate_int8: if generate_int8:
calibrator.generate_quantized_data(sampling_data) calibrator.save_int8_model()
fluid.io.save_inference_model(int8_model, feed_dict, fetch_targets,
exe, calibrator.sampling_program)
print( print(
"Calibration is done and the corresponding files were generated at {}". "Calibration is done and the corresponding files were generated at {}".
format(os.path.abspath("calibration_out"))) format(os.path.abspath("calibration_out")))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册