From a59d7762deff8dc9f017261ccb5a49187b3b3dbf Mon Sep 17 00:00:00 2001
From: ceci3
Date: Mon, 4 Jul 2022 11:02:39 +0800
Subject: [PATCH] skip quant matmul in mha (#1232)

---
 paddleslim/auto_compression/compressor.py | 11 +++++++++++
 paddleslim/common/patterns.py             | 15 +++++++++++++--
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/paddleslim/auto_compression/compressor.py b/paddleslim/auto_compression/compressor.py
index cdba847b..7d21bbe7 100644
--- a/paddleslim/auto_compression/compressor.py
+++ b/paddleslim/auto_compression/compressor.py
@@ -313,6 +313,17 @@ class AutoCompression:
             model_filename=model_filename,
             params_filename=params_filename,
             executor=exe))
         _, _, model_type = get_patterns(inference_program)
+        if self.model_filename is None:
+            new_model_filename = '__new_model__'
+        else:
+            new_model_filename = 'new_' + self.model_filename
+        program_bytes = inference_program._remove_training_info(
+            clip_extra=False).desc.serialize_to_string()
+        with open(os.path.join(self.model_dir, new_model_filename), "wb") as f:
+            f.write(program_bytes)
+        shutil.move(
+            os.path.join(self.model_dir, new_model_filename),
+            os.path.join(self.model_dir, self.model_filename))
         _logger.info(f"Detect model type: {model_type}")
         return model_type
 
diff --git a/paddleslim/common/patterns.py b/paddleslim/common/patterns.py
index 5f19ad7e..def7faa4 100644
--- a/paddleslim/common/patterns.py
+++ b/paddleslim/common/patterns.py
@@ -43,7 +43,7 @@ def find_final_nodes(program):
     return final_nodes
 
 
-def _is_mha(pattern_ops, pattern_ops_type):
+def _is_mha(pattern_ops, pattern_ops_type, skip_quant_tensor_list=[]):
     """ judge whether this pattern is multihead attention """
     if pattern_ops_type.count('softmax') != 1 or pattern_ops_type.count(
             'fetch') > 0:
@@ -53,6 +53,7 @@ def _is_mha(pattern_ops, pattern_ops_type):
     for op in pattern_ops:
         if op.type() in ['matmul', 'matmul_v2']:
             if not is_dynamic_weight_op(op):
+                skip_quant_tensor_list.extend(op._op.input('X'))
                 matmul_num += 1
     if matmul_num == 2:
         return True
@@ -81,6 +82,7 @@ def _is_ffn(pattern_ops, pattern_ops_type):
 def get_patterns(program, only_final_node=True):
     """ distinguish the pattern in the program and get distillation node """
     distill_node = []
+    skip_quant_tensor_list = []
     patterns = {}
     graph = GraphWrapper(program)
     block_num = 0
@@ -110,7 +112,8 @@ def get_patterns(program, only_final_node=True):
                     pattern_name = shortcut_start_op.type() + '$' + str(op.idx(
                     ))
 
-                if _is_mha(pattern_ops, pattern_ops_type):
+                if _is_mha(pattern_ops, pattern_ops_type,
+                           skip_quant_tensor_list):
                     model_type = 'transformer'
                     pattern_name = 'MHA$' + str(block_num)
 
@@ -145,4 +148,12 @@ def get_patterns(program, only_final_node=True):
             distill_node.append('teacher_' + out_var.name())
             distill_node.append(out_var.name())
 
+    #### skip quant matmul in attention
+    if model_type == 'transformer':
+        for block_id in range(len(program.blocks)):
+            for op in program.blocks[block_id].ops:
+                for inp_name in op.input_arg_names:
+                    if inp_name in skip_quant_tensor_list:
+                        op._set_attr("op_namescope", "skip_quant")
+
     return patterns, distill_node, model_type
-- 
GitLab
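
Note on the mechanism: _is_mha collects the activation inputs of the two matmuls in a multi-head attention block whose operands are both activations rather than weights (the query-key product and the score-value product), and get_patterns then tags every op that reads one of those tensors with the "skip_quant" namescope, which downstream quantization passes can use to leave those matmuls un-quantized. The sketch below is a minimal standalone restatement of that marking step, assuming a Paddle static-graph Program has already been loaded; mark_skip_quant is a hypothetical helper name used only for illustration, not a PaddleSlim API.

    import paddle

    paddle.enable_static()


    def mark_skip_quant(program, skip_quant_tensor_list):
        # Tag every op that consumes one of the listed tensors; quantization
        # passes that honor the "skip_quant" namescope will then skip the op.
        for block in program.blocks:
            for op in block.ops:
                for inp_name in op.input_arg_names:
                    if inp_name in skip_quant_tensor_list:
                        op._set_attr("op_namescope", "skip_quant")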