Unverified commit 25628587, authored by cc, committed by GitHub

Collect output scale for quantized op and fused op (#23369)

* Collect output scale for quantized op and fused op
* Post_training_quantization sets batch_generator to support LoD tensor
Parent 6162cf2f
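For reference, a minimal usage sketch of the renamed out-scale passes, mirroring the test changes further down; `main_graph` and `test_graph` are assumed to be existing `IrGraph` instances built elsewhere:

import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass

place = fluid.CPUPlace()
scope = fluid.global_scope()

# Collect moving-average output scales while the quantized training graph runs.
OutScaleForTrainingPass(scope=scope, place=place).apply(main_graph)

# ... run training iterations on main_graph ...

# Write the collected scales into the inference graph as op attributes.
OutScaleForInferencePass(scope=scope).apply(test_graph)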
@@ -18,12 +18,13 @@ from ..... import compat as cpt
 from .... import core
 from ....framework import IrGraph
 from ....framework import IrNode
+from ....framework import Operator
 from .... import unique_name

 __all__ = [
     'QuantizationTransformPass', 'QuantizationFreezePass', 'ConvertToInt8Pass',
-    'TransformForMobilePass', 'ScaleForTrainingPass', 'ScaleForInferencePass',
-    'AddQuantDequantPass'
+    'TransformForMobilePass', 'OutScaleForTrainingPass',
+    'OutScaleForInferencePass', 'AddQuantDequantPass'
 ]

 _fake_quant_op_list = [
@@ -40,9 +41,9 @@ _fake_quant_dequant_op_list = [
 ]

 _out_scale_op_list = [
-    "mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", "depthwise_conv2d",
-    "batch_norm", "concat", "tanh", "pad", "elementwise_add", "elementwise_mul",
-    "dropout", "split", "prelu", "conv2d_transpose", "leaky_relu"
+    "conv2d", "depthwise_conv2d", "mul", "matmul", "relu", "leaky_relu",
+    "relu6", "sigmoid", "tanh", "prelu", "swish", "softmax", "batch_norm",
+    "elementwise_add", "pool2d", "reshape2", "transpose2"
 ]

 # list op real input and output names, to avoid processing input such as AxisTensor.
@@ -67,6 +68,7 @@ _op_real_in_out_name = {
     "not_equal": [["X", "Y"], ["Out"]],
     "reshape": [["X"], ["Out"]],
     "reshape2": [["X"], ["Out"]],
+    "transpose2": [["X"], ["Out"]],
     "bilinear_interp": [["X"], ["Out"]],
     "nearest_interp": [["X"], ["Out"]],
     "trilinear_interp": [["X"], ["Out"]],
@@ -76,11 +78,49 @@ _op_real_in_out_name = {
     "relu": [["X"], ["Out"]],
     "relu6": [["X"], ["Out"]],
     "leaky_relu": [["X"], ["Out"]],
+    "prelu": [["X"], ["Out"]],
     "tanh": [["X"], ["Out"]],
     "swish": [["X"], ["Out"]],
+    "dropout": [["X"], ["Out"]],
+    "batch_norm": [["X"], ["Y"]],
+    "sigmoid": [["X"], ["Y"]],
 }


+def _get_op_input_var_names(op):
+    """ """
+    assert isinstance(op, (IrNode, Operator)), \
+        "The input op should be IrNode or Operator."
+    var_names = []
+    op_name = op.name() if isinstance(op, IrNode) \
+        else op.type
+    name_list = _op_real_in_out_name[op_name][0]
+    for name in name_list:
+        var_name = op.input(name)
+        if isinstance(var_name, list):
+            var_names.extend(var_name)
+        else:
+            var_names.append(var_name)
+    return var_names
+
+
+def _get_op_output_var_names(op):
+    """ """
+    assert isinstance(op, (IrNode, Operator)), \
+        "The input op should be IrNode or Operator."
+    var_names = []
+    op_name = op.name() if isinstance(op, IrNode) \
+        else op.type
+    name_list = _op_real_in_out_name[op_name][1]
+    for name in name_list:
+        var_name = op.output(name)
+        if isinstance(var_name, list):
+            var_names.extend(var_name)
+        else:
+            var_names.append(var_name)
+    return var_names
+
+
 def _init_var_node(var_node, value, scope, place):
     assert isinstance(value,
                       np.ndarray), 'The type of value should be numpy array.'
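A standalone sketch of what these two helpers compute, using the `_op_real_in_out_name` entry for `batch_norm` from this diff; the `FakeOp` stand-in and the variable names are purely illustrative assumptions, not Paddle code:

_op_real_in_out_name = {"batch_norm": [["X"], ["Y"]]}

class FakeOp(object):
    # Hypothetical stand-in exposing only the slice of the IrNode/Operator API used here.
    type = "batch_norm"

    def input(self, slot):
        return {"X": ["conv2d_0.tmp_0"]}[slot]

    def output(self, slot):
        return {"Y": ["batch_norm_0.tmp_2"]}[slot]

op = FakeOp()
# Flatten the real input/output slots into plain variable-name lists,
# skipping auxiliary inputs that are not listed in the table.
in_names = [v for slot in _op_real_in_out_name[op.type][0] for v in op.input(slot)]
out_names = [v for slot in _op_real_in_out_name[op.type][1] for v in op.output(slot)]
print(in_names, out_names)  # ['conv2d_0.tmp_0'] ['batch_norm_0.tmp_2']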
@@ -97,17 +137,18 @@ def _is_input_all_not_persistable(graph, op_node):
     Analyse the real inputs of the op node are all not persistable.
     '''
     is_input_all_not_persistable = True
-    op_node_name = op_node.name()
-    input_name_list = _op_real_in_out_name[op_node_name][0]
-    for input_name in input_name_list:
-        for arg_name in op_node.input(input_name):
-            in_node = graph._find_node_by_name(op_node.inputs, arg_name)
-            is_input_all_not_persistable = (is_input_all_not_persistable and \
-                (not in_node.persistable()))
+    for var_name in _get_op_input_var_names(op_node):
+        in_node = graph._find_node_by_name(op_node.inputs, var_name)
+        is_input_all_not_persistable = (is_input_all_not_persistable and \
+            (not in_node.persistable()))
     return is_input_all_not_persistable


 class QuantizationTransformPass(object):
+    """
+    Quantize the ops that have weights. Add quant and dequant ops for the quantized
+    ops's inputs.
+    """
     _supported_quantizable_op_type = [
         'conv2d', 'depthwise_conv2d', 'mul', 'matmul'
     ]
@@ -124,8 +165,7 @@ class QuantizationTransformPass(object):
                  skip_pattern=['skip_quant'],
                  quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
         """
-        Convert and rewrite the IrGraph according to weight and
-        activation quantization type.
+        Constructor.

         Args:
             scope(fluid.Scope): When activation use 'range_abs_max' as the quantize
@@ -1088,7 +1128,7 @@ class TransformForMobilePass(object):
         return graph


-class ScaleForTrainingPass(object):
+class OutScaleForTrainingPass(object):
     def __init__(self, scope=None, place=None, moving_rate=0.9):
         """
         This pass is used for calculating output scales of some operators.
@@ -1195,7 +1235,7 @@ class ScaleForTrainingPass(object):
         return "%s@scale" % (var_name)


-class ScaleForInferencePass(object):
+class OutScaleForInferencePass(object):
     def __init__(self, scope=None):
         """
         This pass is used for setting output scales of some operators.
@@ -1226,7 +1266,7 @@ class ScaleForInferencePass(object):
                 scale_name = self._scale_name(op_node.output_arg_names()[0])
                 scale_v = np.array(
                     self._scope.find_var(scale_name).get_tensor())[0]
-                op_node.op()._set_attr("out_scale", float(scale_v))
+                op_node.op()._set_attr("out_threshold", float(scale_v))
         graph.resolve_hazard()
         return graph
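The `out_threshold` attribute written above is what later consumers read back from the program. A sketch of inspecting it, assuming `program` is a `fluid.Program` that has already been processed by OutScaleForInferencePass:

for block in program.blocks:
    for op in block.ops:
        if op.has_attr("out_threshold"):
            # The per-op output scale collected during training.
            print(op.type, op.attr("out_threshold"))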
@@ -1238,6 +1278,10 @@ class ScaleForInferencePass(object):


 class AddQuantDequantPass(object):
+    """
+    Quantize the ops that do not have weights, and add quant_dequant op for the
+    quantized ops's inputs.
+    """
     _supported_quantizable_op_type = [
         "pool2d", "elementwise_add", "concat", "softmax", "argmax", "transpose",
         "equal", "gather", "greater_equal", "greater_than", "less_equal",
@@ -1259,9 +1303,7 @@ class AddQuantDequantPass(object):
                  quantizable_op_type=["elementwise_add", "pool2d"],
                  is_full_quantized=False):
         """
-        This pass add quant_dequant op for some ops, of which all the inputs must be
-        not persistable.
-        The input scales can be obtained from the quant_dequant op.
+        Constructor.

         Args:
             scope(fluid.Scope): The scope is used to initialize these new parameters.
@@ -1338,10 +1380,7 @@ class AddQuantDequantPass(object):
                     op_node.op()._set_attr("quantization_type",
                                            "qat_without_weight")
                     op_node.op()._set_attr("activation_bits", self._quant_bits)
-                    input_name_list = _op_real_in_out_name[op_node.name()][0]
-                    arg_names = []
-                    for input_name in input_name_list:
-                        arg_names.extend(op_node.input(input_name))
+                    arg_names = _get_op_input_var_names(op_node)
                     for arg_name in arg_names:
                         in_node = graph._find_node_by_name(op_node.inputs, arg_name)
                         if arg_name in dequantized_vars_map:
...
@@ -107,15 +107,14 @@ function(save_qat_model_test target qat_model_dir fp32_model_save_path int8_mode
     --quantized_ops ${quantized_ops})
 endfunction()

-# Disable the unittest temporary
-list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
-
 if(WIN32)
     list(REMOVE_ITEM TEST_OPS test_light_nas)
+    list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
     list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50)
     list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
 endif()

-# Disable unittest for random error
+# Disable unittest for random error temporary
 list(REMOVE_ITEM TEST_OPS test_quantization_scale_pass)

 if(LINUX AND WITH_MKLDNN)
...
@@ -140,9 +140,9 @@ class TestPostTrainingQuantization(unittest.TestCase):
         self.batch_size = 1 if os.environ.get('DATASET') == 'full' else 50
         self.sample_iterations = 50 if os.environ.get(
-            'DATASET') == 'full' else 1
+            'DATASET') == 'full' else 2
         self.infer_iterations = 50000 if os.environ.get(
-            'DATASET') == 'full' else 1
+            'DATASET') == 'full' else 2

         self.timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
         self.int8_model = os.path.join(os.getcwd(),
@@ -287,11 +287,12 @@ class TestPostTrainingQuantization(unittest.TestCase):
         (int8_throughput, int8_latency, int8_acc1) = self.run_program(
             self.int8_model, batch_size, infer_iterations)

+        print("---Post training quantization of {} method---".format(algo))
         print(
-            "FP32 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}".
+            "FP32 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}.".
             format(model, batch_size, fp32_throughput, fp32_latency, fp32_acc1))
         print(
-            "INT8 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}".
+            "INT8 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}.\n".
             format(model, batch_size, int8_throughput, int8_latency, int8_acc1))
         sys.stdout.flush()
@@ -308,7 +309,10 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
         ]
         data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
         quantizable_op_type = [
-            "conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
+            "conv2d",
+            "depthwise_conv2d",
+            "mul",
+            "pool2d",
         ]
         is_full_quantize = False
         is_use_cache_file = False
@@ -326,10 +330,12 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
         ]
         data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
         quantizable_op_type = [
-            "conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
+            "conv2d",
+            "mul",
         ]
         is_full_quantize = False
         is_use_cache_file = False
+        # The accuracy diff of post-traing quantization (abs_max) maybe bigger
         diff_threshold = 0.05
         self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
                       is_full_quantize, is_use_cache_file, diff_threshold)
...
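The commit message also notes that post-training quantization now takes a `batch_generator` so that LoDTensor inputs can be fed; that change itself is not part of the hunks shown here. A rough sketch of such a generator, with shapes, LoD, and iteration count as illustrative assumptions:

import numpy as np
import paddle.fluid as fluid

def batch_generator():
    place = fluid.CPUPlace()
    for _ in range(10):
        # One already-batched feed per yield: a LoDTensor holding two
        # sequences of lengths 1 and 2 (3 rows total).
        data = np.random.random((3, 128)).astype('float32')
        yield [fluid.create_lod_tensor(data, [[1, 2]], place)]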
@@ -22,8 +22,8 @@ import paddle
 from paddle.fluid.framework import IrGraph
 from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
 from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
-from paddle.fluid.contrib.slim.quantization import ScaleForTrainingPass
-from paddle.fluid.contrib.slim.quantization import ScaleForInferencePass
+from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
+from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
 from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
 from paddle.fluid import core
@@ -112,7 +112,7 @@ class TestQuantizationScalePass(unittest.TestCase):
         add_quant_dequant_pass.apply(main_graph)
         add_quant_dequant_pass.apply(test_graph)

-        scale_training_pass = ScaleForTrainingPass(scope=scope, place=place)
+        scale_training_pass = OutScaleForTrainingPass(scope=scope, place=place)
         scale_training_pass.apply(main_graph)

         dev_name = '_gpu' if use_cuda else '_cpu'
@@ -151,7 +151,7 @@ class TestQuantizationScalePass(unittest.TestCase):
         if not for_ci:
             print('{}: {}'.format('loss' + dev_name, loss_v))

-        scale_inference_pass = ScaleForInferencePass(scope=scope)
+        scale_inference_pass = OutScaleForInferencePass(scope=scope)
         scale_inference_pass.apply(test_graph)

         # Freeze graph for inference, but the weight of fc/conv is still float type.
...