Unverified · Commit fe50020f authored by Jason, committed by GitHub

Merge pull request #76 from SunAhong1993/syf0519

Fix the post-training quantization
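Summary (readable from the diff below): this change ports PaddleX's `PaddleXPostTrainingQuantization` wrapper to the post-training quantization API of PaddlePaddle 1.8. The import of `_op_real_in_out_name` becomes `_out_scale_op_list`, `_preprocess` is split into `_load_model_data`, `_collect_target_varnames` and `_set_activation_persistable`, a `min_max` algorithm joins `KL` and `abs_max` (with new `_quantized_var_min`/`_quantized_var_max`/`_quantized_var_abs_max` bookkeeping), and the upgrade hint in the error message moves from `paddlepaddle>=1.7.0` to `>=1.8.0`.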
@@ -139,9 +139,10 @@ class BaseAPI:
         dataset.num_samples = batch_size * batch_num
         try:
             from .slim.post_quantization import PaddleXPostTrainingQuantization
+            PaddleXPostTrainingQuantization._collect_target_varnames
         except:
             raise Exception(
-                "Model Quantization is not available, try to upgrade your paddlepaddle>=1.7.0"
+                "Model Quantization is not available, try to upgrade your paddlepaddle>=1.8.0"
             )
         is_use_cache_file = True
         if cache_dir is None:
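The added `PaddleXPostTrainingQuantization._collect_target_varnames` line is a feature probe: the import alone can succeed against an older PaddlePaddle, so the code touches an attribute that only exists once the class is built on the 1.8 API, letting the bare `except` turn the resulting `AttributeError` into an actionable upgrade message. A minimal, generic sketch of the same pattern (names are illustrative, not part of the diff):

    def has_attr_feature(cls, attr_name):
        """Return True if `cls` exposes `attr_name`: a capability probe
        that avoids parsing version strings."""
        try:
            getattr(cls, attr_name)
            return True
        except AttributeError:
            return False

    # Hypothetical usage:
    # if not has_attr_feature(PaddleXPostTrainingQuantization,
    #                         "_collect_target_varnames"):
    #     raise Exception("try to upgrade your paddlepaddle>=1.8.0")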
@@ -544,4 +545,4 @@ class BaseAPI:
                             best_accuracy))
             if eval_dataset is not None and early_stop:
                 if earlystop(current_accuracy):
                     break
\ No newline at end of file
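The second hunk only shifts line numbers; the early-stop block itself is unchanged context. For orientation, `earlystop` behaves as a patience counter over the evaluation metric. A generic sketch of such a callable (purely illustrative, not necessarily PaddleX's actual `EarlyStop` implementation):

    class EarlyStop:
        """Signal a stop when the metric fails to improve for `patience` evals."""

        def __init__(self, patience=5, min_delta=1e-4):
            self.patience, self.min_delta = patience, min_delta
            self.best, self.bad_rounds = None, 0

        def __call__(self, current):
            if self.best is None or current > self.best + self.min_delta:
                self.best, self.bad_rounds = current, 0
            else:
                self.bad_rounds += 1
            return self.bad_rounds >= self.patience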
@@ -14,7 +14,7 @@
 from paddle.fluid.contrib.slim.quantization.quantization_pass import QuantizationTransformPass
 from paddle.fluid.contrib.slim.quantization.quantization_pass import AddQuantDequantPass
-from paddle.fluid.contrib.slim.quantization.quantization_pass import _op_real_in_out_name
+from paddle.fluid.contrib.slim.quantization.quantization_pass import _out_scale_op_list
 from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization
 import paddlex.utils.logging as logging
 import paddle.fluid as fluid
@@ -44,7 +44,6 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
     fp32 model. It uses calibrate data to calculate the scale factor of
     quantized variables, and inserts fake quant/dequant op to obtain the
     quantized model.
-
     Args:
         executor(fluid.Executor): The executor to load, run and save the
             quantized model.
@@ -78,6 +77,21 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
         Returns:
             None
         '''
+        self._support_activation_quantize_type = [
+            'range_abs_max', 'moving_average_abs_max', 'abs_max'
+        ]
+        self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
+        self._support_algo_type = ['KL', 'abs_max', 'min_max']
+        self._support_quantize_op_type = \
+            list(set(QuantizationTransformPass._supported_quantizable_op_type +
+                AddQuantDequantPass._supported_quantizable_op_type))
+
+        # Check inputs
+        assert executor is not None, "The executor cannot be None."
+        assert batch_size > 0, "The batch_size should be greater than 0."
+        assert algo in self._support_algo_type, \
+            "The algo should be KL, abs_max or min_max."
+
         self._executor = executor
         self._dataset = dataset
         self._batch_size = batch_size
@@ -86,18 +100,19 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
         self._algo = algo
         self._is_use_cache_file = is_use_cache_file
         self._cache_dir = cache_dir
+        self._activation_bits = 8
+        self._weight_bits = 8
+        self._activation_quantize_type = 'range_abs_max'
+        self._weight_quantize_type = 'channel_wise_abs_max'
         if self._is_use_cache_file and not os.path.exists(self._cache_dir):
             os.mkdir(self._cache_dir)
 
-        supported_quantizable_op_type = \
-            QuantizationTransformPass._supported_quantizable_op_type + \
-            AddQuantDequantPass._supported_quantizable_op_type
         if is_full_quantize:
-            self._quantizable_op_type = supported_quantizable_op_type
+            self._quantizable_op_type = self._support_quantize_op_type
         else:
             self._quantizable_op_type = quantizable_op_type
             for op_type in self._quantizable_op_type:
-                assert op_type in supported_quantizable_op_type + \
+                assert op_type in self._support_quantize_op_type + \
                     AddQuantDequantPass._activation_type, \
                     op_type + " is not supported for quantization."
@@ -107,25 +122,29 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
         self._fetch_list = list(outputs.values())
         self._data_loader = None
 
-        self._op_real_in_out_name = _op_real_in_out_name
+        self._out_scale_op_list = _out_scale_op_list
         self._bit_length = 8
         self._quantized_weight_var_name = set()
         self._quantized_act_var_name = set()
         self._sampling_data = {}
-        self._quantized_var_scale_factor = {}
+        self._quantized_var_kl_threshold = {}
+        self._quantized_var_min = {}
+        self._quantized_var_max = {}
+        self._quantized_var_abs_max = {}
 
     def quantize(self):
         '''
         Quantize the fp32 model. Use calibrate data to calculate the scale factor of
         quantized variables, and inserts fake quant/dequant op to obtain the
         quantized model.
-
         Args:
             None
         Returns:
             the program of quantized model.
         '''
-        self._preprocess()
+        self._load_model_data()
+        self._collect_target_varnames()
+        self._set_activation_persistable()
         batch_ct = 0
         for data in self._data_loader():
             batch_ct += 1
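Three bookkeeping dicts appear next to the renamed `_quantized_var_kl_threshold`: `_quantized_var_min`/`_quantized_var_max` serve the new `min_max` algorithm, and `_quantized_var_abs_max` serves `abs_max`. A sketch of how a `min_max`-style sampler could maintain running per-variable bounds (the update logic is illustrative; the real `_sample_threshold` is inherited from PaddlePaddle's `PostTrainingQuantization`):

    import numpy as np

    def sample_min_max(var_name, tensor, var_min, var_max):
        """Fold one batch's tensor into running per-variable min/max bounds."""
        t_min, t_max = float(np.min(tensor)), float(np.max(tensor))
        var_min[var_name] = min(var_min.get(var_name, t_min), t_min)
        var_max[var_name] = max(var_max.get(var_name, t_max), t_max)

    # Hypothetical usage over two calibration batches:
    mins, maxs = {}, {}
    for batch in (np.array([-1.5, 0.2]), np.array([0.9, 3.0])):
        sample_min_max("conv1_out", batch, mins, maxs)
    assert mins["conv1_out"] == -1.5 and maxs["conv1_out"] == 3.0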
@@ -140,7 +159,10 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
                 feed=data,
                 fetch_list=self._fetch_list,
                 return_numpy=False)
-            self._sample_data(batch_id)
+            if self._algo == "KL":
+                self._sample_data(batch_id)
+            else:
+                self._sample_threshold()
             end = time.time()
             logging.debug('[Run batch data] Batch={}/{}, time_each_batch={} s.'.format(
                 str(batch_id + 1),
@@ -150,19 +172,23 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
             if self._batch_nums and batch_id >= self._batch_nums:
                 break
         logging.info("All run batch: ".format(batch_id))
+        self._reset_activation_persistable()
         logging.info("Calculate scale factor ...")
-        self._calculate_scale_factor()
+        if self._algo == "KL":
+            self._calculate_kl_threshold()
         logging.info("Update the program ...")
-        self._update_program()
+        if self._algo in ["KL", "abs_max"]:
+            self._update_program()
+        else:
+            self._save_input_threhold()
         logging.info("Save ...")
-        self._save_output_scale()
+        self._save_output_threshold()
         logging.info("Finish quant!")
         return self._program
 
     def save_quantized_model(self, save_model_path):
         '''
         Save the quantized model to the disk.
-
         Args:
             save_model_path(str): The path to save the quantized model
         Returns:
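The tail of `quantize()` now branches on the algorithm: only `KL` needs the extra threshold-calculation pass, `KL` and `abs_max` rewrite the program with fake quant/dequant ops, while `min_max` merely records input thresholds. A condensed, self-contained view of that flow (stub callables stand in for the real methods; `_save_input_threhold` keeps its spelling from the source):

    def finish_quantization(algo, calculate_kl_threshold, update_program,
                            save_input_threhold, save_output_threshold):
        """Illustrative condensation of the branching added to quantize()."""
        if algo == "KL":
            calculate_kl_threshold()      # histogram/KL-based thresholds
        if algo in ["KL", "abs_max"]:
            update_program()              # insert fake quant/dequant ops
        else:                             # "min_max"
            save_input_threhold()         # record input thresholds only
        save_output_threshold()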
@@ -176,88 +202,47 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
             executor=self._executor,
             params_filename='__params__',
             main_program=self._program)
 
-    def _preprocess(self):
+    def _load_model_data(self):
         '''
-        Load model and set data loader, collect the variable names for sampling,
-        and set activation variables to be persistable.
+        Set data loader.
         '''
         feed_vars = [fluid.framework._get_var(var.name, self._program) \
                      for var in self._feed_list]
         self._data_loader = fluid.io.DataLoader.from_generator(
             feed_list=feed_vars, capacity=3 * self._batch_size, iterable=True)
         self._data_loader.set_sample_list_generator(
             self._dataset.generator(self._batch_size, drop_last=True),
             places=self._place)
 
-        # collect the variable names for sampling
-        persistable_var_names = []
-        for var in self._program.list_vars():
-            if var.persistable:
-                persistable_var_names.append(var.name)
-        for op in self._program.global_block().ops:
-            op_type = op.type
-            if op_type in self._quantizable_op_type:
-                if op_type in ("conv2d", "depthwise_conv2d"):
-                    self._quantized_act_var_name.add(op.input("Input")[0])
-                    self._quantized_weight_var_name.add(op.input("Filter")[0])
-                    self._quantized_act_var_name.add(op.output("Output")[0])
-                elif op_type == "mul":
-                    if self._is_input_all_not_persistable(
-                            op, persistable_var_names):
-                        op._set_attr("skip_quant", True)
-                        logging.warning(
-                            "Skip quant a mul op for two input variables are not persistable"
-                        )
-                    else:
-                        self._quantized_act_var_name.add(op.input("X")[0])
-                        self._quantized_weight_var_name.add(op.input("Y")[0])
-                        self._quantized_act_var_name.add(op.output("Out")[0])
-                else:
-                    # process other quantizable op type, the input must all not persistable
-                    if self._is_input_all_not_persistable(
-                            op, persistable_var_names):
-                        input_output_name_list = self._op_real_in_out_name[
-                            op_type]
-                        for input_name in input_output_name_list[0]:
-                            for var_name in op.input(input_name):
-                                self._quantized_act_var_name.add(var_name)
-                        for output_name in input_output_name_list[1]:
-                            for var_name in op.output(output_name):
-                                self._quantized_act_var_name.add(var_name)
-
-        # set activation variables to be persistable, so can obtain
-        # the tensor data in sample_data
-        for var in self._program.list_vars():
-            if var.name in self._quantized_act_var_name:
-                var.persistable = True
-
-    def _calculate_scale_factor(self):
+    def _calculate_kl_threshold(self):
         '''
-        Calculate the scale factor of quantized variables.
+        Calculate the KL threshold of quantized variables.
         '''
-        # apply channel_wise_abs_max quantization for weights
+        assert self._algo == "KL", "The algo should be KL to calculate kl threshold."
         ct = 1
+        # Abs_max threshold for weights
         for var_name in self._quantized_weight_var_name:
             start = time.time()
-            data = self._sampling_data[var_name]
-            scale_factor_per_channel = []
-            for i in range(data.shape[0]):
-                abs_max_value = np.max(np.abs(data[i]))
-                scale_factor_per_channel.append(abs_max_value)
-            self._quantized_var_scale_factor[
-                var_name] = scale_factor_per_channel
+            weight_data = self._sampling_data[var_name]
+            weight_threshold = None
+            if self._weight_quantize_type == "abs_max":
+                weight_threshold = np.max(np.abs(weight_data))
+            elif self._weight_quantize_type == "channel_wise_abs_max":
+                weight_threshold = []
+                for i in range(weight_data.shape[0]):
+                    abs_max_value = np.max(np.abs(weight_data[i]))
+                    weight_threshold.append(abs_max_value)
+            self._quantized_var_kl_threshold[var_name] = weight_threshold
             end = time.time()
             logging.debug('[Calculate weight] Weight_id={}/{}, time_each_weight={} s.'.format(
                 str(ct),
                 str(len(self._quantized_weight_var_name)),
                 str(end-start)))
             ct += 1
         ct = 1
-        # apply kl quantization for activation
+        # KL threshold for activations
         if self._is_use_cache_file:
             for var_name in self._quantized_act_var_name:
                 start = time.time()
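The weight branch above now honors `_weight_quantize_type`: `abs_max` yields a single threshold for the whole tensor, while `channel_wise_abs_max` yields one per output channel (axis 0 of the filter). A self-contained numpy illustration of the two reductions (the filter shape is hypothetical):

    import numpy as np

    # Hypothetical conv filter, layout (out_channels, in_channels, kh, kw).
    weight_data = np.random.randn(8, 3, 3, 3).astype("float32")

    # abs_max: one scale for the whole tensor.
    per_tensor_threshold = np.max(np.abs(weight_data))

    # channel_wise_abs_max: one scale per output channel, as in the loop above.
    per_channel_threshold = [np.max(np.abs(weight_data[i]))
                             for i in range(weight_data.shape[0])]
    assert len(per_channel_threshold) == weight_data.shape[0]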
@@ -269,13 +254,8 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
                     sampling_data.append(np.load(file_path))
                     os.remove(file_path)
                 sampling_data = np.concatenate(sampling_data)
-
-                if self._algo == "KL":
-                    self._quantized_var_scale_factor[var_name] = \
-                        self._get_kl_scaling_factor(np.abs(sampling_data))
-                else:
-                    self._quantized_var_scale_factor[var_name] = \
-                        np.max(np.abs(sampling_data))
+                self._quantized_var_kl_threshold[var_name] = \
+                    self._get_kl_scaling_factor(np.abs(sampling_data))
                 end = time.time()
                 logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format(
                     str(ct),
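Both this cache-file branch and the in-memory branch in the next hunk now compute only the KL threshold; they differ just in where the per-batch activation samples live. A sketch of the cache-file side in isolation (the file-naming scheme is illustrative, not necessarily the exact one the class uses):

    import os
    import numpy as np

    def load_cached_samples(cache_dir, var_name, num_batches):
        """Concatenate per-batch activation dumps for one variable, deleting
        each .npy file after reading, as the loop above does."""
        chunks = []
        for batch_id in range(num_batches):
            file_path = os.path.join(cache_dir,
                                     "{}_{}.npy".format(var_name, batch_id))
            chunks.append(np.load(file_path))
            os.remove(file_path)  # free disk space as we go
        return np.concatenate(chunks)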
@@ -287,15 +267,13 @@ class PaddleXPostTrainingQuantization(PostTrainingQuantization):
                 start = time.time()
                 self._sampling_data[var_name] = np.concatenate(
                     self._sampling_data[var_name])
-                if self._algo == "KL":
-                    self._quantized_var_scale_factor[var_name] = \
-                        self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
-                else:
-                    self._quantized_var_scale_factor[var_name] = \
-                        np.max(np.abs(self._sampling_data[var_name]))
+                self._quantized_var_kl_threshold[var_name] = \
+                    self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
                 end = time.time()
                 logging.debug('[Calculate activation] Activation_id={}/{}, time_each_activation={} s.'.format(
                     str(ct),
                     str(len(self._quantized_act_var_name)),
                     str(end-start)))
                 ct += 1
\ No newline at end of file