From 6557ef3d6891c67f571cfe9a1963a519e382612c Mon Sep 17 00:00:00 2001 From: zhouzj <41366441+zzjjay@users.noreply.github.com> Date: Wed, 28 Dec 2022 10:03:34 +0800 Subject: [PATCH] fix bug of pruning post-process nodes. (#1610) --- .../create_compressed_program.py | 21 +++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/paddleslim/auto_compression/create_compressed_program.py b/paddleslim/auto_compression/create_compressed_program.py index 677bba96..669d2433 100644 --- a/paddleslim/auto_compression/create_compressed_program.py +++ b/paddleslim/auto_compression/create_compressed_program.py @@ -121,11 +121,24 @@ def _get_distill_node(student_program, config): return node -def _get_target_node(distill_node): +def _get_target_node(distill_node, teacher=False): + tmp_nodes = set() + if isinstance(distill_node[0], list): + for n_list in distill_node: + for n in n_list: + tmp_nodes.add(n) + else: + for n in distill_node: + tmp_nodes.add(n) + targets = [] - for idx, node in enumerate(distill_node): - if idx % 2 != 0: + for node in tmp_nodes: + if teacher and 'teacher_' in node: + tmp = node.split('teacher_')[-1] + targets.append(tmp) + if not teacher and 'teacher_' not in node: targets.append(node) + return targets @@ -189,7 +202,7 @@ def _load_program_and_merge(executor, _remove_fetch_node(teacher_program) - target_nodes = _get_target_node(distill_node_pair) + target_nodes = _get_target_node(distill_node_pair, True) teacher_program = teacher_program._prune(target_nodes) data_name_map = {} -- GitLab