Unverified commit 589cd878, authored by cc, committed by GitHub

Post_training_quantizaion supports min_max methon (#23078)

* Post_training_quantizaion supports min_max methon

Parent: 194a22c5
@@ -37,7 +37,10 @@ def _load_variable_data(scope, var_name):
     '''
     Load variable value from scope
     '''
-    return np.array(scope.find_var(var_name).get_tensor())
+    var_node = scope.find_var(var_name)
+    assert var_node is not None, \
+        "Cannot find " + var_name + " in scope."
+    return np.array(var_node.get_tensor())


 def _set_variable_data(scope, place, var_name, np_value):
@@ -53,6 +56,12 @@ def _set_variable_data(scope, place, var_name, np_value):

 class PostTrainingQuantization(object):
+    """
+    Utilizing post training quantization method to quantize the FP32 model,
+    and it uses calibrate data to get the quantization information for all
+    quantized variables.
+    """

     def __init__(self,
                  executor=None,
                  scope=None,
@@ -70,13 +79,10 @@ class PostTrainingQuantization(object):
                  is_use_cache_file=False,
                  cache_dir="./temp_post_training"):
         '''
-        The class utilizes post training quantization methon to quantize the
-        fp32 model. It uses calibrate data to calculate the scale factor of
-        quantized variables, and inserts fake quant/dequant op to obtain the
-        quantized model.
+        Constructor.
         Args:
             executor(fluid.Executor): The executor to load, run and save the
                 quantized model.
             scope(fluid.Scope, optional): The scope of the program, use it to load
                 and save variables. If scope=None, get scope by global_scope().
@@ -96,9 +102,11 @@ class PostTrainingQuantization(object):
             batch_nums(int, optional): If batch_nums is not None, the number of
                 calibrate data is batch_size*batch_nums. If batch_nums is None, use
                 all data provided by sample_generator as calibrate data.
-            algo(str, optional): If algo=KL, use KL-divergenc method to
-                get the more precise scale factor. If algo='direct', use
-                abs_max methon to get the scale factor. Default is KL.
+            algo(str, optional): If algo='KL', use the KL-divergence method to
+                get the KL threshold for quantized activations and get the abs_max
+                value for quantized weights. If algo='abs_max', get the abs max
+                value for activations and weights. If algo='min_max', get the min
+                and max value for quantized activations and weights. Default is 'KL'.
             quantizable_op_type(list[str], optional): List the type of ops
                 that will be quantized. Default is ["conv2d", "depthwise_conv2d",
                 "mul"].
@@ -158,7 +166,9 @@ class PostTrainingQuantization(object):
         assert executor is not None, "The executor cannot be None."
         assert model_dir is not None, "The model_dir cannot be None."
         assert sample_generator is not None, \
             "The sample_generator cannot be None."
+        assert algo in ['KL', 'abs_max', 'min_max'], \
+            "The algo should be KL, abs_max or min_max."

         self._executor = executor
         self._scope = global_scope() if scope == None else scope
@@ -182,8 +192,7 @@ class PostTrainingQuantization(object):
         else:
             self._quantizable_op_type = quantizable_op_type
             for op_type in self._quantizable_op_type:
-                assert op_type in supported_quantizable_op_type + \
-                    AddQuantDequantPass._activation_type, \
+                assert op_type in supported_quantizable_op_type, \
                     op_type + " is not supported for quantization."

         self._place = self._executor.place
@@ -197,20 +206,25 @@ class PostTrainingQuantization(object):
         self._quantized_weight_var_name = set()
         self._quantized_act_var_name = set()
         self._sampling_data = {}
-        self._quantized_var_scale_factor = {}
+        self._quantized_var_kl_threshold = {}
+        self._quantized_var_min = {}
+        self._quantized_var_max = {}
+        self._quantized_var_abs_max = {}

     def quantize(self):
         '''
-        Quantize the fp32 model. Use calibrate data to calculate the scale factor of
-        quantized variables, and inserts fake quant/dequant op to obtain the
-        quantized model.
+        Load the FP32 model, and use the calibrate data to run the forward stage.
+        Based on the sampled data, calculate the quantization information and obtain
+        the final quantized model.
         Args:
             None
         Returns:
             the program of quantized model.
         '''
-        self._preprocess()
+        self._load_model_data()
+        self._collect_quantized_varnames()
+        self._set_activation_persistable()

         batch_id = 0
         for data in self._data_loader():
@@ -218,22 +232,29 @@ class PostTrainingQuantization(object):
                                feed=data,
                                fetch_list=self._fetch_list,
                                return_numpy=False)
-            self._sample_data(batch_id)
+            if self._algo == "KL":
+                self._sample_data(batch_id)
+            else:
+                self._sample_threshold()

             if batch_id % 5 == 0:
-                _logger.info("run batch: " + str(batch_id))
+                _logger.info("Run batch: " + str(batch_id))
             batch_id += 1
             if self._batch_nums and batch_id >= self._batch_nums:
                 break
-        _logger.info("all run batch: " + str(batch_id))
+        _logger.info("Finish all batch: " + str(batch_id))

-        _logger.info("calculate scale factor ...")
-        self._calculate_scale_factor()
-        _logger.info("update the program ...")
-        self._update_program()
-        self._save_output_scale()
+        self._reset_activation_persistable()
+
+        if self._algo == "KL":
+            self._calculate_kl_threshold()
+        if self._algo in ["KL", "abs_max"]:
+            self._update_program()
+        else:
+            self._save_input_threhold()
+        self._save_output_threshold()

         return self._program

     def save_quantized_model(self, save_model_path):
@@ -252,12 +273,11 @@ class PostTrainingQuantization(object):
             executor=self._executor,
             main_program=self._program)

-    def _preprocess(self):
+    def _load_model_data(self):
         '''
-        Load model and set data loader, collect the variable names for sampling,
-        and set activation variables to be persistable.
+        Load model and set data loader.
         '''
-        # load model and set data loader
+        _logger.info("Load model and set data loader ...")
         [self._program, self._feed_list, self._fetch_list] = \
             io.load_inference_model(dirname=self._model_dir,
                                     executor=self._executor,
@@ -273,7 +293,12 @@ class PostTrainingQuantization(object):
             drop_last=True,
             places=self._place)

-        # collect the variable names for sampling.
+    def _collect_quantized_varnames(self):
+        '''
+        Collect the variable names for sampling, and set activation
+        variables to be persistable.
+        '''
+        _logger.info("Collect quantized variable names ...")
         # TODO(juncaipeng), consider the name_scope of skip_quant and
         # reduce the variables for sampling
         persistable_var_names = []
@@ -284,46 +309,109 @@ class PostTrainingQuantization(object):
         for op in self._program.global_block().ops:
             op_type = op.type
             if op_type in self._quantizable_op_type:
-                if op_type in ("conv2d", "depthwise_conv2d"):
-                    self._quantized_act_var_name.add(op.input("Input")[0])
-                    self._quantized_weight_var_name.add(op.input("Filter")[0])
-                    self._quantized_act_var_name.add(op.output("Output")[0])
-                elif op_type in ["mul", "matmul"]:
-                    x_var_name = op.input("X")[0]
-                    if x_var_name in persistable_var_names:
-                        self._quantized_weight_var_name.add(x_var_name)
-                    else:
-                        self._quantized_act_var_name.add(x_var_name)
-                    y_var_name = op.input("Y")[0]
-                    if y_var_name in persistable_var_names:
-                        self._quantized_weight_var_name.add(y_var_name)
-                    else:
-                        self._quantized_act_var_name.add(y_var_name)
-                    self._quantized_act_var_name.add(op.output("Out")[0])
-                else:
-                    # process other quantizable op type, the input must all not persistable
-                    if self._is_input_all_not_persistable(
-                            op, persistable_var_names):
-                        input_output_name_list = self._op_real_in_out_name[
-                            op_type]
-                        for input_name in input_output_name_list[0]:
-                            for var_name in op.input(input_name):
-                                self._quantized_act_var_name.add(var_name)
-                        for output_name in input_output_name_list[1]:
-                            for var_name in op.output(output_name):
-                                self._quantized_act_var_name.add(var_name)
-
-        # set activation variables to be persistable, so can obtain
-        # the tensor data in sample_data
+                name_list = self._op_real_in_out_name[op_type]
+                for input_name in name_list[0]:
+                    for var_name in op.input(input_name):
+                        if var_name in persistable_var_names:
+                            self._quantized_weight_var_name.add(var_name)
+                        else:
+                            self._quantized_act_var_name.add(var_name)
+                for output_name in name_list[1]:
+                    for var_name in op.output(output_name):
+                        if var_name in persistable_var_names:
+                            self._quantized_weight_var_name.add(var_name)
+                        else:
+                            self._quantized_act_var_name.add(var_name)
+
+    def _set_activation_persistable(self):
+        '''
+        Set activation variables to be persistable, so can obtain
+        the tensor data in sample_data
+        '''
+        persistable_var_names = []
+        for var in self._program.list_vars():
+            if var.persistable:
+                persistable_var_names.append(var.name)
         for var in self._program.list_vars():
             if var.name in self._quantized_act_var_name:
                 var.persistable = True

+    def _reset_activation_persistable(self):
+        '''
+        Reset activations to be not persistable.
+        '''
+        for var in self._program.list_vars():
+            if var.name in self._quantized_act_var_name:
+                var.persistable = False
+    def _sample_threshold(self):
+        '''
+        Sample the input threshold (min, max, or abs_max) in every iteration.
+        '''
+        assert self._algo in ["abs_max", "min_max"], \
+            "The algo should be abs_max or min_max to sample min max value."
+        if self._algo == "abs_max":
+            # Only calculate abs_max value for weight for once
+            if self._quantized_var_abs_max == {}:
+                for var_name in self._quantized_weight_var_name:
+                    var_tensor = _load_variable_data(self._scope, var_name)
+                    abs_max_per_channel = []
+                    for i in range(var_tensor.shape[0]):
+                        abs_max_per_channel.append(
+                            float(np.max(np.abs(var_tensor[i]))))
+                    self._quantized_var_abs_max[var_name] = abs_max_per_channel
+            for var_name in self._quantized_act_var_name:
+                var_tensor = _load_variable_data(self._scope, var_name)
+                abs_max_value = float(np.max(np.abs(var_tensor)))
+                if (var_name not in self._quantized_var_abs_max) or \
+                        (abs_max_value > self._quantized_var_abs_max[var_name]):
+                    self._quantized_var_abs_max[var_name] = abs_max_value
+        elif self._algo == "min_max":
+            if self._quantized_var_min == {} and self._quantized_var_max == {}:
+                for var_name in self._quantized_weight_var_name:
+                    var_tensor = _load_variable_data(self._scope, var_name)
+                    min_per_channel = []
+                    max_per_channle = []
+                    for i in range(var_tensor.shape[0]):
+                        min_per_channel.append(float(np.min(var_tensor[i])))
+                        max_per_channle.append(float(np.max(var_tensor[i])))
+                    self._quantized_var_min[var_name] = min_per_channel
+                    self._quantized_var_max[var_name] = max_per_channle
+            for var_name in self._quantized_act_var_name:
+                var_tensor = _load_variable_data(self._scope, var_name)
+                min_value = float(np.min(var_tensor))
+                max_value = float(np.max(var_tensor))
+                if (var_name not in self._quantized_var_min) or \
+                        (min_value < self._quantized_var_min[var_name]):
+                    self._quantized_var_min[var_name] = min_value
+                if (var_name not in self._quantized_var_max) or \
+                        (max_value > self._quantized_var_max[var_name]):
+                    self._quantized_var_max[var_name] = max_value
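The min/max and abs_max statistics gathered here feed the fake-quant machinery downstream. As a rough illustration of why they are collected (not code from this PR), symmetric int8 scales could be derived from them like this:

```python
import numpy as np

# Illustrative only: deriving int8 quantization scales from the sampled
# statistics above. These formulas assume symmetric quantization and are a
# simplification of what the fake-quant kernels actually do.
def scale_from_abs_max(abs_max_value, num_bits=8):
    # symmetric: map [-abs_max, abs_max] onto [-127, 127]
    return abs_max_value / float(2 ** (num_bits - 1) - 1)

def scale_from_min_max(min_value, max_value, num_bits=8):
    # conservative symmetric scale from a (min, max) pair
    return max(abs(min_value), abs(max_value)) / float(2 ** (num_bits - 1) - 1)

print(scale_from_abs_max(6.0))          # ~0.0472
print(scale_from_min_max(-1.5, 6.0))    # ~0.0472
```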
+    def _save_input_threhold(self):
+        '''
+        Save input threshold to the quantized op.
+        '''
+        assert self._algo == "min_max", \
+            "The algo should be min_max to save input threshold."
+        for op in self._program.global_block().ops:
+            if op.type in self._quantizable_op_type:
+                input_name_list = self._op_real_in_out_name[op.type][0]
+                for input_name in input_name_list:
+                    for var_name in op.input(input_name):
+                        assert var_name in self._quantized_var_min
+                        assert var_name in self._quantized_var_max
+                        op._set_attr(var_name + ".min",
+                                     self._quantized_var_min[var_name])
+                        op._set_attr(var_name + ".max",
+                                     self._quantized_var_max[var_name])
+
     def _sample_data(self, iter):
         '''
         Sample the tensor data of quantized variables,
         applied in every iteration.
         '''
+        assert self._algo == "KL", "The algo should be KL to sample data."
         for var_name in self._quantized_weight_var_name:
             if var_name not in self._sampling_data:
                 var_tensor = _load_variable_data(self._scope, var_name)
@@ -344,19 +432,20 @@ class PostTrainingQuantization(object):
                 var_tensor = var_tensor.ravel()
                 self._sampling_data[var_name].append(var_tensor)

-    def _calculate_scale_factor(self):
+    def _calculate_kl_threshold(self):
         '''
-        Calculate the scale factor of quantized variables.
+        Calculate the KL threshold of quantized variables.
         '''
+        _logger.info("Calculate KL threshold ...")
+        assert self._algo == "KL", "The algo should be KL to calculate kl threshold."
         # apply channel_wise_abs_max quantization for weights
         for var_name in self._quantized_weight_var_name:
             data = self._sampling_data[var_name]
-            scale_factor_per_channel = []
+            threshold_per_channel = []
             for i in range(data.shape[0]):
                 abs_max_value = np.max(np.abs(data[i]))
-                scale_factor_per_channel.append(abs_max_value)
-            self._quantized_var_scale_factor[
-                var_name] = scale_factor_per_channel
+                threshold_per_channel.append(abs_max_value)
+            self._quantized_var_kl_threshold[var_name] = threshold_per_channel

         # apply kl quantization for activation
         if self._is_use_cache_file:
@@ -369,36 +458,25 @@ class PostTrainingQuantization(object):
                     sampling_data.append(np.load(file_path))
                     os.remove(file_path)
                 sampling_data = np.concatenate(sampling_data)
-                if self._algo == "KL":
-                    self._quantized_var_scale_factor[var_name] = \
-                        self._get_kl_scaling_factor(np.abs(sampling_data))
-                else:
-                    self._quantized_var_scale_factor[var_name] = \
-                        np.max(np.abs(sampling_data))
+                self._quantized_var_kl_threshold[var_name] = \
+                    self._get_kl_scaling_factor(np.abs(sampling_data))
         else:
             for var_name in self._quantized_act_var_name:
                 self._sampling_data[var_name] = np.concatenate(
                     self._sampling_data[var_name])
-                if self._algo == "KL":
-                    self._quantized_var_scale_factor[var_name] = \
-                        self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
-                else:
-                    self._quantized_var_scale_factor[var_name] = \
-                        np.max(np.abs(self._sampling_data[var_name]))
+                self._quantized_var_kl_threshold[var_name] = \
+                    self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
     def _update_program(self):
         '''
-        Insert fake_quantize/fake_dequantize op to the program.
+        Use QuantizationTransformPass and AddQuantDequantPass to insert
+        fake_quantize, fake_dequantize and fake_quant_dequant op.
+        Besides, save all kl threshold to the scale var node.
         '''
-        # reset quantized activation variable
-        for var in self._program.list_vars():
-            if var.name in self._quantized_act_var_name:
-                var.persistable = False
-
-        # use QuantizationTransformPass to insert fake_quantize/fake_dequantize op
+        _logger.info("Update the program ...")
         graph = IrGraph(core.Graph(self._program.desc), for_test=True)
+        # use QuantizationTransformPass to insert fake_quant/fake_dequantize op
         major_quantizable_op_types = []
         for op_type in QuantizationTransformPass._supported_quantizable_op_type:
             if op_type in self._quantizable_op_type:
@@ -424,8 +502,12 @@ class PostTrainingQuantization(object):
             quantizable_op_type=minor_quantizable_op_types)
         add_quant_dequant_pass.apply(graph)

-        # save scale factor to scale var node
-        for key, val in self._quantized_var_scale_factor.items():
+        # save abs_max or KL threshold to scale var node
+        if self._algo == "KL":
+            scale_dict = self._quantized_var_kl_threshold
+        else:
+            scale_dict = self._quantized_var_abs_max
+        for key, val in scale_dict.items():
             _set_variable_data(
                 self._scope,
                 self._place,
@@ -450,33 +532,34 @@ class PostTrainingQuantization(object):
         freeze_pass.apply(graph)
         self._program = graph.to_program()

-    def _save_output_scale(self):
+    def _save_output_threshold(self):
         '''
-        Save output scale to the quantized op.
+        Save output threshold to the quantized op.
         '''
-        output_scale_name = "output_scale"
         for op in self._program.global_block().ops:
             if op.type in self._quantizable_op_type:
                 output_name_list = self._op_real_in_out_name[op.type][1]
                 for output_name in output_name_list:
-                    for output_var_name in op.output(output_name):
-                        if output_var_name in self._quantized_var_scale_factor:
-                            op._set_attr(output_scale_name,
-                                         self._quantized_var_scale_factor[
-                                             output_var_name])
-
-    def _is_input_all_not_persistable(self, op, persistable_var_names):
-        '''
-        Analyze the real inputs of the op are all not persistable.
-        '''
-        is_input_all_not_persistable = True
-        input_name_list = self._op_real_in_out_name[op.type][0]
-        for input_name in input_name_list:
-            for var_name in op.input(input_name):
-                if var_name in persistable_var_names:
-                    is_input_all_not_persistable = False
-                    break
-        return is_input_all_not_persistable
+                    for var_name in op.output(output_name):
+                        if self._algo == "KL":
+                            assert var_name in self._quantized_var_kl_threshold
+                            op._set_attr(
+                                var_name + ".threshold",
+                                self._quantized_var_kl_threshold[var_name])
+                            op._set_attr("quantization_type", "post_kl")
+                        elif self._algo == "abs_max":
+                            assert var_name in self._quantized_var_abs_max
+                            op._set_attr(var_name + ".threshold",
+                                         self._quantized_var_abs_max[var_name])
+                            op._set_attr("quantization_type", "post_abs_max")
+                        elif self._algo == "min_max":
+                            assert var_name in self._quantized_var_min
+                            assert var_name in self._quantized_var_max
+                            op._set_attr(var_name + ".min",
+                                         self._quantized_var_min[var_name])
+                            op._set_attr(var_name + ".max",
+                                         self._quantized_var_max[var_name])
+                            op._set_attr("quantization_type", "post_min_max")

     def _get_kl_scaling_factor(self, activation_blob, num_quantized_bins=255):
         '''
......
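_get_kl_scaling_factor (body collapsed above) follows the usual KL-divergence calibration idea: pick the activation clipping threshold whose quantized histogram stays closest, in KL divergence, to the original one. A simplified sketch of that selection criterion, assuming SciPy is available and glossing over the histogram-smoothing details of the real implementation:

```python
import numpy as np
from scipy import stats  # assumed available; used only for the KL divergence

def kl_threshold(abs_activations, num_bins=2048, num_quantized_bins=255):
    # Histogram of |activation| values.
    hist, bin_edges = np.histogram(abs_activations, bins=num_bins)
    bin_width = bin_edges[1] - bin_edges[0]
    best_kl, best_threshold = float("inf"), float(bin_edges[-1])
    # Try every candidate clipping point from num_quantized_bins bins upward.
    for i in range(num_quantized_bins, num_bins + 1):
        # Reference distribution P: keep the first i bins, fold the tail into the last bin.
        p = hist[:i].astype(np.float64)
        p[-1] += hist[i:].sum()
        # Candidate distribution Q: quantize those i bins into num_quantized_bins
        # groups, then expand each group back over its original bins.
        q = np.zeros(i, dtype=np.float64)
        for idx in np.array_split(np.arange(i), num_quantized_bins):
            q[idx] = hist[idx].sum() / float(len(idx))
        if q.sum() == 0:
            continue
        kl = stats.entropy(p / p.sum(), q / q.sum())  # KL(P || Q)
        if kl < best_kl:
            best_kl, best_threshold = kl, (i + 0.5) * bin_width
    return best_threshold
```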
@@ -35,6 +35,10 @@ _fake_dequant_op_list = [
     'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
 ]

+_fake_quant_dequant_op_list = [
+    'fake_quantize_dequantize_moving_average_abs_max'
+]
+
 _out_scale_op_list = [
     "mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", "depthwise_conv2d",
     "batch_norm", "concat", "tanh", "pad", "elementwise_add", "elementwise_mul",
@@ -44,7 +48,7 @@ _out_scale_op_list = [
 # list op real input and output names, to avoid processing input such as AxisTensor.
 _op_real_in_out_name = {
     "conv2d": [["Input", "Filter"], ["Output"]],
-    "depthwise_conv2d": [["Input"], ["Output"]],
+    "depthwise_conv2d": [["Input", "Filter"], ["Output"]],
     "mul": [["X", "Y"], ["Out"]],
     "matmul": [["X", "Y"], ["Out"]],
     "pool2d": [["X"], ["Out"]],
@@ -236,6 +240,7 @@ class QuantizationTransformPass(object):
                     op_node.op()._set_attr("skip_quant", True)

         def _transform_forward(graph, op):
+            op.op()._set_attr("quantization_type", "qat_with_weight")
             for var_node in op.inputs:
                 if var_node.name() not in op.input_arg_names():
                     continue
@@ -290,7 +295,7 @@ class QuantizationTransformPass(object):
         # The loop for transforming the forward graph:
         for op in ops:
             if op.name() in self._quantizable_ops:
-                if not QuantizationTransformPass._is_skip_quant(graph, op):
+                if not self._is_skip_quant(graph, op):
                     _transform_forward(graph, op)
         # The loop for renaming the inputs of backward op.
         for op in ops:
@@ -636,8 +641,7 @@ class QuantizationTransformPass(object):
         """
         return "%s.scale" % (var_name)

-    @staticmethod
-    def _is_skip_quant(graph, op_node):
+    def _is_skip_quant(self, graph, op_node):
         """
         Analyse whether the op node skips quantization.
         """
@@ -650,20 +654,20 @@ class QuantizationTransformPass(object):
         if op_node.name() in ["mul", "matmul"] and \
             _is_input_all_not_persistable(graph, op_node):
             is_skip = True
+        if op_node.op().has_attr("quantization_type") and \
+            op_node.op().attr("quantization_type") == "qat_without_weight":
+            is_skip = True
         return is_skip


 class QuantizationFreezePass(object):
-    _supported_quantizable_op_type = \
-        QuantizationTransformPass._supported_quantizable_op_type

     def __init__(self,
                  scope,
                  place,
                  weight_bits=8,
                  activation_bits=8,
                  weight_quantize_type='abs_max',
-                 quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
+                 quantizable_op_type=None):
         """
         The freeze pass is used to adjust the quantize operator order, for example:
             1) `activation -> quant -> dequant -> conv2d` will be frozen into
@@ -679,9 +683,8 @@ class QuantizationFreezePass(object):
             weight_quantize_type(str): quantization type for weights, support 'abs_max' and
                 'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
                 since weights are fixed once the model is well trained.
-            quantizable_op_type(list[str]): List the type of ops that will be quantized.
-                Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
-                QuantizationTransformPass and ConvertToInt8Pass must be the same as this.
+            quantizable_op_type(list[str]): This input param will be removed later. The pass
+                will process all quantized ops, so it is not necessary to set this param.
         """
         assert scope is not None, \
             'The scope cannot be set None.'
@@ -692,16 +695,12 @@ class QuantizationFreezePass(object):
         self._weight_bits = weight_bits
         self._activation_bits = activation_bits
         self._weight_quantize_type = weight_quantize_type
-        self._quantizable_ops = quantizable_op_type
-        for op in self._quantizable_ops:
-            assert op in QuantizationFreezePass._supported_quantizable_op_type, \
-                op + " is not supported for quantization."
         self._conv_ops = ['conv2d', 'depthwise_conv2d']
         self._fake_quant_op_names = _fake_quant_op_list
         self._fake_dequant_op_names = _fake_dequant_op_list
         self._op_input_rename_map = collections.OrderedDict()
         self._op_output_rename_map = collections.OrderedDict()
-        self._var_scale_map = collections.OrderedDict()
+        self._quant_var_scale_map = collections.OrderedDict()

     def apply(self, graph):
         """
@@ -712,6 +711,7 @@ class QuantizationFreezePass(object):
         Returns:
             None
         """
+        # Get input scales in fake quant op and process weights
         persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
         ops = graph.all_op_nodes()
         for op_node in ops:
@@ -733,7 +733,7 @@ class QuantizationFreezePass(object):
                 else:
                     scale_v = self._load_var(
                         op_node.output('OutScale')[0])[0]
-                self._var_scale_map[input_arg_name] = scale_v
+                self._quant_var_scale_map[input_arg_name] = scale_v
                 self._remove_fake_quant_and_dequant_op(graph, op_node)
                 # quantize weight and restore
                 param_v = self._load_var(input_arg_name)
@@ -743,32 +743,29 @@ class QuantizationFreezePass(object):
                 else:
                     scale_v = graph._find_node_by_name(
                         op_node.outputs, op_node.output('OutScale')[0])
-                self._var_scale_map[input_arg_name] = scale_v
+                self._quant_var_scale_map[input_arg_name] = scale_v

+        # Remove all fake dequant op
         ops = graph.all_op_nodes()
         for op_node in ops:
             op_name = op_node.name()
             if op_name in self._fake_dequant_op_names:
                 self._remove_fake_quant_and_dequant_op(graph, op_node)

+        # Insert post dequant op
         ops = graph.all_op_nodes()
         for op_node in ops:
-            op_name = op_node.name()
-            if op_name in self._quantizable_ops:
-                # only process the node that is quantized by QuantizationTransformPass
-                is_op_node_quantized = False
-                for var_node in op_node.inputs:
-                    var_name = var_node.name()
-                    if var_name.endswith('.dequantized'):
-                        is_op_node_quantized = True
-                if is_op_node_quantized:
-                    if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops:
-                        self._insert_post_channel_dequant_op(graph, op_node)
-                    else:
-                        self._insert_post_dequant_op(graph, op_node)
+            op_node_desc = op_node.op()
+            if op_node_desc.has_attr("quantization_type") and \
+                op_node_desc.attr("quantization_type") == "qat_with_weight":
+                if self._weight_quantize_type == 'channel_wise_abs_max' \
+                    and op_node.name() in self._conv_ops:
+                    self._insert_post_channel_dequant_op(graph, op_node)
+                else:
+                    self._insert_post_dequant_op(graph, op_node)

+        # Rename inputs of the followed ops after inserting dequant_op after fc/conv
         for op_node in ops:
-            # insert dequant_op after fc/conv, need to rename inputs of the followed ops
             for var_node in op_node.inputs:
                 if var_node.node in self._op_output_rename_map:
                     old_in = var_node
@@ -802,7 +799,7 @@ class QuantizationFreezePass(object):
                 new_in.clear_outputs()
                 graph.update_input_link(old_in, new_in, op_node)
             original_var_name = self._original_var_name(name)
-            scale_v = self._var_scale_map[original_var_name]
+            scale_v = self._quant_var_scale_map[original_var_name]
             if original_var_name in persistable_vars:
                 assert isinstance(
                     scale_v,
@@ -811,7 +808,7 @@ class QuantizationFreezePass(object):
                 channel_scale = np.array(scale_v)
             else:
                 assert isinstance(scale_v, IrNode)
-                scale_var_node = self._var_scale_map[original_var_name]
+                scale_var_node = self._quant_var_scale_map[original_var_name]

         if len(op_node.output_arg_names()) != 1:
             raise ValueError("Only support one output, but op %s has"
@@ -867,7 +864,7 @@ class QuantizationFreezePass(object):
                 new_in.clear_outputs()
                 graph.update_input_link(old_in, new_in, op_node)
             original_var_name = self._original_var_name(name)
-            scale_v = self._var_scale_map[original_var_name]
+            scale_v = self._quant_var_scale_map[original_var_name]
             if original_var_name in persistable_vars:
                 assert self._is_float(
                     scale_v), 'The scale of parameter %s is not a float.' % (
@@ -876,7 +873,7 @@ class QuantizationFreezePass(object):
             else:
                 max_range *= act_range
                 assert isinstance(scale_v, IrNode)
-                scale_var_node = self._var_scale_map[original_var_name]
+                scale_var_node = self._quant_var_scale_map[original_var_name]

         if len(op_node.output_arg_names()) != 1:
             raise ValueError("Only support one output, but op %s has"
@@ -963,13 +960,7 @@ class QuantizationFreezePass(object):

 class ConvertToInt8Pass(object):
-    _supported_quantizable_op_type = \
-        QuantizationTransformPass._supported_quantizable_op_type
-
-    def __init__(self,
-                 scope,
-                 place,
-                 quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
+    def __init__(self, scope, place, quantizable_op_type=None):
         """
         Convert the weights into int8_t type.
@@ -977,9 +968,8 @@ class ConvertToInt8Pass(object):
             scope(fluid.Scope): scope is used to get the weight tensor values.
             place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the
                 8bits weight tensors.
-            quantizable_op_type(list[str]): List the type of ops that will be quantized.
-                Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
-                QuantizationTransformPass and QuantizationFreezePass must be the same as this.
+            quantizable_op_type(list[str]): This input param will be removed later. The pass
+                will process all quantized ops, so it is not necessary to set this param.
         """
         assert scope is not None, \
             'The scope cannot be set None.'
@@ -987,10 +977,6 @@ class ConvertToInt8Pass(object):
             'The place cannot be set None.'
         self._scope = scope
         self._place = place
-        self._quantizable_ops = quantizable_op_type
-        for op in self._quantizable_ops:
-            assert op in ConvertToInt8Pass._supported_quantizable_op_type, \
-                op + " is not supported for quantization."

     def apply(self, graph):
         """
@@ -1006,10 +992,8 @@ class ConvertToInt8Pass(object):
         ops = graph.all_op_nodes()
         input_map = {}
         for op_node in ops:
-            op_name = op_node.name()
-            if op_name in self._quantizable_ops:
-                if QuantizationTransformPass._is_skip_quant(graph, op_node):
-                    continue
+            if op_node.op().has_attr("quantization_type") and \
+                op_node.op().attr("quantization_type") == "qat_with_weight":
                 for var_node in op_node.inputs:
                     name = var_node.name()
                     if name in persistable_vars:
@@ -1259,9 +1243,9 @@ class AddQuantDequantPass(object):
         "equal", "gather", "greater_equal", "greater_than", "less_equal",
         "less_than", "mean", "not_equal", "reshape", "reshape2",
         "bilinear_interp", "nearest_interp", "trilinear_interp", "slice",
-        "squeeze", "elementwise_sub", "mul", "matmul"
+        "squeeze", "elementwise_sub", "mul", "matmul", "relu", "relu6",
+        "leaky_relu", "tanh", "swish"
     ]
-    _activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"]

     def __init__(self,
                  scope=None,
@@ -1307,8 +1291,7 @@ class AddQuantDequantPass(object):
         else:
             self._quantizable_op_type = quantizable_op_type
             for op_type in quantizable_op_type:
-                assert op_type in AddQuantDequantPass._supported_quantizable_op_type + \
-                    AddQuantDequantPass._activation_type, \
+                assert op_type in AddQuantDequantPass._supported_quantizable_op_type, \
                     op_type + " is not supported for quantization."
         self._quantizable_grad_op_type = [
             '%s_grad' % (op) for op in self._quantizable_op_type
@@ -1343,17 +1326,15 @@ class AddQuantDequantPass(object):
                 elif isinstance(self._skip_pattern, str):
                     is_skip = op_node.op().has_attr("op_namescope") and \
                         op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
-                is_op_node_quantized = False
-                for var_node in op_node.inputs:
-                    var_name = var_node.name()
-                    if var_name.endswith('.dequantized'):
-                        is_op_node_quantized = True
-                if is_skip or is_op_node_quantized or \
+                is_quantized = op_node.op().has_attr("quantization_type") and \
+                    op_node.op().attr("quantization_type") == "qat_with_weight"
+                if is_skip or is_quantized or \
                     (not _is_input_all_not_persistable(graph, op_node)):
                     continue

+                op_node.op()._set_attr("quantization_type",
+                                       "qat_without_weight")
+                op_node.op()._set_attr("activation_bits", self._quant_bits)
                 input_name_list = _op_real_in_out_name[op_node.name()][0]
                 arg_names = []
                 for input_name in input_name_list:
......
@@ -264,7 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
         ptq.save_quantized_model(self.int8_model)

     def run_test(self, model, algo, data_urls, data_md5s, quantizable_op_type,
-                 is_full_quantize, is_use_cache_file):
+                 is_full_quantize, is_use_cache_file, diff_threshold):
         infer_iterations = self.infer_iterations
         batch_size = self.batch_size
         sample_iterations = self.sample_iterations
@@ -296,11 +296,11 @@ class TestPostTrainingQuantization(unittest.TestCase):
         sys.stdout.flush()

         delta_value = fp32_acc1 - int8_acc1
-        self.assertLess(delta_value, 0.025)
+        self.assertLess(delta_value, diff_threshold)


-class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
-    def test_post_training_mobilenetv1(self):
+class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
+    def test_post_training_kl_mobilenetv1(self):
         model = "MobileNet-V1"
         algo = "KL"
         data_urls = [
@@ -310,10 +310,29 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
         quantizable_op_type = [
             "conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
         ]
-        is_full_quantize = True
+        is_full_quantize = False
         is_use_cache_file = False
+        diff_threshold = 0.025
         self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
-                      is_full_quantize, is_use_cache_file)
+                      is_full_quantize, is_use_cache_file, diff_threshold)
+
+
+class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
+    def test_post_training_abs_max_mobilenetv1(self):
+        model = "MobileNet-V1"
+        algo = "abs_max"
+        data_urls = [
+            'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
+        ]
+        data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
+        quantizable_op_type = [
+            "conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
+        ]
+        is_full_quantize = False
+        is_use_cache_file = False
+        diff_threshold = 0.05
+        self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
+                      is_full_quantize, is_use_cache_file, diff_threshold)


 if __name__ == '__main__':
......
@@ -20,7 +20,7 @@ from test_post_training_quantization_mobilenetv1 import TestPostTrainingQuantization
 class TestPostTrainingForResnet50(TestPostTrainingQuantization):
     def test_post_training_resnet50(self):
         model = "ResNet-50"
-        algo = "direct"
+        algo = "min_max"
         data_urls = [
             'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
         ]
@@ -28,8 +28,9 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
         quantizable_op_type = ["conv2d", "mul"]
         is_full_quantize = False
         is_use_cache_file = False
+        diff_threshold = 0.025
         self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
-                      is_full_quantize, is_use_cache_file)
+                      is_full_quantize, is_use_cache_file, diff_threshold)


 if __name__ == '__main__':
......