From 041d6211f2661d6aa029ee3b3e22856c16becae9 Mon Sep 17 00:00:00 2001
From: ceci3
Date: Thu, 23 Jun 2022 17:35:17 +0800
Subject: [PATCH] fix find op itself in transformer prune (#1181)

* fix find op itself in transformer prune

* update
---
 .../auto_compression/transformer_pruner.py    | 18 ++++++++++++++++++
 paddleslim/auto_compression/utils/predict.py |  2 +-
 2 files changed, 19 insertions(+), 1 deletion(-)

diff --git a/paddleslim/auto_compression/transformer_pruner.py b/paddleslim/auto_compression/transformer_pruner.py
index 8371bfab..0cb3011e 100644
--- a/paddleslim/auto_compression/transformer_pruner.py
+++ b/paddleslim/auto_compression/transformer_pruner.py
@@ -49,6 +49,22 @@ def find_next_ops(block, var_name):
     return res_ops
 
 
+def find_op_itself(block, var_name, op_type):
+    """
+    Find the op itself in the block by its output variable name and op type.
+    """
+    res_ops = []
+    for op in block.ops:
+        if var_name in op.output_arg_names:
+            if op.type == op_type:
+                res_ops.append(op)
+    if len(res_ops) > 1:
+        _logger.error(
+            'find_op_itself matched more than one op; something may be wrong.'
+        )
+    return res_ops
+
+
 def insert_eltmul_op(block, op, head_mask, block_num):
     """ Insert elementwise mul op to matmul input_mask and head_mask to program"""
     op_idx = block.ops.index(op)
@@ -305,6 +321,8 @@ class TransformerPruner:
             next_op = find_next_ops(block, var_name)
             if next_op[0].type == 'dropout':
                 op = next_op[0]
+            else:  ### no dropout follows; find the op itself
+                op = find_op_itself(block, var_name, op.type)[0]
             insert_eltmul_op(block, op, head_mask, block_num)
         logits = block.var(fetch_list[0])
         labels = block.create_var(
diff --git a/paddleslim/auto_compression/utils/predict.py b/paddleslim/auto_compression/utils/predict.py
index a6e5d219..af1a09b9 100644
--- a/paddleslim/auto_compression/utils/predict.py
+++ b/paddleslim/auto_compression/utils/predict.py
@@ -19,7 +19,7 @@ def with_variable_shape(model_dir, model_filename=None, params_filename=None):
     paddle.enable_static()
     exe = paddle.static.Executor(paddle.CPUPlace())
     [inference_program, feed_target_names, fetch_targets] = (
-        paddle.static.load_inference_model(
+        paddle.fluid.io.load_inference_model(
             model_dir,
             exe,
             model_filename=model_filename,
-- 
GitLab
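
For readers outside the Paddle codebase, here is a minimal standalone sketch of the lookup the new `find_op_itself` helper performs. The `Op` and `Block` namedtuples are hypothetical stand-ins for Paddle's real framework types, modeling only the two attributes the helper reads (`type` and `output_arg_names`); the toy block mirrors the case the patch fixes, where no dropout op follows the attention matmul:

```python
# A minimal, standalone sketch of the find_op_itself lookup. Op and Block
# are hypothetical stand-ins for Paddle's framework types, not the real API.
from collections import namedtuple

Op = namedtuple('Op', ['type', 'output_arg_names'])
Block = namedtuple('Block', ['ops'])


def find_op_itself(block, var_name, op_type):
    """Return every op of `op_type` whose outputs include `var_name`."""
    return [
        op for op in block.ops
        if var_name in op.output_arg_names and op.type == op_type
    ]


# Toy block: the attention matmul's output feeds softmax directly,
# i.e. no dropout op follows -- the case the patch handles.
block = Block(ops=[
    Op(type='matmul', output_arg_names=['attn_scores']),
    Op(type='softmax', output_arg_names=['attn_probs']),
])

# Recover the producing matmul itself before inserting the
# elementwise-mul that applies the head mask.
op = find_op_itself(block, 'attn_scores', 'matmul')[0]
print(op.type)  # -> matmul
```

Before the fix, the `else` branch did not exist, so when the matched op had no trailing `dropout`, `op` still pointed at an earlier op and the head-mask multiply was inserted at the wrong position; the helper recovers the producing op by its output variable name and type instead.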
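
The `predict.py` hunk swaps `paddle.static.load_inference_model` for the legacy `paddle.fluid.io.load_inference_model`. A plausible reason is the calling convention: the fluid API takes a model directory plus separate `model_filename`/`params_filename` arguments, which matches how `with_variable_shape` is invoked. A hedged sketch of that call style, assuming Paddle 2.x with the legacy `fluid` namespace still available; the directory and file names in the usage comment are hypothetical placeholders:

```python
# Sketch of the legacy fluid loading call used by the patch. Assumes a
# Paddle 2.x install where paddle.fluid is still present.
import paddle


def load_from_dir(model_dir, model_filename=None, params_filename=None):
    """Load an inference program from a directory with explicit file names."""
    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())
    # Returns [inference_program, feed_target_names, fetch_targets].
    return paddle.fluid.io.load_inference_model(
        model_dir,
        exe,
        model_filename=model_filename,
        params_filename=params_filename)


# Hypothetical usage, given an exported model on disk:
# program, feeds, fetches = load_from_dir(
#     'inference_dir', 'model.pdmodel', 'model.pdiparams')
```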