Unverified commit 2d8281d5, authored by cc, committed by GitHub

Remove the cache in post_training_quantization, test=develop (#26450)

* Remove the cache in post_training_quantization, test=develop

Parent commit: 3ae3b864
@@ -143,7 +143,7 @@ class PostTrainingQuantization(object):
                  weight_quantize_type='channel_wise_abs_max',
                  optimize_model=False,
                  is_use_cache_file=False,
-                 cache_dir="./temp_post_training"):
+                 cache_dir=None):
         '''
         Constructor.
@@ -206,13 +206,8 @@ class PostTrainingQuantization(object):
                 `conv2d/depthwise_conv2d + bn`, the weights scale for all channel will
                 be different. In address this problem, fuse the pattern before
                 quantization. Default False.
-            is_use_cache_file(bool, optional): If set is_use_cache_file as False,
-                all temp data will be saved in memory. If set is_use_cache_file as True,
-                it will save temp data to disk. When the fp32 model is complex or
-                the number of calibrate data is large, we should set is_use_cache_file
-                as True. Defalut is False.
-            cache_dir(str, optional): When is_use_cache_file is True, set cache_dir as
-                the directory for saving temp data. Default is ./temp_post_training.
+            is_use_cache_file(bool, optional): This param is deprecated.
+            cache_dir(str, optional): This param is deprecated.
         Returns:
             None
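Since `is_use_cache_file` and `cache_dir` are now ignored, callers can simply omit them. Below is a minimal usage sketch, not part of this commit: the import path follows the Paddle 1.8/2.0-era contrib API, and the model directory plus the `calib_reader` generator are placeholders.

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

def calib_reader():                  # placeholder calibration reader
    for _ in range(10):
        yield [np.random.random((1, 3, 224, 224)).astype("float32")]

exe = fluid.Executor(fluid.CPUPlace())

ptq = PostTrainingQuantization(
    executor=exe,
    model_dir="./fp32_model",        # placeholder: FP32 inference model directory
    sample_generator=calib_reader,
    batch_nums=10,
    algo="KL",
    weight_quantize_type='channel_wise_abs_max')
# The deprecated is_use_cache_file / cache_dir arguments are simply omitted.
ptq.quantize()
ptq.save_quantized_model("./quant_model")  # placeholder output directory
```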
@@ -302,10 +297,6 @@ class PostTrainingQuantization(object):
             assert op_type in self._support_quantize_op_type, \
                 op_type + " is not supported for quantization."
         self._optimize_model = optimize_model
-        self._is_use_cache_file = is_use_cache_file
-        self._cache_dir = cache_dir
-        if self._is_use_cache_file and not os.path.exists(self._cache_dir):
-            os.mkdir(self._cache_dir)
 
         # Define variables
         self._place = self._executor.place
@@ -317,11 +308,17 @@ class PostTrainingQuantization(object):
         self._out_scale_op_list = _out_scale_op_list
         self._quantized_weight_var_name = set()
         self._quantized_act_var_name = set()
-        self.weight_op_pairs = {}
+        self._weight_op_pairs = {}
+        # The vars for alog = KL
+        self._sampling_act_abs_min_max = {}
+        self._sampling_act_histogram = {}
         self._sampling_data = {}
         self._quantized_var_kl_threshold = {}
+        self._histogram_bins = 2048
+        # The vars for algo = min_max
         self._quantized_var_min = {}
         self._quantized_var_max = {}
+        # The vars for algo = abs_max
         self._quantized_var_abs_max = {}
 
     def quantize(self):
@@ -339,6 +336,25 @@ class PostTrainingQuantization(object):
         self._collect_target_varnames()
         self._set_activation_persistable()
 
+        if self._algo == "KL":
+            _logger.info("Preparation stage ...")
+            batch_id = 0
+            for data in self._data_loader():
+                self._executor.run(program=self._program,
+                                   feed=data,
+                                   fetch_list=self._fetch_list,
+                                   return_numpy=False,
+                                   scope=self._scope)
+                self._collect_activation_abs_min_max()
+                if batch_id % 5 == 0:
+                    _logger.info("Run batch: " + str(batch_id))
+                batch_id += 1
+                if self._batch_nums and batch_id >= self._batch_nums:
+                    break
+            _logger.info("Finish preparation stage, all batch:" + str(batch_id))
+            self._init_sampling_act_histogram()
+
+        _logger.info("Sampling stage ...")
         batch_id = 0
         for data in self._data_loader():
             self._executor.run(program=self._program,
@@ -346,17 +362,13 @@ class PostTrainingQuantization(object):
                                fetch_list=self._fetch_list,
                                return_numpy=False,
                                scope=self._scope)
-            if self._algo == "KL":
-                self._sample_data(batch_id)
-            else:
-                self._sample_threshold()
+            self._sampling()
             if batch_id % 5 == 0:
                 _logger.info("Run batch: " + str(batch_id))
             batch_id += 1
             if self._batch_nums and batch_id >= self._batch_nums:
                 break
-        _logger.info("Finish all batch: " + str(batch_id))
+        _logger.info("Finish sampling stage, all batch: " + str(batch_id))
 
         self._reset_activation_persistable()
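The KL path now makes two passes over the calibration data instead of caching every activation tensor: a preparation pass that only tracks each activation's absolute range, and a sampling pass that folds each batch into a fixed-edge histogram. A standalone NumPy sketch of that pattern, on synthetic data rather than Paddle tensors:

```python
import numpy as np

batches = [np.random.randn(8, 64).astype("float32") for _ in range(20)]  # synthetic activations
HIST_BINS = 2048

# Pass 1: only remember the abs-max, so the bin edges can be fixed up front.
abs_max = max(float(np.max(np.abs(b))) for b in batches)

# Pass 2: accumulate counts into one histogram with those fixed edges.
edges = np.linspace(0.0, abs_max, HIST_BINS + 1)
hist = np.zeros(HIST_BINS, dtype=np.int64)
for b in batches:
    hist += np.histogram(np.abs(b), bins=edges)[0]

print(hist.sum(), "values summarized in", hist.nbytes, "bytes")
```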
@@ -397,6 +409,7 @@ class PostTrainingQuantization(object):
                       target_vars=self._fetch_list,
                       executor=self._executor,
                       main_program=self._program)
+        _logger.info("The quantized model is saved in " + save_model_path)
 
     def _load_model_data(self):
         '''
@@ -454,7 +467,7 @@ class PostTrainingQuantization(object):
                 for var_name in var_name_list:
                     if var_name in persistable_var_names:
                         self._quantized_weight_var_name.add(var_name)
-                        self.weight_op_pairs[var_name] = op_type
+                        self._weight_op_pairs[var_name] = op_type
                     else:
                         self._quantized_act_var_name.add(var_name)
@@ -494,20 +507,18 @@ class PostTrainingQuantization(object):
             if var.name in self._quantized_act_var_name:
                 var.persistable = False
 
-    def _sample_threshold(self):
+    def _sampling(self):
         '''
-        Sample the input threshold(min, max, or abs_max) in every iterations.
+        Sample the min/max, abs_max or histogram in every iterations.
         '''
-        assert self._algo in ["abs_max", "min_max"], \
-            "The algo should be abs_max or min_max for _sample_threshold."
         if self._algo == "abs_max":
-            self._sample_threshold_abs_max()
+            self._sample_abs_max()
         elif self._algo == "min_max":
-            self._sample_threshold_min_max()
+            self._sample_min_max()
+        elif self._algo == "KL":
+            self._sample_histogram()
 
-    def _sample_threshold_abs_max(self):
-        assert self._algo == "abs_max", \
-            "The algo should be abs_max for _sample_threshold_abs_max."
+    def _sample_abs_max(self):
         # Only calculate abs_max value for weight for once
         if self._quantized_var_abs_max == {}:
             for var_name in self._quantized_weight_var_name:
@@ -516,7 +527,7 @@ class PostTrainingQuantization(object):
                 abs_max_value = float(np.max(np.abs(var_tensor)))
             elif self._weight_quantize_type == "channel_wise_abs_max":
                 abs_max_value = []
-                if self.weight_op_pairs[
+                if self._weight_op_pairs[
                         var_name] in _channelwise_quant_axis1_ops:
                     for i in range(var_tensor.shape[1]):
                         abs_max_value.append(
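For context on the axis handling above: conv-style weights are usually given one scale per output channel (axis 0), while ops listed in `_channelwise_quant_axis1_ops` are scaled along axis 1. A small illustration with a made-up weight tensor, not taken from this commit:

```python
import numpy as np

w = np.random.randn(8, 3, 3, 3).astype("float32")  # made-up [out_c, in_c, kh, kw] weight

# One abs-max scale per output channel (axis 0), the usual conv2d case.
scales_axis0 = [float(np.max(np.abs(w[i]))) for i in range(w.shape[0])]

# One abs-max scale per axis-1 slice, the _channelwise_quant_axis1_ops case.
scales_axis1 = [float(np.max(np.abs(w[:, i]))) for i in range(w.shape[1])]

print(len(scales_axis0), len(scales_axis1))  # 8 vs. 3
```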
@@ -534,9 +545,7 @@ class PostTrainingQuantization(object):
                     (abs_max_value > self._quantized_var_abs_max[var_name]):
                 self._quantized_var_abs_max[var_name] = abs_max_value
 
-    def _sample_threshold_min_max(self):
-        assert self._algo == "min_max", \
-            "The algo should be min_max for _sample_threshold_min_max."
+    def _sample_min_max(self):
         if self._quantized_var_min == {} and self._quantized_var_max == {}:
             for var_name in self._quantized_weight_var_name:
                 var_tensor = _load_variable_data(self._scope, var_name)
@@ -546,7 +555,7 @@ class PostTrainingQuantization(object):
                 elif self._weight_quantize_type == "channel_wise_abs_max":
                     min_value = []
                     max_value = []
-                    if self.weight_op_pairs[
+                    if self._weight_op_pairs[
                             var_name] in _channelwise_quant_axis1_ops:
                         for i in range(var_tensor.shape[1]):
                             min_value.append(float(np.min(var_tensor[:, i])))
@@ -569,6 +578,14 @@ class PostTrainingQuantization(object):
                     (max_value > self._quantized_var_max[var_name]):
                 self._quantized_var_max[var_name] = max_value
 
+    def _sample_histogram(self):
+        for var_name in self._quantized_act_var_name:
+            var_tensor = _load_variable_data(self._scope, var_name)
+            var_tensor_abs = np.abs(var_tensor)
+            bins = self._sampling_act_histogram[var_name][1]
+            hist, _ = np.histogram(var_tensor_abs, bins=bins)
+            self._sampling_act_histogram[var_name][0] += hist
+
     def _save_input_threhold(self):
         '''
         Save input threshold to the quantized op.
@@ -585,27 +602,36 @@ class PostTrainingQuantization(object):
                 op._set_attr(var_name + ".max",
                              self._quantized_var_max[var_name])
 
-    def _sample_data(self, iter):
+    def _collect_activation_abs_min_max(self):
         '''
-        Sample the tensor data of quantized variables,
-        applied in every iteration.
+        Collect the abs_min and abs_max for all activation. When algo = KL,
+        get the min and max value, and then calculate the threshold.
         '''
-        assert self._algo == "KL", "The algo should be KL to sample data."
-        if self._is_use_cache_file:
-            for var_name in self._quantized_act_var_name:
-                var_tensor = _load_variable_data(self._scope, var_name)
-                var_tensor = var_tensor.ravel()
-                save_path = os.path.join(
-                    self._cache_dir,
-                    var_name.replace("/", ".") + "_" + str(iter) + ".npy")
-                np.save(save_path, var_tensor)
-        else:
-            for var_name in self._quantized_act_var_name:
-                if var_name not in self._sampling_data:
-                    self._sampling_data[var_name] = []
-                var_tensor = _load_variable_data(self._scope, var_name)
-                var_tensor = var_tensor.ravel()
-                self._sampling_data[var_name].append(var_tensor)
+        for var_name in self._quantized_act_var_name:
+            var_tensor = _load_variable_data(self._scope, var_name)
+            var_tensor = np.abs(var_tensor)
+            min_value = float(np.min(var_tensor))
+            max_value = float(np.max(var_tensor))
+            if var_name not in self._sampling_act_abs_min_max:
+                self._sampling_act_abs_min_max[
+                    var_name] = [min_value, max_value]
+            else:
+                if min_value < self._sampling_act_abs_min_max[var_name][0]:
+                    self._sampling_act_abs_min_max[var_name][0] = min_value
+                if max_value > self._sampling_act_abs_min_max[var_name][1]:
+                    self._sampling_act_abs_min_max[var_name][1] = max_value
+
+    def _init_sampling_act_histogram(self):
+        '''
+        Based on the min/max value, init the sampling_act_histogram.
+        '''
+        for var_name in self._quantized_act_var_name:
+            if var_name not in self._sampling_act_histogram:
+                min_val = self._sampling_act_abs_min_max[var_name][0]
+                max_val = self._sampling_act_abs_min_max[var_name][1]
+                hist, hist_edeges = np.histogram(
+                    [], bins=self._histogram_bins, range=(min_val, max_val))
+                self._sampling_act_histogram[var_name] = [hist, hist_edeges]
 
     def _calculate_kl_threshold(self):
         '''
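One detail worth noting in `_init_sampling_act_histogram`: calling `np.histogram` on an empty array is just a convenient way to get an all-zero count vector together with the fixed bin edges that later batches reuse. A quick standalone check of that behavior:

```python
import numpy as np

hist, edges = np.histogram([], bins=2048, range=(0.0, 1.0))
assert hist.sum() == 0      # zero counts, ready to be accumulated into
assert len(edges) == 2049   # fixed edges reused by _sample_histogram
```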
@@ -621,7 +647,7 @@ class PostTrainingQuantization(object):
                 weight_threshold = float(np.max(np.abs(weight_data)))
             elif self._weight_quantize_type == "channel_wise_abs_max":
                 weight_threshold = []
-                if self.weight_op_pairs[
+                if self._weight_op_pairs[
                         var_name] in _channelwise_quant_axis1_ops:
                     for i in range(weight_data.shape[1]):
                         weight_threshold.append(
@@ -632,25 +658,10 @@ class PostTrainingQuantization(object):
                             float(np.max(np.abs(weight_data[i]))))
             self._quantized_var_kl_threshold[var_name] = weight_threshold
 
-        # KL threshold for activations
-        if self._is_use_cache_file:
-            for var_name in self._quantized_act_var_name:
-                sampling_data = []
-                filenames = [f for f in os.listdir(self._cache_dir) \
-                    if re.match(var_name.replace("/", ".") + '_[0-9]+.npy', f)]
-                for filename in filenames:
-                    file_path = os.path.join(self._cache_dir, filename)
-                    sampling_data.append(np.load(file_path))
-                    os.remove(file_path)
-                sampling_data = np.concatenate(sampling_data)
-                self._quantized_var_kl_threshold[var_name] = \
-                    self._get_kl_scaling_factor(np.abs(sampling_data))
-        else:
-            for var_name in self._quantized_act_var_name:
-                self._sampling_data[var_name] = np.concatenate(
-                    self._sampling_data[var_name])
-                self._quantized_var_kl_threshold[var_name] = \
-                    self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
+        for var_name in self._quantized_act_var_name:
+            hist, hist_edeges = self._sampling_act_histogram[var_name]
+            self._quantized_var_kl_threshold[var_name] = \
+                self._get_kl_scaling_factor(hist, hist_edeges)
 
     def _update_program(self):
         '''
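The payoff of the histogram approach shows up here: the old path had to concatenate every cached activation sample before computing the KL threshold, while the new path only reads 2048 counts per activation. Rough, illustrative arithmetic; the activation shape and batch count below are invented for the example:

```python
import numpy as np

act_shape = (32, 256, 56, 56)          # invented activation shape for one batch
num_batches = 100

cached_bytes = np.prod(act_shape) * num_batches * 4   # float32 samples kept by the old path
hist_bytes = 2048 * 8                                  # one int64 histogram kept by the new path

print(round(cached_bytes / 2**30, 1), "GiB vs.", hist_bytes // 1024, "KiB")  # ~9.6 GiB vs. 16 KiB
```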
@@ -765,22 +776,15 @@ class PostTrainingQuantization(object):
         for var_name in out_var_names:
             analysis_and_save_info(op, var_name)
 
-    def _get_kl_scaling_factor(self, activation_blob, num_quantized_bins=255):
+    def _get_kl_scaling_factor(self, hist, hist_edeges, num_quantized_bins=255):
         '''
         Using the KL-divergenc method to get the more precise scaling factor.
         '''
-        max_val = np.max(activation_blob)
-        min_val = np.min(activation_blob)
-        if min_val >= 0:
-            hist, hist_edeges = np.histogram(
-                activation_blob, bins=2048, range=(min_val, max_val))
-            ending_iter = 2047
-            starting_iter = int(ending_iter * 0.7)
-        else:
-            _logger.error("Please first apply abs to activation_blob.")
+        ending_iter = self._histogram_bins - 1
+        starting_iter = int(ending_iter * 0.7)
         bin_width = hist_edeges[1] - hist_edeges[0]
-        P_sum = len(np.array(activation_blob).ravel())
+        P_sum = np.sum(np.array(hist).ravel())
         min_kl_divergence = 0
         min_kl_index = 0
         kl_inited = False
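The body of `_get_kl_scaling_factor` is truncated in this view. For readers unfamiliar with the technique, the sketch below shows a standard TensorRT-style KL-threshold search over a histogram; it is an independent, unoptimized illustration under those assumptions, not the exact code in this file.

```python
import numpy as np

def kl_threshold_sketch(hist, hist_edges, num_quantized_bins=255):
    """Illustrative KL-divergence threshold search, not the Paddle implementation."""
    hist = np.asarray(hist, dtype=np.float64)
    best_kl, best_idx = np.inf, len(hist) - 1
    for i in range(num_quantized_bins, len(hist) + 1):
        # Reference distribution P: clip at bin i, fold the tail into the last bin.
        p = hist[:i].copy()
        p[-1] += hist[i:].sum()
        # Candidate distribution Q: collapse the i bins to num_quantized_bins
        # levels, then spread each level back over its non-empty source bins.
        q = np.zeros(i)
        step = i / num_quantized_bins
        for j in range(num_quantized_bins):
            start, end = int(j * step), int((j + 1) * step)
            chunk = hist[start:end]
            nonzero = chunk > 0
            if nonzero.any():
                q[start:end][nonzero] = chunk.sum() / nonzero.sum()
        # KL(P || Q) over bins that are non-empty in both distributions.
        mask = (p > 0) & (q > 0)
        p_n, q_n = p[mask] / p.sum(), q[mask] / q.sum()
        kl = float(np.sum(p_n * np.log(p_n / q_n)))
        if kl < best_kl:
            best_kl, best_idx = kl, i
    bin_width = hist_edges[1] - hist_edges[0]
    return best_idx * bin_width   # clipping threshold in activation units

# Tiny usage example on synthetic data:
data = np.abs(np.random.randn(100000))
hist, edges = np.histogram(data, bins=2048, range=(0.0, float(data.max())))
print(kl_threshold_sketch(hist, edges))
```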
...