未验证 提交 00b11a4a 编写于 作者: J juncaipeng 提交者: GitHub

Support more ops in post training quantization, test=develop (#21073)

* Support  more ops in post training quantization, and save the output scale in quantized op.
* Update docs in post training quantization and qat 
上级 23876de5
......@@ -23,6 +23,7 @@ from ....log_helper import get_logger
from .quantization_pass import QuantizationTransformPass
from .quantization_pass import QuantizationFreezePass
from .quantization_pass import AddQuantDequantPass
from .quantization_pass import _op_real_in_out_name
__all__ = ['PostTrainingQuantization']
......@@ -39,10 +40,8 @@ class PostTrainingQuantization(object):
batch_nums=None,
scope=None,
algo="KL",
quantizable_op_type=[
"conv2d", "depthwise_conv2d", "mul", "pool2d",
"elementwise_add"
]):
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False):
'''
The class utilizes post training quantization methon to quantize the
fp32 model. It uses calibrate data to calculate the scale factor of
......@@ -66,7 +65,14 @@ class PostTrainingQuantization(object):
abs_max methon to get the scale factor. Default is KL.
quantizable_op_type(list[str], optional): List the type of ops
that will be quantized. Default is ["conv2d", "depthwise_conv2d",
"mul", "pool2d", "elementwise_add"].
"mul"].
is_full_quantized(bool, optional): If set is_full_quantized as True,
apply quantization to all supported quantizable op type. If set
is_full_quantized as False, only apply quantization to the op type
according to the input quantizable_op_type.
Returns:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
......@@ -98,13 +104,18 @@ class PostTrainingQuantization(object):
self._batch_size = batch_size
self._batch_nums = batch_nums
self._scope = global_scope() if scope == None else scope
self._quantizable_op_type = quantizable_op_type
self._algo = algo
supported_quantizable_op_type = [
"conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
]
supported_quantizable_op_type = \
QuantizationTransformPass._supported_quantizable_op_type + \
AddQuantDequantPass._supported_quantizable_op_type
if is_full_quantize:
self._quantizable_op_type = supported_quantizable_op_type
else:
self._quantizable_op_type = quantizable_op_type
for op_type in self._quantizable_op_type:
assert op_type in supported_quantizable_op_type, \
assert op_type in supported_quantizable_op_type + \
AddQuantDequantPass._activation_type, \
op_type + " is not supported for quantization."
self._place = self._executor.place
......@@ -113,6 +124,7 @@ class PostTrainingQuantization(object):
self._fetch_list = None
self._data_loader = None
self._op_real_in_out_name = _op_real_in_out_name
self._bit_length = 8
self._quantized_weight_var_name = []
self._quantized_act_var_name = []
......@@ -125,10 +137,12 @@ class PostTrainingQuantization(object):
quantized variables, and inserts fake quant/dequant op to obtain the
quantized model.
Return:
Args:
None
Returns:
the program of quantized model.
'''
self._prepare()
self._preprocess()
batch_id = 0
for data in self._data_loader():
......@@ -136,7 +150,6 @@ class PostTrainingQuantization(object):
feed=data,
fetch_list=self._fetch_list)
self._sample_data()
if batch_id % 5 == 0:
_logger.info("run batch: " + str(batch_id))
batch_id += 1
......@@ -144,9 +157,13 @@ class PostTrainingQuantization(object):
break
_logger.info("all run batch: " + str(batch_id))
_logger.info("calculate scale factor ...")
self._calculate_scale_factor()
_logger.info("update the program ...")
self._update_program()
self._save_output_scale()
return self._program
def save_quantized_model(self, save_model_path):
......@@ -155,7 +172,7 @@ class PostTrainingQuantization(object):
Args:
save_model_path(str): The path to save the quantized model
Return:
Returns:
None
'''
io.save_inference_model(
......@@ -165,7 +182,7 @@ class PostTrainingQuantization(object):
executor=self._executor,
main_program=self._program)
def _prepare(self):
def _preprocess(self):
'''
Load model and set data loader, collect the variable names for sampling,
and set activation variables to be persistable.
......@@ -183,14 +200,13 @@ class PostTrainingQuantization(object):
drop_last=True,
places=self._place)
#collect the variable names for sampling
# collect the variable names for sampling
persistable_var_names = []
for var in self._program.list_vars():
if var.persistable:
persistable_var_names.append(var.name)
block = self._program.global_block()
for op in block.ops:
for op in self._program.global_block().ops:
op_type = op.type
if op_type in self._quantizable_op_type:
if op_type in ("conv2d", "depthwise_conv2d"):
......@@ -199,29 +215,30 @@ class PostTrainingQuantization(object):
op.input("Filter")[0])
self._quantized_act_var_name.append(op.output("Output")[0])
elif op_type == "mul":
x_var_name = op.input("X")[0]
y_var_name = op.input("Y")[0]
if x_var_name not in persistable_var_names and \
y_var_name not in persistable_var_names:
if self._is_input_all_not_persistable(
op, persistable_var_names):
op._set_attr("skip_quant", True)
_logger.warning("A mul op skip quant for two "
_logger.warning("Skip quant a mul op for two "
"input variables are not persistable")
else:
self._quantized_act_var_name.append(x_var_name)
self._quantized_weight_var_name.append(y_var_name)
self._quantized_act_var_name.append(op.output("Out")[0])
elif op_type == "pool2d":
self._quantized_act_var_name.append(op.input("X")[0])
elif op_type == "elementwise_add":
x_var_name = op.input("X")[0]
y_var_name = op.input("Y")[0]
if x_var_name not in persistable_var_names and \
y_var_name not in persistable_var_names:
self._quantized_act_var_name.append(x_var_name)
self._quantized_act_var_name.append(y_var_name)
# set activation variables to be persistable,
# so can obtain the tensor data in sample_data stage
self._quantized_weight_var_name.append(op.input("Y")[0])
self._quantized_act_var_name.append(op.output("Out")[0])
else:
# process other quantizable op type, the input must all not persistable
if self._is_input_all_not_persistable(
op, persistable_var_names):
input_output_name_list = self._op_real_in_out_name[
op_type]
for input_name in input_output_name_list[0]:
for var_name in op.input(input_name):
self._quantized_act_var_name.append(var_name)
for output_name in input_output_name_list[1]:
for var_name in op.output(output_name):
self._quantized_act_var_name.append(var_name)
# set activation variables to be persistable, so can obtain
# the tensor data in sample_data
for var in self._program.list_vars():
if var.name in self._quantized_act_var_name:
var.persistable = True
......@@ -246,8 +263,7 @@ class PostTrainingQuantization(object):
'''
Calculate the scale factor of quantized variables.
'''
_logger.info("calculate scale factor ...")
# apply channel_wise_abs_max quantization for weights
for var_name in self._quantized_weight_var_name:
data = self._sampling_data[var_name]
scale_factor_per_channel = []
......@@ -257,6 +273,7 @@ class PostTrainingQuantization(object):
self._quantized_var_scale_factor[
var_name] = scale_factor_per_channel
# apply kl quantization for activation
for var_name in self._quantized_act_var_name:
if self._algo == "KL":
self._quantized_var_scale_factor[var_name] = \
......@@ -269,8 +286,7 @@ class PostTrainingQuantization(object):
'''
Insert fake_quantize/fake_dequantize op to the program.
'''
_logger.info("update the program ...")
# reset quantized activation variable
for var in self._program.list_vars():
if var.name in self._quantized_act_var_name:
var.persistable = False
......@@ -278,10 +294,10 @@ class PostTrainingQuantization(object):
# use QuantizationTransformPass to insert fake_quantize/fake_dequantize op
graph = IrGraph(core.Graph(self._program.desc), for_test=True)
qtp_quantizable_op_type = []
for op_type in ["conv2d", "depthwise_conv2d", "mul"]:
major_quantizable_op_types = []
for op_type in QuantizationTransformPass._supported_quantizable_op_type:
if op_type in self._quantizable_op_type:
qtp_quantizable_op_type.append(op_type)
major_quantizable_op_types.append(op_type)
transform_pass = QuantizationTransformPass(
scope=self._scope,
place=self._place,
......@@ -289,18 +305,18 @@ class PostTrainingQuantization(object):
activation_bits=self._bit_length,
activation_quantize_type='moving_average_abs_max',
weight_quantize_type='channel_wise_abs_max',
quantizable_op_type=qtp_quantizable_op_type)
quantizable_op_type=major_quantizable_op_types)
transform_pass.apply(graph)
# use AddQuantDequantPass to insert fake_quant_dequant op
aqdp_quantizable_op_type = []
for op_type in ["pool2d", "elementwise_add"]:
minor_quantizable_op_types = []
for op_type in AddQuantDequantPass._supported_quantizable_op_type:
if op_type in self._quantizable_op_type:
aqdp_quantizable_op_type.append(op_type)
minor_quantizable_op_types.append(op_type)
add_quant_dequant_pass = AddQuantDequantPass(
scope=self._scope,
place=self._place,
quantizable_op_type=aqdp_quantizable_op_type)
quantizable_op_type=minor_quantizable_op_types)
add_quant_dequant_pass.apply(graph)
# save scale factor to scale var node
......@@ -319,10 +335,25 @@ class PostTrainingQuantization(object):
weight_bits=self._bit_length,
activation_bits=self._bit_length,
weight_quantize_type='channel_wise_abs_max',
quantizable_op_type=qtp_quantizable_op_type)
quantizable_op_type=major_quantizable_op_types)
freeze_pass.apply(graph)
self._program = graph.to_program()
def _save_output_scale(self):
'''
Save output scale to the quantized op.
'''
output_scale_name = "output_scale"
for op in self._program.global_block().ops:
if op.type in self._quantizable_op_type:
output_name_list = self._op_real_in_out_name[op.type][1]
for output_name in output_name_list:
output_var_name = op.output(output_name)[0]
if output_var_name in self._quantized_var_scale_factor:
op._set_attr(
output_scale_name,
self._quantized_var_scale_factor[output_var_name])
def _load_var_value(self, var_name):
'''
Load variable value from scope
......@@ -331,7 +362,7 @@ class PostTrainingQuantization(object):
def _set_var_node_value(self, var_node_name, np_value):
'''
Set the value of var node by name, if the node is not exits,
Set the value of var node by name, if the node exits,
'''
assert isinstance(np_value, np.ndarray), \
'The type of value should be numpy array.'
......@@ -340,6 +371,19 @@ class PostTrainingQuantization(object):
tensor = var_node.get_tensor()
tensor.set(np_value, self._place)
def _is_input_all_not_persistable(self, op, persistable_var_names):
'''
Analyze the real inputs of the op are all not persistable.
'''
is_input_all_not_persistable = True
input_name_list = self._op_real_in_out_name[op.type][0]
for input_name in input_name_list:
for var_name in op.input(input_name):
if var_name in persistable_var_names:
is_input_all_not_persistable = False
break
return is_input_all_not_persistable
def _get_kl_scaling_factor(self, activation_blob, num_quantized_bins=255):
'''
Using the KL-divergenc method to get the more precise scaling factor.
......@@ -441,7 +485,7 @@ class PostTrainingQuantization(object):
tmp_sum2 += 0
else:
if q_idx == 0:
print("Fatal error!, idx = " + str(idx) +
_logger.error("Fatal error!, idx = " + str(idx) +
" qindex = 0! p_idx = " + str(p_idx))
tmp_sum1 += p_idx * (math.log(Q_sum * p_idx))
tmp_sum2 += p_idx * (math.log(P_sum * q_idx))
......
......@@ -41,6 +41,40 @@ _out_scale_op_list = [
"dropout", "split", "prelu", "conv2d_transpose", "leaky_relu"
]
# list op real input and output names, to avoid processing input such as AxisTensor.
_op_real_in_out_name = {
"conv2d": [["Input", "Filter"], ["Output"]],
"depthwise_conv2d": [["Input"], ["Output"]],
"mul": [["X", "Y"], ["Out"]],
"pool2d": [["X"], ["Out"]],
"elementwise_add": [["X", "Y"], ["Out"]],
"concat": [["X"], ["Out"]],
"softmax": [["X"], ["Out"]],
"argmax": [["X"], ["Out"]],
"transpose": [["X"], ["Out"]],
"equal": [["X", "Y"], ["Out"]],
"gather": [["X"], ["Out"]],
"greater_equal": [["X", "Y"], ["Out"]],
"greater_than": [["X", "Y"], ["Out"]],
"less_equal": [["X", "Y"], ["Out"]],
"less_than": [["X", "Y"], ["Out"]],
"mean": [["X"], ["Out"]],
"not_equal": [["X", "Y"], ["Out"]],
"reshape": [["X"], ["Out"]],
"reshape2": [["X"], ["Out"]],
"bilinear_interp": [["X"], ["Out"]],
"nearest_interp": [["X"], ["Out"]],
"trilinear_interp": [["X"], ["Out"]],
"slice": [["Input"], ["Out"]],
"squeeze": [["X"], ["Out"]],
"elementwise_sub": [["X", "Y"], ["Out"]],
"relu": [["X"], ["Out"]],
"relu6": [["X"], ["Out"]],
"leaky_relu": [["X"], ["Out"]],
"tanh": [["X"], ["Out"]],
"swish": [["X"], ["Out"]],
}
def _init_var_node(var_node, value, scope, place):
assert isinstance(value,
......@@ -54,6 +88,8 @@ def _init_var_node(var_node, value, scope, place):
class QuantizationTransformPass(object):
_supported_quantizable_op_type = ['conv2d', 'depthwise_conv2d', 'mul']
def __init__(self,
scope=None,
place=None,
......@@ -75,25 +111,27 @@ class QuantizationTransformPass(object):
initialize these new parameters.
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize new
parameters described above.
weight_bits (int): quantization bit number for weights,
weight_bits(int): quantization bit number for weights,
the bias is not quantized.
activation_bits (int): quantization bit number for activation.
activation_quantize_type (str): quantization type for activation,
activation_bits(int): quantization bit number for activation.
activation_quantize_type(str): quantization type for activation,
now support 'abs_max', 'range_abs_max' and 'moving_average_abs_max'.
If use 'abs_max' mode, the quantization scale will be calculated
dynamically each step in both training and testing period. If use
'range_abs_max', a static quantization scale will be calculated
during training and used in inference.
weight_quantize_type (str): quantization type for weights,
weight_quantize_type(str): quantization type for weights,
support 'abs_max' and 'channel_wise_abs_max'. The 'range_abs_max'
usually is not used for weight, since weights are fixed once the
model is well trained.
window_size (int): the window size for 'range_abs_max' quantization.
window_size(int): the window size for 'range_abs_max' quantization.
moving_rate(float): the param for 'moving_average_abs_max' quantization.
skip_pattern(str): The user-defined quantization skip pattern, which
will be presented in the name scope of an op. When the skip pattern is
detected in an op's name scope, the corresponding op will not be quantized.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"].
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationFreezePass and ConvertToInt8Pass must be the same as this.
Examples:
.. code-block:: python
......@@ -139,9 +177,8 @@ class QuantizationTransformPass(object):
self._moving_rate = moving_rate
self._quantizable_ops = quantizable_op_type
supported_quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
for op in self._quantizable_ops:
assert op in supported_quantizable_ops, \
assert op in QuantizationTransformPass._supported_quantizable_op_type, \
op + " is not supported for quantization."
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._quantizable_grad_ops = [
......@@ -158,6 +195,8 @@ class QuantizationTransformPass(object):
Args:
graph(IrGraph): the applied graph.
Returns:
None
"""
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
......@@ -589,6 +628,16 @@ class QuantizationTransformPass(object):
class QuantizationFreezePass(object):
_supported_quantizable_op_type = \
QuantizationTransformPass._supported_quantizable_op_type
def __init__(self,
scope,
place,
weight_bits=8,
activation_bits=8,
weight_quantize_type='abs_max',
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
"""
The freeze pass is used to adjust the quantize operator order, for example:
1) `activation -> quant -> dequant -> conv2d` will be freezed into
......@@ -599,22 +648,15 @@ class QuantizationFreezePass(object):
Args:
scope(fluid.Scope): scope is used to get the weight tensor values.
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors.
weight_bits (int): quantization bit number for weights.
activation_bits (int): quantization bit number for activation.
weight_quantize_type (str): quantization type for weights, support 'abs_max' and
weight_bits(int): quantization bit number for weights.
activation_bits(int): quantization bit number for activation.
weight_quantize_type(str): quantization type for weights, support 'abs_max' and
'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
since weights are fixed once the model is well trained.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"].
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationTransformPass and ConvertToInt8Pass must be the same as this.
"""
def __init__(self,
scope,
place,
weight_bits=8,
activation_bits=8,
weight_quantize_type='abs_max',
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
assert scope is not None, \
'The scope cannot be set None.'
assert place is not None, \
......@@ -625,9 +667,8 @@ class QuantizationFreezePass(object):
self._activation_bits = activation_bits
self._weight_quantize_type = weight_quantize_type
self._quantizable_ops = quantizable_op_type
supported_quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
for op in self._quantizable_ops:
assert op in supported_quantizable_ops, \
assert op in QuantizationFreezePass._supported_quantizable_op_type, \
op + " is not supported for quantization."
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._fake_quant_op_names = _fake_quant_op_list
......@@ -642,6 +683,8 @@ class QuantizationFreezePass(object):
Args:
graph(IrGraph): the applied graph.
Returns:
None
"""
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
ops = graph.all_op_nodes()
......@@ -895,6 +938,13 @@ class QuantizationFreezePass(object):
class ConvertToInt8Pass(object):
_supported_quantizable_op_type = \
QuantizationTransformPass._supported_quantizable_op_type
def __init__(self,
scope,
place,
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
"""
Convert the weights into int8_t type.
......@@ -903,13 +953,9 @@ class ConvertToInt8Pass(object):
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the
8bits weight tensors.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"].
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationTransformPass and QuantizationFreezePass must be the same as this.
"""
def __init__(self,
scope,
place,
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
assert scope is not None, \
'The scope cannot be set None.'
assert place is not None, \
......@@ -917,9 +963,8 @@ class ConvertToInt8Pass(object):
self._scope = scope
self._place = place
self._quantizable_ops = quantizable_op_type
supported_quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
for op in self._quantizable_ops:
assert op in supported_quantizable_ops, \
assert op in ConvertToInt8Pass._supported_quantizable_op_type, \
op + " is not supported for quantization."
def apply(self, graph):
......@@ -929,6 +974,8 @@ class ConvertToInt8Pass(object):
Args:
graph(IrGraph): the applied graph.
Returns:
None
"""
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
ops = graph.all_op_nodes()
......@@ -993,11 +1040,10 @@ class ConvertToInt8Pass(object):
class TransformForMobilePass(object):
def __init__(self):
"""
This pass is used to convert the freezed graph for paddle-mobile execution.
"""
def __init__(self):
self._fake_quant_op_names = _fake_quant_op_list
self._fake_dequant_op_names = _fake_dequant_op_list
......@@ -1009,6 +1055,8 @@ class TransformForMobilePass(object):
Args:
graph(IrGraph): the graph will be transformed.
Returns:
None
"""
ops = graph.all_op_nodes()
for op_node in ops:
......@@ -1183,16 +1231,45 @@ class ScaleForInferencePass(object):
class AddQuantDequantPass(object):
_supported_quantizable_op_type = [
"pool2d", "elementwise_add", "concat", "softmax", "argmax", "transpose",
"equal", "gather", "greater_equal", "greater_than", "less_equal",
"less_than", "mean", "not_equal", "reshape", "reshape2",
"bilinear_interp", "nearest_interp", "trilinear_interp", "slice",
"squeeze", "elementwise_sub"
]
_activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"]
def __init__(self,
scope=None,
place=None,
moving_rate=0.9,
quant_bits=8,
skip_pattern='skip_quant',
quantizable_op_type=["elementwise_add", "pool2d"]):
quantizable_op_type=["elementwise_add", "pool2d", "concat"],
is_full_quantized=False):
"""
This pass is used to add quant_dequant op for some ops, such as the
'elementwise_add' and 'pool2d' op.
This pass add quant_dequant op for some ops, of which all the inputs must be
not persistable.
The input scales can be obtained from the quant_dequant op.
Args:
scope(fluid.Scope): The scope is used to initialize these new parameters.
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to initialize new
parameters described above.
moving_rate(float, optional): the param for 'quant_dequant_moving_average_abs_max'
quantization. Default is 0.9.
quant_bits(int, optional): quantization bit number for activation. Default is 8.
skip_pattern(str, optional): The user-defined quantization skip pattern, which
will be presented in the name scope of an op. When the skip pattern is
detected in an op's name scope, the corresponding op will not be quantized.
Default is 'skip_quant'.
quantizable_op_type(list[str], optional): List the type of ops that will be
quantized. Default is ["elementwise_add", "pool2d", "concat"].
is_full_quantized(bool, optional): If set is_full_quantized as True, apply
quantization to all supported quantizable op type. If set is_full_quantized
as False, only apply quantization to the op type according to the input
quantizable_op_type.
"""
self._scope = scope
self._place = place
......@@ -1200,60 +1277,67 @@ class AddQuantDequantPass(object):
self._quant_bits = quant_bits
self._is_test = None
self._skip_pattern = skip_pattern
if is_full_quantized:
self._quantizable_op_type = \
AddQuantDequantPass._supported_quantizable_op_type
else:
self._quantizable_op_type = quantizable_op_type
for op_type in quantizable_op_type:
assert op_type in AddQuantDequantPass._supported_quantizable_op_type + \
AddQuantDequantPass._activation_type, \
op_type + " is not supported for quantization."
self._quantizable_grad_op_type = [
'%s_grad' % (op) for op in self._quantizable_op_type
]
supported_quantizable_op_type = ["elementwise_add", "pool2d"]
for op_type in quantizable_op_type:
assert op_type in supported_quantizable_op_type, \
op_type + " is not supported for quantization."
assert self._scope != None, "scope must not be None."
assert self._place != None, "place must not be None."
def apply(self, graph):
"""
Add quant_dequant before some ops, such as the 'elementwise_add'
and 'pool2d' op.
Add quant_dequant before some ops, such as the 'elementwise_add',
'pool2d' and 'concat' op.
Args:
graph(IrGraph): the target graph.
Returns:
None
"""
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
self._is_test = graph.is_test()
dequantized_vars_map = collections.OrderedDict()
ops = graph.all_op_nodes()
for op_node in ops:
# Forward stage, insert quant_dequant op
all_op_nodes = graph.all_op_nodes()
for op_node in all_op_nodes:
if op_node.name() in self._quantizable_op_type:
if isinstance(self._skip_pattern, str) and \
op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1:
continue
in_nodes_all_not_persistable = True
for input_name in op_node.input_arg_names():
in_node = graph._find_node_by_name(op_node.inputs,
input_name)
in_nodes_all_not_persistable = (
in_nodes_all_not_persistable and
not in_node.persistable())
if not in_nodes_all_not_persistable:
if not self._is_input_all_not_persistable(graph, op_node):
continue
input_names = op_node.input_arg_names()
for input_name in input_names:
input_name_list = _op_real_in_out_name[op_node.name()][0]
for input_name in input_name_list:
for arg_name in op_node.input(input_name):
in_node = graph._find_node_by_name(op_node.inputs,
input_name)
if input_name in dequantized_vars_map:
quant_var_node = dequantized_vars_map[input_name]
arg_name)
if arg_name in dequantized_vars_map:
quant_var_node = dequantized_vars_map[arg_name]
else:
quant_var_node, scale_var_node = \
quant_var_node, _ = \
self._inser_quant_dequant_moving_average_abs_max_op(
graph, in_node, self._quant_bits)
dequantized_vars_map[input_name] = quant_var_node
graph.update_input_link(in_node, quant_var_node, op_node)
dequantized_vars_map[arg_name] = quant_var_node
graph.update_input_link(in_node, quant_var_node,
op_node)
for op_node in ops:
# Backward stage, update input link
for op_node in all_op_nodes:
if op_node.name() in self._quantizable_grad_op_type:
for input_name in op_node.input_arg_names():
if input_name in dequantized_vars_map:
......@@ -1266,6 +1350,21 @@ class AddQuantDequantPass(object):
graph.resolve_hazard()
return graph
def _is_input_all_not_persistable(self, graph, op_node):
'''
Analyse the real inputs of the op node are all not persistable.
'''
is_input_all_not_persistable = True
op_node_name = op_node.name()
input_name_list = _op_real_in_out_name[op_node_name][0]
for input_name in input_name_list:
for arg_name in op_node.input(input_name):
in_node = graph._find_node_by_name(op_node.inputs, arg_name)
is_input_all_not_persistable = (is_input_all_not_persistable and \
(not in_node.persistable()))
return is_input_all_not_persistable
def _inser_quant_dequant_moving_average_abs_max_op(self, graph, var_node,
quant_bits):
"""Insert fake_quantize_dequantize_moving_average_abs_max op.
......
......@@ -233,7 +233,10 @@ class TestPostTrainingQuantization(unittest.TestCase):
acc1 = np.sum(test_info) / cnt
return (throughput, latency, acc1)
def generate_quantized_model(self, model_path, algo="KL"):
def generate_quantized_model(self,
model_path,
algo="KL",
is_full_quantize=False):
self.int8_model = os.path.join(os.getcwd(),
"post_training_" + self.timestamp)
try:
......@@ -257,7 +260,8 @@ class TestPostTrainingQuantization(unittest.TestCase):
model_path=model_path,
data_reader=val_reader,
algo=algo,
quantizable_op_type=quantizable_op_type)
quantizable_op_type=quantizable_op_type,
is_full_quantize=is_full_quantize)
ptq.quantize()
ptq.save_quantized_model(self.int8_model)
......@@ -285,7 +289,9 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
print("Start INT8 post training quantization for {0} on {1} images ...".
format(self.model, self.sample_iterations * self.batch_size))
self.generate_quantized_model(
self.model_cache_folder + "/model", algo=self.algo)
self.model_cache_folder + "/model",
algo=self.algo,
is_full_quantize=True)
print("Start INT8 inference for {0} on {1} images ...".format(
self.model, self.infer_iterations * self.batch_size))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册