未验证 提交 6557ef3d 编写于 作者: Z zhouzj 提交者: GitHub

fix bug of pruning post-process nodes. (#1610)

上级 86f6115c
...@@ -121,11 +121,24 @@ def _get_distill_node(student_program, config): ...@@ -121,11 +121,24 @@ def _get_distill_node(student_program, config):
return node return node
def _get_target_node(distill_node): def _get_target_node(distill_node, teacher=False):
tmp_nodes = set()
if isinstance(distill_node[0], list):
for n_list in distill_node:
for n in n_list:
tmp_nodes.add(n)
else:
for n in distill_node:
tmp_nodes.add(n)
targets = [] targets = []
for idx, node in enumerate(distill_node): for node in tmp_nodes:
if idx % 2 != 0: if teacher and 'teacher_' in node:
tmp = node.split('teacher_')[-1]
targets.append(tmp)
if not teacher and 'teacher_' not in node:
targets.append(node) targets.append(node)
return targets return targets
...@@ -189,7 +202,7 @@ def _load_program_and_merge(executor, ...@@ -189,7 +202,7 @@ def _load_program_and_merge(executor,
_remove_fetch_node(teacher_program) _remove_fetch_node(teacher_program)
target_nodes = _get_target_node(distill_node_pair) target_nodes = _get_target_node(distill_node_pair, True)
teacher_program = teacher_program._prune(target_nodes) teacher_program = teacher_program._prune(target_nodes)
data_name_map = {} data_name_map = {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册