diff --git a/paddleslim/common/patterns.py b/paddleslim/common/patterns.py
index def7faa4489c273f76b824ec64fc51751fd24a71..c5047d1936cb61705a1cc7b2bcc9152300d9c510 100644
--- a/paddleslim/common/patterns.py
+++ b/paddleslim/common/patterns.py
@@ -53,7 +53,6 @@ def _is_mha(pattern_ops, pattern_ops_type, skip_quant_tensor_list=[]):
     for op in pattern_ops:
         if op.type() in ['matmul', 'matmul_v2']:
             if not is_dynamic_weight_op(op):
-                skip_quant_tensor_list.extend(op._op.input('X'))
                 matmul_num += 1
     if matmul_num == 2:
         return True
@@ -88,6 +87,8 @@ def get_patterns(program, only_final_node=True):
     block_num = 0
     model_type = None
     for op in graph.ops():
+        if len(op.all_inputs()) == 0 or op.all_inputs()[0] is None:
+            continue
         belonged_teacher = False
         for inp in op.all_inputs():
             if 'teacher' in inp._var.name:
@@ -106,8 +107,9 @@ def get_patterns(program, only_final_node=True):
                     out_var_name = op.all_outputs()[0]._var.name

                     shortcut_start_op = shortcut_start_op[0]
+                    next_op = graph.next_ops(op)
                     pattern_ops, pattern_ops_type = traversal_ops(
-                        shortcut_start_op, graph, op.idx())
+                        shortcut_start_op, graph, next_op[0].idx())

                     pattern_name = shortcut_start_op.type() + '$' + str(op.idx(
                     ))
diff --git a/paddleslim/common/patterns_common.py b/paddleslim/common/patterns_common.py
index d4e24b5f9f1cb667ddab565b2b4d5db4758fd0b2..c19e7ee72eff7a30305647bf91dad5bfa6117891 100644
--- a/paddleslim/common/patterns_common.py
+++ b/paddleslim/common/patterns_common.py
@@ -51,6 +51,7 @@ def get_weight(op, return_name=True):
                 return inp.name()
             else:
                 return inp
+    return None


 def is_dynamic_weight_op(op):
diff --git a/paddleslim/core/graph_wrapper.py b/paddleslim/core/graph_wrapper.py
index 14fc925605df8ed188b94240e38a8b10efa79dd1..19f39d9066f2d4fc45674d425a0dd1230a97cd9f 100644
--- a/paddleslim/core/graph_wrapper.py
+++ b/paddleslim/core/graph_wrapper.py
@@ -357,7 +357,8 @@ class GraphWrapper(object):
         ops = []
         for p in self.ops():
             for out_var in op.all_outputs():
-                if out_var in p.all_inputs():
+                if len(p.all_inputs()) > 0 and p.all_inputs()[
+                        0] is not None and out_var in p.all_inputs():
                     if p.idx() != op.idx():
                         ops.append(p)
         return sorted(ops)
diff --git a/paddleslim/quant/quanter.py b/paddleslim/quant/quanter.py
index 9f8ed323e3ab6eb86f9a476d9895d89dddf31eca..aab7252342c3fc5d8927d440ede561f9d79ad75f 100755
--- a/paddleslim/quant/quanter.py
+++ b/paddleslim/quant/quanter.py
@@ -33,12 +33,16 @@ from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
 from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
 from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass
 from ..common import get_logger
+from ..common.patterns import get_patterns
+from ..common.patterns_common import is_dynamic_weight_op, get_weight
+from ..core.graph_wrapper import GraphWrapper
 _logger = get_logger(__name__, level=logging.INFO)

 try:
     from paddle.fluid.contrib.slim.quantization import QuantizationTransformPassV2
     from paddle.fluid.contrib.slim.quantization import QuantWeightPass
     from paddle.fluid.contrib.slim.quantization import AddQuantDequantPassV2
+    from paddle.fluid.contrib.slim.quantization import PostTrainingQuantizationProgram
 except:
     _logger.warning(
         "Some functions fail to import, please update PaddlePaddle version to 2.3+"
@@ -98,6 +102,10 @@ _quant_config_default = {
     'is_full_quantize': False,
     # if True, use onnx format to quant.
     'onnx_format': False,
+    # whether to run post-training quantization first to get initial scales for quant_aware
+    'quant_post_first': False,
+    # whether the quantization scales are trainable
+    'scale_trainable': True
 }


@@ -254,7 +262,12 @@ def quant_aware(program,
                 optimizer_func=None,
                 executor=None,
                 return_program=False,
-                draw_graph=False):
+                calib_config={},
+                draw_graph=False,
+                return_scale_dict=False,
+                scale_dict=None,
+                model_type=None,
+                pattern_ops=None):
     """Add quantization and dequantization operators to "program" for quantization training or testing.
@@ -271,7 +284,7 @@ def quant_aware(program,
             Default: ``None``.
         for_test(bool): If the 'program' parameter is a test program, this parameter should be set to ``True``.
             Otherwise, set to ``False``.Default: False
-        weight_quantize_func(function): Function that defines how to quantize weight. Using this 
+        weight_quantize_func(function): Function that defines how to quantize weight. Using this
             can quickly test if user's quantization method works or not. In this function, user should
             both define quantization function and dequantization function, that is, the function's input
             is non-quantized weight and function returns dequantized weight. If None, will use
@@ -301,6 +314,12 @@ def quant_aware(program,
             Default is False.
         draw_graph(bool): whether to draw graph when quantization is initialized. In order to prevent cycle,
             the ERNIE model needs to be set to True. Default is False.
+        return_scale_dict(bool): If the user wants the scale dict, model_type and pattern_ops to be returned as well, set this argument to True.
+            Default is False.
+        scale_dict(dict): A scale dict used to initialize the scales in the program. Default is None.
+        model_type(str): Model type, either 'transformer' or 'non-transformer'. If the model type is transformer, its patterns will be analyzed.
+            Default is None.
+        pattern_ops(dict): A dict mapping pattern names to their corresponding ops. Default is None.
     Returns:
         paddle.static.CompiledProgram | paddle.static.Program: Program with quantization and dequantization ``operators``
     """
@@ -313,52 +332,164 @@ def quant_aware(program,
         config = _parse_configs(config)
     _logger.info("quant_aware config {}".format(config))

-    main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
-
-    transform_pass_ops = []
-    quant_dequant_ops = []
-    for op_type in config['quantize_op_types']:
-        if op_type in TRANSFORM_PASS_OP_TYPES:
-            transform_pass_ops.append(op_type)
-        elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
-            quant_dequant_ops.append(op_type)
-    if len(transform_pass_ops) > 0:
-        trannsform_func = 'QuantizationTransformPassV2' if config[
-            'onnx_format'] else 'QuantizationTransformPass'
-        transform_pass = eval(trannsform_func)(
-            scope=scope,
-            place=place,
-            weight_bits=config['weight_bits'],
-            activation_bits=config['activation_bits'],
-            activation_quantize_type=config['activation_quantize_type'],
-            weight_quantize_type=config['weight_quantize_type'],
-            window_size=config['window_size'],
-            moving_rate=config['moving_rate'],
-            quantizable_op_type=transform_pass_ops,
-            skip_pattern=config['not_quant_pattern'],
-            weight_quantize_func=weight_quantize_func,
-            act_quantize_func=act_quantize_func,
-            weight_preprocess_func=weight_preprocess_func,
-            act_preprocess_func=act_preprocess_func,
-            optimizer_func=optimizer_func,
-            executor=executor)
-
-        transform_pass.apply(main_graph)
-
-    if len(quant_dequant_ops) > 0:
-        qdq_func = 'AddQuantDequantPassV2' if config[
-            'onnx_format'] else 'AddQuantDequantPass'
-        quant_dequant_pass = eval(qdq_func)(
-            scope=scope,
-            place=place,
-            moving_rate=config['moving_rate'],
-            quant_bits=config['activation_bits'],
-            skip_pattern=config['not_quant_pattern'],
-            quantizable_op_type=quant_dequant_ops)
-        quant_dequant_pass.apply(main_graph)
+    def find_next_ops(program, var_name):
+        """
+        Find all ops that take the given variable as an input.
+        """
+        block = program.global_block()
+        res_ops = []
+        for op in block.ops:
+            if var_name in op.input_arg_names:
+                res_ops.append(op)
+        return res_ops
+
+    def find_pre_ops(program, var_name):
+        """
+        Find all ops that produce the given variable.
+        """
+        block = program.global_block()
+        res_ops = []
+        for op in block.ops:
+            if var_name in op.output_arg_names:
+                res_ops.append(op)
+        return res_ops
+
+    def _is_skip_layernorm(program, op):
+        if get_weight(op) is not None:
+            return False
+
+        output_names = op._op.output_arg_names
+        for output_name in output_names:
+            for next_op in find_next_ops(program, output_name):
+                if next_op.type == 'layer_norm':
+                    return True
+        return False
+
+    skip_tensor_list = []
+    same_scale_tensor_list = []
+    if model_type == 'transformer' and pattern_ops is None:
+        pattern_ops, _, model_type = get_patterns(program)
+        if model_type != 'transformer':
+            _logger.info(
+                'Warning! After analysis, the real model type is not transformer! If you encounter this situation, please raise an issue to let us know in which case "get_patterns" decides that the model type is not transformer.'
+            )
+    if model_type == 'transformer':
+        not_skip_quant_list = []
+        for part_name, ops in pattern_ops.items():
+            if 'MHA' in part_name:
+                qkv_weight_tensor = []
+                qkv_output_tensor = []
+                ### collect the Q/K/V projection weight tensors
+                output_names = ops[0]._op.output_arg_names
+                for output_name in output_names:
+                    for next_op in find_next_ops(program, output_name):
+                        if next_op.type in ['mul', 'matmul_v2']:
+                            qkv_weight_tensor.append(next_op.input('Y')[0])
+
+                same_scale_tensor_list.append(qkv_weight_tensor)
+
+                for op in ops:
+                    if op._op.type in ['matmul', 'matmul_v2'] and (
+                            not is_dynamic_weight_op(op)):
+                        input_names = op._op.input_arg_names
+                        for input_name in input_names:
+                            pre_op = find_pre_ops(program, input_name)[0]
+                            if pre_op.type == 'softmax' or pre_op.type == 'dropout':
+                                continue
+                            elif pre_op.type == 'scale':
+                                qkv_output_tensor.append(
+                                    input_name + '#/#{}'.format(
+                                        pre_op.attr('scale')))
+                            else:
+                                qkv_output_tensor.append(input_name)
+                    elif op._op.type == 'elementwise_add':
+                        if _is_skip_layernorm(program, op):
+                            not_skip_quant_list.append(op)
+                same_scale_tensor_list.append(qkv_output_tensor)
+            elif 'FFN' in part_name:
+                for op in ops:
+                    if op._op.type == 'elementwise_add':
+                        if _is_skip_layernorm(program, op):
+                            not_skip_quant_list.append(op)
+        tmp_graph = GraphWrapper(program)
+        for op in tmp_graph.ops():
+            ### mark elementwise_add ops outside skip-layernorm patterns as skip_quant
+            if op._op.type == 'elementwise_add' and op not in not_skip_quant_list:
+                op._op._set_attr("op_namescope", "skip_quant")
+
+    is_test = True if for_test else not config['scale_trainable']
+    if config['quant_post_first'] and for_test:
+        if 'quantizable_op_type' not in calib_config:
+            calib_config['quantizable_op_type'] = config['quantize_op_types']
+        exe = paddle.static.Executor() if executor is None else executor
+        post_training_quantization = PostTrainingQuantizationProgram(
+            exe,
+            program,
+            freeze_model=False,
+            skip_tensor_list=skip_tensor_list,
+            same_scale_tensor_list=same_scale_tensor_list,
+            scale_trainable=config['scale_trainable'],
+            batch_nums=10,
+            scale_dict=scale_dict,
+            return_graph=True,
+            **calib_config)
+        main_graph = post_training_quantization.quantize()
+        scale_dict = post_training_quantization._scale_dict
+    else:
+        main_graph = IrGraph(core.Graph(program.desc), for_test=for_test)
+        transform_pass_ops = []
+        quant_dequant_ops = []
+        for op_type in config['quantize_op_types']:
+            if op_type in TRANSFORM_PASS_OP_TYPES:
+                transform_pass_ops.append(op_type)
+            elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
+                quant_dequant_ops.append(op_type)
+        if len(transform_pass_ops) > 0:
+            trannsform_func = 'QuantizationTransformPassV2' if config[
+                'onnx_format'] else 'QuantizationTransformPass'
+            transform_pass = eval(trannsform_func)(
+                scope=scope,
+                place=place,
+                weight_bits=config['weight_bits'],
+                activation_bits=config['activation_bits'],
+                activation_quantize_type=config['activation_quantize_type'],
+                weight_quantize_type=config['weight_quantize_type'],
+                window_size=config['window_size'],
+                moving_rate=config['moving_rate'],
+                quantizable_op_type=transform_pass_ops,
+                skip_pattern=config['not_quant_pattern'],
+                weight_quantize_func=weight_quantize_func,
+                act_quantize_func=act_quantize_func,
+                weight_preprocess_func=weight_preprocess_func,
+                act_preprocess_func=act_preprocess_func,
+                optimizer_func=optimizer_func,
+                executor=executor,
+                is_test=is_test)
+
+            transform_pass.apply(main_graph)
+
+        if len(quant_dequant_ops) > 0:
+            qdq_func = 'AddQuantDequantPassV2' if config[
+                'onnx_format'] else 'AddQuantDequantPass'
+            quant_dequant_pass = eval(qdq_func)(
+                scope=scope,
+                place=place,
+                moving_rate=config['moving_rate'],
+                quant_bits=config['activation_bits'],
+                skip_pattern=config['not_quant_pattern'],
+                quantizable_op_type=quant_dequant_ops,
+                is_test=is_test,
+                scale_dict=scale_dict)
+
+            quant_dequant_pass.apply(main_graph)

     out_scale_training_pass = OutScaleForTrainingPass(
-        scope=scope, place=place, moving_rate=config['moving_rate'])
+        scope=scope,
+        place=place,
+        moving_rate=config['moving_rate'],
+        is_test=is_test,
+        scale_dict=scale_dict)
+
     out_scale_training_pass.apply(main_graph)

     if (weight_preprocess_func is not None or
@@ -378,7 +509,11 @@ def quant_aware(program,
         quant_program = main_graph.to_program()
     else:
         quant_program = paddle.static.CompiledProgram(main_graph.graph)
-    return quant_program
+
+    if return_scale_dict:
+        return quant_program, scale_dict, model_type, pattern_ops
+    else:
+        return quant_program


 def quant_post_static(
diff --git a/tests/test_quant_post_quant_aware.py b/tests/test_quant_post_quant_aware.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4531f954b1daab295799896a2e6cc0321890d2a
--- /dev/null
+++ b/tests/test_quant_post_quant_aware.py
@@ -0,0 +1,161 @@
+import sys
+import random
+sys.path.append("../")
+import unittest
+import paddle
+import paddle.nn as nn
+from paddle.io import Dataset
+from paddleslim.quant import quant_aware, convert
+from paddle.nn import TransformerEncoderLayer, TransformerEncoder, Linear
+from paddleslim.quant import quant_aware, convert
+from static_case import StaticCase
+sys.path.append("../demo")
+from models import MobileNet
+from layers import conv_bn_layer
+import paddle.dataset.mnist as reader
+from paddle.fluid.framework import IrGraph
+from paddle.fluid import core
+import numpy as np
+
+np.random.seed(0)
+random.seed(0)
+paddle.seed(0)
+
+
+class RandomDataset(Dataset):
+    def __init__(self, num_samples):
+        self.num_samples = num_samples
+
+    def __getitem__(self, idx):
+        enc_input = np.random.random([4, 128]).astype('float32')
+        attn_mask = np.random.random([2, 4, 4]).astype('float32')
+        label = np.random.randint(0, 2, (1, )).astype('int64')
+        return enc_input, attn_mask, label
+
+    def __len__(self):
+        return self.num_samples
+
+
+class TestQuantPostQuantAwareCase1(StaticCase):
+    def test_accuracy(self):
+        def simple_transformer(enc_input, attn_mask):
+            encoder_layer = nn.TransformerEncoderLayer(128, 2, 512)
+            encoder = TransformerEncoder(encoder_layer, 2)
+            encoder_output = encoder(enc_input, attn_mask)
+            first_token = encoder_output[:, 0]
+            bias = paddle.full(shape=[1, 128], fill_value=1e-6)
+            linear = Linear(128, 2)
+            logits = linear(first_token + bias)
+            return logits
+
+        enc_input = paddle.static.data(
+            name='enc_input', shape=[None, 4, 128], dtype='float32')
+        attn_mask = paddle.static.data(
+            name='attn_mask', shape=[None, 2, 4, 4], dtype='float32')
+        label = paddle.static.data(name='label', shape=[None, 1], dtype='int64')
+        out = simple_transformer(enc_input, attn_mask)
+        cost = paddle.nn.functional.loss.cross_entropy(input=out, label=label)
+        avg_cost = paddle.mean(x=cost)
+        acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
+        optimizer = paddle.optimizer.Momentum(
+            momentum=0.9,
+            learning_rate=0.01,
+            weight_decay=paddle.regularizer.L2Decay(4e-5))
+        optimizer.minimize(avg_cost)
+        main_prog = paddle.static.default_main_program()
+        val_prog = main_prog.clone(for_test=True)
+
+        place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
+        ) else paddle.CPUPlace()
+        exe = paddle.static.Executor(place)
+        exe.run(paddle.static.default_startup_program())
+
+        train_dataset = RandomDataset(100)
+        test_dataset = RandomDataset(50)
+        train_loader = paddle.io.DataLoader(
+            train_dataset,
+            places=place,
+            feed_list=[enc_input, attn_mask, label],
+            drop_last=True,
+            return_list=False,
+            batch_size=10)
+        valid_loader = paddle.io.DataLoader(
+            test_dataset,
+            places=place,
+            feed_list=[enc_input, attn_mask, label],
+            batch_size=10,
+            return_list=False)
+
+        def train(program):
+            iter = 0
+            for data in train_loader():
+                cost, top1 = exe.run(program,
+                                     feed=data,
+                                     fetch_list=[avg_cost, acc_top1])
+                iter += 1
+                if iter % 100 == 0:
+                    print('train iter={}, avg loss {}, acc_top1 {}'.format(
+                        iter, cost, top1))
+
+        def test(program):
+            iter = 0
+            result = [[], []]
+            for data in valid_loader():
+                cost, top1 = exe.run(program,
+                                     feed=data,
+                                     fetch_list=[avg_cost, acc_top1])
+                iter += 1
+                if iter % 100 == 0:
+                    print('eval iter={}, avg loss {}, acc_top1 {}'.format(
+                        iter, cost, top1))
+                result[0].append(cost)
+                result[1].append(top1)
+            print(' avg loss {}, acc_top1 {}'.format(
+                np.mean(result[0]), np.mean(result[1])))
+            return np.mean(result[1])
+
+        train(main_prog)
+        top1_1 = test(main_prog)
+
+        config = {
+            'weight_quantize_type': 'channel_wise_abs_max',
+            'activation_quantize_type': 'moving_average_abs_max',
+            'quantize_op_types':
+            ['conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'elementwise_add'],
+            'quant_post_first': True,
+            'scale_trainable': True
+        }
+        calib_config = {
+            'data_loader': valid_loader,
+            'algo': 'abs_max',
+            'feed_list': ['enc_input', 'attn_mask', 'label'],
+            'fetch_list': [avg_cost, acc_top1]
+        }
+        quant_eval_prog, scale_dict, _, _ = quant_aware(
+            val_prog,
+            place,
+            config,
+            for_test=True,
+            calib_config=calib_config,
+            model_type='transformer',
+            return_scale_dict=True)
+        quant_train_prog = quant_aware(
+            main_prog,
+            place,
+            config,
+            for_test=False,
+            calib_config=calib_config,
+            return_program=True,
+            scale_dict=scale_dict,
+            model_type='transformer')
+        train(quant_train_prog)
+        quant_eval_prog, int8_prog = convert(
+            quant_eval_prog, place, config, save_int8=True)
+        top1_2 = test(quant_eval_prog)
+        # values before quantization and after quantization should be close
+        print("before quantization: top1: {}".format(top1_1))
+        print("after quantization: top1: {}".format(top1_2))
+
+
+if __name__ == '__main__':
+    unittest.main()
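Usage sketch (not part of the patch): a minimal example, distilled from tests/test_quant_post_quant_aware.py above, of how the new quant-post-first flow is intended to be driven. The programs, place, data loader and feed/fetch lists are assumed to be prepared by the caller exactly as in that test, and the config keys mirror the new entries added to _quant_config_default; treat this as an illustration under those assumptions rather than a definitive API reference.

# Minimal sketch of the quant-post-first workflow introduced in this patch.
# main_prog/val_prog/place/data_loader/feed_list/fetch_list are assumed to be
# prepared by the caller (see tests/test_quant_post_quant_aware.py).
from paddleslim.quant import quant_aware, convert


def quantize_with_initial_scales(main_prog, val_prog, place, data_loader,
                                 feed_list, fetch_list):
    config = {
        'weight_quantize_type': 'channel_wise_abs_max',
        'activation_quantize_type': 'moving_average_abs_max',
        'quantize_op_types':
        ['conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'elementwise_add'],
        # run post-training quantization first to collect initial scales
        'quant_post_first': True,
        # keep the calibrated scales trainable during quant-aware training
        'scale_trainable': True
    }
    calib_config = {
        'data_loader': data_loader,
        'algo': 'abs_max',
        'feed_list': feed_list,
        'fetch_list': fetch_list
    }

    # 1) Calibrate on the test program and collect the initial scale_dict.
    quant_eval_prog, scale_dict, model_type, pattern_ops = quant_aware(
        val_prog,
        place,
        config,
        for_test=True,
        calib_config=calib_config,
        model_type='transformer',
        return_scale_dict=True)

    # 2) Reuse scale_dict to initialize the scales of the training program.
    quant_train_prog = quant_aware(
        main_prog,
        place,
        config,
        for_test=False,
        calib_config=calib_config,
        scale_dict=scale_dict,
        model_type='transformer',
        return_program=True)

    # 3) Fine-tune quant_train_prog, then freeze the eval program for inference.
    quant_eval_prog, int8_prog = convert(
        quant_eval_prog, place, config, save_int8=True)
    return quant_train_prog, quant_eval_prog, int8_prog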