未验证 提交 6557ef3d 编写于 作者: Z zhouzj 提交者: GitHub

fix bug of pruning post-process nodes. (#1610)

上级 86f6115c
......@@ -121,11 +121,24 @@ def _get_distill_node(student_program, config):
return node
def _get_target_node(distill_node):
def _get_target_node(distill_node, teacher=False):
tmp_nodes = set()
if isinstance(distill_node[0], list):
for n_list in distill_node:
for n in n_list:
tmp_nodes.add(n)
else:
for n in distill_node:
tmp_nodes.add(n)
targets = []
for idx, node in enumerate(distill_node):
if idx % 2 != 0:
for node in tmp_nodes:
if teacher and 'teacher_' in node:
tmp = node.split('teacher_')[-1]
targets.append(tmp)
if not teacher and 'teacher_' not in node:
targets.append(node)
return targets
......@@ -189,7 +202,7 @@ def _load_program_and_merge(executor,
_remove_fetch_node(teacher_program)
target_nodes = _get_target_node(distill_node_pair)
target_nodes = _get_target_node(distill_node_pair, True)
teacher_program = teacher_program._prune(target_nodes)
data_name_map = {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册