未验证 提交 2d8281d5 编写于 作者: C cc 提交者: GitHub

Remove the cache in post_traning_quantization, test=develop (#26450)

* Remove the cache in post_traning_quantization, test=develop
上级 3ae3b864
......@@ -143,7 +143,7 @@ class PostTrainingQuantization(object):
weight_quantize_type='channel_wise_abs_max',
optimize_model=False,
is_use_cache_file=False,
cache_dir="./temp_post_training"):
cache_dir=None):
'''
Constructor.
......@@ -206,13 +206,8 @@ class PostTrainingQuantization(object):
`conv2d/depthwise_conv2d + bn`, the weights scale for all channel will
be different. In address this problem, fuse the pattern before
quantization. Default False.
is_use_cache_file(bool, optional): If set is_use_cache_file as False,
all temp data will be saved in memory. If set is_use_cache_file as True,
it will save temp data to disk. When the fp32 model is complex or
the number of calibrate data is large, we should set is_use_cache_file
as True. Defalut is False.
cache_dir(str, optional): When is_use_cache_file is True, set cache_dir as
the directory for saving temp data. Default is ./temp_post_training.
is_use_cache_file(bool, optional): This param is deprecated.
cache_dir(str, optional): This param is deprecated.
Returns:
None
......@@ -302,10 +297,6 @@ class PostTrainingQuantization(object):
assert op_type in self._support_quantize_op_type, \
op_type + " is not supported for quantization."
self._optimize_model = optimize_model
self._is_use_cache_file = is_use_cache_file
self._cache_dir = cache_dir
if self._is_use_cache_file and not os.path.exists(self._cache_dir):
os.mkdir(self._cache_dir)
# Define variables
self._place = self._executor.place
......@@ -317,11 +308,17 @@ class PostTrainingQuantization(object):
self._out_scale_op_list = _out_scale_op_list
self._quantized_weight_var_name = set()
self._quantized_act_var_name = set()
self.weight_op_pairs = {}
self._weight_op_pairs = {}
# The vars for alog = KL
self._sampling_act_abs_min_max = {}
self._sampling_act_histogram = {}
self._sampling_data = {}
self._quantized_var_kl_threshold = {}
self._histogram_bins = 2048
# The vars for algo = min_max
self._quantized_var_min = {}
self._quantized_var_max = {}
# The vars for algo = abs_max
self._quantized_var_abs_max = {}
def quantize(self):
......@@ -339,6 +336,25 @@ class PostTrainingQuantization(object):
self._collect_target_varnames()
self._set_activation_persistable()
if self._algo == "KL":
_logger.info("Preparation stage ...")
batch_id = 0
for data in self._data_loader():
self._executor.run(program=self._program,
feed=data,
fetch_list=self._fetch_list,
return_numpy=False,
scope=self._scope)
self._collect_activation_abs_min_max()
if batch_id % 5 == 0:
_logger.info("Run batch: " + str(batch_id))
batch_id += 1
if self._batch_nums and batch_id >= self._batch_nums:
break
_logger.info("Finish preparation stage, all batch:" + str(batch_id))
self._init_sampling_act_histogram()
_logger.info("Sampling stage ...")
batch_id = 0
for data in self._data_loader():
self._executor.run(program=self._program,
......@@ -346,17 +362,13 @@ class PostTrainingQuantization(object):
fetch_list=self._fetch_list,
return_numpy=False,
scope=self._scope)
if self._algo == "KL":
self._sample_data(batch_id)
else:
self._sample_threshold()
self._sampling()
if batch_id % 5 == 0:
_logger.info("Run batch: " + str(batch_id))
batch_id += 1
if self._batch_nums and batch_id >= self._batch_nums:
break
_logger.info("Finish all batch: " + str(batch_id))
_logger.info("Finish sampling stage, all batch: " + str(batch_id))
self._reset_activation_persistable()
......@@ -397,6 +409,7 @@ class PostTrainingQuantization(object):
target_vars=self._fetch_list,
executor=self._executor,
main_program=self._program)
_logger.info("The quantized model is saved in " + save_model_path)
def _load_model_data(self):
'''
......@@ -454,7 +467,7 @@ class PostTrainingQuantization(object):
for var_name in var_name_list:
if var_name in persistable_var_names:
self._quantized_weight_var_name.add(var_name)
self.weight_op_pairs[var_name] = op_type
self._weight_op_pairs[var_name] = op_type
else:
self._quantized_act_var_name.add(var_name)
......@@ -494,20 +507,18 @@ class PostTrainingQuantization(object):
if var.name in self._quantized_act_var_name:
var.persistable = False
def _sample_threshold(self):
def _sampling(self):
'''
Sample the input threshold(min, max, or abs_max) in every iterations.
Sample the min/max, abs_max or histogram in every iterations.
'''
assert self._algo in ["abs_max", "min_max"], \
"The algo should be abs_max or min_max for _sample_threshold."
if self._algo == "abs_max":
self._sample_threshold_abs_max()
self._sample_abs_max()
elif self._algo == "min_max":
self._sample_threshold_min_max()
self._sample_min_max()
elif self._algo == "KL":
self._sample_histogram()
def _sample_threshold_abs_max(self):
assert self._algo == "abs_max", \
"The algo should be abs_max for _sample_threshold_abs_max."
def _sample_abs_max(self):
# Only calculate abs_max value for weight for once
if self._quantized_var_abs_max == {}:
for var_name in self._quantized_weight_var_name:
......@@ -516,7 +527,7 @@ class PostTrainingQuantization(object):
abs_max_value = float(np.max(np.abs(var_tensor)))
elif self._weight_quantize_type == "channel_wise_abs_max":
abs_max_value = []
if self.weight_op_pairs[
if self._weight_op_pairs[
var_name] in _channelwise_quant_axis1_ops:
for i in range(var_tensor.shape[1]):
abs_max_value.append(
......@@ -534,9 +545,7 @@ class PostTrainingQuantization(object):
(abs_max_value > self._quantized_var_abs_max[var_name]):
self._quantized_var_abs_max[var_name] = abs_max_value
def _sample_threshold_min_max(self):
assert self._algo == "min_max", \
"The algo should be min_max for _sample_threshold_min_max."
def _sample_min_max(self):
if self._quantized_var_min == {} and self._quantized_var_max == {}:
for var_name in self._quantized_weight_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
......@@ -546,7 +555,7 @@ class PostTrainingQuantization(object):
elif self._weight_quantize_type == "channel_wise_abs_max":
min_value = []
max_value = []
if self.weight_op_pairs[
if self._weight_op_pairs[
var_name] in _channelwise_quant_axis1_ops:
for i in range(var_tensor.shape[1]):
min_value.append(float(np.min(var_tensor[:, i])))
......@@ -569,6 +578,14 @@ class PostTrainingQuantization(object):
(max_value > self._quantized_var_max[var_name]):
self._quantized_var_max[var_name] = max_value
def _sample_histogram(self):
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor_abs = np.abs(var_tensor)
bins = self._sampling_act_histogram[var_name][1]
hist, _ = np.histogram(var_tensor_abs, bins=bins)
self._sampling_act_histogram[var_name][0] += hist
def _save_input_threhold(self):
'''
Save input threshold to the quantized op.
......@@ -585,27 +602,36 @@ class PostTrainingQuantization(object):
op._set_attr(var_name + ".max",
self._quantized_var_max[var_name])
def _sample_data(self, iter):
def _collect_activation_abs_min_max(self):
'''
Sample the tensor data of quantized variables,
applied in every iteration.
Collect the abs_min and abs_max for all activation. When algo = KL,
get the min and max value, and then calculate the threshold.
'''
assert self._algo == "KL", "The algo should be KL to sample data."
if self._is_use_cache_file:
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = var_tensor.ravel()
save_path = os.path.join(
self._cache_dir,
var_name.replace("/", ".") + "_" + str(iter) + ".npy")
np.save(save_path, var_tensor)
else:
for var_name in self._quantized_act_var_name:
if var_name not in self._sampling_data:
self._sampling_data[var_name] = []
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = var_tensor.ravel()
self._sampling_data[var_name].append(var_tensor)
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
var_tensor = np.abs(var_tensor)
min_value = float(np.min(var_tensor))
max_value = float(np.max(var_tensor))
if var_name not in self._sampling_act_abs_min_max:
self._sampling_act_abs_min_max[
var_name] = [min_value, max_value]
else:
if min_value < self._sampling_act_abs_min_max[var_name][0]:
self._sampling_act_abs_min_max[var_name][0] = min_value
if max_value > self._sampling_act_abs_min_max[var_name][1]:
self._sampling_act_abs_min_max[var_name][1] = max_value
def _init_sampling_act_histogram(self):
'''
Based on the min/max value, init the sampling_act_histogram.
'''
for var_name in self._quantized_act_var_name:
if var_name not in self._sampling_act_histogram:
min_val = self._sampling_act_abs_min_max[var_name][0]
max_val = self._sampling_act_abs_min_max[var_name][1]
hist, hist_edeges = np.histogram(
[], bins=self._histogram_bins, range=(min_val, max_val))
self._sampling_act_histogram[var_name] = [hist, hist_edeges]
def _calculate_kl_threshold(self):
'''
......@@ -621,7 +647,7 @@ class PostTrainingQuantization(object):
weight_threshold = float(np.max(np.abs(weight_data)))
elif self._weight_quantize_type == "channel_wise_abs_max":
weight_threshold = []
if self.weight_op_pairs[
if self._weight_op_pairs[
var_name] in _channelwise_quant_axis1_ops:
for i in range(weight_data.shape[1]):
weight_threshold.append(
......@@ -632,25 +658,10 @@ class PostTrainingQuantization(object):
float(np.max(np.abs(weight_data[i]))))
self._quantized_var_kl_threshold[var_name] = weight_threshold
# KL threshold for activations
if self._is_use_cache_file:
for var_name in self._quantized_act_var_name:
sampling_data = []
filenames = [f for f in os.listdir(self._cache_dir) \
if re.match(var_name.replace("/", ".") + '_[0-9]+.npy', f)]
for filename in filenames:
file_path = os.path.join(self._cache_dir, filename)
sampling_data.append(np.load(file_path))
os.remove(file_path)
sampling_data = np.concatenate(sampling_data)
self._quantized_var_kl_threshold[var_name] = \
self._get_kl_scaling_factor(np.abs(sampling_data))
else:
for var_name in self._quantized_act_var_name:
self._sampling_data[var_name] = np.concatenate(
self._sampling_data[var_name])
self._quantized_var_kl_threshold[var_name] = \
self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
for var_name in self._quantized_act_var_name:
hist, hist_edeges = self._sampling_act_histogram[var_name]
self._quantized_var_kl_threshold[var_name] = \
self._get_kl_scaling_factor(hist, hist_edeges)
def _update_program(self):
'''
......@@ -765,22 +776,15 @@ class PostTrainingQuantization(object):
for var_name in out_var_names:
analysis_and_save_info(op, var_name)
def _get_kl_scaling_factor(self, activation_blob, num_quantized_bins=255):
def _get_kl_scaling_factor(self, hist, hist_edeges, num_quantized_bins=255):
'''
Using the KL-divergenc method to get the more precise scaling factor.
'''
max_val = np.max(activation_blob)
min_val = np.min(activation_blob)
if min_val >= 0:
hist, hist_edeges = np.histogram(
activation_blob, bins=2048, range=(min_val, max_val))
ending_iter = 2047
starting_iter = int(ending_iter * 0.7)
else:
_logger.error("Please first apply abs to activation_blob.")
ending_iter = self._histogram_bins - 1
starting_iter = int(ending_iter * 0.7)
bin_width = hist_edeges[1] - hist_edeges[0]
P_sum = len(np.array(activation_blob).ravel())
P_sum = np.sum(np.array(hist).ravel())
min_kl_divergence = 0
min_kl_index = 0
kl_inited = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册