Unverified commit a59d7762, authored by ceci3, committed by GitHub

skip quant matmul in mha (#1232)

Parent 171f5cfe
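In short, this commit does two things. First, in `AutoCompression`, after pattern detection the loaded inference program is re-serialized with `_remove_training_info(clip_extra=False)` and written back over the original model file. Second, in the pattern utilities, `_is_mha` now collects the `X` inputs of the `matmul`/`matmul_v2` ops for which `is_dynamic_weight_op` is false (typically the Q·Kᵀ and attention·V products) into `skip_quant_tensor_list`, and `get_patterns` tags every op in a transformer model that consumes one of those tensors with `op_namescope = "skip_quant"`, so those attention matmuls can be excluded from quantization. Hedged sketches of these steps follow the relevant hunks below.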
......@@ -313,6 +313,17 @@ class AutoCompression:
            model_filename=model_filename, params_filename=params_filename,
            executor=exe))
        _, _, model_type = get_patterns(inference_program)
        if self.model_filename is None:
            new_model_filename = '__new_model__'
        else:
            new_model_filename = 'new_' + self.model_filename
        program_bytes = inference_program._remove_training_info(
            clip_extra=False).desc.serialize_to_string()
        with open(os.path.join(self.model_dir, new_model_filename), "wb") as f:
            f.write(program_bytes)
        shutil.move(
            os.path.join(self.model_dir, new_model_filename),
            os.path.join(self.model_dir, self.model_filename))
        _logger.info(f"Detect model type: {model_type}")
        return model_type
......
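A minimal, runnable sketch of the save-then-overwrite step added above. It substitutes a plain bytes payload for the serialized `Program` desc and uses throwaway file names; only the write-then-`shutil.move` pattern mirrors the hunk.

```python
import os
import shutil
import tempfile

# Placeholders for self.model_dir / self.model_filename and for the bytes
# produced by inference_program._remove_training_info(
#     clip_extra=False).desc.serialize_to_string()
model_dir = tempfile.mkdtemp()
model_filename = "model.pdmodel"
program_bytes = b"serialized-program-desc"

# Pretend an original model file already exists.
with open(os.path.join(model_dir, model_filename), "wb") as f:
    f.write(b"old-model")

# Write the new program under a temporary name, then replace the original.
new_model_filename = "new_" + model_filename
with open(os.path.join(model_dir, new_model_filename), "wb") as f:
    f.write(program_bytes)
shutil.move(
    os.path.join(model_dir, new_model_filename),
    os.path.join(model_dir, model_filename))

print(open(os.path.join(model_dir, model_filename), "rb").read())
```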
......@@ -43,7 +43,7 @@ def find_final_nodes(program):
    return final_nodes
def _is_mha(pattern_ops, pattern_ops_type):
def _is_mha(pattern_ops, pattern_ops_type, skip_quant_tensor_list=[]):
""" judge whether this pattern is multihead attention """
if pattern_ops_type.count('softmax') != 1 or pattern_ops_type.count(
'fetch') > 0:
......@@ -53,6 +53,7 @@ def _is_mha(pattern_ops, pattern_ops_type):
    for op in pattern_ops:
        if op.type() in ['matmul', 'matmul_v2']:
            if not is_dynamic_weight_op(op):
                skip_quant_tensor_list.extend(op._op.input('X'))
                matmul_num += 1
    if matmul_num == 2:
        return True
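For context, a small sketch (Paddle 2.x dynamic-graph API, illustrative shapes) of the two activation-by-activation matmuls that a multi-head-attention block contains and that `_is_mha` appears to target: it only counts `matmul`/`matmul_v2` ops for which `is_dynamic_weight_op` is false, i.e. the Q·Kᵀ score product and the softmax-output·V product rather than the weight-bearing Q/K/V projections.

```python
import paddle

# Toy shapes: batch=2, heads=4, seq_len=8, head_dim=16.
q = paddle.randn([2, 4, 8, 16])
k = paddle.randn([2, 4, 8, 16])
v = paddle.randn([2, 4, 8, 16])

# First activation-by-activation matmul: attention scores Q @ K^T.
scores = paddle.matmul(q, k, transpose_y=True) / (16 ** 0.5)
# The single softmax that the pattern requires.
probs = paddle.nn.functional.softmax(scores)
# Second activation-by-activation matmul: probs @ V.
out = paddle.matmul(probs, v)
print(out.shape)  # [2, 4, 8, 16]
```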
......@@ -81,6 +82,7 @@ def _is_ffn(pattern_ops, pattern_ops_type):
def get_patterns(program, only_final_node=True):
""" distinguish the pattern in the program and get distillation node """
distill_node = []
skip_quant_tensor_list = []
patterns = {}
graph = GraphWrapper(program)
block_num = 0
......@@ -110,7 +112,8 @@ def get_patterns(program, only_final_node=True):
                    pattern_name = shortcut_start_op.type() + '$' + str(op.idx(
                    ))
                if _is_mha(pattern_ops, pattern_ops_type):
                if _is_mha(pattern_ops, pattern_ops_type,
                           skip_quant_tensor_list):
                    model_type = 'transformer'
                    pattern_name = 'MHA$' + str(block_num)
......@@ -145,4 +148,12 @@ def get_patterns(program, only_final_node=True):
            distill_node.append('teacher_' + out_var.name())
            distill_node.append(out_var.name())

    #### skip quant matmul in attention
    if model_type == 'transformer':
        for block_id in range(len(program.blocks)):
            for op in program.blocks[block_id].ops:
                for inp_name in op.input_arg_names:
                    if inp_name in skip_quant_tensor_list:
                        op._set_attr("op_namescope", "skip_quant")

    return patterns, distill_node, model_type
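A hedged sketch of the consumer side: Paddle/PaddleSlim quantization passes conventionally skip ops whose `op_namescope` attribute contains `skip_quant`, which is what the marking above relies on. The helper below only illustrates reading that mark back from a static `Program`; the real skipping logic lives in the quantization passes and may differ in detail.

```python
import paddle

paddle.enable_static()


def ops_marked_skip_quant(program):
    """Return (op_type, input names) for every op carrying the skip_quant mark."""
    marked = []
    for block in program.blocks:
        for op in block.ops:
            if op.has_attr("op_namescope") and \
                    "skip_quant" in op.attr("op_namescope"):
                marked.append((op.type, list(op.input_arg_names)))
    return marked


# Example: inspect the default program (empty here, so this prints []).
print(ops_marked_skip_quant(paddle.static.default_main_program()))
```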