Unverified commit 4281eb49, authored by XGZhang, committed by GitHub

add new post-quant methods (#32208)

Parent: cb81826a
@@ -55,7 +55,7 @@ def _set_variable_data(scope, place, var_name, np_value):
     Set the value of var node by name, if the node exists,
     '''
     assert isinstance(np_value, np.ndarray), \
         'The type of value should be numpy array.'
     var_node = scope.find_var(var_name)
     if var_node != None:
         tensor = var_node.get_tensor()
@@ -138,8 +138,10 @@ class PostTrainingQuantization(object):
                  batch_size=10,
                  batch_nums=None,
                  algo="KL",
+                 hist_percent=0.99999,
                  quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
                  is_full_quantize=False,
+                 bias_correction=False,
                  activation_bits=8,
                  weight_bits=8,
                  activation_quantize_type='range_abs_max',
@@ -180,7 +182,13 @@ class PostTrainingQuantization(object):
             get the KL threshold for quantized activations and get the abs_max
             value for quantized weights. If algo='abs_max', get the abs max
             value for activations and weights. If algo='min_max', get the min
-            and max value for quantized activations and weights. Default is KL.
+            and max value for quantized activations and weights. If algo='avg',
+            get the average value among the max values for activations. If
+            algo='hist', get the value of the 'hist_percent' quantile as the
+            threshold. If algo='mse', get the value which makes the quantization
+            mse loss minimal. Default is KL.
+        hist_percent(float, optional): The percentile threshold used by the
+            'hist' algo for activations. Default is 0.99999.
         quantizable_op_type(list[str], optional): List the type of ops
             that will be quantized. Default is ["conv2d", "depthwise_conv2d",
             "mul"].
@@ -188,6 +196,8 @@ class PostTrainingQuantization(object):
             apply quantization to all supported quantizable op type. If set
             is_full_quantized as False, only apply quantization to the op type
             according to the input quantizable_op_type.
+        bias_correction(bool, optional): If set as True, use the bias correction
+            method of https://arxiv.org/abs/1810.05723. Default is False.
         activation_bits(int): quantization bit number for activation.
         weight_bits(int, optional): quantization bit number for weights.
         activation_quantize_type(str): quantization type for activation,
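
To make the four calibration rules described above concrete, here is a minimal NumPy sketch that computes each per-tensor activation threshold on a toy sample. It is illustrative only, not the class's internal code; note that the patch's own MSE search clips to [0, scale], while this sketch uses symmetric clipping.

import numpy as np

np.random.seed(0)
acts = np.random.randn(1000).astype('float32')   # toy activation sample
abs_acts = np.abs(acts)

# abs_max: the largest magnitude seen during sampling
t_abs_max = float(abs_acts.max())

# avg: average of per-batch abs-max values (here, 10 batches of 100)
t_avg = float(abs_acts.reshape(10, 100).max(axis=1).mean())

# hist: the magnitude at the hist_percent quantile of a 2048-bin histogram
hist, edges = np.histogram(abs_acts, bins=2048)
cdf = np.cumsum(hist / hist.sum())
t_hist = float(edges[int(np.searchsorted(cdf, 0.99999)) + 1])

# mse: scan candidate scales, keep the one with the lowest quant-dequant MSE
bins = 2 ** 7 - 1                                 # 8-bit signed activations
best, t_mse = float('inf'), t_abs_max
for s in np.arange(0.3, 1.0, 0.02):
    scale = float(s) * t_abs_max
    qdq = np.round(np.clip(acts, -scale, scale) / scale * bins) / bins * scale
    loss = float(((acts - qdq) ** 2).mean())
    if loss < best:
        best, t_mse = loss, scale

print(t_abs_max, t_avg, t_hist, t_mse)

The avg and mse rules trade a little clipping for lower rounding error, which is why they often beat plain abs_max on long-tailed activation distributions.
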
@@ -255,7 +265,9 @@ class PostTrainingQuantization(object):
             'range_abs_max', 'moving_average_abs_max', 'abs_max'
         ]
         self._support_weight_quantize_type = ['abs_max', 'channel_wise_abs_max']
-        self._support_algo_type = ['KL', 'abs_max', 'min_max']
+        self._support_algo_type = [
+            'KL', 'hist', 'avg', 'mse', 'abs_max', 'min_max'
+        ]
         self._dynamic_quantize_op_type = ['lstm']
         self._support_quantize_op_type = \
             list(set(QuantizationTransformPass._supported_quantizable_op_type +
@@ -270,7 +282,7 @@ class PostTrainingQuantization(object):
             "cannot be None in the same time."
         assert batch_size > 0, "The batch_size should be greater than 0."
         assert algo in self._support_algo_type, \
-            "The algo should be KL, abs_max or min_max."
+            "The algo should be KL, hist, mse, avg, abs_max or min_max."
         assert activation_quantize_type in self._support_activation_quantize_type, \
             "The activation_quantize_type ({}) should in ({}).".format(
                 activation_quantize_type, self._support_activation_quantize_type)
@@ -279,6 +291,7 @@ class PostTrainingQuantization(object):
                 weight_quantize_type, self._support_weight_quantize_type)

         # Save input params
+        self._bias_correction = bias_correction
         self._executor = executor
         self._scope = global_scope() if scope == None else scope
         self._model_dir = model_dir
@@ -289,6 +302,7 @@ class PostTrainingQuantization(object):
         self._batch_size = batch_size
         self._batch_nums = batch_nums
         self._algo = algo
+        self._hist_percent = hist_percent
         self._activation_bits = activation_bits
         self._weight_bits = weight_bits
         self._activation_quantize_type = activation_quantize_type
...@@ -314,17 +328,21 @@ class PostTrainingQuantization(object): ...@@ -314,17 +328,21 @@ class PostTrainingQuantization(object):
self._quantized_weight_var_name = set() self._quantized_weight_var_name = set()
self._quantized_act_var_name = set() self._quantized_act_var_name = set()
self._weight_op_pairs = {} self._weight_op_pairs = {}
# The vars for alog = KL # The vars for alog = KL or hist
self._sampling_act_abs_min_max = {} self._sampling_act_abs_min_max = {}
self._sampling_act_histogram = {} self._sampling_act_histogram = {}
self._sampling_data = {} self._sampling_data = {}
self._quantized_var_kl_threshold = {} self._quantized_var_threshold = {}
self._histogram_bins = 2048 self._histogram_bins = 2048
# The vars for algo = min_max # The vars for algo = min_max
self._quantized_var_min = {} self._quantized_var_min = {}
self._quantized_var_max = {} self._quantized_var_max = {}
# The vars for algo = abs_max # The vars for algo = avg
self._quantized_var_abs_max = {} self._quantized_var_avg = {}
# The best loss of algo = mse
self._best_mse_loss = {}
# The threshold for algo = abs_max, mse or avg
self._quantized_threshold = {}
def quantize(self): def quantize(self):
''' '''
@@ -341,7 +359,7 @@ class PostTrainingQuantization(object):
         self._collect_target_varnames()
         self._set_activation_persistable()

-        if self._algo == "KL":
+        if self._algo in ["KL", "hist"]:
             _logger.info("Preparation stage ...")
             batch_id = 0
             for data in self._data_loader():
@@ -374,13 +392,14 @@ class PostTrainingQuantization(object):
             if self._batch_nums and batch_id >= self._batch_nums:
                 break
         _logger.info("Finish sampling stage, all batch: " + str(batch_id))
         self._reset_activation_persistable()

-        if self._algo == "KL":
-            self._calculate_kl_threshold()
-
-        if self._algo in ["KL", "abs_max"]:
+        if self._algo == 'avg':
+            for var_name in self._quantized_act_var_name:
+                self._quantized_threshold[var_name] = \
+                np.array(self._quantized_var_avg[var_name]).mean()
+        if self._algo in ["KL", "hist"]:
+            self._calculate_kl_hist_threshold()
+        if self._algo in ["KL", "abs_max", "hist", "avg", "mse"]:
             self._update_program()
         else:
             self._save_input_threhold()
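
For orientation, the new options surface in the constructor as follows. This is a hedged usage sketch assuming the fluid static-graph toolchain of this branch; the model path and calibration reader are hypothetical.

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

def calib_reader():
    # hypothetical calibration reader; yields one MNIST-shaped sample at a time
    for _ in range(100):
        yield [np.random.rand(1, 28, 28).astype('float32')]

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# 'hist' takes the hist_percent quantile of the activation histogram as the
# scale; bias_correction enables the weight correction cited above.
ptq = PostTrainingQuantization(
    executor=exe,
    model_dir='./fp32_model',       # hypothetical path to a saved fp32 model
    sample_generator=calib_reader,
    batch_size=10,
    batch_nums=10,
    algo='hist',
    hist_percent=0.9999,
    bias_correction=True,
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"])
ptq.quantize()
ptq.save_quantized_model('./int8_model')
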
@@ -526,14 +545,84 @@ class PostTrainingQuantization(object):
         '''
         if self._algo == "abs_max":
             self._sample_abs_max()
+        elif self._algo == "avg":
+            self._sample_avg()
         elif self._algo == "min_max":
             self._sample_min_max()
-        elif self._algo == "KL":
+        elif self._algo == "mse":
+            self._sample_mse()
+        elif self._algo in ["KL", "hist"]:
             self._sample_histogram()
+
+    def _sample_mse(self):
+        if self._quantized_threshold == {}:
+            for var_name in self._quantized_weight_var_name:
+                var_tensor = _load_variable_data(self._scope, var_name)
+                if self._weight_quantize_type == "abs_max":
+                    abs_max_value = float(np.max(np.abs(var_tensor)))
+                elif self._weight_quantize_type == "channel_wise_abs_max":
+                    abs_max_value = []
+                    if self._weight_op_pairs[
+                            var_name] in _channelwise_quant_axis1_ops:
+                        for i in range(var_tensor.shape[1]):
+                            abs_max_value.append(
+                                float(np.max(np.abs(var_tensor[:, i]))))
+                    else:
+                        for i in range(var_tensor.shape[0]):
+                            abs_max_value.append(
+                                float(np.max(np.abs(var_tensor[i]))))
+                self._quantized_threshold[var_name] = abs_max_value
+        _logger.info("MSE searching stage ...")
+        for var_name in self._quantized_act_var_name:
+            var_tensor = _load_variable_data(self._scope, var_name)
+            var_tensor = var_tensor.flatten()
+            abs_max_value = float(np.max(np.abs(var_tensor)))
+            s = 0.3
+            if var_name not in self._best_mse_loss:
+                self._best_mse_loss[var_name] = float('inf')
+            while s <= 1.0:
+                scale = s * abs_max_value
+                s += 0.02
+                bins = 2**(self._activation_bits - 1) - 1
+                quant_dequant_var = np.round(
+                    np.clip(var_tensor, 0.0, scale) / scale *
+                    bins) / bins * scale
+                mse_loss = ((var_tensor - quant_dequant_var)**2).mean()
+                if mse_loss <= self._best_mse_loss[var_name]:
+                    self._best_mse_loss[var_name] = mse_loss
+                    self._quantized_threshold[var_name] = scale
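
The expression inside the search loop above is a fake quant-dequant round trip. Spelled out for one candidate scale, with illustrative values (the loop sweeps s from 0.3 to 1.0 times the abs-max):

import numpy as np

x = np.array([-0.9, -0.2, 0.1, 0.7], dtype='float32')
scale = 0.8                    # one candidate scale: s * abs_max
bins = 2 ** (8 - 1) - 1        # 127 levels for 8-bit activations

q = np.round(np.clip(x, 0.0, scale) / scale * bins)   # quantize (clip as in the patch)
x_qdq = q / bins * scale                              # dequantize back to float
mse = float(((x - x_qdq) ** 2).mean())                # the loss being minimized
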
+    def _sample_avg(self):
+        if self._quantized_threshold == {}:
+            for var_name in self._quantized_weight_var_name:
+                var_tensor = _load_variable_data(self._scope, var_name)
+                if self._weight_quantize_type == "abs_max":
+                    abs_max_value = float(np.max(np.abs(var_tensor)))
+                elif self._weight_quantize_type == "channel_wise_abs_max":
+                    abs_max_value = []
+                    if self._weight_op_pairs[
+                            var_name] in _channelwise_quant_axis1_ops:
+                        for i in range(var_tensor.shape[1]):
+                            abs_max_value.append(
+                                float(np.max(np.abs(var_tensor[:, i]))))
+                    else:
+                        for i in range(var_tensor.shape[0]):
+                            abs_max_value.append(
+                                float(np.max(np.abs(var_tensor[i]))))
+                self._quantized_threshold[var_name] = abs_max_value
+        for var_name in self._quantized_act_var_name:
+            var_tensor = _load_variable_data(self._scope, var_name)
+            abs_max_value = float(np.max(np.abs(var_tensor)))
+            if (var_name not in self._quantized_var_avg):
+                self._quantized_var_avg[var_name] = []
+            abs_avg_value = float(np.mean(np.max( \
+                np.abs(var_tensor.reshape(var_tensor.shape[0], -1)), axis=(1))))
+            self._quantized_var_avg[var_name].append(abs_avg_value)
+            continue
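
Per batch, _sample_avg records the mean of the per-sample abs-max values; quantize() later averages those records across batches. A standalone sketch of the statistic, using a made-up batch:

import numpy as np

batch = np.random.randn(10, 3, 32, 32).astype('float32')  # N=10 samples

# abs-max over each sample's non-batch axes, then the mean across the batch;
# one such value is appended per calibration batch
per_sample_max = np.abs(batch.reshape(batch.shape[0], -1)).max(axis=1)
batch_record = float(per_sample_max.mean())
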
     def _sample_abs_max(self):
-        # Only calculate abs_max value for weight for once
-        if self._quantized_var_abs_max == {}:
+        if self._quantized_threshold == {}:
             for var_name in self._quantized_weight_var_name:
                 var_tensor = _load_variable_data(self._scope, var_name)
                 if self._weight_quantize_type == "abs_max":
@@ -549,14 +638,14 @@ class PostTrainingQuantization(object):
                     for i in range(var_tensor.shape[0]):
                         abs_max_value.append(
                             float(np.max(np.abs(var_tensor[i]))))
-                self._quantized_var_abs_max[var_name] = abs_max_value
+                self._quantized_threshold[var_name] = abs_max_value

         for var_name in self._quantized_act_var_name:
             var_tensor = _load_variable_data(self._scope, var_name)
             abs_max_value = float(np.max(np.abs(var_tensor)))
-            if (var_name not in self._quantized_var_abs_max) or \
-                    (abs_max_value > self._quantized_var_abs_max[var_name]):
-                self._quantized_var_abs_max[var_name] = abs_max_value
+            if (var_name not in self._quantized_threshold) or \
+                    (abs_max_value > self._quantized_threshold[var_name]):
+                self._quantized_threshold[var_name] = abs_max_value

     def _sample_min_max(self):
         if self._quantized_var_min == {} and self._quantized_var_max == {}:
@@ -646,12 +735,12 @@ class PostTrainingQuantization(object):
                 [], bins=self._histogram_bins, range=(min_val, max_val))
             self._sampling_act_histogram[var_name] = [hist, hist_edeges]

-    def _calculate_kl_threshold(self):
+    def _calculate_kl_hist_threshold(self):
         '''
-        Calculate the KL threshold of quantized variables.
+        Calculate the KL or hist threshold of quantized variables.
         '''
-        _logger.info("Calculate KL threshold ...")
-        assert self._algo == "KL", "The algo should be KL to calculate kl threshold."
+        _logger.info("Calculate {} threshold ...".format(self._algo))
+        assert self._algo in ["KL", "hist"], "The algo should be KL or hist."

         # Abs_max threshold for weights
         for var_name in self._quantized_weight_var_name:
@@ -669,18 +758,22 @@ class PostTrainingQuantization(object):
                 for i in range(weight_data.shape[0]):
                     weight_threshold.append(
                         float(np.max(np.abs(weight_data[i]))))
-            self._quantized_var_kl_threshold[var_name] = weight_threshold
+            self._quantized_var_threshold[var_name] = weight_threshold

         for var_name in self._quantized_act_var_name:
             hist, hist_edeges = self._sampling_act_histogram[var_name]
-            self._quantized_var_kl_threshold[var_name] = \
-                self._get_kl_scaling_factor(hist, hist_edeges)
+            if self._algo == "KL":
+                self._quantized_var_threshold[var_name] = \
+                    self._get_kl_scaling_factor(hist, hist_edeges)
+            elif self._algo == "hist":
+                self._quantized_var_threshold[var_name] = \
+                    self._get_hist_scaling_factor(hist, hist_edeges)
     def _update_program(self):
         '''
         Use QuantizationTransformPass and AddQuantDequantPass to insert
         fake_quantize, fake_dequantize and fake_quant_dequant op.
-        Besides, save all kl thresholds to the scale var node.
+        Besides, save all thresholds to the scale var node.
         '''
         _logger.info("Update the program ...")
         graph = IrGraph(core.Graph(self._program.desc), for_test=True)
@@ -711,11 +804,11 @@ class PostTrainingQuantization(object):
             quantizable_op_type=minor_quantizable_op_types)
         add_quant_dequant_pass.apply(graph)

-        # save abs_max or KL threshold to scale var node
-        if self._algo == "KL":
-            scale_dict = self._quantized_var_kl_threshold
+        # save threshold to scale var node
+        if self._algo in ["KL", "hist"]:
+            scale_dict = self._quantized_var_threshold
         else:
-            scale_dict = self._quantized_var_abs_max
+            scale_dict = self._quantized_threshold
         for key, val in scale_dict.items():
             _set_variable_data(
                 self._scope,
@@ -734,6 +827,7 @@ class PostTrainingQuantization(object):
         freeze_pass = QuantizationFreezePass(
             scope=self._scope,
             place=self._place,
+            bias_correction=self._bias_correction,
             weight_bits=self._weight_bits,
             activation_bits=self._activation_bits,
             weight_quantize_type=self._weight_quantize_type,
@@ -761,20 +855,28 @@ class PostTrainingQuantization(object):
             out_var_name + " is not the output of the op"
         if self._algo == "KL":
             # For compatibility, we save output threshold by two methods.
-            save_info(op_node, out_var_name,
-                      self._quantized_var_kl_threshold, "out_threshold",
-                      "post_kl")
+            save_info(op_node, out_var_name, self._quantized_var_threshold,
+                      "out_threshold", "post_kl")
             save_info(
-                op_node, out_var_name, self._quantized_var_kl_threshold,
+                op_node, out_var_name, self._quantized_var_threshold,
                 argname_index[0] + str(argname_index[1]) + "_threshold",
                 "post_kl")
-        elif self._algo == "abs_max":
-            save_info(op_node, out_var_name, self._quantized_var_abs_max,
-                      "out_threshold", "post_abs_max")
+        elif self._algo == "hist":
+            # For compatibility, we save output threshold by two methods.
+            save_info(op_node, out_var_name, self._quantized_var_threshold,
+                      "out_threshold", "post_hist")
             save_info(
-                op_node, out_var_name, self._quantized_var_abs_max,
+                op_node, out_var_name, self._quantized_var_threshold,
                 argname_index[0] + str(argname_index[1]) + "_threshold",
-                "post_kl")
+                "post_hist")
+        elif self._algo in ["avg", "abs_max", "mse"]:
+            save_info(op_node, out_var_name, self._quantized_threshold,
+                      "out_threshold", "post_" + str(self._algo))
+            save_info(
+                op_node, out_var_name, self._quantized_threshold,
+                argname_index[0] + str(argname_index[1]) + "_threshold",
+                "post_" + str(self._algo))
         elif self._algo == "min_max":
             save_info(op_node, out_var_name, self._quantized_var_min,
                       "out_min", "post_min_max")
@@ -817,10 +919,27 @@ class PostTrainingQuantization(object):
                 op._set_attr("quantization_type", quantization_type)
                 op._set_attr("bit_length", self._weight_bits)

-    def _get_kl_scaling_factor(self, hist, hist_edeges, num_quantized_bins=255):
+    def _get_hist_scaling_factor(self, hist, hist_edges):
+        '''
+        Using the hist method to get the scaling factor.
+        '''
+        threshold_rate = self._hist_percent
+        hist = hist / float(sum(hist))
+        hist_sum = 0
+        hist_index = 0
+        for i in range(len(hist)):
+            hist_sum += hist[i]
+            if hist_sum >= threshold_rate:
+                hist_index = i + 1
+                break
+        bin_width = hist_edges[1] - hist_edges[0]
+        return (hist_index - 0.5) * bin_width
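
A tiny worked example of the rule above, with a made-up 4-bin histogram and hist_percent=0.9: the cumulative mass crosses 0.9 in the second bin, so the returned threshold is the centre of that bin.

import numpy as np

hist = np.array([70., 20., 8., 2.])
hist_edges = np.array([0.0, 0.25, 0.5, 0.75, 1.0])

p = hist / hist.sum()                         # [0.70, 0.20, 0.08, 0.02]
cum = np.cumsum(p)                            # [0.70, 0.90, 0.98, 1.00]
hist_index = int(np.argmax(cum >= 0.9)) + 1   # 2, as in the loop above
bin_width = hist_edges[1] - hist_edges[0]     # 0.25
threshold = (hist_index - 0.5) * bin_width    # 0.375, the centre of bin 2
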
+    def _get_kl_scaling_factor(self, hist, hist_edeges):
         '''
         Using the KL-divergence method to get the more precise scaling factor.
         '''
+        num_quantized_bins = 2**(self._activation_bits - 1) - 1
         ending_iter = self._histogram_bins - 1
         starting_iter = int(ending_iter * 0.7)
         bin_width = hist_edeges[1] - hist_edeges[0]
...
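
The KL path is truncated here, but its shape is the classic TensorRT-style search: try successively smaller clipping points, coarsen the clipped histogram down to the integer grid, and keep the clip with the smallest KL divergence from the original distribution. Below is a heavily simplified sketch under that assumption, not the code from this file:

import numpy as np

def kl_threshold(hist, hist_edges, quant_bins=127, start_frac=0.7):
    """Pick the clipping bin whose quantized histogram is closest in KL."""
    bin_width = hist_edges[1] - hist_edges[0]
    total = len(hist)
    best_kl, best_i = float('inf'), total - 1
    for i in range(int(total * start_frac), total):
        p = hist[:i].astype('float64').copy()
        p[-1] += hist[i:].sum()              # fold the clipped tail into the last bin
        # coarsen the first i bins to quant_bins levels, then expand back
        idx = np.arange(i) * quant_bins // i
        coarse = np.bincount(idx, weights=p, minlength=quant_bins)
        counts = np.bincount(idx, minlength=quant_bins)
        q = coarse[idx] / counts[idx]        # spread each level uniformly
        p, q = p / p.sum(), q / q.sum()
        mask = p > 0
        kl = float(np.sum(p[mask] * np.log(p[mask] / np.maximum(q[mask], 1e-12))))
        if kl < best_kl:
            best_kl, best_i = kl, i
    return (best_i + 0.5) * bin_width
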
@@ -1070,6 +1070,7 @@ class QuantizationFreezePass(object):
     def __init__(self,
                  scope,
                  place,
+                 bias_correction=False,
                  weight_bits=8,
                  activation_bits=8,
                  weight_quantize_type='abs_max',
@@ -1085,6 +1086,8 @@ class QuantizationFreezePass(object):
         scope(fluid.Scope): scope is used to get the weight tensor values.
         place(fluid.CPUPlace|fluid.CUDAPlace|str): place is used to restore the weight tensors.
             If it's string, it can be ``cpu``, and ``gpu:x``, where ``x`` is the index of the GPUs.
+        bias_correction(bool): whether to use bias correction for post-training quantization.
+            https://arxiv.org/abs/1810.05723.
         weight_bits(int): quantization bit number for weights.
         activation_bits(int): quantization bit number for activation.
         weight_quantize_type(str): quantization type for weights, support 'abs_max' and
@@ -1098,6 +1101,7 @@ class QuantizationFreezePass(object):
         assert place is not None, \
             'The place cannot be set None.'
         self._scope = scope
+        self._bias_correction = bias_correction
         self._place = _get_paddle_place(place)
         self._weight_bits = weight_bits
         self._activation_bits = activation_bits
@@ -1154,7 +1158,10 @@ class QuantizationFreezePass(object):
                 else:
                     quant_axis = 0
                 quantized_param_v = self._quant(
-                    param_v, scale_v, self._weight_bits, quant_axis)
+                    param_v.copy(), scale_v, self._weight_bits, quant_axis)
+                if self._bias_correction == True:
+                    quantized_param_v = self._bias_correction_w(
+                        param_v, quantized_param_v, scale_v, quant_axis)
                 self._restore_var(input_arg_name, quantized_param_v)
                 self._remove_fake_quant_and_dequant_op(graph, op_node)
@@ -1373,6 +1380,8 @@ class QuantizationFreezePass(object):
         if isinstance(scale, list):
             for i, s in enumerate(scale):
+                if s == 0.0:
+                    s = 1e-8
                 if quant_axis == 0:
                     x[i] = _clip(x[i], s)
                     x[i] = np.round(x[i] / s * bnt)
@@ -1384,6 +1393,46 @@ class QuantizationFreezePass(object):
             x = np.round(x / scale * bnt)
         return x
+    def _bias_correction_w(self, x, x_quant, scale_v, quant_axis):
+        '''
+        Bias correction for weight
+        '''
+        eps = 1e-8
+        bnt = (1 << (self._weight_bits - 1)) - 1
+        x_dequant = x_quant.copy()
+        if isinstance(scale_v, list):
+            if quant_axis == 0:
+                for i, s in enumerate(scale_v):
+                    x_dequant[i] = x_dequant[i] * s / bnt
+                quant_bias = x - x_dequant
+                mean_bias = quant_bias.reshape(quant_bias.shape[0], -1).mean(-1)
+                std_orig = x.reshape(x.shape[0], -1).std(-1)
+                std_quant = x_dequant.reshape(x_dequant.shape[0], -1).std(-1)
+                std_bias = std_orig / (std_quant + eps)
+            else:
+                for i, s in enumerate(scale_v):
+                    x_dequant[:, i] = x_quant[:, i] * s / bnt
+                quant_bias = x - x_dequant
+                mean_bias = np.array([
+                    quant_bias[:, i].mean() for i in range(quant_bias.shape[1])
+                ])
+                std_orig = np.array([x[:, i].std() for i in range(x.shape[1])])
+                std_quant = np.array(
+                    [x_dequant[:, i].std() for i in range(x_dequant.shape[1])])
+                std_bias = std_orig / (std_quant + eps)
+        else:
+            x_dequant = x_quant * scale_v / bnt
+            mean_bias = (x - x_dequant).mean()
+            std_bias = x.std() / (x_dequant.std() + eps)
+        if mean_bias.ndim == 1:
+            std_bias = np.resize(std_bias, x.shape)
+            mean_bias = np.resize(mean_bias, x.shape)
+
+        x_dequant = (mean_bias + x_dequant) * std_bias
+        quantized_param_v = self._quant(x_dequant, scale_v, self._weight_bits,
+                                        quant_axis)
+        return quantized_param_v
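
The per-tensor branch above can be checked in isolation: quantize, measure how rounding shifted the weight mean and shrank its spread, fold both corrections back in, then re-quantize. A minimal sketch assuming symmetric abs_max weight quantization:

import numpy as np

np.random.seed(0)
w = np.random.randn(64, 128).astype('float32')   # toy float weights
bnt = (1 << (8 - 1)) - 1                         # 127 for 8-bit weights
scale = float(np.abs(w).max())

w_q = np.round(w / scale * bnt)                  # quantize
w_dq = w_q * scale / bnt                         # dequantize

mean_bias = float((w - w_dq).mean())             # mean shift from rounding
std_bias = float(w.std() / (w_dq.std() + 1e-8))  # spread ratio
w_corrected = (w_dq + mean_bias) * std_bias      # match the float statistics
w_q_final = np.round(np.clip(w_corrected, -scale, scale) / scale * bnt)
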
 class ConvertToInt8Pass(object):
     def __init__(self, scope, place, quantizable_op_type=None):
...
@@ -204,6 +204,66 @@ class TestPostTrainingKLForMnist(TestPostTrainingQuantization):
                       quant_iterations)


+class TestPostTraininghistForMnist(TestPostTrainingQuantization):
+    def test_post_training_hist(self):
+        model_name = "mnist_model"
+        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
+        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
+        algo = "hist"
+        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = True
+        diff_threshold = 0.01
+        batch_size = 10
+        infer_iterations = 50
+        quant_iterations = 5
+        self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold, batch_size, infer_iterations,
+                      quant_iterations)
+
+
+class TestPostTrainingmseForMnist(TestPostTrainingQuantization):
+    def test_post_training_mse(self):
+        model_name = "mnist_model"
+        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
+        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
+        algo = "mse"
+        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = True
+        diff_threshold = 0.01
+        batch_size = 10
+        infer_iterations = 50
+        quant_iterations = 5
+        self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold, batch_size, infer_iterations,
+                      quant_iterations)
+
+
+class TestPostTrainingavgForMnist(TestPostTrainingQuantization):
+    def test_post_training_avg(self):
+        model_name = "mnist_model"
+        data_url = "http://paddle-inference-dist.bj.bcebos.com/int8/mnist_model.tar.gz"
+        data_md5 = "be71d3997ec35ac2a65ae8a145e2887c"
+        algo = "avg"
+        quantizable_op_type = ["conv2d", "depthwise_conv2d", "mul"]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = True
+        diff_threshold = 0.01
+        batch_size = 10
+        infer_iterations = 50
+        quant_iterations = 5
+        self.run_test(model_name, data_url, data_md5, algo, quantizable_op_type,
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold, batch_size, infer_iterations,
+                      quant_iterations)
+
+
 class TestPostTrainingAbsMaxForMnist(TestPostTrainingQuantization):
     def test_post_training_abs_max(self):
         model_name = "mnist_model"
...
@@ -328,6 +328,50 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
                       diff_threshold)


+class TestPostTrainingavgForMobilenetv1(TestPostTrainingQuantization):
+    def test_post_training_avg_mobilenetv1(self):
+        model = "MobileNet-V1"
+        algo = "avg"
+        data_urls = [
+            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
+        ]
+        data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
+        quantizable_op_type = [
+            "conv2d",
+            "depthwise_conv2d",
+            "mul",
+        ]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = True
+        diff_threshold = 0.025
+        self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold)
+
+
+class TestPostTraininghistForMobilenetv1(TestPostTrainingQuantization):
+    def test_post_training_hist_mobilenetv1(self):
+        model = "MobileNet-V1"
+        algo = "hist"
+        data_urls = [
+            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
+        ]
+        data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
+        quantizable_op_type = [
+            "conv2d",
+            "depthwise_conv2d",
+            "mul",
+        ]
+        is_full_quantize = False
+        is_use_cache_file = False
+        is_optimize_model = True
+        diff_threshold = 0.025
+        self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
+                      is_full_quantize, is_use_cache_file, is_optimize_model,
+                      diff_threshold)
+
+
 class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
     def test_post_training_abs_max_mobilenetv1(self):
         model = "MobileNet-V1"
...
@@ -257,6 +257,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
                      use_cuda,
                      seed,
                      activation_quant_type,
+                     bias_correction=False,
                      weight_quant_type='abs_max',
                      for_ci=True,
                      quant_skip_pattern='skip_quant'):
@@ -355,7 +356,8 @@ class TestQuantizationFreezePass(unittest.TestCase):
         # Freeze graph for inference, but the weight of fc/conv is still float type.
         freeze_pass = QuantizationFreezePass(
-            scope=scope, place=place, weight_quantize_type=weight_quant_type)
+            scope=scope, place=place, bias_correction=bias_correction, \
+            weight_quantize_type=weight_quant_type)
         freeze_pass.apply(test_graph)
         if not for_ci:
             marked_nodes = set()
@@ -472,6 +474,13 @@ class TestQuantizationFreezePass(unittest.TestCase):
     def test_freeze_graph_cuda_static(self):
         if fluid.core.is_compiled_with_cuda():
             with fluid.unique_name.guard():
+                self.freeze_graph(
+                    True,
+                    seed=1,
+                    activation_quant_type='range_abs_max',
+                    bias_correction=True,
+                    weight_quant_type='abs_max',
+                    for_ci=True)
                 self.freeze_graph(
                     True,
                     seed=1,
@@ -496,6 +505,13 @@ class TestQuantizationFreezePass(unittest.TestCase):
                     activation_quant_type='moving_average_abs_max',
                     weight_quant_type='channel_wise_abs_max',
                     for_ci=True)
+                self.freeze_graph(
+                    True,
+                    seed=1,
+                    activation_quant_type='moving_average_abs_max',
+                    bias_correction=True,
+                    weight_quant_type='channel_wise_abs_max',
+                    for_ci=True)

     def test_freeze_graph_cpu_static(self):
         with fluid.unique_name.guard():
...