diff --git a/paddleslim/auto_compression/create_compressed_program.py b/paddleslim/auto_compression/create_compressed_program.py index 677bba963c89255b51d92102a96c7375d215d8c6..669d2433cbc0ad87bc5a36187908091fac17cffd 100644 --- a/paddleslim/auto_compression/create_compressed_program.py +++ b/paddleslim/auto_compression/create_compressed_program.py @@ -121,11 +121,24 @@ def _get_distill_node(student_program, config): return node -def _get_target_node(distill_node): +def _get_target_node(distill_node, teacher=False): + tmp_nodes = set() + if isinstance(distill_node[0], list): + for n_list in distill_node: + for n in n_list: + tmp_nodes.add(n) + else: + for n in distill_node: + tmp_nodes.add(n) + targets = [] - for idx, node in enumerate(distill_node): - if idx % 2 != 0: + for node in tmp_nodes: + if teacher and 'teacher_' in node: + tmp = node.split('teacher_')[-1] + targets.append(tmp) + if not teacher and 'teacher_' not in node: targets.append(node) + return targets @@ -189,7 +202,7 @@ def _load_program_and_merge(executor, _remove_fetch_node(teacher_program) - target_nodes = _get_target_node(distill_node_pair) + target_nodes = _get_target_node(distill_node_pair, True) teacher_program = teacher_program._prune(target_nodes) data_name_map = {}