提交 4d461939 编写于 作者: W wanghaoshuang

Merge branch 'quant_dequant' into 'develop'

add elementwise_add, pool2d in quant_aware

See merge request !27
...@@ -20,11 +20,19 @@ from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass ...@@ -20,11 +20,19 @@ from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid import core from paddle.fluid import core
WEIGHT_QUANTIZATION_TYPES=['abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'] WEIGHT_QUANTIZATION_TYPES = [
ACTIVATION_QUANTIZATION_TYPES=['abs_max','range_abs_max', 'moving_average_abs_max'] 'abs_max', 'channel_wise_abs_max', 'range_abs_max',
'moving_average_abs_max'
]
ACTIVATION_QUANTIZATION_TYPES = [
'abs_max', 'range_abs_max', 'moving_average_abs_max'
]
VALID_DTYPES = ['int8'] VALID_DTYPES = ['int8']
TRANSFORM_PASS_OP_TYPES = ['conv2d', 'depthwise_conv2d', 'mul']
QUANT_DEQUANT_PASS_OP_TYPES = ['elementwise_add', 'pool2d']
_quant_config_default = { _quant_config_default = {
# weight quantize type, default is 'abs_max' # weight quantize type, default is 'abs_max'
...@@ -38,7 +46,8 @@ _quant_config_default = { ...@@ -38,7 +46,8 @@ _quant_config_default = {
# ops of name_scope in not_quant_pattern list, will not be quantized # ops of name_scope in not_quant_pattern list, will not be quantized
'not_quant_pattern': ['skip_quant'], 'not_quant_pattern': ['skip_quant'],
# ops of type in quantize_op_types, will be quantized # ops of type in quantize_op_types, will be quantized
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'], 'quantize_op_types':
['conv2d', 'depthwise_conv2d', 'mul', 'elementwise_add', 'pool2d'],
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8' # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8', 'dtype': 'int8',
# window size for 'range_abs_max' quantization. default is 10000 # window size for 'range_abs_max' quantization. default is 10000
...@@ -88,6 +97,12 @@ def _parse_configs(user_config): ...@@ -88,6 +97,12 @@ def _parse_configs(user_config):
assert isinstance(configs['quantize_op_types'], list), \ assert isinstance(configs['quantize_op_types'], list), \
"quantize_op_types must be a list" "quantize_op_types must be a list"
for op_type in configs['quantize_op_types']:
assert (op_type in QUANT_DEQUANT_PASS_OP_TYPES) or (
op_type in TRANSFORM_PASS_OP_TYPES), "{} is not support, \
now support op types are {}".format(
op_type, TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES)
assert isinstance(configs['dtype'], str), \ assert isinstance(configs['dtype'], str), \
"dtype must be a str." "dtype must be a str."
...@@ -132,19 +147,37 @@ def quant_aware(program, place, config, scope=None, for_test=False): ...@@ -132,19 +147,37 @@ def quant_aware(program, place, config, scope=None, for_test=False):
config = _parse_configs(config) config = _parse_configs(config)
main_graph = IrGraph(core.Graph(program.desc), for_test=for_test) main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
transform_pass = QuantizationTransformPass( transform_pass_ops = []
scope=scope, quant_dequant_ops = []
place=place, for op_type in config['quantize_op_types']:
weight_bits=config['weight_bits'], if op_type in TRANSFORM_PASS_OP_TYPES:
activation_bits=config['activation_bits'], transform_pass_ops.append(op_type)
activation_quantize_type=config['activation_quantize_type'], elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
weight_quantize_type=config['weight_quantize_type'], quant_dequant_ops.append(op_type)
window_size=config['window_size'], if len(transform_pass_ops) > 0:
moving_rate=config['moving_rate'], transform_pass = QuantizationTransformPass(
quantizable_op_type=config['quantize_op_types'], scope=scope,
skip_pattern=config['not_quant_pattern']) place=place,
weight_bits=config['weight_bits'],
transform_pass.apply(main_graph) activation_bits=config['activation_bits'],
activation_quantize_type=config['activation_quantize_type'],
weight_quantize_type=config['weight_quantize_type'],
window_size=config['window_size'],
moving_rate=config['moving_rate'],
quantizable_op_type=transform_pass_ops,
skip_pattern=config['not_quant_pattern'])
transform_pass.apply(main_graph)
if len(quant_dequant_ops) > 0:
quant_dequant_pass = AddQuantDequantPass(
scope=scope,
place=place,
moving_rate=config['moving_rate'],
quant_bits=config['activation_bits'],
skip_pattern=config['not_quant_pattern'],
quantizable_op_type=quant_dequant_ops)
quant_dequant_pass.apply(main_graph)
if for_test: if for_test:
quant_program = main_graph.to_program() quant_program = main_graph.to_program()
...@@ -168,7 +201,7 @@ def quant_post(program, place, config, scope=None): ...@@ -168,7 +201,7 @@ def quant_post(program, place, config, scope=None):
pass pass
def convert(program, scope, place, config, save_int8=False): def convert(program, place, config, scope=None, save_int8=False):
""" """
add quantization ops in program. the program returned is not trainable. add quantization ops in program. the program returned is not trainable.
Args: Args:
...@@ -183,7 +216,7 @@ def convert(program, scope, place, config, save_int8=False): ...@@ -183,7 +216,7 @@ def convert(program, scope, place, config, save_int8=False):
fluid.Program: frozen int8 program which can be used for inference. fluid.Program: frozen int8 program which can be used for inference.
if save_int8 is False, this value is None. if save_int8 is False, this value is None.
""" """
scope = fluid.global_scope() if not scope else scope
test_graph = IrGraph(core.Graph(program.desc), for_test=True) test_graph = IrGraph(core.Graph(program.desc), for_test=True)
# Freeze the graph after training by adjusting the quantize # Freeze the graph after training by adjusting the quantize
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册