Unverified commit 00b11a4a, authored by J juncaipeng, committed by GitHub

Support more ops in post training quantization, test=develop (#21073)

* Support more ops in post training quantization, and save the output scale in the quantized ops.
* Update docs in post training quantization and QAT
Parent 23876de5
@@ -23,6 +23,7 @@ from ....log_helper import get_logger
 from .quantization_pass import QuantizationTransformPass
 from .quantization_pass import QuantizationFreezePass
 from .quantization_pass import AddQuantDequantPass
+from .quantization_pass import _op_real_in_out_name
 __all__ = ['PostTrainingQuantization']
@@ -39,10 +40,8 @@ class PostTrainingQuantization(object):
                  batch_nums=None,
                  scope=None,
                  algo="KL",
-                 quantizable_op_type=[
-                     "conv2d", "depthwise_conv2d", "mul", "pool2d",
-                     "elementwise_add"
-                 ]):
+                 quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
+                 is_full_quantize=False):
        '''
        The class utilizes post training quantization method to quantize the
        fp32 model. It uses calibrate data to calculate the scale factor of
@@ -66,7 +65,14 @@ class PostTrainingQuantization(object):
                abs_max method to get the scale factor. Default is KL.
            quantizable_op_type(list[str], optional): List the type of ops
                that will be quantized. Default is ["conv2d", "depthwise_conv2d",
-                "mul", "pool2d", "elementwise_add"].
+                "mul"].
+            is_full_quantize(bool, optional): If True, apply quantization to all
+                supported op types. If False, only apply quantization to the op
+                types listed in quantizable_op_type. Default is False.
+        Returns:
+            None
        Examples:
        .. code-block:: python
            import paddle.fluid as fluid
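For orientation, a minimal end-to-end sketch of how the patched class is meant to be used; the model path, the calibration reader val_reader and the executor setup are placeholders rather than part of this patch:

import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import PostTrainingQuantization

place = fluid.CPUPlace()
exe = fluid.Executor(place)
# val_reader: a hypothetical batched calibration data reader
ptq = PostTrainingQuantization(
    executor=exe,
    model_path="./fp32_model",        # assumed directory of the FP32 inference model
    data_reader=val_reader,
    batch_size=16,
    batch_nums=10,
    algo="KL",
    quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
    is_full_quantize=False)
ptq.quantize()
ptq.save_quantized_model("./int8_model")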
@@ -98,13 +104,18 @@ class PostTrainingQuantization(object):
        self._batch_size = batch_size
        self._batch_nums = batch_nums
        self._scope = global_scope() if scope == None else scope
-        self._quantizable_op_type = quantizable_op_type
        self._algo = algo
-        supported_quantizable_op_type = [
-            "conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
-        ]
+        supported_quantizable_op_type = \
+            QuantizationTransformPass._supported_quantizable_op_type + \
+            AddQuantDequantPass._supported_quantizable_op_type
+        if is_full_quantize:
+            self._quantizable_op_type = supported_quantizable_op_type
+        else:
+            self._quantizable_op_type = quantizable_op_type
        for op_type in self._quantizable_op_type:
-            assert op_type in supported_quantizable_op_type, \
+            assert op_type in supported_quantizable_op_type + \
+                AddQuantDequantPass._activation_type, \
                op_type + " is not supported for quantization."
        self._place = self._executor.place
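The way is_full_quantize resolves the final op-type list can be summarized in a small sketch (the helper name is hypothetical; it simply mirrors the selection logic above):

def resolve_quantizable_op_type(user_op_types, is_full_quantize):
    # Hypothetical helper mirroring the selection above, not a library API.
    supported = QuantizationTransformPass._supported_quantizable_op_type + \
        AddQuantDequantPass._supported_quantizable_op_type
    if is_full_quantize:
        return list(supported)
    for op_type in user_op_types:
        assert op_type in supported + AddQuantDequantPass._activation_type, \
            op_type + " is not supported for quantization."
    return list(user_op_types)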
@@ -113,6 +124,7 @@ class PostTrainingQuantization(object):
        self._fetch_list = None
        self._data_loader = None
+        self._op_real_in_out_name = _op_real_in_out_name
        self._bit_length = 8
        self._quantized_weight_var_name = []
        self._quantized_act_var_name = []
@@ -125,10 +137,12 @@ class PostTrainingQuantization(object):
        quantized variables, and inserts fake quant/dequant op to obtain the
        quantized model.
-        Return:
+        Args:
+            None
+        Returns:
            the program of quantized model.
        '''
-        self._prepare()
+        self._preprocess()
        batch_id = 0
        for data in self._data_loader():
@@ -136,7 +150,6 @@ class PostTrainingQuantization(object):
                               feed=data,
                               fetch_list=self._fetch_list)
            self._sample_data()
            if batch_id % 5 == 0:
                _logger.info("run batch: " + str(batch_id))
            batch_id += 1
@@ -144,9 +157,13 @@ class PostTrainingQuantization(object):
                break
        _logger.info("all run batch: " + str(batch_id))
+        _logger.info("calculate scale factor ...")
        self._calculate_scale_factor()
+        _logger.info("update the program ...")
        self._update_program()
+        self._save_output_scale()
        return self._program
    def save_quantized_model(self, save_model_path):
@@ -155,7 +172,7 @@ class PostTrainingQuantization(object):
        Args:
            save_model_path(str): The path to save the quantized model
-        Return:
+        Returns:
            None
        '''
        io.save_inference_model(
@@ -165,7 +182,7 @@ class PostTrainingQuantization(object):
            executor=self._executor,
            main_program=self._program)
-    def _prepare(self):
+    def _preprocess(self):
        '''
        Load model and set data loader, collect the variable names for sampling,
        and set activation variables to be persistable.
@@ -183,14 +200,13 @@ class PostTrainingQuantization(object):
            drop_last=True,
            places=self._place)
-        #collect the variable names for sampling
+        # collect the variable names for sampling
        persistable_var_names = []
        for var in self._program.list_vars():
            if var.persistable:
                persistable_var_names.append(var.name)
-        block = self._program.global_block()
-        for op in block.ops:
+        for op in self._program.global_block().ops:
            op_type = op.type
            if op_type in self._quantizable_op_type:
                if op_type in ("conv2d", "depthwise_conv2d"):
@@ -199,29 +215,30 @@ class PostTrainingQuantization(object):
                        op.input("Filter")[0])
                    self._quantized_act_var_name.append(op.output("Output")[0])
                elif op_type == "mul":
-                    x_var_name = op.input("X")[0]
-                    y_var_name = op.input("Y")[0]
-                    if x_var_name not in persistable_var_names and \
-                        y_var_name not in persistable_var_names:
+                    if self._is_input_all_not_persistable(
+                            op, persistable_var_names):
                        op._set_attr("skip_quant", True)
-                        _logger.warning("A mul op skip quant for two "
+                        _logger.warning("Skip quant a mul op for two "
                                        "input variables are not persistable")
                    else:
-                        self._quantized_act_var_name.append(x_var_name)
-                        self._quantized_weight_var_name.append(y_var_name)
-                        self._quantized_act_var_name.append(op.output("Out")[0])
-                elif op_type == "pool2d":
-                    self._quantized_act_var_name.append(op.input("X")[0])
-                elif op_type == "elementwise_add":
-                    x_var_name = op.input("X")[0]
-                    y_var_name = op.input("Y")[0]
-                    if x_var_name not in persistable_var_names and \
-                        y_var_name not in persistable_var_names:
-                        self._quantized_act_var_name.append(x_var_name)
-                        self._quantized_act_var_name.append(y_var_name)
+                        self._quantized_act_var_name.append(op.input("X")[0])
+                        self._quantized_weight_var_name.append(op.input("Y")[0])
+                        self._quantized_act_var_name.append(op.output("Out")[0])
+                else:
+                    # process other quantizable op type, the input must all not persistable
+                    if self._is_input_all_not_persistable(
+                            op, persistable_var_names):
+                        input_output_name_list = self._op_real_in_out_name[
+                            op_type]
+                        for input_name in input_output_name_list[0]:
+                            for var_name in op.input(input_name):
+                                self._quantized_act_var_name.append(var_name)
+                        for output_name in input_output_name_list[1]:
+                            for var_name in op.output(output_name):
+                                self._quantized_act_var_name.append(var_name)
-        # set activation variables to be persistable,
-        # so can obtain the tensor data in sample_data stage
+        # set activation variables to be persistable, so can obtain
+        # the tensor data in sample_data
        for var in self._program.list_vars():
            if var.name in self._quantized_act_var_name:
                var.persistable = True
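Marking the activation variables persistable is what allows their tensors to be read back from the scope after each executor run; a hedged sketch of that lookup, with a made-up variable name:

import numpy as np
import paddle.fluid as fluid

# Sketch: read the tensor of a persistable activation variable after exe.run().
scope = fluid.global_scope()
var_name = "conv2d_0.tmp_0"          # hypothetical activation variable name
tensor = scope.find_var(var_name).get_tensor()
activation_data = np.array(tensor)   # numpy copy used for abs_max / KL sampling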
@@ -246,8 +263,7 @@ class PostTrainingQuantization(object):
        '''
        Calculate the scale factor of quantized variables.
        '''
-        _logger.info("calculate scale factor ...")
+        # apply channel_wise_abs_max quantization for weights
        for var_name in self._quantized_weight_var_name:
            data = self._sampling_data[var_name]
            scale_factor_per_channel = []
@@ -257,6 +273,7 @@ class PostTrainingQuantization(object):
            self._quantized_var_scale_factor[
                var_name] = scale_factor_per_channel
+        # apply kl quantization for activation
        for var_name in self._quantized_act_var_name:
            if self._algo == "KL":
                self._quantized_var_scale_factor[var_name] = \
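For weights the scale is computed per output channel with abs_max, while activations use KL or plain abs_max; a simplified sketch of the two abs_max variants (channel axis 0 is assumed here, as for conv2d filters):

import numpy as np

def channel_wise_abs_max_scale(weight):
    # One scale per output channel (axis 0 assumed, as for conv2d filters).
    return [np.max(np.abs(weight[i])) for i in range(weight.shape[0])]

def abs_max_scale(activation_samples):
    # Single scale over all sampled batches of one activation variable.
    return max(np.max(np.abs(batch)) for batch in activation_samples)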
@@ -269,8 +286,7 @@ class PostTrainingQuantization(object):
        '''
        Insert fake_quantize/fake_dequantize op to the program.
        '''
-        _logger.info("update the program ...")
+        # reset quantized activation variable
        for var in self._program.list_vars():
            if var.name in self._quantized_act_var_name:
                var.persistable = False
@@ -278,10 +294,10 @@ class PostTrainingQuantization(object):
        # use QuantizationTransformPass to insert fake_quantize/fake_dequantize op
        graph = IrGraph(core.Graph(self._program.desc), for_test=True)
-        qtp_quantizable_op_type = []
-        for op_type in ["conv2d", "depthwise_conv2d", "mul"]:
+        major_quantizable_op_types = []
+        for op_type in QuantizationTransformPass._supported_quantizable_op_type:
            if op_type in self._quantizable_op_type:
-                qtp_quantizable_op_type.append(op_type)
+                major_quantizable_op_types.append(op_type)
        transform_pass = QuantizationTransformPass(
            scope=self._scope,
            place=self._place,
@@ -289,18 +305,18 @@ class PostTrainingQuantization(object):
            activation_bits=self._bit_length,
            activation_quantize_type='moving_average_abs_max',
            weight_quantize_type='channel_wise_abs_max',
-            quantizable_op_type=qtp_quantizable_op_type)
+            quantizable_op_type=major_quantizable_op_types)
        transform_pass.apply(graph)
        # use AddQuantDequantPass to insert fake_quant_dequant op
-        aqdp_quantizable_op_type = []
-        for op_type in ["pool2d", "elementwise_add"]:
+        minor_quantizable_op_types = []
+        for op_type in AddQuantDequantPass._supported_quantizable_op_type:
            if op_type in self._quantizable_op_type:
-                aqdp_quantizable_op_type.append(op_type)
+                minor_quantizable_op_types.append(op_type)
        add_quant_dequant_pass = AddQuantDequantPass(
            scope=self._scope,
            place=self._place,
-            quantizable_op_type=aqdp_quantizable_op_type)
+            quantizable_op_type=minor_quantizable_op_types)
        add_quant_dequant_pass.apply(graph)
        # save scale factor to scale var node
@@ -319,10 +335,25 @@ class PostTrainingQuantization(object):
            weight_bits=self._bit_length,
            activation_bits=self._bit_length,
            weight_quantize_type='channel_wise_abs_max',
-            quantizable_op_type=qtp_quantizable_op_type)
+            quantizable_op_type=major_quantizable_op_types)
        freeze_pass.apply(graph)
        self._program = graph.to_program()
def _save_output_scale(self):
'''
Save output scale to the quantized op.
'''
output_scale_name = "output_scale"
for op in self._program.global_block().ops:
if op.type in self._quantizable_op_type:
output_name_list = self._op_real_in_out_name[op.type][1]
for output_name in output_name_list:
output_var_name = op.output(output_name)[0]
if output_var_name in self._quantized_var_scale_factor:
op._set_attr(
output_scale_name,
self._quantized_var_scale_factor[output_var_name])
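A downstream consumer, for example an inference-side pass, could read the saved attribute back; a hedged sketch assuming quantized_program is the program returned by quantize():

# Sketch: read back the "output_scale" attribute written by _save_output_scale().
for op in quantized_program.global_block().ops:
    if op.has_attr("output_scale"):
        print(op.type, op.attr("output_scale"))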
    def _load_var_value(self, var_name):
        '''
        Load variable value from scope
@@ -331,7 +362,7 @@ class PostTrainingQuantization(object):
    def _set_var_node_value(self, var_node_name, np_value):
        '''
-        Set the value of var node by name, if the node is not exits,
+        Set the value of var node by name, if the node exists.
        '''
        assert isinstance(np_value, np.ndarray), \
            'The type of value should be numpy array.'
@@ -340,6 +371,19 @@ class PostTrainingQuantization(object):
        tensor = var_node.get_tensor()
        tensor.set(np_value, self._place)
def _is_input_all_not_persistable(self, op, persistable_var_names):
'''
Check whether the real inputs of the op are all not persistable.
'''
is_input_all_not_persistable = True
input_name_list = self._op_real_in_out_name[op.type][0]
for input_name in input_name_list:
for var_name in op.input(input_name):
if var_name in persistable_var_names:
is_input_all_not_persistable = False
break
return is_input_all_not_persistable
    def _get_kl_scaling_factor(self, activation_blob, num_quantized_bins=255):
        '''
        Using the KL divergence method to get the more precise scaling factor.
@@ -441,7 +485,7 @@ class PostTrainingQuantization(object):
                    tmp_sum2 += 0
                else:
                    if q_idx == 0:
-                        print("Fatal error!, idx = " + str(idx) +
-                              " qindex = 0! p_idx = " + str(p_idx))
+                        _logger.error("Fatal error!, idx = " + str(idx) +
+                                      " qindex = 0! p_idx = " + str(p_idx))
                    tmp_sum1 += p_idx * (math.log(Q_sum * p_idx))
                    tmp_sum2 += p_idx * (math.log(P_sum * q_idx))
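The KL approach builds a histogram of the absolute activation values and picks the clipping threshold whose quantized distribution diverges least from the original; a simplified sketch of that search (the method above additionally smooths zero bins and merges histogram bins more carefully):

import numpy as np

def _kl(p, q):
    # Discrete KL divergence between two unnormalized histograms.
    p = p / p.sum()
    q = q / q.sum()
    mask = (p > 0) & (q > 0)
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def simple_kl_threshold(abs_values, num_bins=2048, num_quantized_bins=255):
    # Simplified KL threshold search for one activation variable.
    hist, edges = np.histogram(np.abs(abs_values), bins=num_bins)
    best_kl, best_threshold = float("inf"), edges[-1]
    for i in range(num_quantized_bins, num_bins + 1, num_quantized_bins):
        p = hist[:i].astype(np.float64)
        p[-1] += hist[i:].sum()  # fold the clipped tail into the last kept bin
        # re-quantize p down to num_quantized_bins bins, then expand it back
        groups = np.array_split(p, num_quantized_bins)
        q = np.concatenate([np.full(len(g), g.sum() / len(g)) for g in groups])
        kl = _kl(p, q)
        if kl < best_kl:
            best_kl, best_threshold = kl, edges[i]
    return best_threshold  # used as the activation scale factor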
......
@@ -41,6 +41,40 @@ _out_scale_op_list = [
    "dropout", "split", "prelu", "conv2d_transpose", "leaky_relu"
]
# map op types to their real input and output names, to avoid processing auxiliary inputs such as AxisTensor.
_op_real_in_out_name = {
"conv2d": [["Input", "Filter"], ["Output"]],
"depthwise_conv2d": [["Input"], ["Output"]],
"mul": [["X", "Y"], ["Out"]],
"pool2d": [["X"], ["Out"]],
"elementwise_add": [["X", "Y"], ["Out"]],
"concat": [["X"], ["Out"]],
"softmax": [["X"], ["Out"]],
"argmax": [["X"], ["Out"]],
"transpose": [["X"], ["Out"]],
"equal": [["X", "Y"], ["Out"]],
"gather": [["X"], ["Out"]],
"greater_equal": [["X", "Y"], ["Out"]],
"greater_than": [["X", "Y"], ["Out"]],
"less_equal": [["X", "Y"], ["Out"]],
"less_than": [["X", "Y"], ["Out"]],
"mean": [["X"], ["Out"]],
"not_equal": [["X", "Y"], ["Out"]],
"reshape": [["X"], ["Out"]],
"reshape2": [["X"], ["Out"]],
"bilinear_interp": [["X"], ["Out"]],
"nearest_interp": [["X"], ["Out"]],
"trilinear_interp": [["X"], ["Out"]],
"slice": [["Input"], ["Out"]],
"squeeze": [["X"], ["Out"]],
"elementwise_sub": [["X", "Y"], ["Out"]],
"relu": [["X"], ["Out"]],
"relu6": [["X"], ["Out"]],
"leaky_relu": [["X"], ["Out"]],
"tanh": [["X"], ["Out"]],
"swish": [["X"], ["Out"]],
}
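The point of this table is that many operators carry auxiliary inputs (AxisTensor, Shape, ShapeTensor and the like) that must not be treated as quantizable data tensors; a small hedged illustration, where op stands for an arbitrary fluid Operator:

# Sketch: enumerate only the real data inputs/outputs of an op.
input_names, output_names = _op_real_in_out_name[op.type]
real_inputs = [name for slot in input_names for name in op.input(slot)]
real_outputs = [name for slot in output_names for name in op.output(slot)]
# e.g. for "reshape2" this yields only "X", skipping Shape/ShapeTensor inputs.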
def _init_var_node(var_node, value, scope, place):
    assert isinstance(value,
@@ -54,6 +88,8 @@ def _init_var_node(var_node, value, scope, place):
class QuantizationTransformPass(object):
+    _supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
+
    def __init__(self,
                 scope=None,
                 place=None,
@@ -75,25 +111,27 @@ class QuantizationTransformPass(object):
                initialize these new parameters.
            place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize new
                parameters described above.
-            weight_bits (int): quantization bit number for weights,
+            weight_bits(int): quantization bit number for weights,
                the bias is not quantized.
-            activation_bits (int): quantization bit number for activation.
-            activation_quantize_type (str): quantization type for activation,
+            activation_bits(int): quantization bit number for activation.
+            activation_quantize_type(str): quantization type for activation,
                now support 'abs_max', 'range_abs_max' and 'moving_average_abs_max'.
                If use 'abs_max' mode, the quantization scale will be calculated
                dynamically each step in both training and testing period. If use
                'range_abs_max', a static quantization scale will be calculated
                during training and used in inference.
-            weight_quantize_type (str): quantization type for weights,
+            weight_quantize_type(str): quantization type for weights,
                support 'abs_max' and 'channel_wise_abs_max'. The 'range_abs_max'
                usually is not used for weight, since weights are fixed once the
                model is well trained.
-            window_size (int): the window size for 'range_abs_max' quantization.
+            window_size(int): the window size for 'range_abs_max' quantization.
+            moving_rate(float): the param for 'moving_average_abs_max' quantization.
            skip_pattern(str): The user-defined quantization skip pattern, which
                will be presented in the name scope of an op. When the skip pattern is
                detected in an op's name scope, the corresponding op will not be quantized.
            quantizable_op_type(list[str]): List the type of ops that will be quantized.
-                Default is ["conv2d", "depthwise_conv2d", "mul"].
+                Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
+                QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
        Examples:
        .. code-block:: python
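A typical construction of this pass, sketched under the assumption that an IrGraph has already been built from a Program named main_program:

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass

# main_program: a hypothetical fluid Program that should be quantized
graph = IrGraph(core.Graph(main_program.desc), for_test=False)
transform_pass = QuantizationTransformPass(
    scope=fluid.global_scope(),
    place=fluid.CPUPlace(),
    weight_bits=8,
    activation_bits=8,
    activation_quantize_type='moving_average_abs_max',
    weight_quantize_type='channel_wise_abs_max',
    quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'])
transform_pass.apply(graph)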
@@ -139,9 +177,8 @@ class QuantizationTransformPass(object):
        self._moving_rate = moving_rate
        self._quantizable_ops = quantizable_op_type
-        supported_quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
        for op in self._quantizable_ops:
-            assert op in supported_quantizable_ops, \
+            assert op in QuantizationTransformPass._supported_quantizable_op_type, \
                op + " is not supported for quantization."
        self._conv_ops = ['conv2d', 'depthwise_conv2d']
        self._quantizable_grad_ops = [
@@ -158,6 +195,8 @@ class QuantizationTransformPass(object):
        Args:
            graph(IrGraph): the applied graph.
+        Returns:
+            None
        """
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
@@ -589,6 +628,16 @@ class QuantizationTransformPass(object):
class QuantizationFreezePass(object):
_supported_quantizable_op_type = \
QuantizationTransformPass._supported_quantizable_op_type
def __init__(self,
scope,
place,
weight_bits=8,
activation_bits=8,
weight_quantize_type='abs_max',
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
""" """
The freeze pass is used to adjust the quantize operator order, for example: The freeze pass is used to adjust the quantize operator order, for example:
1) `activation -> quant -> dequant -> conv2d` will be freezed into 1) `activation -> quant -> dequant -> conv2d` will be freezed into
...@@ -599,22 +648,15 @@ class QuantizationFreezePass(object): ...@@ -599,22 +648,15 @@ class QuantizationFreezePass(object):
        Args:
            scope(fluid.Scope): scope is used to get the weight tensor values.
            place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors.
-            weight_bits (int): quantization bit number for weights.
-            activation_bits (int): quantization bit number for activation.
-            weight_quantize_type (str): quantization type for weights, support 'abs_max' and
+            weight_bits(int): quantization bit number for weights.
+            activation_bits(int): quantization bit number for activation.
+            weight_quantize_type(str): quantization type for weights, support 'abs_max' and
                'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
                since weights are fixed once the model is well trained.
            quantizable_op_type(list[str]): List the type of ops that will be quantized.
-                Default is ["conv2d", "depthwise_conv2d", "mul"].
+                Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
+                QuantizationTransformPass and ConvertToInt8Pass must be the same as this.
        """
-    def __init__(self,
-                 scope,
-                 place,
-                 weight_bits=8,
-                 activation_bits=8,
-                 weight_quantize_type='abs_max',
-                 quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
        assert scope is not None, \
            'The scope cannot be set None.'
        assert place is not None, \
@@ -625,9 +667,8 @@ class QuantizationFreezePass(object):
        self._activation_bits = activation_bits
        self._weight_quantize_type = weight_quantize_type
        self._quantizable_ops = quantizable_op_type
-        supported_quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
        for op in self._quantizable_ops:
-            assert op in supported_quantizable_ops, \
+            assert op in QuantizationFreezePass._supported_quantizable_op_type, \
                op + " is not supported for quantization."
        self._conv_ops = ['conv2d', 'depthwise_conv2d']
        self._fake_quant_op_names = _fake_quant_op_list
@@ -642,6 +683,8 @@ class QuantizationFreezePass(object):
        Args:
            graph(IrGraph): the applied graph.
+        Returns:
+            None
        """
        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
        ops = graph.all_op_nodes()
@@ -895,6 +938,13 @@ class QuantizationFreezePass(object):
class ConvertToInt8Pass(object):
_supported_quantizable_op_type = \
QuantizationTransformPass._supported_quantizable_op_type
def __init__(self,
scope,
place,
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
""" """
Convert the weights into int8_t type. Convert the weights into int8_t type.
...@@ -903,13 +953,9 @@ class ConvertToInt8Pass(object): ...@@ -903,13 +953,9 @@ class ConvertToInt8Pass(object):
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the
8bits weight tensors. 8bits weight tensors.
quantizable_op_type(list[str]): List the type of ops that will be quantized. quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"]. Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationTransformPass and QuantizationFreezePass must be the same as this.
""" """
-    def __init__(self,
-                 scope,
-                 place,
-                 quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
        assert scope is not None, \
            'The scope cannot be set None.'
        assert place is not None, \
@@ -917,9 +963,8 @@ class ConvertToInt8Pass(object):
        self._scope = scope
        self._place = place
        self._quantizable_ops = quantizable_op_type
-        supported_quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
        for op in self._quantizable_ops:
-            assert op in supported_quantizable_ops, \
+            assert op in ConvertToInt8Pass._supported_quantizable_op_type, \
                op + " is not supported for quantization."
    def apply(self, graph):
@@ -929,6 +974,8 @@ class ConvertToInt8Pass(object):
        Args:
            graph(IrGraph): the applied graph.
+        Returns:
+            None
        """
        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
        ops = graph.all_op_nodes()
@@ -993,11 +1040,10 @@ class ConvertToInt8Pass(object):
class TransformForMobilePass(object):
+    def __init__(self):
        """
        This pass is used to convert the freezed graph for paddle-mobile execution.
        """
-    def __init__(self):
        self._fake_quant_op_names = _fake_quant_op_list
        self._fake_dequant_op_names = _fake_dequant_op_list
@@ -1009,6 +1055,8 @@ class TransformForMobilePass(object):
        Args:
            graph(IrGraph): the graph will be transformed.
+        Returns:
+            None
        """
        ops = graph.all_op_nodes()
        for op_node in ops:
@@ -1183,16 +1231,45 @@ class ScaleForInferencePass(object):
class AddQuantDequantPass(object):
_supported_quantizable_op_type = [
"pool2d", "elementwise_add", "concat", "softmax", "argmax", "transpose",
"equal", "gather", "greater_equal", "greater_than", "less_equal",
"less_than", "mean", "not_equal", "reshape", "reshape2",
"bilinear_interp", "nearest_interp", "trilinear_interp", "slice",
"squeeze", "elementwise_sub"
]
_activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"]
    def __init__(self,
                 scope=None,
                 place=None,
                 moving_rate=0.9,
                 quant_bits=8,
                 skip_pattern='skip_quant',
-                 quantizable_op_type=["elementwise_add", "pool2d"]):
+                 quantizable_op_type=["elementwise_add", "pool2d", "concat"],
+                 is_full_quantized=False):
        """
-        This pass is used to add quant_dequant op for some ops, such as the
-        'elementwise_add' and 'pool2d' op.
+        This pass adds a quant_dequant op for some ops, all of whose inputs
+        must be not persistable. The input scales can be obtained from the
+        quant_dequant op.
+        Args:
+            scope(fluid.Scope): The scope is used to initialize these new parameters.
+            place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize new
+                parameters described above.
+            moving_rate(float, optional): the param for 'quant_dequant_moving_average_abs_max'
+                quantization. Default is 0.9.
+            quant_bits(int, optional): quantization bit number for activation. Default is 8.
+            skip_pattern(str, optional): The user-defined quantization skip pattern, which
+                will be presented in the name scope of an op. When the skip pattern is
+                detected in an op's name scope, the corresponding op will not be quantized.
+                Default is 'skip_quant'.
+            quantizable_op_type(list[str], optional): List the type of ops that will be
+                quantized. Default is ["elementwise_add", "pool2d", "concat"].
+            is_full_quantized(bool, optional): If True, apply quantization to all
+                supported op types; if False, only to the op types listed in
+                quantizable_op_type. Default is False.
        """
        self._scope = scope
        self._place = place
@@ -1200,60 +1277,67 @@ class AddQuantDequantPass(object):
        self._quant_bits = quant_bits
        self._is_test = None
        self._skip_pattern = skip_pattern
+        if is_full_quantized:
+            self._quantizable_op_type = \
+                AddQuantDequantPass._supported_quantizable_op_type
+        else:
            self._quantizable_op_type = quantizable_op_type
+            for op_type in quantizable_op_type:
+                assert op_type in AddQuantDequantPass._supported_quantizable_op_type + \
+                    AddQuantDequantPass._activation_type, \
+                    op_type + " is not supported for quantization."
        self._quantizable_grad_op_type = [
            '%s_grad' % (op) for op in self._quantizable_op_type
        ]
-        supported_quantizable_op_type = ["elementwise_add", "pool2d"]
-        for op_type in quantizable_op_type:
-            assert op_type in supported_quantizable_op_type, \
-                op_type + " is not supported for quantization."
+        assert self._scope != None, "scope must not be None."
+        assert self._place != None, "place must not be None."
    def apply(self, graph):
        """
-        Add quant_dequant before some ops, such as the 'elementwise_add'
-        and 'pool2d' op.
+        Add quant_dequant before some ops, such as the 'elementwise_add',
+        'pool2d' and 'concat' op.
        Args:
            graph(IrGraph): the target graph.
+        Returns:
+            None
        """
        assert isinstance(graph,
                          IrGraph), 'graph must be the instance of IrGraph.'
        self._is_test = graph.is_test()
        dequantized_vars_map = collections.OrderedDict()
-        ops = graph.all_op_nodes()
-        for op_node in ops:
+        # Forward stage, insert quant_dequant op
+        all_op_nodes = graph.all_op_nodes()
+        for op_node in all_op_nodes:
            if op_node.name() in self._quantizable_op_type:
                if isinstance(self._skip_pattern, str) and \
                        op_node.op().has_attr("op_namescope") and \
                        op_node.op().attr("op_namescope").find(self._skip_pattern) != -1:
                    continue
-                in_nodes_all_not_persistable = True
-                for input_name in op_node.input_arg_names():
-                    in_node = graph._find_node_by_name(op_node.inputs,
-                                                       input_name)
-                    in_nodes_all_not_persistable = (
-                        in_nodes_all_not_persistable and
-                        not in_node.persistable())
-                if not in_nodes_all_not_persistable:
+                if not self._is_input_all_not_persistable(graph, op_node):
                    continue
-                input_names = op_node.input_arg_names()
-                for input_name in input_names:
-                    in_node = graph._find_node_by_name(op_node.inputs,
-                                                       input_name)
-                    if input_name in dequantized_vars_map:
-                        quant_var_node = dequantized_vars_map[input_name]
-                    else:
-                        quant_var_node, scale_var_node = \
-                            self._inser_quant_dequant_moving_average_abs_max_op(
-                                graph, in_node, self._quant_bits)
-                        dequantized_vars_map[input_name] = quant_var_node
-                    graph.update_input_link(in_node, quant_var_node, op_node)
+                input_name_list = _op_real_in_out_name[op_node.name()][0]
+                for input_name in input_name_list:
+                    for arg_name in op_node.input(input_name):
+                        in_node = graph._find_node_by_name(op_node.inputs,
+                                                           arg_name)
+                        if arg_name in dequantized_vars_map:
+                            quant_var_node = dequantized_vars_map[arg_name]
+                        else:
+                            quant_var_node, _ = \
+                                self._inser_quant_dequant_moving_average_abs_max_op(
+                                    graph, in_node, self._quant_bits)
+                            dequantized_vars_map[arg_name] = quant_var_node
+                        graph.update_input_link(in_node, quant_var_node,
                                                op_node)
-        for op_node in ops:
+        # Backward stage, update input link
+        for op_node in all_op_nodes:
            if op_node.name() in self._quantizable_grad_op_type:
                for input_name in op_node.input_arg_names():
                    if input_name in dequantized_vars_map:
@@ -1266,6 +1350,21 @@ class AddQuantDequantPass(object):
        graph.resolve_hazard()
        return graph
def _is_input_all_not_persistable(self, graph, op_node):
'''
Check whether the real inputs of the op node are all not persistable.
'''
is_input_all_not_persistable = True
op_node_name = op_node.name()
input_name_list = _op_real_in_out_name[op_node_name][0]
for input_name in input_name_list:
for arg_name in op_node.input(input_name):
in_node = graph._find_node_by_name(op_node.inputs, arg_name)
is_input_all_not_persistable = (is_input_all_not_persistable and \
(not in_node.persistable()))
return is_input_all_not_persistable
    def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node,
                                                       quant_bits):
        """Insert fake_quantize_dequantize_moving_average_abs_max op.
......
@@ -233,7 +233,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
        acc1 = np.sum(test_info) / cnt
        return (throughput, latency, acc1)
-    def generate_quantized_model(self, model_path, algo="KL"):
+    def generate_quantized_model(self,
+                                 model_path,
+                                 algo="KL",
+                                 is_full_quantize=False):
        self.int8_model = os.path.join(os.getcwd(),
                                       "post_training_" + self.timestamp)
        try:
@@ -257,7 +260,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
            model_path=model_path,
            data_reader=val_reader,
            algo=algo,
-            quantizable_op_type=quantizable_op_type)
+            quantizable_op_type=quantizable_op_type,
+            is_full_quantize=is_full_quantize)
        ptq.quantize()
        ptq.save_quantized_model(self.int8_model)
@@ -285,7 +289,9 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
print("Start INT8 post training quantization for {0} on {1} images ...". print("Start INT8 post training quantization for {0} on {1} images ...".
format(self.model, self.sample_iterations * self.batch_size)) format(self.model, self.sample_iterations * self.batch_size))
self.generate_quantized_model( self.generate_quantized_model(
self.model_cache_folder + "/model", algo=self.algo) self.model_cache_folder + "/model",
algo=self.algo,
is_full_quantize=True)
print("Start INT8 inference for {0} on {1} images ...".format( print("Start INT8 inference for {0} on {1} images ...".format(
self.model, self.infer_iterations * self.batch_size)) self.model, self.infer_iterations * self.batch_size))
......