未验证 提交 25628587 编写于 作者: C cc 提交者: GitHub

Collect output scale for quantized op and fused op (#23369)

* Collect output scale for quantized op and fused op
* Post_training_quantizaion sets batch_generator to support lod tensor
上级 6162cf2f
......@@ -18,12 +18,13 @@ from ..... import compat as cpt
from .... import core
from ....framework import IrGraph
from ....framework import IrNode
from ....framework import Operator
from .... import unique_name
__all__ = [
'QuantizationTransformPass', 'QuantizationFreezePass', 'ConvertToInt8Pass',
'TransformForMobilePass', 'ScaleForTrainingPass', 'ScaleForInferencePass',
'AddQuantDequantPass'
'TransformForMobilePass', 'OutScaleForTrainingPass',
'OutScaleForInferencePass', 'AddQuantDequantPass'
]
_fake_quant_op_list = [
......@@ -40,9 +41,9 @@ _fake_quant_dequant_op_list = [
]
_out_scale_op_list = [
"mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", "depthwise_conv2d",
"batch_norm", "concat", "tanh", "pad", "elementwise_add", "elementwise_mul",
"dropout", "split", "prelu", "conv2d_transpose", "leaky_relu"
"conv2d", "depthwise_conv2d", "mul", "matmul", "relu", "leaky_relu",
"relu6", "sigmoid", "tanh", "prelu", "swish", "softmax", "batch_norm",
"elementwise_add", "pool2d", "reshape2", "transpose2"
]
# list op real input and output names, to avoid processing input such as AxisTensor.
......@@ -67,6 +68,7 @@ _op_real_in_out_name = {
"not_equal": [["X", "Y"], ["Out"]],
"reshape": [["X"], ["Out"]],
"reshape2": [["X"], ["Out"]],
"transpose2": [["X"], ["Out"]],
"bilinear_interp": [["X"], ["Out"]],
"nearest_interp": [["X"], ["Out"]],
"trilinear_interp": [["X"], ["Out"]],
......@@ -76,11 +78,49 @@ _op_real_in_out_name = {
"relu": [["X"], ["Out"]],
"relu6": [["X"], ["Out"]],
"leaky_relu": [["X"], ["Out"]],
"prelu": [["X"], ["Out"]],
"tanh": [["X"], ["Out"]],
"swish": [["X"], ["Out"]],
"dropout": [["X"], ["Out"]],
"batch_norm": [["X"], ["Y"]],
"sigmoid": [["X"], ["Y"]],
}
def _get_op_input_var_names(op):
""" """
assert isinstance(op, (IrNode, Operator)), \
"The input op should be IrNode or Operator."
var_names = []
op_name = op.name() if isinstance(op, IrNode) \
else op.type
name_list = _op_real_in_out_name[op_name][0]
for name in name_list:
var_name = op.input(name)
if isinstance(var_name, list):
var_names.extend(var_name)
else:
var_names.append(var_name)
return var_names
def _get_op_output_var_names(op):
""" """
assert isinstance(op, (IrNode, Operator)), \
"The input op should be IrNode or Operator."
var_names = []
op_name = op.name() if isinstance(op, IrNode) \
else op.type
name_list = _op_real_in_out_name[op_name][1]
for name in name_list:
var_name = op.output(name)
if isinstance(var_name, list):
var_names.extend(var_name)
else:
var_names.append(var_name)
return var_names
def _init_var_node(var_node, value, scope, place):
assert isinstance(value,
np.ndarray), 'The type of value should be numpy array.'
......@@ -97,17 +137,18 @@ def _is_input_all_not_persistable(graph, op_node):
Analyse the real inputs of the op node are all not persistable.
'''
is_input_all_not_persistable = True
op_node_name = op_node.name()
input_name_list = _op_real_in_out_name[op_node_name][0]
for input_name in input_name_list:
for arg_name in op_node.input(input_name):
in_node = graph._find_node_by_name(op_node.inputs, arg_name)
for var_name in _get_op_input_var_names(op_node):
in_node = graph._find_node_by_name(op_node.inputs, var_name)
is_input_all_not_persistable = (is_input_all_not_persistable and \
(not in_node.persistable()))
return is_input_all_not_persistable
class QuantizationTransformPass(object):
"""
Quantize the ops that have weights. Add quant and dequant ops for the quantized
ops's inputs.
"""
_supported_quantizable_op_type = [
'conv2d', 'depthwise_conv2d', 'mul', 'matmul'
]
......@@ -124,8 +165,7 @@ class QuantizationTransformPass(object):
skip_pattern=['skip_quant'],
quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul']):
"""
Convert and rewrite the IrGraph according to weight and
activation quantization type.
Constructor.
Args:
scope(fluid.Scope): When activation use 'range_abs_max' as the quantize
......@@ -1088,7 +1128,7 @@ class TransformForMobilePass(object):
return graph
class ScaleForTrainingPass(object):
class OutScaleForTrainingPass(object):
def __init__(self, scope=None, place=None, moving_rate=0.9):
"""
This pass is used for calculating output scales of some operators.
......@@ -1195,7 +1235,7 @@ class ScaleForTrainingPass(object):
return "%s@scale" % (var_name)
class ScaleForInferencePass(object):
class OutScaleForInferencePass(object):
def __init__(self, scope=None):
"""
This pass is used for setting output scales of some operators.
......@@ -1226,7 +1266,7 @@ class ScaleForInferencePass(object):
scale_name = self._scale_name(op_node.output_arg_names()[0])
scale_v = np.array(
self._scope.find_var(scale_name).get_tensor())[0]
op_node.op()._set_attr("out_scale", float(scale_v))
op_node.op()._set_attr("out_threshold", float(scale_v))
graph.resolve_hazard()
return graph
......@@ -1238,6 +1278,10 @@ class ScaleForInferencePass(object):
class AddQuantDequantPass(object):
"""
Quantize the ops that do not have weights, and add quant_dequant op for the
quantized ops's inputs.
"""
_supported_quantizable_op_type = [
"pool2d", "elementwise_add", "concat", "softmax", "argmax", "transpose",
"equal", "gather", "greater_equal", "greater_than", "less_equal",
......@@ -1259,9 +1303,7 @@ class AddQuantDequantPass(object):
quantizable_op_type=["elementwise_add", "pool2d"],
is_full_quantized=False):
"""
This pass add quant_dequant op for some ops, of which all the inputs must be
not persistable.
The input scales can be obtained from the quant_dequant op.
Constructor.
Args:
scope(fluid.Scope): The scope is used to initialize these new parameters.
......@@ -1338,10 +1380,7 @@ class AddQuantDequantPass(object):
op_node.op()._set_attr("quantization_type",
"qat_without_weight")
op_node.op()._set_attr("activation_bits", self._quant_bits)
input_name_list = _op_real_in_out_name[op_node.name()][0]
arg_names = []
for input_name in input_name_list:
arg_names.extend(op_node.input(input_name))
arg_names = _get_op_input_var_names(op_node)
for arg_name in arg_names:
in_node = graph._find_node_by_name(op_node.inputs, arg_name)
if arg_name in dequantized_vars_map:
......
......@@ -107,15 +107,14 @@ function(save_qat_model_test target qat_model_dir fp32_model_save_path int8_mode
--quantized_ops ${quantized_ops})
endfunction()
# Disable the unittest temporary
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
if(WIN32)
list(REMOVE_ITEM TEST_OPS test_light_nas)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50)
list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
endif()
# Disable unittest for random error
# Disable unittest for random error temporary
list(REMOVE_ITEM TEST_OPS test_quantization_scale_pass)
if(LINUX AND WITH_MKLDNN)
......
......@@ -140,9 +140,9 @@ class TestPostTrainingQuantization(unittest.TestCase):
self.batch_size = 1 if os.environ.get('DATASET') == 'full' else 50
self.sample_iterations = 50 if os.environ.get(
'DATASET') == 'full' else 1
'DATASET') == 'full' else 2
self.infer_iterations = 50000 if os.environ.get(
'DATASET') == 'full' else 1
'DATASET') == 'full' else 2
self.timestamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
self.int8_model = os.path.join(os.getcwd(),
......@@ -287,11 +287,12 @@ class TestPostTrainingQuantization(unittest.TestCase):
(int8_throughput, int8_latency, int8_acc1) = self.run_program(
self.int8_model, batch_size, infer_iterations)
print("---Post training quantization of {} method---".format(algo))
print(
"FP32 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}".
"FP32 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}.".
format(model, batch_size, fp32_throughput, fp32_latency, fp32_acc1))
print(
"INT8 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}".
"INT8 {0}: batch_size {1}, throughput {2} images/second, latency {3} second, accuracy {4}.\n".
format(model, batch_size, int8_throughput, int8_latency, int8_acc1))
sys.stdout.flush()
......@@ -308,7 +309,10 @@ class TestPostTrainingKLForMobilenetv1(TestPostTrainingQuantization):
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
"conv2d",
"depthwise_conv2d",
"mul",
"pool2d",
]
is_full_quantize = False
is_use_cache_file = False
......@@ -326,10 +330,12 @@ class TestPostTrainingAbsMaxForMobilenetv1(TestPostTrainingQuantization):
]
data_md5s = ['13892b0716d26443a8cdea15b3c6438b']
quantizable_op_type = [
"conv2d", "depthwise_conv2d", "mul", "pool2d", "elementwise_add"
"conv2d",
"mul",
]
is_full_quantize = False
is_use_cache_file = False
# The accuracy diff of post-traing quantization (abs_max) maybe bigger
diff_threshold = 0.05
self.run_test(model, algo, data_urls, data_md5s, quantizable_op_type,
is_full_quantize, is_use_cache_file, diff_threshold)
......
......@@ -22,8 +22,8 @@ import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ScaleForTrainingPass
from paddle.fluid.contrib.slim.quantization import ScaleForInferencePass
from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid import core
......@@ -112,7 +112,7 @@ class TestQuantizationScalePass(unittest.TestCase):
add_quant_dequant_pass.apply(main_graph)
add_quant_dequant_pass.apply(test_graph)
scale_training_pass = ScaleForTrainingPass(scope=scope, place=place)
scale_training_pass = OutScaleForTrainingPass(scope=scope, place=place)
scale_training_pass.apply(main_graph)
dev_name = '_gpu' if use_cuda else '_cpu'
......@@ -151,7 +151,7 @@ class TestQuantizationScalePass(unittest.TestCase):
if not for_ci:
print('{}: {}'.format('loss' + dev_name, loss_v))
scale_inference_pass = ScaleForInferencePass(scope=scope)
scale_inference_pass = OutScaleForInferencePass(scope=scope)
scale_inference_pass.apply(test_graph)
# Freeze graph for inference, but the weight of fc/conv is still float type.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册