Unverified    Commit cbce658d    Authored by: W wuyefeilin    Committed by: GitHub

update post_quantization.py (#255)

* update train.py

* update post_quantization.py

Parent 27121d0f
@@ -14,12 +14,14 @@
 from paddle.fluid.contrib.slim.quantization.quantization_pass import QuantizationTransformPass
 from paddle.fluid.contrib.slim.quantization.quantization_pass import AddQuantDequantPass
-from paddle.fluid.contrib.slim.quantization.quantization_pass import _op_real_in_out_name
+from paddle.fluid.contrib.slim.quantization.quantization_pass import _out_scale_op_list
 from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
+import utils.logging as logging
 import paddle.fluid as fluid
 import os
-import utils.logging as logging
+import re
+import numpy as np
+import time
 
 
 class HumanSegPostTrainingQuantization(PostTrainingQuantization):
@@ -42,7 +44,6 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
         fp32 model. It uses calibrate data to calculate the scale factor of
         quantized variables, and inserts fake quant/dequant op to obtain the
         quantized model.
-
         Args:
             executor(fluid.Executor): The executor to load, run and save the
                 quantized model.
@@ -76,6 +77,21 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
         Returns:
             None
         '''
+        self._support_activation_quantize_type = [
+            'range_abs_max', 'moving_average_abs_max', 'abs_max'
+        ]
+        self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
+        self._support_algo_type = ['KL', 'abs_max', 'min_max']
+        self._support_quantize_op_type = \
+            list(set(QuantizationTransformPass._supported_quantizable_op_type +
+                     AddQuantDequantPass._supported_quantizable_op_type))
+
+        # Check inputs
+        assert executor is not None, "The executor cannot be None."
+        assert batch_size > 0, "The batch_size should be greater than 0."
+        assert algo in self._support_algo_type, \
+            "The algo should be KL, abs_max or min_max."
+
         self._executor = executor
         self._dataset = dataset
         self._batch_size = batch_size
@@ -84,18 +100,19 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
         self._algo = algo
         self._is_use_cache_file = is_use_cache_file
         self._cache_dir = cache_dir
+        self._activation_bits = 8
+        self._weight_bits = 8
+        self._activation_quantize_type = 'range_abs_max'
+        self._weight_quantize_type = 'channel_wise_abs_max'
         if self._is_use_cache_file and not os.path.exists(self._cache_dir):
             os.mkdir(self._cache_dir)
 
-        supported_quantizable_op_type = \
-            QuantizationTransformPass._supported_quantizable_op_type + \
-            AddQuantDequantPass._supported_quantizable_op_type
         if is_full_quantize:
-            self._quantizable_op_type = supported_quantizable_op_type
+            self._quantizable_op_type = self._support_quantize_op_type
         else:
             self._quantizable_op_type = quantizable_op_type
             for op_type in self._quantizable_op_type:
-                assert op_type in supported_quantizable_op_type + \
+                assert op_type in self._support_quantize_op_type + \
                     AddQuantDequantPass._activation_type, \
                     op_type + " is not supported for quantization."
@@ -105,53 +122,72 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
         self._fetch_list = list(outputs.values())
         self._data_loader = None
 
-        self._op_real_in_out_name = _op_real_in_out_name
+        self._out_scale_op_list = _out_scale_op_list
         self._bit_length = 8
         self._quantized_weight_var_name = set()
         self._quantized_act_var_name = set()
         self._sampling_data = {}
-        self._quantized_var_scale_factor = {}
+        self._quantized_var_kl_threshold = {}
+        self._quantized_var_min = {}
+        self._quantized_var_max = {}
+        self._quantized_var_abs_max = {}
 
     def quantize(self):
         '''
         Quantize the fp32 model. Use calibrate data to calculate the scale factor of
         quantized variables, and inserts fake quant/dequant op to obtain the
         quantized model.
         Args:
             None
         Returns:
             the program of quantized model.
         '''
-        self._preprocess()
+        self._load_model_data()
+        self._collect_target_varnames()
+        self._set_activation_persistable()
+
+        batch_ct = 0
+        for data in self._data_loader():
+            batch_ct += 1
+            if self._batch_nums and batch_ct >= self._batch_nums:
+                break
+
         batch_id = 0
+        logging.info("Start to run batch!")
         for data in self._data_loader():
+            start = time.time()
             self._executor.run(
                 program=self._program,
                 feed=data,
                 fetch_list=self._fetch_list,
                 return_numpy=False)
-            self._sample_data(batch_id)
-
-            if batch_id % 5 == 0:
-                logging.info("run batch: {}".format(batch_id))
+            if self._algo == "KL":
+                self._sample_data(batch_id)
+            else:
+                self._sample_threshold()
+            end = time.time()
+            logging.debug(
+                '[Run batch data] Batch={}/{}, time_each_batch={} s.'.format(
+                    str(batch_id + 1), str(batch_ct), str(end - start)))
             batch_id += 1
             if self._batch_nums and batch_id >= self._batch_nums:
                 break
-        logging.info("all run batch: ".format(batch_id))
-        logging.info("calculate scale factor ...")
-        self._calculate_scale_factor()
-        logging.info("update the program ...")
-        self._update_program()
-        self._save_output_scale()
+        logging.info("All run batch: {}".format(batch_id))
+
+        self._reset_activation_persistable()
+        logging.info("Calculate scale factor ...")
+        if self._algo == "KL":
+            self._calculate_kl_threshold()
+
+        logging.info("Update the program ...")
+        if self._algo in ["KL", "abs_max"]:
+            self._update_program()
+        else:
+            self._save_input_threhold()
+
+        logging.info("Save ...")
+        self._save_output_threshold()
+        logging.info("Finish quant!")
         return self._program
 
     def save_quantized_model(self, save_model_path):
         '''
         Save the quantized model to the disk.
         Args:
             save_model_path(str): The path to save the quantized model
         Returns:
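Note: for orientation, here is a minimal sketch of driving this class end to end. The constructor signature itself lies outside the diff hunks, so the keyword arguments below are inferred from the attributes the diff assigns; `test_prog`, `feed_vars`, `fetch_vars`, and `eval_dataset` are hypothetical placeholders.

    import paddle.fluid as fluid

    place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # eval_dataset is assumed to expose generator(batch_size, drop_last),
    # which _load_model_data() wires into a DataLoader.
    ptq = HumanSegPostTrainingQuantization(
        executor=exe,
        dataset=eval_dataset,
        program=test_prog,    # assumed name: the fp32 inference program
        inputs=feed_vars,     # assumed name: input variables (become _feed_list)
        outputs=fetch_vars,   # a dict; the diff builds _fetch_list from outputs.values()
        batch_size=10,
        batch_nums=10,
        algo='KL')            # 'KL', 'abs_max' or 'min_max', per the new assert
    quant_program = ptq.quantize()
    ptq.save_quantized_model('./quant_model')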
@@ -166,59 +202,78 @@ class HumanSegPostTrainingQuantization(PostTrainingQuantization):
             params_filename='__params__',
             main_program=self._program)
 
-    def _preprocess(self):
+    def _load_model_data(self):
         '''
-        Load model and set data loader, collect the variable names for sampling,
-        and set activation variables to be persistable.
+        Set data loader.
         '''
         feed_vars = [fluid.framework._get_var(var.name, self._program) \
                      for var in self._feed_list]
         self._data_loader = fluid.io.DataLoader.from_generator(
             feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
         self._data_loader.set_sample_list_generator(
             self._dataset.generator(self._batch_size, drop_last=True),
             places=self._place)
 
-        # collect the variable names for sampling
-        persistable_var_names = []
-        for var in self._program.list_vars():
-            if var.persistable:
-                persistable_var_names.append(var.name)
-
-        for op in self._program.global_block().ops:
-            op_type = op.type
-            if op_type in self._quantizable_op_type:
-                if op_type in ("conv2d", "depthwise_conv2d"):
-                    self._quantized_act_var_name.add(op.input("Input")[0])
-                    self._quantized_weight_var_name.add(op.input("Filter")[0])
-                    self._quantized_act_var_name.add(op.output("Output")[0])
-                elif op_type == "mul":
-                    if self._is_input_all_not_persistable(
-                            op, persistable_var_names):
-                        op._set_attr("skip_quant", True)
-                        logging.warning(
-                            "Skip quant a mul op for two input variables are not persistable"
-                        )
-                    else:
-                        self._quantized_act_var_name.add(op.input("X")[0])
-                        self._quantized_weight_var_name.add(op.input("Y")[0])
-                        self._quantized_act_var_name.add(op.output("Out")[0])
-                else:
-                    # process other quantizable op type, the input must all not persistable
-                    if self._is_input_all_not_persistable(
-                            op, persistable_var_names):
-                        input_output_name_list = self._op_real_in_out_name[
-                            op_type]
-                        for input_name in input_output_name_list[0]:
-                            for var_name in op.input(input_name):
-                                self._quantized_act_var_name.add(var_name)
-                        for output_name in input_output_name_list[1]:
-                            for var_name in op.output(output_name):
-                                self._quantized_act_var_name.add(var_name)
-
-        # set activation variables to be persistable, so can obtain
-        # the tensor data in sample_data
-        for var in self._program.list_vars():
-            if var.name in self._quantized_act_var_name:
-                var.persistable = True
+    def _calculate_kl_threshold(self):
+        '''
+        Calculate the KL threshold of quantized variables.
+        '''
+        assert self._algo == "KL", "The algo should be KL to calculate kl threshold."
+
+        ct = 1
+        # Abs_max threshold for weights
+        for var_name in self._quantized_weight_var_name:
+            start = time.time()
+            weight_data = self._sampling_data[var_name]
+            weight_threshold = None
+            if self._weight_quantize_type == "abs_max":
+                weight_threshold = np.max(np.abs(weight_data))
+            elif self._weight_quantize_type == "channel_wise_abs_max":
+                weight_threshold = []
+                for i in range(weight_data.shape[0]):
+                    abs_max_value = np.max(np.abs(weight_data[i]))
+                    weight_threshold.append(abs_max_value)
+            self._quantized_var_kl_threshold[var_name] = weight_threshold
+            end = time.time()
+            logging.debug(
+                '[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'.
+                format(
+                    str(ct), str(len(self._quantized_weight_var_name)),
+                    str(end - start)))
+            ct += 1
+
+        ct = 1
+        # KL threshold for activations
+        if self._is_use_cache_file:
+            for var_name in self._quantized_act_var_name:
+                start = time.time()
+                sampling_data = []
+                filenames = [f for f in os.listdir(self._cache_dir) \
+                    if re.match(var_name + '_[0-9]+.npy', f)]
+                for filename in filenames:
+                    file_path = os.path.join(self._cache_dir, filename)
+                    sampling_data.append(np.load(file_path))
+                    os.remove(file_path)
+                sampling_data = np.concatenate(sampling_data)
+                self._quantized_var_kl_threshold[var_name] = \
+                    self._get_kl_scaling_factor(np.abs(sampling_data))
+                end = time.time()
+                logging.debug(
+                    '[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'
+                    .format(
+                        str(ct), str(len(self._quantized_act_var_name)),
+                        str(end - start)))
+                ct += 1
+        else:
+            for var_name in self._quantized_act_var_name:
+                start = time.time()
+                self._sampling_data[var_name] = np.concatenate(
+                    self._sampling_data[var_name])
+                self._quantized_var_kl_threshold[var_name] = \
+                    self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
+                end = time.time()
+                logging.debug(
+                    '[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'
+                    .format(
+                        str(ct), str(len(self._quantized_act_var_name)),
+                        str(end - start)))
+                ct += 1
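Note: `_get_kl_scaling_factor` is inherited from PostTrainingQuantization and never appears in this diff. As a rough, self-contained illustration of what a KL-based threshold search does (not Paddle's actual implementation; the `bins` and `quant_bins` defaults are assumptions), it picks the clipping threshold whose quantized histogram stays closest, in KL divergence, to the original |activation| histogram:

    import numpy as np

    def kl_divergence(p, q):
        # KL(P||Q) over bins where both histograms are positive.
        p = p.astype(np.float64) / p.sum()
        q = q.astype(np.float64) / q.sum()
        mask = (p > 0) & (q > 0)
        return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

    def simple_kl_threshold(abs_data, bins=2048, quant_bins=255):
        hist, edges = np.histogram(abs_data, bins=bins)
        best_threshold, best_kl = float(edges[-1]), float("inf")
        for i in range(quant_bins, bins + 1):
            ref = hist[:i].astype(np.float64)
            ref[-1] += hist[i:].sum()  # fold the clipped tail into the last bin
            if ref.sum() == 0:
                continue
            # Merge the i bins down to quant_bins levels, then spread each merged
            # mass uniformly back over its originally non-empty bins.
            cand = np.zeros_like(ref)
            for chunk in np.array_split(np.arange(i), quant_bins):
                nonzero = chunk[ref[chunk] > 0]
                if nonzero.size:
                    cand[nonzero] = ref[chunk].sum() / nonzero.size
            if cand.sum() == 0:
                continue
            kl = kl_divergence(ref, cand)
            if kl < best_kl:
                best_kl, best_threshold = kl, float(edges[i])
        return best_threshold

In the cached branch above, such a function would play the role of `self._get_kl_scaling_factor(np.abs(sampling_data))`: the returned clipping threshold becomes the activation's quantization scale.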