未验证 提交 589cd878 编写于 作者: C cc 提交者: GitHub

Post_training_quantizaion supports min_max methon (#23078)

* Post_training_quantizaion supports min_max methon
上级 194a22c5
......@@ -37,7 +37,10 @@ def _load_variable_data(scope, var_name):
'''
Load variable value from scope
'''
return np.array(scope.find_var(var_name).get_tensor())
var_node = scope.find_var(var_name)
assert var_node is not None, \
"Cannot find " + var_name + " in scope."
return np.array(var_node.get_tensor())
def _set_variable_data(scope, place, var_name, np_value):
......@@ -53,6 +56,12 @@ def _set_variable_data(scope, place, var_name, np_value):
class PostTrainingQuantization(object):
"""
Utilizing post training quantization methon to quantize the FP32 model,
and it uses calibrate data to get the quantization information for all
quantized variables.
"""
def __init__(self,
executor=None,
scope=None,
......@@ -70,13 +79,10 @@ class PostTrainingQuantization(object):
is_use_cache_file=False,
cache_dir="./temp_post_training"):
'''
The class utilizes post training quantization methon to quantize the
fp32 model. It uses calibrate data to calculate the scale factor of
quantized variables, and inserts fake quant/dequant op to obtain the
quantized model.
Constructor.
Args:
executor(fluid.Executor): The executor to load, run and save the
executor(fluid.Executor): The executor to load, run and save the
quantized model.
scope(fluid.Scope, optional): The scope of the program, use it to load
and save variables. If scope=None, get scope by global_scope().
......@@ -96,9 +102,11 @@ class PostTrainingQuantization(object):
batch_nums(int, optional): If batch_nums is not None, the number of
calibrate data is batch_size*batch_nums. If batch_nums is None, use
all data provided by sample_generator as calibrate data.
algo(str, optional): If algo=KL, use KL-divergenc method to
get the more precise scale factor. If algo='direct', use
abs_max methon to get the scale factor. Default is KL.
algo(str, optional): If algo='KL', use KL-divergenc method to
get the KL threshold for quantized activations and get the abs_max
value for quantized weights. If algo='abs_max', get the abs max
value for activations and weights. If algo= 'min_max', get the min
and max value for quantized activations and weights. Default is KL.
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is ["conv2d", "depthwise_conv2d",
"mul"].
......@@ -158,7 +166,9 @@ class PostTrainingQuantization(object):
assert executor is not None, "The executor cannot be None."
assert model_dir is not None, "The model_dir cannot be None."
assert sample_generator is not None, \
"The sample_generator cannot be None."
"The sample_generator cannot be None."
assert algo in ['KL', 'abs_max', 'min_max'], \
"The algo should be KL, abs_max or min_max."
self._executor = executor
self._scope = global_scope() if scope == None else scope
......@@ -182,8 +192,7 @@ class PostTrainingQuantization(object):
else:
self._quantizable_op_type = quantizable_op_type
for op_type in self._quantizable_op_type:
assert op_type in supported_quantizable_op_type + \
AddQuantDequantPass._activation_type, \
assert op_type in supported_quantizable_op_type, \
op_type + " is not supported for quantization."
self._place = self._executor.place
......@@ -197,20 +206,25 @@ class PostTrainingQuantization(object):
self._quantized_weight_var_name = set()
self._quantized_act_var_name = set()
self._sampling_data = {}
self._quantized_var_scale_factor = {}
self._quantized_var_kl_threshold = {}
self._quantized_var_min = {}
self._quantized_var_max = {}
self._quantized_var_abs_max = {}
def quantize(self):
'''
Quantize the fp32 model. Use calibrate data to calculate the scale factor of
quantized variables, and inserts fake quant/dequant op to obtain the
quantized model.
Load the FP32 model, and use the calibrate data to calculate the forward-stage.
Based on the sample data, we can get the quantization information, and obtain
the final quantized model.
Args:
None
Returns:
the program of quantized model.
'''
self._preprocess()
self._load_model_data()
self._collect_quantized_varnames()
self._set_activation_persistable()
batch_id = 0
for data in self._data_loader():
......@@ -218,22 +232,29 @@ class PostTrainingQuantization(object):
feed=data,
fetch_list=self._fetch_list,
return_numpy=False)
self._sample_data(batch_id)
if self._algo == "KL":
self._sample_data(batch_id)
else:
self._sample_threshold()
if batch_id % 5 == 0:
_logger.info("run batch: " + str(batch_id))
_logger.info("Run batch: " + str(batch_id))
batch_id += 1
if self._batch_nums and batch_id >= self._batch_nums:
break
_logger.info("all run batch: " + str(batch_id))
_logger.info("Finish all batch: " + str(batch_id))
_logger.info("calculate scale factor ...")
self._calculate_scale_factor()
self._reset_activation_persistable()
_logger.info("update the program ...")
self._update_program()
if self._algo == "KL":
self._calculate_kl_threshold()
self._save_output_scale()
if self._algo in ["KL", "abs_max"]:
self._update_program()
else:
self._save_input_threhold()
self._save_output_threshold()
return self._program
def save_quantized_model(self, save_model_path):
......@@ -252,12 +273,11 @@ class PostTrainingQuantization(object):
executor=self._executor,
main_program=self._program)
def _preprocess(self):
def _load_model_data(self):
'''
Load model and set data loader, collect the variable names for sampling,
and set activation variables to be persistable.
Load model and set data loader.
'''
# load model and set data loader
_logger.info("Load model and set data loader ...")
[self._program, self._feed_list, self._fetch_list] = \
io.load_inference_model(dirname=self._model_dir,
executor=self._executor,
......@@ -273,7 +293,12 @@ class PostTrainingQuantization(object):
drop_last=True,
places=self._place)
# collect the variable names for sampling.
def _collect_quantized_varnames(self):
'''
Collect the variable names for sampling, and set activation
variables to be persistable.
'''
_logger.info("Collect quantized variable names ...")
# TODO(juncaipeng), consider the name_scope of skip_quant and
# reduce the variables for sampling
persistable_var_names = []
......@@ -284,46 +309,109 @@ class PostTrainingQuantization(object):
for op in self._program.global_block().ops:
op_type = op.type
if op_type in self._quantizable_op_type:
if op_type in ("conv2d", "depthwise_conv2d"):
self._quantized_act_var_name.add(op.input("Input")[0])
self._quantized_weight_var_name.add(op.input("Filter")[0])
self._quantized_act_var_name.add(op.output("Output")[0])
elif op_type in ["mul", "matmul"]:
x_var_name = op.input("X")[0]
if x_var_name in persistable_var_names:
self._quantized_weight_var_name.add(x_var_name)
else:
self._quantized_act_var_name.add(x_var_name)
y_var_name = op.input("Y")[0]
if y_var_name in persistable_var_names:
self._quantized_weight_var_name.add(y_var_name)
else:
self._quantized_act_var_name.add(y_var_name)
self._quantized_act_var_name.add(op.output("Out")[0])
else:
# process other quantizable op type, the input must all not persistable
if self._is_input_all_not_persistable(
op, persistable_var_names):
input_output_name_list = self._op_real_in_out_name[
op_type]
for input_name in input_output_name_list[0]:
for var_name in op.input(input_name):
self._quantized_act_var_name.add(var_name)
for output_name in input_output_name_list[1]:
for var_name in op.output(output_name):
self._quantized_act_var_name.add(var_name)
# set activation variables to be persistable, so can obtain
# the tensor data in sample_data
name_list = self._op_real_in_out_name[op_type]
for input_name in name_list[0]:
for var_name in op.input(input_name):
if var_name in persistable_var_names:
self._quantized_weight_var_name.add(var_name)
else:
self._quantized_act_var_name.add(var_name)
for output_name in name_list[1]:
for var_name in op.output(output_name):
if var_name in persistable_var_names:
self._quantized_weight_var_name.add(var_name)
else:
self._quantized_act_var_name.add(var_name)
def _set_activation_persistable(self):
'''
Set activation variables to be persistable, so can obtain
the tensor data in sample_data
'''
persistable_var_names = []
for var in self._program.list_vars():
if var.persistable:
persistable_var_names.append(var.name)
for var in self._program.list_vars():
if var.name in self._quantized_act_var_name:
var.persistable = True
def _reset_activation_persistable(self):
'''
Reset activations to be not persistable.
'''
for var in self._program.list_vars():
if var.name in self._quantized_act_var_name:
var.persistable = False
def _sample_threshold(self):
'''
Sample the input threshold(min, max, or abs_max) in every iterations.
'''
assert self._algo in ["abs_max", "min_max"], \
"The algo should be abs_max or min_max to sample min max value."
if self._algo == "abs_max":
# Only calculate abs_max value for weight for once
if self._quantized_var_abs_max == {}:
for var_name in self._quantized_weight_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
abs_max_per_channel = []
for i in range(var_tensor.shape[0]):
abs_max_per_channel.append(
float(np.max(np.abs(var_tensor[i]))))
self._quantized_var_abs_max[var_name] = abs_max_per_channel
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
abs_max_value = float(np.max(np.abs(var_tensor)))
if (var_name not in self._quantized_var_abs_max) or \
(abs_max_value > self._quantized_var_abs_max[var_name]):
self._quantized_var_abs_max[var_name] = abs_max_value
elif self._algo == "min_max":
if self._quantized_var_min == {} and self._quantized_var_max == {}:
for var_name in self._quantized_weight_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
min_per_channel = []
max_per_channle = []
for i in range(var_tensor.shape[0]):
min_per_channel.append(float(np.min(var_tensor[i])))
max_per_channle.append(float(np.max(var_tensor[i])))
self._quantized_var_min[var_name] = min_per_channel
self._quantized_var_max[var_name] = max_per_channle
for var_name in self._quantized_act_var_name:
var_tensor = _load_variable_data(self._scope, var_name)
min_value = float(np.min(var_tensor))
max_value = float(np.max(var_tensor))
if (var_name not in self._quantized_var_min) or \
(min_value < self._quantized_var_min[var_name]):
self._quantized_var_min[var_name] = min_value
if (var_name not in self._quantized_var_max) or \
(max_value > self._quantized_var_max[var_name]):
self._quantized_var_max[var_name] = max_value
def _save_input_threhold(self):
'''
Save input threshold to the quantized op.
'''
assert self._algo == "min_max", \
"The algo should be min_max to save input threshold."
for op in self._program.global_block().ops:
if op.type in self._quantizable_op_type:
input_name_list = self._op_real_in_out_name[op.type][0]
for input_name in input_name_list:
for var_name in op.input(input_name):
assert var_name in self._quantized_var_min
assert var_name in self._quantized_var_max
op._set_attr(var_name + ".min",
self._quantized_var_min[var_name])
op._set_attr(var_name + ".max",
self._quantized_var_max[var_name])
def _sample_data(self, iter):
'''
Sample the tensor data of quantized variables,
applied in every iteration.
'''
assert self._algo == "KL", "The algo should be KL to sample data."
for var_name in self._quantized_weight_var_name:
if var_name not in self._sampling_data:
var_tensor = _load_variable_data(self._scope, var_name)
......@@ -344,19 +432,20 @@ class PostTrainingQuantization(object):
var_tensor = var_tensor.ravel()
self._sampling_data[var_name].append(var_tensor)
def _calculate_scale_factor(self):
def _calculate_kl_threshold(self):
'''
Calculate the scale factor of quantized variables.
Calculate the KL threshold of quantized variables.
'''
_logger.info("Calculate KL threshold ...")
assert self._algo == "KL", "The algo should be KL to calculate kl threshold."
# apply channel_wise_abs_max quantization for weights
for var_name in self._quantized_weight_var_name:
data = self._sampling_data[var_name]
scale_factor_per_channel = []
threshold_per_channel = []
for i in range(data.shape[0]):
abs_max_value = np.max(np.abs(data[i]))
scale_factor_per_channel.append(abs_max_value)
self._quantized_var_scale_factor[
var_name] = scale_factor_per_channel
threshold_per_channel.append(abs_max_value)
self._quantized_var_kl_threshold[var_name] = threshold_per_channel
# apply kl quantization for activation
if self._is_use_cache_file:
......@@ -369,36 +458,25 @@ class PostTrainingQuantization(object):
sampling_data.append(np.load(file_path))
os.remove(file_path)
sampling_data = np.concatenate(sampling_data)
if self._algo == "KL":
self._quantized_var_scale_factor[var_name] = \
self._get_kl_scaling_factor(np.abs(sampling_data))
else:
self._quantized_var_scale_factor[var_name] = \
np.max(np.abs(sampling_data))
self._quantized_var_kl_threshold[var_name] = \
self._get_kl_scaling_factor(np.abs(sampling_data))
else:
for var_name in self._quantized_act_var_name:
self._sampling_data[var_name] = np.concatenate(
self._sampling_data[var_name])
if self._algo == "KL":
self._quantized_var_scale_factor[var_name] = \
self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
else:
self._quantized_var_scale_factor[var_name] = \
np.max(np.abs(self._sampling_data[var_name]))
self._quantized_var_kl_threshold[var_name] = \
self._get_kl_scaling_factor(np.abs(self._sampling_data[var_name]))
def _update_program(self):
'''
Insert fake_quantize/fake_dequantize op to the program.
Use QuantizationTransformPass and AddQuantDequantPass to insert
fake_quantize, fake_dequantize and fake_quant_dequant op.
Besides, save all kl threshold to the scale var node.
'''
# reset quantized activation variable
for var in self._program.list_vars():
if var.name in self._quantized_act_var_name:
var.persistable = False
# use QuantizationTransformPass to insert fake_quantize/fake_dequantize op
_logger.info("Update the program ...")
graph = IrGraph(core.Graph(self._program.desc), for_test=True)
# use QuantizationTransformPass to insert fake_quant/fake_dequantize op
major_quantizable_op_types = []
for op_type in QuantizationTransformPass._supported_quantizable_op_type:
if op_type in self._quantizable_op_type:
......@@ -424,8 +502,12 @@ class PostTrainingQuantization(object):
quantizable_op_type=minor_quantizable_op_types)
add_quant_dequant_pass.apply(graph)
# save scale factor to scale var node
for key, val in self._quantized_var_scale_factor.items():
# save abs_max or KL threshold to scale var node
if self._algo == "KL":
scale_dict = self._quantized_var_kl_threshold
else:
scale_dict = self._quantized_var_abs_max
for key, val in scale_dict.items():
_set_variable_data(
self._scope,
self._place,
......@@ -450,33 +532,34 @@ class PostTrainingQuantization(object):
freeze_pass.apply(graph)
self._program = graph.to_program()
def _save_output_scale(self):
def _save_output_threshold(self):
'''
Save output scale to the quantized op.
Save output threshold to the quantized op.
'''
output_scale_name = "output_scale"
for op in self._program.global_block().ops:
if op.type in self._quantizable_op_type:
output_name_list = self._op_real_in_out_name[op.type][1]
for output_name in output_name_list:
for output_var_name in op.output(output_name):
if output_var_name in self._quantized_var_scale_factor:
op._set_attr(output_scale_name,
self._quantized_var_scale_factor[
output_var_name])
def _is_input_all_not_persistable(self, op, persistable_var_names):
'''
Analyze the real inputs of the op are all not persistable.
'''
is_input_all_not_persistable = True
input_name_list = self._op_real_in_out_name[op.type][0]
for input_name in input_name_list:
for var_name in op.input(input_name):
if var_name in persistable_var_names:
is_input_all_not_persistable = False
break
return is_input_all_not_persistable
for var_name in op.output(output_name):
if self._algo == "KL":
assert var_name in self._quantized_var_kl_threshold
op._set_attr(
var_name + ".threshold",
self._quantized_var_kl_threshold[var_name])
op._set_attr("quantization_type", "post_kl")
elif self._algo == "abs_max":
assert var_name in self._quantized_var_abs_max
op._set_attr(var_name + ".threshold",
self._quantized_var_abs_max[var_name])
op._set_attr("quantization_type", "post_abs_max")
elif self._algo == "min_max":
assert var_name in self._quantized_var_min
assert var_name in self._quantized_var_max
op._set_attr(var_name + ".min",
self._quantized_var_min[var_name])
op._set_attr(var_name + ".max",
self._quantized_var_max[var_name])
op._set_attr("quantization_type", "post_min_max")
def _get_kl_scaling_factor(self, activation_blob, num_quantized_bins=255):
'''
......
......@@ -35,6 +35,10 @@ _fake_dequant_op_list = [
'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
]
_fake_quant_dequant_op_list = [
'fake_quantize_dequantize_moving_average_abs_max'
]
_out_scale_op_list = [
"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", "depthwise_conv2d",
"batch_norm", "concat", "tanh", "pad", "elementwise_add", "elementwise_mul",
......@@ -44,7 +48,7 @@ _out_scale_op_list = [
# list op real input and output names, to avoid processing input such as AxisTensor.
_op_real_in_out_name = {
"conv2d": [["Input", "Filter"], ["Output"]],
"depthwise_conv2d": [["Input"], ["Output"]],
"depthwise_conv2d": [["Input", "Filter"], ["Output"]],
"mul": [["X", "Y"], ["Out"]],
"matmul": [["X", "Y"], ["Out"]],
"pool2d": [["X"], ["Out"]],
......@@ -236,6 +240,7 @@ class QuantizationTransformPass(object):
op_node.op()._set_attr("skip_quant", True)
def _transform_forward(graph, op):
op.op()._set_attr("quantization_type", "qat_with_weight")
for var_node in op.inputs:
if var_node.name() not in op.input_arg_names():
continue
......@@ -290,7 +295,7 @@ class QuantizationTransformPass(object):
# The loop for transforming the forward graph:
for op in ops:
if op.name() in self._quantizable_ops:
if not QuantizationTransformPass._is_skip_quant(graph, op):
if not self._is_skip_quant(graph, op):
_transform_forward(graph, op)
# The loop for renaming the inputs of backward op.
for op in ops:
......@@ -636,8 +641,7 @@ class QuantizationTransformPass(object):
"""
return "%s.scale" % (var_name)
@staticmethod
def _is_skip_quant(graph, op_node):
def _is_skip_quant(self, graph, op_node):
"""
Analyse whether the op node skips quantization.
"""
......@@ -650,20 +654,20 @@ class QuantizationTransformPass(object):
if op_node.name() in ["mul", "matmul"] and \
_is_input_all_not_persistable(graph, op_node):
is_skip = True
if op_node.op().has_attr("quantization_type") and \
op_node.op().attr("quantization_type") == "qat_without_weight":
is_skip = True
return is_skip
class QuantizationFreezePass(object):
_supported_quantizable_op_type = \
QuantizationTransformPass._supported_quantizable_op_type
def __init__(self,
scope,
place,
weight_bits=8,
activation_bits=8,
weight_quantize_type='abs_max',
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
quantizable_op_type=None):
"""
The freeze pass is used to adjust the quantize operator order, for example:
1) `activation -> quant -> dequant -> conv2d` will be frozen into
......@@ -679,9 +683,8 @@ class QuantizationFreezePass(object):
weight_quantize_type(str): quantization type for weights, support 'abs_max' and
'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
since weights are fixed once the model is well trained.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationTransformPass and ConvertToInt8Pass must be the same as this.
quantizable_op_type(list[str]): This input param will be removed latter. The pass
will process all quantized op, so it is not necessary to set the input param.
"""
assert scope is not None, \
'The scope cannot be set None.'
......@@ -692,16 +695,12 @@ class QuantizationFreezePass(object):
self._weight_bits = weight_bits
self._activation_bits = activation_bits
self._weight_quantize_type = weight_quantize_type
self._quantizable_ops = quantizable_op_type
for op in self._quantizable_ops:
assert op in QuantizationFreezePass._supported_quantizable_op_type, \
op + " is not supported for quantization."
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._fake_quant_op_names = _fake_quant_op_list
self._fake_dequant_op_names = _fake_dequant_op_list
self._op_input_rename_map = collections.OrderedDict()
self._op_output_rename_map = collections.OrderedDict()
self._var_scale_map = collections.OrderedDict()
self._quant_var_scale_map = collections.OrderedDict()
def apply(self, graph):
"""
......@@ -712,6 +711,7 @@ class QuantizationFreezePass(object):
Returns:
None
"""
# Get input scales in fake quant op and process weights
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
ops = graph.all_op_nodes()
for op_node in ops:
......@@ -733,7 +733,7 @@ class QuantizationFreezePass(object):
else:
scale_v = self._load_var(
op_node.output('OutScale')[0])[0]
self._var_scale_map[input_arg_name] = scale_v
self._quant_var_scale_map[input_arg_name] = scale_v
self._remove_fake_quant_and_dequant_op(graph, op_node)
# quantize weight and restore
param_v = self._load_var(input_arg_name)
......@@ -743,32 +743,29 @@ class QuantizationFreezePass(object):
else:
scale_v = graph._find_node_by_name(
op_node.outputs, op_node.output('OutScale')[0])
self._var_scale_map[input_arg_name] = scale_v
self._quant_var_scale_map[input_arg_name] = scale_v
# Remove all fake dequant op
ops = graph.all_op_nodes()
for op_node in ops:
op_name = op_node.name()
if op_name in self._fake_dequant_op_names:
self._remove_fake_quant_and_dequant_op(graph, op_node)
# Insert post dequant op
ops = graph.all_op_nodes()
for op_node in ops:
op_name = op_node.name()
if op_name in self._quantizable_ops:
# only process the node that is quantized by QuantizationTransformPass
is_op_node_quantized = False
for var_node in op_node.inputs:
var_name = var_node.name()
if var_name.endswith('.dequantized'):
is_op_node_quantized = True
if is_op_node_quantized:
if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops:
self._insert_post_channel_dequant_op(graph, op_node)
else:
self._insert_post_dequant_op(graph, op_node)
op_node_desc = op_node.op()
if op_node_desc.has_attr("quantization_type") and \
op_node_desc.attr("quantization_type") == "qat_with_weight":
if self._weight_quantize_type == 'channel_wise_abs_max' \
and op_node.name() in self._conv_ops:
self._insert_post_channel_dequant_op(graph, op_node)
else:
self._insert_post_dequant_op(graph, op_node)
# Rename inputs of the followed ops after inserting dequant_op after fc/conv
for op_node in ops:
# insert dequant_op after fc/conv, need to rename inputs of the followed ops
for var_node in op_node.inputs:
if var_node.node in self._op_output_rename_map:
old_in = var_node
......@@ -802,7 +799,7 @@ class QuantizationFreezePass(object):
new_in.clear_outputs()
graph.update_input_link(old_in, new_in, op_node)
original_var_name = self._original_var_name(name)
scale_v = self._var_scale_map[original_var_name]
scale_v = self._quant_var_scale_map[original_var_name]
if original_var_name in persistable_vars:
assert isinstance(
scale_v,
......@@ -811,7 +808,7 @@ class QuantizationFreezePass(object):
channel_scale = np.array(scale_v)
else:
assert isinstance(scale_v, IrNode)
scale_var_node = self._var_scale_map[original_var_name]
scale_var_node = self._quant_var_scale_map[original_var_name]
if len(op_node.output_arg_names()) != 1:
raise ValueError("Only support one output, but op %s has"
......@@ -867,7 +864,7 @@ class QuantizationFreezePass(object):
new_in.clear_outputs()
graph.update_input_link(old_in, new_in, op_node)
original_var_name = self._original_var_name(name)
scale_v = self._var_scale_map[original_var_name]
scale_v = self._quant_var_scale_map[original_var_name]
if original_var_name in persistable_vars:
assert self._is_float(
scale_v), 'The scale of parameter %s is not a float.' % (
......@@ -876,7 +873,7 @@ class QuantizationFreezePass(object):
else:
max_range *= act_range
assert isinstance(scale_v, IrNode)
scale_var_node = self._var_scale_map[original_var_name]
scale_var_node = self._quant_var_scale_map[original_var_name]
if len(op_node.output_arg_names()) != 1:
raise ValueError("Only support one output, but op %s has"
......@@ -963,13 +960,7 @@ class QuantizationFreezePass(object):
class ConvertToInt8Pass(object):
_supported_quantizable_op_type = \
QuantizationTransformPass._supported_quantizable_op_type
def __init__(self,
scope,
place,
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
def __init__(self, scope, place, quantizable_op_type=None):
"""
Convert the weights into int8_t type.
......@@ -977,9 +968,8 @@ class ConvertToInt8Pass(object):
scope(fluid.Scope): scope is used to get the weight tensor values.
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the
8bits weight tensors.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationTransformPass and QuantizationFreezePass must be the same as this.
quantizable_op_type(list[str]): This input param will be removed latter. The pass
will process all quantized op, so it is not necessary to set the input param.
"""
assert scope is not None, \
'The scope cannot be set None.'
......@@ -987,10 +977,6 @@ class ConvertToInt8Pass(object):
'The place cannot be set None.'
self._scope = scope
self._place = place
self._quantizable_ops = quantizable_op_type
for op in self._quantizable_ops:
assert op in ConvertToInt8Pass._supported_quantizable_op_type, \
op + " is not supported for quantization."
def apply(self, graph):
"""
......@@ -1006,10 +992,8 @@ class ConvertToInt8Pass(object):
ops = graph.all_op_nodes()
input_map = {}
for op_node in ops:
op_name = op_node.name()
if op_name in self._quantizable_ops:
if QuantizationTransformPass._is_skip_quant(graph, op_node):
continue
if op_node.op().has_attr("quantization_type") and \
op_node.op().attr("quantization_type") == "qat_with_weight":
for var_node in op_node.inputs:
name = var_node.name()
if name in persistable_vars:
......@@ -1259,9 +1243,9 @@ class AddQuantDequantPass(object):
"equal", "gather", "greater_equal", "greater_than", "less_equal",
"less_than", "mean", "not_equal", "reshape", "reshape2",
"bilinear_interp", "nearest_interp", "trilinear_interp", "slice",
"squeeze", "elementwise_sub", "mul", "matmul"
"squeeze", "elementwise_sub", "mul", "matmul", "relu", "relu6",
"leaky_relu", "tanh", "swish"
]
_activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"]
def __init__(self,
scope=None,
......@@ -1307,8 +1291,7 @@ class AddQuantDequantPass(object):
else:
self._quantizable_op_type = quantizable_op_type
for op_type in quantizable_op_type:
assert op_type in AddQuantDequantPass._supported_quantizable_op_type + \
AddQuantDequantPass._activation_type, \
assert op_type in AddQuantDequantPass._supported_quantizable_op_type, \
op_type + " is not supported for quantization."
self._quantizable_grad_op_type = [
'%s_grad' % (op) for op in self._quantizable_op_type
......@@ -1343,17 +1326,15 @@ class AddQuantDequantPass(object):
elif isinstance(self._skip_pattern, str):
is_skip = op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
is_op_node_quantized = False
for var_node in op_node.inputs:
var_name = var_node.name()
if var_name.endswith('.dequantized'):
is_op_node_quantized = True
if is_skip or is_op_node_quantized or \
is_quantized = op_node.op().has_attr("quantization_type") and \
op_node.op().attr("quantization_type") == "qat_with_weight"
if is_skip or is_quantized or \
(not _is_input_all_not_persistable(graph, op_node)):
continue
op_node.op()._set_attr("quantization_type",
"qat_without_weight")
op_node.op()._set_attr("activation_bits", self._quant_bits)
input_name_list = _op_real_in_out_name[op_node.name()][0]
arg_names = []
for input_name in input_name_list:
......
......@@ -264,7 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
ptq.save_quantized_model(self.int8_model)
def run_test(self, model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file):
is_full_quantize, is_use_cache_file, diff_threshold):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
sample_iterations = self.sample_iterations
......@@ -296,11 +296,11 @@ class TestPostTrainingQuantization(unittest.TestCase):
sys.stdout.flush()
delta_value = fp32_acc1 - int8_acc1
self.assertLess(delta_value, 0.025)
self.assertLess(delta_value, diff_threshold)
class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_mobilenetv1(self):
class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_kl_mobilenetv1(self):
model = "MobileNet-V1"
algo = "KL"
data_urls = [
......@@ -310,10 +310,29 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
quantizable_op_type = [
"conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
]
is_full_quantize = True
is_full_quantize = False
is_use_cache_file = False
diff_threshold = 0.025
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file)
is_full_quantize, is_use_cache_file, diff_threshold)
class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_abs_max_mobilenetv1(self):
model = "MobileNet-V1"
algo = "abs_max"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
]
is_full_quantize = False
is_use_cache_file = False
diff_threshold = 0.05
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, diff_threshold)
if __name__ == '__main__':
......
......@@ -20,7 +20,7 @@ from test_post_training_quantization_mobilenetv1 import TestPostTrainingQuantiza
class TestPostTrainingForResnet50(TestPostTrainingQuantization):
def test_post_training_resnet50(self):
model = "ResNet-50"
algo = "direct"
algo = "min_max"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
]
......@@ -28,8 +28,9 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
quantizable_op_type = ["conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
diff_threshold = 0.025
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file)
is_full_quantize, is_use_cache_file, diff_threshold)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册