Unverified commit 8b74fc4f, authored by juncaipeng, committed by GitHub

Fix post training quantization (#21745)

* Fix the post training quantization bug under memory constraints and support inputs that can differ across batches, test=develop
Parent commit: aa4d6a5d
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
+import os
+import re
 import logging
 import numpy as np
 from ....executor import global_scope
@@ -43,7 +45,9 @@ class PostTrainingQuantization(object):
                  scope=None,
                  algo="KL",
                  quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
-                 is_full_quantize=False):
+                 is_full_quantize=False,
+                 is_use_cache_file=False,
+                 cache_dir="./temp_post_training"):
         '''
         The class utilizes post training quantization methon to quantize the
         fp32 model. It uses calibrate data to calculate the scale factor of
@@ -78,9 +82,16 @@ class PostTrainingQuantization(object):
                 that will be quantized. Default is ["conv2d", "depthwise_conv2d",
                 "mul"].
             is_full_quantized(bool, optional): If set is_full_quantized as True,
                 apply quantization to all supported quantizable op type. If set
                 is_full_quantized as False, only apply quantization to the op type
                 according to the input quantizable_op_type.
+            is_use_cache_file(bool, optional): If set is_use_cache_file as False,
+                all temp data will be saved in memory. If set is_use_cache_file as True,
+                it will save temp data to disk. When the fp32 model is complex or
+                the number of calibrate data is large, we should set is_use_cache_file
+                as True. Defalut is False.
+            cache_dir(str, optional): When is_use_cache_file is True, set cache_dir as
+                the directory for saving temp data. Default is ./temp_post_training.
         Returns:
             None
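
For orientation, here is a minimal usage sketch of the two new options. Only the arguments that appear in this diff (model_dir, algo, quantizable_op_type, is_full_quantize, is_use_cache_file, cache_dir, batch_nums) are confirmed by it; the import path, the executor/sample_generator/batch_size parameter names, and all file paths are assumptions for illustration.

# Hedged sketch, not part of the commit; placeholder paths and assumed
# parameter names are marked below.
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization  # assumed import path

def sample_generator():
    # Placeholder reader yielding random "images"; replace with real calibration data.
    for _ in range(64):
        yield [np.random.random((3, 224, 224)).astype("float32")]

exe = fluid.Executor(fluid.CPUPlace())
ptq = PostTrainingQuantization(
    executor=exe,                         # assumed parameter name
    sample_generator=sample_generator,    # assumed parameter name
    model_dir="./fp32_model",             # hypothetical FP32 model directory
    batch_size=16,                        # assumed parameter name
    batch_nums=10,
    algo="KL",
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
    is_full_quantize=False,
    is_use_cache_file=True,               # spill per-batch activation samples to disk
    cache_dir="./temp_post_training")     # created by __init__ if it does not exist
ptq.quantize()
ptq.save_quantized_model("./int8_model")  # hypothetical output directory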
@@ -129,6 +140,10 @@ class PostTrainingQuantization(object):
         self._batch_nums = batch_nums
         self._scope = global_scope() if scope == None else scope
         self._algo = algo
+        self._is_use_cache_file = is_use_cache_file
+        self._cache_dir = cache_dir
+        if self._is_use_cache_file and not os.path.exists(self._cache_dir):
+            os.mkdir(self._cache_dir)
         supported_quantizable_op_type = \
             QuantizationTransformPass._supported_quantizable_op_type + \
@@ -150,8 +165,8 @@ class PostTrainingQuantization(object):
         self._op_real_in_out_name = _op_real_in_out_name
         self._bit_length = 8
-        self._quantized_weight_var_name = []
-        self._quantized_act_var_name = []
+        self._quantized_weight_var_name = set()
+        self._quantized_act_var_name = set()
         self._sampling_data = {}
         self._quantized_var_scale_factor = {}
@@ -174,7 +189,8 @@ class PostTrainingQuantization(object):
                                feed=data,
                                fetch_list=self._fetch_list,
                                return_numpy=False)
-            self._sample_data()
+            self._sample_data(batch_id)
             if batch_id % 5 == 0:
                 _logger.info("run batch: " + str(batch_id))
             batch_id += 1
@@ -238,10 +254,9 @@ class PostTrainingQuantization(object):
             op_type = op.type
             if op_type in self._quantizable_op_type:
                 if op_type in ("conv2d", "depthwise_conv2d"):
-                    self._quantized_act_var_name.append(op.input("Input")[0])
-                    self._quantized_weight_var_name.append(
-                        op.input("Filter")[0])
-                    self._quantized_act_var_name.append(op.output("Output")[0])
+                    self._quantized_act_var_name.add(op.input("Input")[0])
+                    self._quantized_weight_var_name.add(op.input("Filter")[0])
+                    self._quantized_act_var_name.add(op.output("Output")[0])
                 elif op_type == "mul":
                     if self._is_input_all_not_persistable(
                             op, persistable_var_names):
@@ -249,9 +264,9 @@ class PostTrainingQuantization(object):
                         _logger.warning("Skip quant a mul op for two "
                                         "input variables are not persistable")
                     else:
-                        self._quantized_act_var_name.append(op.input("X")[0])
-                        self._quantized_weight_var_name.append(op.input("Y")[0])
-                        self._quantized_act_var_name.append(op.output("Out")[0])
+                        self._quantized_act_var_name.add(op.input("X")[0])
+                        self._quantized_weight_var_name.add(op.input("Y")[0])
+                        self._quantized_act_var_name.add(op.output("Out")[0])
                 else:
                     # process other quantizable op type, the input must all not persistable
                     if self._is_input_all_not_persistable(
@@ -260,10 +275,10 @@ class PostTrainingQuantization(object):
                             op_type]
                         for input_name in input_output_name_list[0]:
                             for var_name in op.input(input_name):
-                                self._quantized_act_var_name.append(var_name)
+                                self._quantized_act_var_name.add(var_name)
                         for output_name in input_output_name_list[1]:
                             for var_name in op.output(output_name):
-                                self._quantized_act_var_name.append(var_name)
+                                self._quantized_act_var_name.add(var_name)

         # set activation variables to be persistable, so can obtain
         # the tensor data in sample_data
@@ -271,7 +286,7 @@ class PostTrainingQuantization(object):
             if var.name in self._quantized_act_var_name:
                 var.persistable = True

-    def _sample_data(self):
+    def _sample_data(self, iter):
         '''
         Sample the tensor data of quantized variables,
         applied in every iteration.
@@ -281,11 +296,20 @@ class PostTrainingQuantization(object):
             var_tensor = self._load_var_value(var_name)
             self._sampling_data[var_name] = var_tensor
-        for var_name in self._quantized_act_var_name:
-            if var_name not in self._sampling_data:
-                self._sampling_data[var_name] = []
-            var_tensor = self._load_var_value(var_name)
-            self._sampling_data[var_name].append(var_tensor)
+        if self._is_use_cache_file:
+            for var_name in self._quantized_act_var_name:
+                var_tensor = self._load_var_value(var_name)
+                var_tensor = var_tensor.ravel()
+                save_path = os.path.join(self._cache_dir,
+                                         var_name + "_" + str(iter) + ".npy")
+                np.save(save_path, var_tensor)
+        else:
+            for var_name in self._quantized_act_var_name:
+                if var_name not in self._sampling_data:
+                    self._sampling_data[var_name] = []
+                var_tensor = self._load_var_value(var_name)
+                var_tensor = var_tensor.ravel()
+                self._sampling_data[var_name].append(var_tensor)

     def _calculate_scale_factor(self):
         '''
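
When is_use_cache_file is enabled, the branch above writes every batch of activation data to its own .npy file instead of accumulating it in memory. A self-contained sketch of that caching pattern follows; the variable name, directory handling, and toy data are illustrative, not taken from Paddle.

# Minimal sketch of the per-batch disk cache: flatten each batch and save it as
# "<var_name>_<iteration>.npy" so it can be reloaded and concatenated later.
import os
import numpy as np

cache_dir = "./temp_post_training"       # same default name as in the diff
os.makedirs(cache_dir, exist_ok=True)    # the diff itself uses os.mkdir in __init__

def cache_batch(var_name, batch_id, tensor):
    # ravel() flattens the batch so samples of different shapes can later be
    # concatenated into a single 1-D array
    save_path = os.path.join(cache_dir, "%s_%d.npy" % (var_name, batch_id))
    np.save(save_path, np.asarray(tensor).ravel())

for batch_id in range(3):
    cache_batch("conv1_out", batch_id, np.random.rand(2, 8, 4, 4))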
@@ -302,13 +326,33 @@ class PostTrainingQuantization(object):
                 var_name] = scale_factor_per_channel
         # apply kl quantization for activation
-        for var_name in self._quantized_act_var_name:
-            if self._algo == "KL":
-                self._quantized_var_scale_factor[var_name] = \
-                    self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
-            else:
-                self._quantized_var_scale_factor[var_name] = \
-                    np.max(np.abs(self._sampling_data[var_name]))
+        if self._is_use_cache_file:
+            for var_name in self._quantized_act_var_name:
+                sampling_data = []
+                filenames = [f for f in os.listdir(self._cache_dir) \
+                    if re.match(var_name + '_[0-9]+.npy', f)]
+                for filename in filenames:
+                    file_path = os.path.join(self._cache_dir, filename)
+                    sampling_data.append(np.load(file_path))
+                    os.remove(file_path)
+                sampling_data = np.concatenate(sampling_data)
+                if self._algo == "KL":
+                    self._quantized_var_scale_factor[var_name] = \
+                        self._get_kl_scaling_factor(np.abs(sampling_data))
+                else:
+                    self._quantized_var_scale_factor[var_name] = \
+                        np.max(np.abs(sampling_data))
+        else:
+            for var_name in self._quantized_act_var_name:
+                self._sampling_data[var_name] = np.concatenate(
+                    self._sampling_data[var_name])
+                if self._algo == "KL":
+                    self._quantized_var_scale_factor[var_name] = \
+                        self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
+                else:
+                    self._quantized_var_scale_factor[var_name] = \
+                        np.max(np.abs(self._sampling_data[var_name]))

     def _update_program(self):
         '''
...
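
On the scale-computation side, the cached files for each activation are matched back by name, concatenated, and reduced to one scale factor. Below is a small sketch of the abs-max branch only (the KL branch delegates to _get_kl_scaling_factor and is omitted); the variable name and directory are illustrative.

# Hedged sketch: rebuild one variable's samples from its cached .npy files and
# take max(|x|) as the scale factor, mirroring the non-KL branch above.
import os
import re
import numpy as np

cache_dir = "./temp_post_training"
var_name = "conv1_out"                   # illustrative variable name

files = os.listdir(cache_dir) if os.path.isdir(cache_dir) else []
chunks = [np.load(os.path.join(cache_dir, f))
          for f in sorted(files)
          if re.match(re.escape(var_name) + r"_[0-9]+\.npy$", f)]

if chunks:
    samples = np.concatenate(chunks)
    scale_factor = np.max(np.abs(samples))
    print("abs-max scale for %s: %f" % (var_name, scale_factor))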
@@ -237,7 +237,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
     def generate_quantized_model(self,
                                  model_path,
                                  algo="KL",
-                                 is_full_quantize=False):
+                                 is_full_quantize=False,
+                                 is_use_cache_file=False):
         try:
             os.system("mkdir " + self.int8_model)
         except Exception as e:
@@ -259,11 +260,13 @@ class TestPostTrainingQuantization(unittest.TestCase):
             model_dir=model_path,
             algo=algo,
             quantizable_op_type=quantizable_op_type,
-            is_full_quantize=is_full_quantize)
+            is_full_quantize=is_full_quantize,
+            is_use_cache_file=is_use_cache_file)
         ptq.quantize()
         ptq.save_quantized_model(self.int8_model)

-    def run_test(self, model, algo, data_urls, data_md5s):
+    def run_test(self, model, algo, data_urls, data_md5s, is_full_quantize,
+                 is_use_cache_file):
         infer_iterations = self.infer_iterations
         batch_size = self.batch_size
         sample_iterations = self.sample_iterations
@@ -277,8 +280,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
         print("Start INT8 post training quantization for {0} on {1} images ...".
               format(model, sample_iterations * batch_size))
-        self.generate_quantized_model(
-            model_cache_folder + "/model", algo=algo, is_full_quantize=True)
+        self.generate_quantized_model(model_cache_folder + "/model", algo,
+                                      is_full_quantize, is_use_cache_file)
         print("Start INT8 inference for {0} on {1} images ...".format(
             model, infer_iterations * batch_size))
@@ -305,7 +308,10 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
             'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
         ]
         data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
-        self.run_test(model, algo, data_urls, data_md5s)
+        is_full_quantize = True
+        is_use_cache_file = False
+        self.run_test(model, algo, data_urls, data_md5s, is_full_quantize,
+                      is_use_cache_file)

 if __name__ == '__main__':
...
@@ -25,7 +25,10 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
             'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
         ]
         data_md5s = ['4a5194524823d9b76da6e738e1367881']
-        self.run_test(model, algo, data_urls, data_md5s)
+        is_full_quantize = False
+        is_use_cache_file = True
+        self.run_test(model, algo, data_urls, data_md5s, is_full_quantize,
+                      is_use_cache_file)

 if __name__ == '__main__':
...