未验证 提交 cbce658d 编写于 作者: W wuyefeilin 提交者: GitHub

update post_quantization.py (#255)

* update train.py

* update post_quantization.py
上级 27121d0f
......@@ -14,12 +14,14 @@
from paddle.fluid.contrib.slim.quantization.quantization_pass import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization.quantization_pass import AddQuantDequantPass
from paddle.fluid.contrib.slim.quantization.quantization_pass import _op_real_in_out_name
from paddle.fluid.contrib.slim.quantization.quantization_pass import _out_scale_op_list
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
import utils.logging as logging
import paddle.fluid as fluid
import os
import utils.logging as logging
import re
import numpy as np
import time
class HumanSegPostTrainingQuantization(PostTrainingQuantization):
......@@ -42,7 +44,6 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
fp32 model. It uses calibrate data to calculate the scale factor of
quantized variables, and inserts fake quant/dequant op to obtain the
quantized model.
Args:
executor(fluid.Executor): The executor to load, run and save the
quantized model.
......@@ -76,6 +77,21 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
Returns:
None
'''
self._support_activation_quantize_type = [
'range_abs_max', 'moving_average_abs_max', 'abs_max'
]
self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
self._support_algo_type = ['KL', 'abs_max', 'min_max']
self._support_quantize_op_type = \
list(set(QuantizationTransformPass._supported_quantizable_op_type +
AddQuantDequantPass._supported_quantizable_op_type))
# Check inputs
assert executor is not None, "The executor cannot be None."
assert batch_size > 0, "The batch_size should be greater than 0."
assert algo in self._support_algo_type, \
"The algo should be KL, abs_max or min_max."
self._executor = executor
self._dataset = dataset
self._batch_size = batch_size
......@@ -84,18 +100,19 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
self._algo = algo
self._is_use_cache_file = is_use_cache_file
self._cache_dir = cache_dir
self._activation_bits = 8
self._weight_bits = 8
self._activation_quantize_type = 'range_abs_max'
self._weight_quantize_type = 'channel_wise_abs_max'
if self._is_use_cache_file and not os.path.exists(self._cache_dir):
os.mkdir(self._cache_dir)
supported_quantizable_op_type = \
QuantizationTransformPass._supported_quantizable_op_type + \
AddQuantDequantPass._supported_quantizable_op_type
if is_full_quantize:
self._quantizable_op_type = supported_quantizable_op_type
self._quantizable_op_type = self._support_quantize_op_type
else:
self._quantizable_op_type = quantizable_op_type
for op_type in self._quantizable_op_type:
assert op_type in supported_quantizable_op_type + \
assert op_type in self._support_quantize_op_type + \
AddQuantDequantPass._activation_type, \
op_type + " is not supported for quantization."
......@@ -105,53 +122,72 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
self._fetch_list = list(outputs.values())
self._data_loader = None
self._op_real_in_out_name = _op_real_in_out_name
self._out_scale_op_list = _out_scale_op_list
self._bit_length = 8
self._quantized_weight_var_name = set()
self._quantized_act_var_name = set()
self._sampling_data = {}
self._quantized_var_scale_factor = {}
self._quantized_var_kl_threshold = {}
self._quantized_var_min = {}
self._quantized_var_max = {}
self._quantized_var_abs_max = {}
def quantize(self):
'''
Quantize the fp32 model. Use calibrate data to calculate the scale factor of
quantized variables, and inserts fake quant/dequant op to obtain the
quantized model.
Args:
None
Returns:
the program of quantized model.
'''
self._preprocess()
self._load_model_data()
self._collect_target_varnames()
self._set_activation_persistable()
batch_ct = 0
for data in self._data_loader():
batch_ct += 1
if self._batch_nums and batch_ct >= self._batch_nums:
break
batch_id = 0
logging.info("Start to run batch!")
for data in self._data_loader():
start = time.time()
self._executor.run(
program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False)
if self._algo == "KL":
self._sample_data(batch_id)
if batch_id % 5 == 0:
logging.info("run batch: {}".format(batch_id))
else:
self._sample_threshold()
end = time.time()
logging.debug(
'[Run batch data] Batch={}/{}, time_each_batch={} s.'.format(
str(batch_id + 1), str(batch_ct), str(end - start)))
batch_id += 1
if self._batch_nums and batch_id >= self._batch_nums:
break
logging.info("all run batch: ".format(batch_id))
logging.info("calculate scale factor ...")
self._calculate_scale_factor()
logging.info("update the program ...")
logging.info("All run batch: ".format(batch_id))
self._reset_activation_persistable()
logging.info("Calculate scale factor ...")
if self._algo == "KL":
self._calculate_kl_threshold()
logging.info("Update the program ...")
if self._algo in ["KL", "abs_max"]:
self._update_program()
self._save_output_scale()
else:
self._save_input_threhold()
logging.info("Save ...")
self._save_output_threshold()
logging.info("Finish quant!")
return self._program
def save_quantized_model(self, save_model_path):
'''
Save the quantized model to the disk.
Args:
save_model_path(str): The path to save the quantized model
Returns:
......@@ -166,59 +202,78 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
params_filename='__params__',
main_program=self._program)
def _preprocess(self):
def _load_model_data(self):
'''
Load model and set data loader, collect the variable names for sampling,
and set activation variables to be persistable.
Set data loader.
'''
feed_vars = [fluid.framework._get_var(var.name, self._program) \
for var in self._feed_list]
self._data_loader = fluid.io.DataLoader.from_generator(
feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
self._data_loader.set_sample_list_generator(
self._dataset.generator(self._batch_size, drop_last=True),
places=self._place)
# collect the variable names for sampling
persistable_var_names = []
for var in self._program.list_vars():
if var.persistable:
persistable_var_names.append(var.name)
def _calculate_kl_threshold(self):
'''
Calculate the KL threshold of quantized variables.
'''
assert self._algo == "KL", "The algo should be KL to calculate kl threshold."
ct = 1
# Abs_max threshold for weights
for var_name in self._quantized_weight_var_name:
start = time.time()
weight_data = self._sampling_data[var_name]
weight_threshold = None
if self._weight_quantize_type == "abs_max":
weight_threshold = np.max(np.abs(weight_data))
elif self._weight_quantize_type == "channel_wise_abs_max":
weight_threshold = []
for i in range(weight_data.shape[0]):
abs_max_value = np.max(np.abs(weight_data[i]))
weight_threshold.append(abs_max_value)
self._quantized_var_kl_threshold[var_name] = weight_threshold
end = time.time()
logging.debug(
'[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'.
format(
str(ct), str(len(self._quantized_weight_var_name)),
str(end - start)))
ct += 1
for op in self._program.global_block().ops:
op_type = op.type
if op_type in self._quantizable_op_type:
if op_type in ("conv2d", "depthwise_conv2d"):
self._quantized_act_var_name.add(op.input("Input")[0])
self._quantized_weight_var_name.add(op.input("Filter")[0])
self._quantized_act_var_name.add(op.output("Output")[0])
elif op_type == "mul":
if self._is_input_all_not_persistable(
op, persistable_var_names):
op._set_attr("skip_quant", True)
logging.warning(
"Skip quant a mul op for two input variables are not persistable"
)
ct = 1
# KL threshold for activations
if self._is_use_cache_file:
for var_name in self._quantized_act_var_name:
start = time.time()
sampling_data = []
filenames = [f for f in os.listdir(self._cache_dir) \
if re.match(var_name + '_[0-9]+.npy', f)]
for filename in filenames:
file_path = os.path.join(self._cache_dir, filename)
sampling_data.append(np.load(file_path))
os.remove(file_path)
sampling_data = np.concatenate(sampling_data)
self._quantized_var_kl_threshold[var_name] = \
self._get_kl_scaling_factor(np.abs(sampling_data))
end = time.time()
logging.debug(
'[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'
.format(
str(ct), str(len(self._quantized_act_var_name)),
str(end - start)))
ct += 1
else:
self._quantized_act_var_name.add(op.input("X")[0])
self._quantized_weight_var_name.add(op.input("Y")[0])
self._quantized_act_var_name.add(op.output("Out")[0])
else:
# process other quantizable op type, the input must all not persistable
if self._is_input_all_not_persistable(
op, persistable_var_names):
input_output_name_list = self._op_real_in_out_name[
op_type]
for input_name in input_output_name_list[0]:
for var_name in op.input(input_name):
self._quantized_act_var_name.add(var_name)
for output_name in input_output_name_list[1]:
for var_name in op.output(output_name):
self._quantized_act_var_name.add(var_name)
# set activation variables to be persistable, so can obtain
# the tensor data in sample_data
for var in self._program.list_vars():
if var.name in self._quantized_act_var_name:
var.persistable = True
for var_name in self._quantized_act_var_name:
start = time.time()
self._sampling_data[var_name] = np.concatenate(
self._sampling_data[var_name])
self._quantized_var_kl_threshold[var_name] = \
self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
end = time.time()
logging.debug(
'[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'
.format(
str(ct), str(len(self._quantized_act_var_name)),
str(end - start)))
ct += 1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册