未验证 提交 041d6211 编写于 作者: C ceci3 提交者: GitHub

fix find op itself in transformer prune (#1181)

* fix find op itself in transformer prune

* update
上级 a4e43aff
...@@ -49,6 +49,22 @@ def find_next_ops(block, var_name): ...@@ -49,6 +49,22 @@ def find_next_ops(block, var_name):
return res_ops return res_ops
def find_op_itself(block, var_name, op_type):
"""
Find ops itself from block by the output variable.
"""
res_ops = []
for op in block.ops:
if var_name in op.output_arg_names:
if op.type == op_type:
res_ops.append(op)
if len(res_ops) > 1:
_logger.error(
'the function of find_op_itself has more than one op, maybe something wrong.'
)
return res_ops
def insert_eltmul_op(block, op, head_mask, block_num): def insert_eltmul_op(block, op, head_mask, block_num):
""" Insert elementwise mul op to matmul input_mask and head_mask to program""" """ Insert elementwise mul op to matmul input_mask and head_mask to program"""
op_idx = block.ops.index(op) op_idx = block.ops.index(op)
...@@ -305,6 +321,8 @@ class TransformerPruner: ...@@ -305,6 +321,8 @@ class TransformerPruner:
next_op = find_next_ops(block, var_name) next_op = find_next_ops(block, var_name)
if next_op[0].type == 'dropout': if next_op[0].type == 'dropout':
op = next_op[0] op = next_op[0]
else: ### find op itself
op = find_op_itself(block, var_name, op.type())[0]
insert_eltmul_op(block, op, head_mask, block_num) insert_eltmul_op(block, op, head_mask, block_num)
logits = block.var(fetch_list[0]) logits = block.var(fetch_list[0])
labels = block.create_var( labels = block.create_var(
......
...@@ -19,7 +19,7 @@ def with_variable_shape(model_dir, model_filename=None, params_filename=None): ...@@ -19,7 +19,7 @@ def with_variable_shape(model_dir, model_filename=None, params_filename=None):
paddle.enable_static() paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace()) exe = paddle.static.Executor(paddle.CPUPlace())
[inference_program, feed_target_names, fetch_targets] = ( [inference_program, feed_target_names, fetch_targets] = (
paddle.static.load_inference_model( paddle.fluid.io.load_inference_model(
model_dir, model_dir,
exe, exe,
model_filename=model_filename, model_filename=model_filename,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册