Unverified Commit 8b74fc4f authored by J juncaipeng and committed by GitHub

Fix post training quantization (#21745)

* fix the post training quantization bug under memory constraints, support inputs whose shape differs across batches, test=develop
Parent aa4d6a5d
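In outline, the fix stops accumulating every sampled activation tensor in memory and, when is_use_cache_file is True, writes each batch's tensor to a per-variable .npy file on disk, flattening it first so batches whose inputs differ in shape can still be concatenated when the scale factor is computed. A simplified sketch of that pattern follows; the helper names and the conv1_out variable are illustrative, not taken from the patch.

```python
import os
import re

import numpy as np


def cache_activation(cache_dir, var_name, batch_id, tensor):
    # Flatten before saving so that batches whose inputs differ in shape
    # can later be concatenated into one 1-D array.
    save_path = os.path.join(cache_dir, "%s_%d.npy" % (var_name, batch_id))
    np.save(save_path, np.asarray(tensor).ravel())


def abs_max_scale(cache_dir, var_name):
    # Gather every cached batch of this variable, concatenate, and take the
    # max absolute value as the scale factor (the non-KL branch of the patch).
    chunks = []
    for filename in os.listdir(cache_dir):
        if re.match(var_name + r"_[0-9]+\.npy", filename):
            file_path = os.path.join(cache_dir, filename)
            chunks.append(np.load(file_path))
            os.remove(file_path)  # free disk space once the data is consumed
    return np.max(np.abs(np.concatenate(chunks)))


if __name__ == "__main__":
    os.makedirs("./temp_post_training", exist_ok=True)
    # Two calibration batches with different shapes, as the commit now allows.
    for batch_id, shape in enumerate([(1, 8), (1, 6)]):
        cache_activation("./temp_post_training", "conv1_out", batch_id,
                         np.random.randn(*shape))
    print(abs_max_scale("./temp_post_training", "conv1_out"))
```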
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import re
import logging
import numpy as np
from ....executor import global_scope
@@ -43,7 +45,9 @@ class PostTrainingQuantization(object):
scope=None,
algo="KL",
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False):
is_full_quantize=False,
is_use_cache_file=False,
cache_dir="./temp_post_training"):
'''
The class utilizes the post training quantization method to quantize the
fp32 model. It uses calibration data to calculate the scale factor of
@@ -78,9 +82,16 @@ class PostTrainingQuantization(object):
that will be quantized. Default is ["conv2d", "depthwise_conv2d",
"mul"].
is_full_quantize(bool, optional): If is_full_quantize is set to True,
apply quantization to all supported quantizable op types. If
is_full_quantize is set to False, only apply quantization to the op types
given in the input quantizable_op_type.
is_use_cache_file(bool, optional): If is_use_cache_file is set to False,
all temp data is kept in memory. If is_use_cache_file is set to True,
temp data is saved to disk. When the fp32 model is complex or
the amount of calibration data is large, is_use_cache_file should be set
to True. Default is False.
cache_dir(str, optional): When is_use_cache_file is True, set cache_dir as
the directory for saving temp data. Default is ./temp_post_training.
Returns:
None
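For orientation, here is a hedged usage sketch of the new options, following the pattern in the updated test later in this diff. The executor, sample_generator, batch_size, and batch_nums arguments are not visible in this hunk and are assumed from the surrounding API; the model path and calibration reader are placeholders.

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization


def sample_generator():
    # Placeholder calibration reader; replace it with one that yields real
    # preprocessed samples matching the model's input.
    for _ in range(100):
        yield [np.random.random((3, 224, 224)).astype("float32")]


exe = fluid.Executor(fluid.CPUPlace())

ptq = PostTrainingQuantization(
    executor=exe,
    sample_generator=sample_generator,
    model_dir="./fp32_model",          # placeholder path to the fp32 model
    batch_size=16,
    batch_nums=10,
    algo="KL",
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
    is_full_quantize=False,
    is_use_cache_file=True,            # spill sampled activations to disk
    cache_dir="./temp_post_training")
ptq.quantize()
ptq.save_quantized_model("./int8_model")
```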
@@ -129,6 +140,10 @@ class PostTrainingQuantization(object):
self._batch_nums = batch_nums
self._scope = global_scope() if scope == None else scope
self._algo = algo
self._is_use_cache_file = is_use_cache_file
self._cache_dir = cache_dir
if self._is_use_cache_file and not os.path.exists(self._cache_dir):
os.mkdir(self._cache_dir)
supported_quantizable_op_type = \
QuantizationTransformPass._supported_quantizable_op_type + \
@@ -150,8 +165,8 @@ class PostTrainingQuantization(object):
self._op_real_in_out_name = _op_real_in_out_name
self._bit_length = 8
self._quantized_weight_var_name = []
self._quantized_act_var_name = []
self._quantized_weight_var_name = set()
self._quantized_act_var_name = set()
self._sampling_data = {}
self._quantized_var_scale_factor = {}
@@ -174,7 +189,8 @@ class PostTrainingQuantization(object):
feed=data,
fetch_list=self._fetch_list,
return_numpy=False)
self._sample_data()
self._sample_data(batch_id)
if batch_id % 5 == 0:
_logger.info("run batch: " + str(batch_id))
batch_id += 1
@@ -238,10 +254,9 @@
op_type = op.type
if op_type in self._quantizable_op_type:
if op_type in ("conv2d", "depthwise_conv2d"):
self._quantized_act_var_name.append(op.input("Input")[0])
self._quantized_weight_var_name.append(
op.input("Filter")[0])
self._quantized_act_var_name.append(op.output("Output")[0])
self._quantized_act_var_name.add(op.input("Input")[0])
self._quantized_weight_var_name.add(op.input("Filter")[0])
self._quantized_act_var_name.add(op.output("Output")[0])
elif op_type == "mul":
if self._is_input_all_not_persistable(
op, persistable_var_names):
@@ -249,9 +264,9 @@
_logger.warning("Skip quant a mul op for two "
"input variables are not persistable")
else:
self._quantized_act_var_name.append(op.input("X")[0])
self._quantized_weight_var_name.append(op.input("Y")[0])
self._quantized_act_var_name.append(op.output("Out")[0])
self._quantized_act_var_name.add(op.input("X")[0])
self._quantized_weight_var_name.add(op.input("Y")[0])
self._quantized_act_var_name.add(op.output("Out")[0])
else:
# process other quantizable op types; all of their inputs must be non-persistable
if self._is_input_all_not_persistable(
@@ -260,10 +275,10 @@
op_type]
for input_name in input_output_name_list[0]:
for var_name in op.input(input_name):
self._quantized_act_var_name.append(var_name)
self._quantized_act_var_name.add(var_name)
for output_name in input_output_name_list[1]:
for var_name in op.output(output_name):
self._quantized_act_var_name.append(var_name)
self._quantized_act_var_name.add(var_name)
# set activation variables to be persistable, so that we can obtain
# the tensor data in _sample_data
@@ -271,7 +286,7 @@
if var.name in self._quantized_act_var_name:
var.persistable = True
def _sample_data(self):
def _sample_data(self, iter):
'''
Sample the tensor data of quantized variables,
applied in every iteration.
@@ -281,11 +296,20 @@
var_tensor = self._load_var_value(var_name)
self._sampling_data[var_name] = var_tensor
for var_name in self._quantized_act_var_name:
if var_name not in self._sampling_data:
self._sampling_data[var_name] = []
var_tensor = self._load_var_value(var_name)
self._sampling_data[var_name].append(var_tensor)
if self._is_use_cache_file:
for var_name in self._quantized_act_var_name:
var_tensor = self._load_var_value(var_name)
var_tensor = var_tensor.ravel()
save_path = os.path.join(self._cache_dir,
var_name + "_" + str(iter) + ".npy")
np.save(save_path, var_tensor)
else:
for var_name in self._quantized_act_var_name:
if var_name not in self._sampling_data:
self._sampling_data[var_name] = []
var_tensor = self._load_var_value(var_name)
var_tensor = var_tensor.ravel()
self._sampling_data[var_name].append(var_tensor)
def _calculate_scale_factor(self):
'''
@@ -302,13 +326,33 @@
var_name] = scale_factor_per_channel
# apply kl quantization for activation
for var_name in self._quantized_act_var_name:
if self._algo == "KL":
self._quantized_var_scale_factor[var_name] = \
self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
else:
self._quantized_var_scale_factor[var_name] = \
np.max(np.abs(self._sampling_data[var_name]))
if self._is_use_cache_file:
for var_name in self._quantized_act_var_name:
sampling_data = []
filenames = [f for f in os.listdir(self._cache_dir) \
if re.match(var_name + '_[0-9]+.npy', f)]
for filename in filenames:
file_path = os.path.join(self._cache_dir, filename)
sampling_data.append(np.load(file_path))
os.remove(file_path)
sampling_data = np.concatenate(sampling_data)
if self._algo == "KL":
self._quantized_var_scale_factor[var_name] = \
self._get_kl_scaling_factor(np.abs(sampling_data))
else:
self._quantized_var_scale_factor[var_name] = \
np.max(np.abs(sampling_data))
else:
for var_name in self._quantized_act_var_name:
self._sampling_data[var_name] = np.concatenate(
self._sampling_data[var_name])
if self._algo == "KL":
self._quantized_var_scale_factor[var_name] = \
self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
else:
self._quantized_var_scale_factor[var_name] = \
np.max(np.abs(self._sampling_data[var_name]))
def _update_program(self):
'''
......
@@ -237,7 +237,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
def generate_quantized_model(self,
model_path,
algo="KL",
is_full_quantize=False):
is_full_quantize=False,
is_use_cache_file=False):
try:
os.system("mkdir " + self.int8_model)
except Exception as e:
@@ -259,11 +260,13 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_dir=model_path,
algo=algo,
quantizable_op_type=quantizable_op_type,
is_full_quantize=is_full_quantize)
is_full_quantize=is_full_quantize,
is_use_cache_file=is_use_cache_file)
ptq.quantize()
ptq.save_quantized_model(self.int8_model)
def run_test(self, model, algo, data_urls, data_md5s):
def run_test(self, model, algo, data_urls, data_md5s, is_full_quantize,
is_use_cache_file):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
sample_iterations = self.sample_iterations
@@ -277,8 +280,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
print("Start INT8 post training quantization for {0} on {1} images ...".
format(model, sample_iterations * batch_size))
self.generate_quantized_model(
model_cache_folder + "/model", algo=algo, is_full_quantize=True)
self.generate_quantized_model(model_cache_folder + "/model", algo,
is_full_quantize, is_use_cache_file)
print("Start INT8 inference for {0} on {1} images ...".format(
model, infer_iterations * batch_size))
@@ -305,7 +308,10 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
self.run_test(model, algo, data_urls, data_md5s)
is_full_quantize = True
is_use_cache_file = False
self.run_test(model, algo, data_urls, data_md5s, is_full_quantize,
is_use_cache_file)
if __name__ == '__main__':
......
@@ -25,7 +25,10 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
]
data_md5s = ['4a5194524823d9b76da6e738e1367881']
self.run_test(model, algo, data_urls, data_md5s)
is_full_quantize = False
is_use_cache_file = True
self.run_test(model, algo, data_urls, data_md5s, is_full_quantize,
is_use_cache_file)
if __name__ == '__main__':
......