Unverified commit 589cd878, authored by C cc, committed by GitHub

Post_training_quantization supports min_max method (#23078)

* Post_training_quantization supports min_max method
Parent 194a22c5
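For context, the min_max algo added by this commit joins the existing "KL" and "abs_max" calibration strategies of PostTrainingQuantization. Below is a minimal sketch of the two simpler strategies, assuming per-tensor symmetric scales; the helper names and the exact min_max rule are illustrative, not the pass's implementation:

import numpy as np

def abs_max_scale(samples):
    # 'abs_max': the scale is the largest absolute activation value
    # observed across calibration batches.
    return max(float(np.max(np.abs(s))) for s in samples)

def min_max_scale(samples):
    # 'min_max': track the running min and max of each activation and
    # derive a symmetric scale from the larger magnitude (one plausible
    # reading of the method, not its authoritative definition).
    lo = min(float(np.min(s)) for s in samples)
    hi = max(float(np.max(s)) for s in samples)
    return max(abs(lo), abs(hi))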
@@ -35,6 +35,10 @@ _fake_dequant_op_list = [
'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
]
_fake_quant_dequant_op_list = [
'fake_quantize_dequantize_moving_average_abs_max'
]
_out_scale_op_list = [
"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", "depthwise_conv2d",
"batch_norm", "concat", "tanh", "pad", "elementwise_add", "elementwise_mul",
@@ -44,7 +48,7 @@ _out_scale_op_list = [
# list op real input and output names, to avoid processing input such as AxisTensor.
_op_real_in_out_name = {
"conv2d": [["Input", "Filter"], ["Output"]],
"depthwise_conv2d": [["Input"], ["Output"]],
"depthwise_conv2d": [["Input", "Filter"], ["Output"]],
"mul": [["X", "Y"], ["Out"]],
"matmul": [["X", "Y"], ["Out"]],
"pool2d": [["X"], ["Out"]],
@@ -236,6 +240,7 @@ class QuantizationTransformPass(object):
op_node.op()._set_attr("skip_quant", True)
def _transform_forward(graph, op):
op.op()._set_attr("quantization_type", "qat_with_weight")
for var_node in op.inputs:
if var_node.name() not in op.input_arg_names():
continue
@@ -290,7 +295,7 @@ class QuantizationTransformPass(object):
# The loop for transforming the forward graph:
for op in ops:
if op.name() in self._quantizable_ops:
if not QuantizationTransformPass._is_skip_quant(graph, op):
if not self._is_skip_quant(graph, op):
_transform_forward(graph, op)
# The loop for renaming the inputs of backward op.
for op in ops:
@@ -636,8 +641,7 @@ class QuantizationTransformPass(object):
"""
return "%s.scale" % (var_name)
@staticmethod
def _is_skip_quant(graph, op_node):
def _is_skip_quant(self, graph, op_node):
"""
Analyse whether the op node skips quantization.
"""
@@ -650,20 +654,20 @@
if op_node.name() in ["mul", "matmul"] and \
_is_input_all_not_persistable(graph, op_node):
is_skip = True
if op_node.op().has_attr("quantization_type") and \
op_node.op().attr("quantization_type") == "qat_without_weight":
is_skip = True
return is_skip
class QuantizationFreezePass(object):
_supported_quantizable_op_type = \
QuantizationTransformPass._supported_quantizable_op_type
def __init__(self,
scope,
place,
weight_bits=8,
activation_bits=8,
weight_quantize_type='abs_max',
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
quantizable_op_type=None):
"""
The freeze pass is used to adjust the quantize operator order, for example:
1) `activation -> quant -> dequant -> conv2d` will be frozen into
@@ -679,9 +683,8 @@ class QuantizationFreezePass(object):
weight_quantize_type(str): quantization type for weights, support 'abs_max' and
'channel_wise_abs_max'. The 'range_abs_max' usually is not used for weight,
since weights are fixed once the model is well trained.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationTransformPass and ConvertToInt8Pass must be the same as this.
quantizable_op_type(list[str]): This input param will be removed later. The pass
will process all quantized ops, so it is not necessary to set this param.
"""
assert scope is not None, \
'The scope cannot be set None.'
@@ -692,16 +695,12 @@ class QuantizationFreezePass(object):
self._weight_bits = weight_bits
self._activation_bits = activation_bits
self._weight_quantize_type = weight_quantize_type
self._quantizable_ops = quantizable_op_type
for op in self._quantizable_ops:
assert op in QuantizationFreezePass._supported_quantizable_op_type, \
op + " is not supported for quantization."
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._fake_quant_op_names = _fake_quant_op_list
self._fake_dequant_op_names = _fake_dequant_op_list
self._op_input_rename_map = collections.OrderedDict()
self._op_output_rename_map = collections.OrderedDict()
self._var_scale_map = collections.OrderedDict()
self._quant_var_scale_map = collections.OrderedDict()
def apply(self, graph):
"""
@@ -712,6 +711,7 @@ class QuantizationFreezePass(object):
Returns:
None
"""
# Get input scales in fake quant op and process weights
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
ops = graph.all_op_nodes()
for op_node in ops:
@@ -733,7 +733,7 @@ class QuantizationFreezePass(object):
else:
scale_v = self._load_var(
op_node.output('OutScale')[0])[0]
self._var_scale_map[input_arg_name] = scale_v
self._quant_var_scale_map[input_arg_name] = scale_v
self._remove_fake_quant_and_dequant_op(graph, op_node)
# quantize weight and restore
param_v = self._load_var(input_arg_name)
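Per the "quantize weight and restore" comment above, the loaded weight is quantized in place at this point. A minimal sketch of symmetric abs_max weight quantization under the usual convention (integer values are kept in float storage until ConvertToInt8Pass casts them):

import numpy as np

def quantize_weight_abs_max(w, bits=8):
    # Scale is the max absolute value; round into the signed n-bit range.
    scale = float(np.max(np.abs(w)))
    bnt = (1 << (bits - 1)) - 1  # 127 for 8 bits
    return np.round(w / scale * bnt), scale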
@@ -743,32 +743,29 @@ class QuantizationFreezePass(object):
else:
scale_v = graph._find_node_by_name(
op_node.outputs, op_node.output('OutScale')[0])
self._var_scale_map[input_arg_name] = scale_v
self._quant_var_scale_map[input_arg_name] = scale_v
# Remove all fake dequant op
ops = graph.all_op_nodes()
for op_node in ops:
op_name = op_node.name()
if op_name in self._fake_dequant_op_names:
self._remove_fake_quant_and_dequant_op(graph, op_node)
# Insert post dequant op
ops = graph.all_op_nodes()
for op_node in ops:
op_name = op_node.name()
if op_name in self._quantizable_ops:
# only process the node that is quantized by QuantizationTransformPass
is_op_node_quantized = False
for var_node in op_node.inputs:
var_name = var_node.name()
if var_name.endswith('.dequantized'):
is_op_node_quantized = True
if is_op_node_quantized:
if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops:
op_node_desc = op_node.op()
if op_node_desc.has_attr("quantization_type") and \
op_node_desc.attr("quantization_type") == "qat_with_weight":
if self._weight_quantize_type == 'channel_wise_abs_max' \
and op_node.name() in self._conv_ops:
self._insert_post_channel_dequant_op(graph, op_node)
else:
self._insert_post_dequant_op(graph, op_node)
# Rename inputs of the following ops after inserting dequant_op after fc/conv
for op_node in ops:
# insert dequant_op after fc/conv, need to rename inputs of the followed ops
for var_node in op_node.inputs:
if var_node.node in self._op_output_rename_map:
old_in = var_node
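For reference, the dequant op inserted after a quantized fc/conv maps the integer-valued output back to floats. A sketch assuming fake_dequantize_max_abs semantics, i.e. Y = X * scale / max_range:

def dequantize_max_abs(x, scale, max_range):
    # Inverse of the quantize step; max_range combines the weight and
    # activation ranges (see the max_range handling later in the
    # _insert_post_dequant_op hunks).
    return x * scale / max_range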
@@ -802,7 +799,7 @@ class QuantizationFreezePass(object):
new_in.clear_outputs()
graph.update_input_link(old_in, new_in, op_node)
original_var_name = self._original_var_name(name)
scale_v = self._var_scale_map[original_var_name]
scale_v = self._quant_var_scale_map[original_var_name]
if original_var_name in persistable_vars:
assert isinstance(
scale_v,
@@ -811,7 +808,7 @@ class QuantizationFreezePass(object):
channel_scale = np.array(scale_v)
else:
assert isinstance(scale_v, IrNode)
scale_var_node = self._var_scale_map[original_var_name]
scale_var_node = self._quant_var_scale_map[original_var_name]
if len(op_node.output_arg_names()) != 1:
raise ValueError("Only support one output, but op %s has"
@@ -867,7 +864,7 @@ class QuantizationFreezePass(object):
new_in.clear_outputs()
graph.update_input_link(old_in, new_in, op_node)
original_var_name = self._original_var_name(name)
scale_v = self._var_scale_map[original_var_name]
scale_v = self._quant_var_scale_map[original_var_name]
if original_var_name in persistable_vars:
assert self._is_float(
scale_v), 'The scale of parameter %s is not a float.' % (
@@ -876,7 +873,7 @@ class QuantizationFreezePass(object):
else:
max_range *= act_range
assert isinstance(scale_v, IrNode)
scale_var_node = self._var_scale_map[original_var_name]
scale_var_node = self._quant_var_scale_map[original_var_name]
if len(op_node.output_arg_names()) != 1:
raise ValueError("Only support one output, but op %s has"
@@ -963,13 +960,7 @@ class QuantizationFreezePass(object):
class ConvertToInt8Pass(object):
_supported_quantizable_op_type = \
QuantizationTransformPass._supported_quantizable_op_type
def __init__(self,
scope,
place,
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
def __init__(self, scope, place, quantizable_op_type=None):
"""
Convert the weights into int8_t type.
@@ -977,9 +968,8 @@ class ConvertToInt8Pass(object):
scope(fluid.Scope): scope is used to get the weight tensor values.
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the
8bits weight tensors.
quantizable_op_type(list[str]): List the type of ops that will be quantized.
Default is ["conv2d", "depthwise_conv2d", "mul"]. The quantizable_op_type in
QuantizationTransformPass and QuantizationFreezePass must be the same as this.
quantizable_op_type(list[str]): This input param will be removed later. The pass
will process all quantized ops, so it is not necessary to set this param.
"""
assert scope is not None, \
'The scope cannot be set None.'
@@ -987,10 +977,6 @@ class ConvertToInt8Pass(object):
'The place cannot be set None.'
self._scope = scope
self._place = place
self._quantizable_ops = quantizable_op_type
for op in self._quantizable_ops:
assert op in ConvertToInt8Pass._supported_quantizable_op_type, \
op + " is not supported for quantization."
def apply(self, graph):
"""
@@ -1006,10 +992,8 @@ class ConvertToInt8Pass(object):
ops = graph.all_op_nodes()
input_map = {}
for op_node in ops:
op_name = op_node.name()
if op_name in self._quantizable_ops:
if QuantizationTransformPass._is_skip_quant(graph, op_node):
continue
if op_node.op().has_attr("quantization_type") and \
op_node.op().attr("quantization_type") == "qat_with_weight":
for var_node in op_node.inputs:
name = var_node.name()
if name in persistable_vars:
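After QuantizationFreezePass, the persistable weights of "qat_with_weight" ops already hold integer values in float storage, so the int8 cast is lossless. A sketch of the conversion under that assumption (not the pass's exact code):

import numpy as np

def weight_to_int8(weight_array):
    # np.rint guards against tiny float error before the narrowing cast.
    return np.rint(weight_array).astype(np.int8)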
@@ -1259,9 +1243,9 @@ class AddQuantDequantPass(object):
"equal", "gather", "greater_equal", "greater_than", "less_equal",
"less_than", "mean", "not_equal", "reshape", "reshape2",
"bilinear_interp", "nearest_interp", "trilinear_interp", "slice",
"squeeze", "elementwise_sub", "mul", "matmul"
"squeeze", "elementwise_sub", "mul", "matmul", "relu", "relu6",
"leaky_relu", "tanh", "swish"
]
_activation_type = ["relu", "relu6", "leaky_relu", "tanh", "swish"]
def __init__(self,
scope=None,
@@ -1307,8 +1291,7 @@ class AddQuantDequantPass(object):
else:
self._quantizable_op_type = quantizable_op_type
for op_type in quantizable_op_type:
assert op_type in AddQuantDequantPass._supported_quantizable_op_type + \
AddQuantDequantPass._activation_type, \
assert op_type in AddQuantDequantPass._supported_quantizable_op_type, \
op_type + " is not supported for quantization."
self._quantizable_grad_op_type = [
'%s_grad' % (op) for op in self._quantizable_op_type
@@ -1343,17 +1326,15 @@ class AddQuantDequantPass(object):
elif isinstance(self._skip_pattern, str):
is_skip = op_node.op().has_attr("op_namescope") and \
op_node.op().attr("op_namescope").find(self._skip_pattern) != -1
is_op_node_quantized = False
for var_node in op_node.inputs:
var_name = var_node.name()
if var_name.endswith('.dequantized'):
is_op_node_quantized = True
if is_skip or is_op_node_quantized or \
is_quantized = op_node.op().has_attr("quantization_type") and \
op_node.op().attr("quantization_type") == "qat_with_weight"
if is_skip or is_quantized or \
(not _is_input_all_not_persistable(graph, op_node)):
continue
op_node.op()._set_attr("quantization_type",
"qat_without_weight")
op_node.op()._set_attr("activation_bits", self._quant_bits)
input_name_list = _op_real_in_out_name[op_node.name()][0]
arg_names = []
for input_name in input_name_list:
......
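Taken together, the passes now coordinate through the "quantization_type" op attribute instead of scanning input names for a '.dequantized' suffix: QuantizationTransformPass tags weighted ops "qat_with_weight", and AddQuantDequantPass tags the remaining ops "qat_without_weight" while skipping anything already tagged. A sketch of the shared check, using only the attribute calls shown above:

def is_weight_quantized(op_desc):
    # Mirrors the checks in QuantizationFreezePass, ConvertToInt8Pass
    # and AddQuantDequantPass.
    return op_desc.has_attr("quantization_type") and \
        op_desc.attr("quantization_type") == "qat_with_weight"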
@@ -264,7 +264,7 @@ class TestPostTrainingQuantization(unittest.TestCase):
ptq.save_quantized_model(self.int8_model)
def run_test(self, model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file):
is_full_quantize, is_use_cache_file, diff_threshold):
infer_iterations = self.infer_iterations
batch_size = self.batch_size
sample_iterations = self.sample_iterations
@@ -296,11 +296,11 @@ class TestPostTrainingQuantization(unittest.TestCase):
sys.stdout.flush()
delta_value = fp32_acc1 - int8_acc1
self.assertLess(delta_value, 0.025)
self.assertLess(delta_value, diff_threshold)
class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_mobilenetv1(self):
class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_kl_mobilenetv1(self):
model = "MobileNet-V1"
algo = "KL"
data_urls = [
@@ -310,10 +310,29 @@ class TestPostTrainingForMobilenetv1(TestPostTrainingQuantization):
quantizable_op_type = [
"conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
]
is_full_quantize = True
is_full_quantize = False
is_use_cache_file = False
diff_threshold = 0.025
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file)
is_full_quantize, is_use_cache_file, diff_threshold)
class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
def test_post_training_abs_max_mobilenetv1(self):
model = "MobileNet-V1"
algo = "abs_max"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/mobilenetv1_int8_model.tar.gz'
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
]
is_full_quantize = False
is_use_cache_file = False
diff_threshold = 0.05
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, diff_threshold)
if __name__ == '__main__':
......
@@ -20,7 +20,7 @@ from test_post_training_quantization_mobilenetv1 import TestPostTrainingQuantiza
class TestPostTrainingForResnet50(TestPostTrainingQuantization):
def test_post_training_resnet50(self):
model = "ResNet-50"
algo = "direct"
algo = "min_max"
data_urls = [
'http://paddle-inference-dist.bj.bcebos.com/int8/resnet50_int8_model.tar.gz'
]
@@ -28,8 +28,9 @@ class TestPostTrainingForResnet50(TestPostTrainingQuantization):
quantizable_op_type = ["conv2d", "mul"]
is_full_quantize = False
is_use_cache_file = False
diff_threshold = 0.025
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file)
is_full_quantize, is_use_cache_file, diff_threshold)
if __name__ == '__main__':
......