未验证 提交 589cd878 编写于 作者: C cc 提交者: GitHub

Post_training_quantizaion supports min_max methon (#23078)

* Post_training_quantizaion supports min_max methon
上级 194a22c5
...@@ -35,6 +35,10 @@ _fake_dequant_op_list = [ ...@@ -35,6 +35,10 @@ _fake_dequant_op_list = [
'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs' 'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
] ]
_fake_quant_dequant_op_list = [
'fake_quantize_dequantize_moving_average_abs_max'
]
_out_scale_op_list = [ _out_scale_op_list = [
"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", "depthwise_conv2d", "mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", "depthwise_conv2d",
"batch_norm", "concat", "tanh", "pad", "elementwise_add", "elementwise_mul", "batch_norm", "concat", "tanh", "pad", "elementwise_add", "elementwise_mul",
...@@ -44,7 +48,7 @@ _out_scale_op_list = [ ...@@ -44,7 +48,7 @@ _out_scale_op_list = [
# list op real input and output names, to avoid processing input such as AxisTensor. # list op real input and output names, to avoid processing input such as AxisTensor.
_op_real_in_out_name = { _op_real_in_out_name = {
"conv2d": [["Input", "Filter"], ["Output"]], "conv2d": [["Input", "Filter"], ["Output"]],
"depthwise_conv2d": [["Input"], ["Output"]], "depthwise_conv2d": [["Input", "Filter"], ["Output"]],
"mul": [["X", "Y"], ["Out"]], "mul": [["X", "Y"], ["Out"]],
"matmul": [["X", "Y"], ["Out"]], "matmul": [["X", "Y"], ["Out"]],
"pool2d": [["X"], ["Out"]], "pool2d": [["X"], ["Out"]],
...@@ -236,6 +240,7 @@ class QuantizationTransformPass(object): ...@@ -236,6 +240,7 @@ class QuantizationTransformPass(object):
op_node.op()._set_attr("skip_quant", True) op_node.op()._set_attr("skip_quant", True)
def _transform_forward(graph, op): def _transform_forward(graph, op):
op.op()._set_attr("quantization_type", "qat_with_weight")
for var_node in op.inputs: for var_node in op.inputs:
if var_node.name() not in op.input_arg_names(): if var_node.name() not in op.input_arg_names():
continue continue
...@@ -290,7 +295,7 @@ class QuantizationTransformPass(object): ...@@ -290,7 +295,7 @@ class QuantizationTransformPass(object):
# The loop for transforming the forward graph: # The loop for transforming the forward graph:
for op in ops: for op in ops:
if op.name() in self._quantizable_ops: if op.name() in self._quantizable_ops:
if not QuantizationTransformPass._is_skip_quant(graph, op): if not self._is_skip_quant(graph, op):
_transform_forward(graph, op) _transform_forward(graph, op)
# The loop for renaming the inputs of backward op. # The loop for renaming the inputs of backward op.
for op in ops: for op in ops:
...@@ -636,8 +641,7 @@ class QuantizationTransformPass(object): ...@@ -636,8 +641,7 @@ class QuantizationTransformPass(object):
""" """
return "%s.scale" % (var_name) return "%s.scale" % (var_name)
@staticmethod def _is_skip_quant(self, graph, op_node):
def _is_skip_quant(graph, op_node):
""" """
Analyse whether the op node skips quantization. Analyse whether the op node skips quantization.
""" """
...@@ -650,20 +654,20 @@ class QuantizationTransformPass(object): ...@@ -650,20 +654,20 @@ class QuantizationTransformPass(object):
if op_node.name() in ["mul", "matmul"] and \ if op_node.name() in ["mul", "matmul"] and \
_is_input_all_not_persistable(graph, op_node): _is_input_all_not_persistable(graph, op_node):
is_skip = True is_skip = True
if op_node.op().has_attr("quantization_type") and \
op_node.op().attr("quantization_type") == "qat_without_weight":
is_skip = True
return is_skip return is_skip
class QuantizationFreezePass(object): class QuantizationFreezePass(object):
_supported_quantizable_op_type = \
QuantizationTransformPass._supported_quantizable_op_type
def __init__(self, def __init__(self,
scope, scope,
place, place,
weight_bits=8, weight_bits=8,
activation_bits=8, activation_bits=8,
weight_quantize_type='abs_max', weight_quantize_type='abs_max',
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']): quantizable_op_type=None):
""" """
The freeze pass is used to adjust the quantize operator order, for example: The freeze pass is used to adjust the quantize operator order, for example:
1) `activation -> quant -> dequant -> conv2d` will be frozen into 1) `activation -> quant -> dequant -> conv2d` will be frozen into
...@@ -679,9 +683,8 @@ class QuantizationFreezePass(object): ...@@ -679,9 +683,8 @@ class QuantizationFreezePass(object):
weight_quantize_type(str): quantization type for weights, support 'abs_max' and weight_quantize_type(str): quantization type for weights, support 'abs_max' and
'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight, 'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
since weights are fixed once the model is well trained. since weights are fixed once the model is well trained.
quantizable_op_type(list[str]): List the type of ops that will be quantized. quantizable_op_type(list[str]): This input param will be removed latter. The pass
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in will process all quantized op, so it is not necessary to set the input param.
QuantizationTransformPass and ConvertToInt8Pass must be the same as this.
""" """
assert scope is not None, \ assert scope is not None, \
'The scope cannot be set None.' 'The scope cannot be set None.'
...@@ -692,16 +695,12 @@ class QuantizationFreezePass(object): ...@@ -692,16 +695,12 @@ class QuantizationFreezePass(object):
self._weight_bits = weight_bits self._weight_bits = weight_bits
self._activation_bits = activation_bits self._activation_bits = activation_bits
self._weight_quantize_type = weight_quantize_type self._weight_quantize_type = weight_quantize_type
self._quantizable_ops = quantizable_op_type
for op in self._quantizable_ops:
assert op in QuantizationFreezePass._supported_quantizable_op_type, \
op + " is not supported for quantization."
self._conv_ops = ['conv2d', 'depthwise_conv2d'] self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._fake_quant_op_names = _fake_quant_op_list self._fake_quant_op_names = _fake_quant_op_list
self._fake_dequant_op_names = _fake_dequant_op_list self._fake_dequant_op_names = _fake_dequant_op_list
self._op_input_rename_map = collections.OrderedDict() self._op_input_rename_map = collections.OrderedDict()
self._op_output_rename_map = collections.OrderedDict() self._op_output_rename_map = collections.OrderedDict()
self._var_scale_map = collections.OrderedDict() self._quant_var_scale_map = collections.OrderedDict()
def apply(self, graph): def apply(self, graph):
""" """
...@@ -712,6 +711,7 @@ class QuantizationFreezePass(object): ...@@ -712,6 +711,7 @@ class QuantizationFreezePass(object):
Returns: Returns:
None None
""" """
# Get input scales in fake quant op and process weights
persistable_vars = [p.name() for p in graph.all_persistable_nodes()] persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
ops = graph.all_op_nodes() ops = graph.all_op_nodes()
for op_node in ops: for op_node in ops:
...@@ -733,7 +733,7 @@ class QuantizationFreezePass(object): ...@@ -733,7 +733,7 @@ class QuantizationFreezePass(object):
else: else:
scale_v = self._load_var( scale_v = self._load_var(
op_node.output('OutScale')[0])[0] op_node.output('OutScale')[0])[0]
self._var_scale_map[input_arg_name] = scale_v self._quant_var_scale_map[input_arg_name] = scale_v
self._remove_fake_quant_and_dequant_op(graph, op_node) self._remove_fake_quant_and_dequant_op(graph, op_node)
# quantize weight and restore # quantize weight and restore
param_v = self._load_var(input_arg_name) param_v = self._load_var(input_arg_name)
...@@ -743,32 +743,29 @@ class QuantizationFreezePass(object): ...@@ -743,32 +743,29 @@ class QuantizationFreezePass(object):
else: else:
scale_v = graph._find_node_by_name( scale_v = graph._find_node_by_name(
op_node.outputs, op_node.output('OutScale')[0]) op_node.outputs, op_node.output('OutScale')[0])
self._var_scale_map[input_arg_name] = scale_v self._quant_var_scale_map[input_arg_name] = scale_v
# Remove all fake dequant op
ops = graph.all_op_nodes() ops = graph.all_op_nodes()
for op_node in ops: for op_node in ops:
op_name = op_node.name() op_name = op_node.name()
if op_name in self._fake_dequant_op_names: if op_name in self._fake_dequant_op_names:
self._remove_fake_quant_and_dequant_op(graph, op_node) self._remove_fake_quant_and_dequant_op(graph, op_node)
# Insert post dequant op
ops = graph.all_op_nodes() ops = graph.all_op_nodes()
for op_node in ops: for op_node in ops:
op_name = op_node.name() op_node_desc = op_node.op()
if op_name in self._quantizable_ops: if op_node_desc.has_attr("quantization_type") and \
# only process the node that is quantized by QuantizationTransformPass op_node_desc.attr("quantization_type") == "qat_with_weight":
is_op_node_quantized = False if self._weight_quantize_type == 'channel_wise_abs_max' \
for var_node in op_node.inputs: and op_node.name() in self._conv_ops:
var_name = var_node.name() self._insert_post_channel_dequant_op(graph, op_node)
if var_name.endswith('.dequantized'): else:
is_op_node_quantized = True self._insert_post_dequant_op(graph, op_node)
if is_op_node_quantized:
if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops:
self._insert_post_channel_dequant_op(graph, op_node)
else:
self._insert_post_dequant_op(graph, op_node)
# Rename inputs of the followed ops after inserting dequant_op after fc/conv
for op_node in ops: for op_node in ops:
# insert dequant_op after fc/conv, need to rename inputs of the followed ops
for var_node in op_node.inputs: for var_node in op_node.inputs:
if var_node.node in self._op_output_rename_map: if var_node.node in self._op_output_rename_map:
old_in = var_node old_in = var_node
...@@ -802,7 +799,7 @@ class QuantizationFreezePass(object): ...@@ -802,7 +799,7 @@ class QuantizationFreezePass(object):
new_in.clear_outputs() new_in.clear_outputs()
graph.update_input_link(old_in, new_in, op_node) graph.update_input_link(old_in, new_in, op_node)
original_var_name = self._original_var_name(name) original_var_name = self._original_var_name(name)
scale_v = self._var_scale_map[original_var_name] scale_v = self._quant_var_scale_map[original_var_name]
if original_var_name in persistable_vars: if original_var_name in persistable_vars:
assert isinstance( assert isinstance(
scale_v, scale_v,
...@@ -811,7 +808,7 @@ class QuantizationFreezePass(object): ...@@ -811,7 +808,7 @@ class QuantizationFreezePass(object):
channel_scale = np.array(scale_v) channel_scale = np.array(scale_v)
else: else:
assert isinstance(scale_v, IrNode) assert isinstance(scale_v, IrNode)
scale_var_node = self._var_scale_map[original_var_name] scale_var_node = self._quant_var_scale_map[original_var_name]
if len(op_node.output_arg_names()) != 1: if len(op_node.output_arg_names()) != 1:
raise ValueError("Only support one output, but op %s has" raise ValueError("Only support one output, but op %s has"
...@@ -867,7 +864,7 @@ class QuantizationFreezePass(object): ...@@ -867,7 +864,7 @@ class QuantizationFreezePass(object):
new_in.clear_outputs() new_in.clear_outputs()
graph.update_input_link(old_in, new_in, op_node) graph.update_input_link(old_in, new_in, op_node)
original_var_name = self._original_var_name(name) original_var_name = self._original_var_name(name)
scale_v = self._var_scale_map[original_var_name] scale_v = self._quant_var_scale_map[original_var_name]
if original_var_name in persistable_vars: if original_var_name in persistable_vars:
assert self._is_float( assert self._is_float(
scale_v), 'The scale of parameter %s is not a float.' % ( scale_v), 'The scale of parameter %s is not a float.' % (
...@@ -876,7 +873,7 @@ class QuantizationFreezePass(object): ...@@ -876,7 +873,7 @@ class QuantizationFreezePass(object):
else: else:
max_range *= act_range max_range *= act_range
assert isinstance(scale_v, IrNode) assert isinstance(scale_v, IrNode)
scale_var_node = self._var_scale_map[original_var_name] scale_var_node = self._quant_var_scale_map[original_var_name]
if len(op_node.output_arg_names()) != 1: if len(op_node.output_arg_names()) != 1:
raise ValueError("Only support one output, but op %s has" raise ValueError("Only support one output, but op %s has"
...@@ -963,13 +960,7 @@ class QuantizationFreezePass(object): ...@@ -963,13 +960,7 @@ class QuantizationFreezePass(object):
class ConvertToInt8Pass(object): class ConvertToInt8Pass(object):
_supported_quantizable_op_type = \ def __init__(self, scope, place, quantizable_op_type=None):
QuantizationTransformPass._supported_quantizable_op_type
def __init__(self,
scope,
place,
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
""" """
Convert the weights into int8_t type. Convert the weights into int8_t type.
...@@ -977,9 +968,8 @@ class ConvertToInt8Pass(object): ...@@ -977,9 +968,8 @@ class ConvertToInt8Pass(object):
scope(fluid.Scope): scope is used to get the weight tensor values. scope(fluid.Scope): scope is used to get the weight tensor values.
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the
8bits weight tensors. 8bits weight tensors.
quantizable_op_type(list[str]): List the type of ops that will be quantized. quantizable_op_type(list[str]): This input param will be removed latter. The pass
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in will process all quantized op, so it is not necessary to set the input param.
QuantizationTransformPass and QuantizationFreezePass must be the same as this.
""" """
assert scope is not None, \ assert scope is not None, \
'The scope cannot be set None.' 'The scope cannot be set None.'
...@@ -987,10 +977,6 @@ class ConvertToInt8Pass(object): ...@@ -987,10 +977,6 @@ class ConvertToInt8Pass(object):
'The place cannot be set None.' 'The place cannot be set None.'
self._scope = scope self._scope = scope
self._place = place self._place = place
self._quantizable_ops = quantizable_op_type
for op in self._quantizable_ops:
assert op in ConvertToInt8Pass._supported_quantizable_op_type, \
op + " is not supported for quantization."
def apply(self, graph): def apply(self, graph):
""" """
...@@ -1006,10 +992,8 @@ class ConvertToInt8Pass(object): ...@@ -1006,10 +992,8 @@ class ConvertToInt8Pass(object):
ops = graph.all_op_nodes() ops = graph.all_op_nodes()
input_map = {} input_map = {}
for op_node in ops: for op_node in ops:
op_name = op_node.name() if op_node.op().has_attr("quantization_type") and \
if op_name in self._quantizable_ops: op_node.op().attr("quantization_type") == "qat_with_weight":
if QuantizationTransformPass._is_skip_quant(graph, op_node):
continue
for var_node in op_node.inputs: for var_node in op_node.inputs:
name = var_node.name() name = var_node.name()
if name in persistable_vars: if name in persistable_vars:
...@@ -1259,9 +1243,9 @@ class AddQuantDequantPass(object): ...@@ -1259,9 +1243,9 @@ class AddQuantDequantPass(object):
"equal", "gather", "greater_equal", "greater_than", "less_equal", "equal", "gather", "greater_equal", "greater_than", "less_equal",
"less_than", "mean", "not_equal", "reshape", "reshape2", "less_than", "mean", "not_equal", "reshape", "reshape2",
"bilinear_interp", "nearest_interp", "trilinear_interp", "slice", "bilinear_interp", "nearest_interp", "trilinear_interp", "slice",
"squeeze", "elementwise_sub", "mul", "matmul" "squeeze", "elementwise_sub", "mul", "matmul", "relu", "relu6",
"leaky_relu", "tanh", "swish"
] ]
_activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"]
def __init__(self, def __init__(self,
scope=None, scope=None,
...@@ -1307,8 +1291,7 @@ class AddQuantDequantPass(object): ...@@ -1307,8 +1291,7 @@ class AddQuantDequantPass(object):
else: else:
self._quantizable_op_type = quantizable_op_type self._quantizable_op_type = quantizable_op_type
for op_type in quantizable_op_type: for op_type in quantizable_op_type:
assert op_type in AddQuantDequantPass._supported_quantizable_op_type + \ assert op_type in AddQuantDequantPass._supported_quantizable_op_type, \
AddQuantDequantPass._activation_type, \
op_type + " is not supported for quantization." op_type + " is not supported for quantization."
self._quantizable_grad_op_type = [ self._quantizable_grad_op_type = [
'%s_grad' % (op) for op in self._quantizable_op_type '%s_grad' % (op) for op in self._quantizable_op_type
...@@ -1343,17 +1326,15 @@ class AddQuantDequantPass(object): ...@@ -1343,17 +1326,15 @@ class AddQuantDequantPass(object):
elif isinstance(self._skip_pattern, str): elif isinstance(self._skip_pattern, str):
is_skip = op_node.op().has_attr("op_namescope") and \ is_skip = op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1 op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
is_quantized = op_node.op().has_attr("quantization_type") and \
is_op_node_quantized = False op_node.op().attr("quantization_type") == "qat_with_weight"
for var_node in op_node.inputs: if is_skip or is_quantized or \
var_name = var_node.name()
if var_name.endswith('.dequantized'):
is_op_node_quantized = True
if is_skip or is_op_node_quantized or \
(not _is_input_all_not_persistable(graph, op_node)): (not _is_input_all_not_persistable(graph, op_node)):
continue continue
op_node.op()._set_attr("quantization_type",
"qat_without_weight")
op_node.op()._set_attr("activation_bits", self._quant_bits)
input_name_list = _op_real_in_out_name[op_node.name()][0] input_name_list = _op_real_in_out_name[op_node.name()][0]
arg_names = [] arg_names = []
for input_name in input_name_list: for input_name in input_name_list:
......
...@@ -264,7 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -264,7 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
ptq.save_quantized_model(self.int8_model) ptq.save_quantized_model(self.int8_model)
def run_test(self, model, algo, data_urls, data_md5s, quantizable_op_type, def run_test(self, model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file): is_full_quantize, is_use_cache_file, diff_threshold):
infer_iterations = self.infer_iterations infer_iterations = self.infer_iterations
batch_size = self.batch_size batch_size = self.batch_size
sample_iterations = self.sample_iterations sample_iterations = self.sample_iterations
...@@ -296,11 +296,11 @@ class TestPostTrainingQuantization(unittest.TestCase): ...@@ -296,11 +296,11 @@ class TestPostTrainingQuantization(unittest.TestCase):
sys.stdout.flush() sys.stdout.flush()
delta_value = fp32_acc1 - int8_acc1 delta_value = fp32_acc1 - int8_acc1
self.assertLess(delta_value, 0.025) self.assertLess(delta_value, diff_threshold)
class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization): class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_mobilenetv1(self): def test_post_training_kl_mobilenetv1(self):
model = "MobileNet-V1" model = "MobileNet-V1"
algo = "KL" algo = "KL"
data_urls = [ data_urls = [
...@@ -310,10 +310,29 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization): ...@@ -310,10 +310,29 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
quantizable_op_type = [ quantizable_op_type = [
"conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add" "conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
] ]
is_full_quantize = True is_full_quantize = False
is_use_cache_file = False is_use_cache_file = False
diff_threshold = 0.025
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type, self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file) is_full_quantize, is_use_cache_file, diff_threshold)
class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_abs_max_mobilenetv1(self):
model = "MobileNet-V1"
algo = "abs_max"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
]
is_full_quantize = False
is_use_cache_file = False
diff_threshold = 0.05
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, diff_threshold)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -20,7 +20,7 @@ from test_post_training_quantization_mobilenetv1 import TestPostTrainingQuantiza ...@@ -20,7 +20,7 @@ from test_post_training_quantization_mobilenetv1 import TestPostTrainingQuantiza
class TestPostTrainingForResnet50(TestPostTrainingQuantization): class TestPostTrainingForResnet50(TestPostTrainingQuantization):
def test_post_training_resnet50(self): def test_post_training_resnet50(self):
model = "ResNet-50" model = "ResNet-50"
algo = "direct" algo = "min_max"
data_urls = [ data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz' 'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
] ]
...@@ -28,8 +28,9 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization): ...@@ -28,8 +28,9 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
quantizable_op_type = ["conv2d", "mul"] quantizable_op_type = ["conv2d", "mul"]
is_full_quantize = False is_full_quantize = False
is_use_cache_file = False is_use_cache_file = False
diff_threshold = 0.025
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type, self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file) is_full_quantize, is_use_cache_file, diff_threshold)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册