提交 b80d8969 编写于 作者: S slf12

add elementwise_add, pool2d

上级 72c800e9
...@@ -20,11 +20,19 @@ from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass ...@@ -20,11 +20,19 @@ from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
from paddle.fluid import core from paddle.fluid import core
WEIGHT_QUANTIZATION_TYPES=['abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'] WEIGHT_QUANTIZATION_TYPES = [
ACTIVATION_QUANTIZATION_TYPES=['abs_max','range_abs_max', 'moving_average_abs_max'] 'abs_max', 'channel_wise_abs_max', 'range_abs_max',
'moving_average_abs_max'
]
ACTIVATION_QUANTIZATION_TYPES = [
'abs_max', 'range_abs_max', 'moving_average_abs_max'
]
VALID_DTYPES = ['int8'] VALID_DTYPES = ['int8']
TRANSFORM_PASS_OP_TYPES = ['conv2d', 'depthwise_conv2d', 'mul']
QUANT_DEQUANT_PASS_OP_TYPES = ['elementwise_add', 'pool2d']
_quant_config_default = { _quant_config_default = {
# weight quantize type, default is 'abs_max' # weight quantize type, default is 'abs_max'
...@@ -38,7 +46,8 @@ _quant_config_default = { ...@@ -38,7 +46,8 @@ _quant_config_default = {
# ops of name_scope in not_quant_pattern list, will not be quantized # ops of name_scope in not_quant_pattern list, will not be quantized
'not_quant_pattern': ['skip_quant'], 'not_quant_pattern': ['skip_quant'],
# ops of type in quantize_op_types, will be quantized # ops of type in quantize_op_types, will be quantized
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'], 'quantize_op_types':
['conv2d', 'depthwise_conv2d', 'mul', 'elementwise_add', 'pool2d'],
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8' # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8', 'dtype': 'int8',
# window size for 'range_abs_max' quantization. defaulf is 10000 # window size for 'range_abs_max' quantization. defaulf is 10000
...@@ -88,6 +97,12 @@ def _parse_configs(user_config): ...@@ -88,6 +97,12 @@ def _parse_configs(user_config):
assert isinstance(configs['quantize_op_types'], list), \ assert isinstance(configs['quantize_op_types'], list), \
"quantize_op_types must be a list" "quantize_op_types must be a list"
for op_type in configs['quantize_op_types']:
assert (op_type in QUANT_DEQUANT_PASS_OP_TYPES) or (
op_type in TRANSFORM_PASS_OP_TYPES), "{} is not support, \
now support op types are {}".format(
op_type, TRANSFORM_PASS_OP_TYPES + QUANT_DEQUANT_PASS_OP_TYPES)
assert isinstance(configs['dtype'], str), \ assert isinstance(configs['dtype'], str), \
"dtype must be a str." "dtype must be a str."
...@@ -132,6 +147,14 @@ def quant_aware(program, place, config, scope=None, for_test=False): ...@@ -132,6 +147,14 @@ def quant_aware(program, place, config, scope=None, for_test=False):
config = _parse_configs(config) config = _parse_configs(config)
main_graph = IrGraph(core.Graph(program.desc), for_test=for_test) main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
transform_pass_ops = []
quant_dequant_ops = []
for op_type in config['quantize_op_types']:
if op_type in TRANSFORM_PASS_OP_TYPES:
transform_pass_ops.append(op_type)
elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
quant_dequant_ops.append(op_type)
if len(transform_pass_ops) > 0:
transform_pass = QuantizationTransformPass( transform_pass = QuantizationTransformPass(
scope=scope, scope=scope,
place=place, place=place,
...@@ -141,11 +164,21 @@ def quant_aware(program, place, config, scope=None, for_test=False): ...@@ -141,11 +164,21 @@ def quant_aware(program, place, config, scope=None, for_test=False):
weight_quantize_type=config['weight_quantize_type'], weight_quantize_type=config['weight_quantize_type'],
window_size=config['window_size'], window_size=config['window_size'],
moving_rate=config['moving_rate'], moving_rate=config['moving_rate'],
quantizable_op_type=config['quantize_op_types'], quantizable_op_type=transform_pass_ops,
skip_pattern=config['not_quant_pattern']) skip_pattern=config['not_quant_pattern'])
transform_pass.apply(main_graph) transform_pass.apply(main_graph)
if len(quant_dequant_ops) > 0:
quant_dequant_pass = AddQuantDequantPass(
scope=scope,
place=place,
moving_rate=config['moving_rate'],
quant_bits=config['activation_bits'],
skip_pattern=config['not_quant_pattern'],
quantizable_op_type=quant_dequant_ops)
quant_dequant_pass.apply(main_graph)
if for_test: if for_test:
quant_program = main_graph.to_program() quant_program = main_graph.to_program()
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册