Commit aa63d5ac authored by Bai Yifan, committed by whs

Make the distillation process not save teacher variables in PaddleSlim (#19633)

* split teacher checkpoints from student checkpoints, test=develop

* add unittest for graph.merge(), test=develop
Parent bb4f8dee
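For context, here is a minimal usage sketch of the workflow this commit targets: a student graph merges a teacher graph for distillation, and saving persistables should afterwards skip the teacher variables. The sketch is not part of the commit; the import path, the toy `build_program` helper, and the layer choices are assumptions, while `GraphWrapper`, `merge()`, and `teacher_persistables` come from the diff below.

```python
# Minimal usage sketch (illustrative only, not part of this commit).
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.graph import GraphWrapper  # assumed import path


def build_program(prefix):
    # Toy program with one persistable FC weight (assumption for illustration).
    prog, startup = fluid.Program(), fluid.Program()
    with fluid.program_guard(prog, startup):
        x = fluid.layers.data(name=prefix + '_x', shape=[8], dtype='float32')
        fluid.layers.fc(input=x, size=4)
    return prog


student_graph = GraphWrapper(build_program('student'))
teacher_graph = GraphWrapper(build_program('teacher'))

# merge() clones the teacher variables into the student program and, with this
# commit, also records every persistable teacher variable in
# student_graph.teacher_persistables.
student_graph.merge(teacher_graph)

# The save path (the method wrapping io.save_vars in the diff below) now
# filters out anything recorded in teacher_persistables, so the checkpoint
# holds only student variables.
print(sorted(student_graph.teacher_persistables.keys()))
```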
@@ -205,6 +205,7 @@ class GraphWrapper(object):
         super(GraphWrapper, self).__init__()
         self.program = Program() if program is None else program
         self.persistables = {}
+        self.teacher_persistables = {}
         for var in self.program.list_vars():
             if var.persistable:
                 self.persistables[var.name] = var
@@ -306,6 +307,8 @@ class GraphWrapper(object):
             graph(GraphWrapper): The graph to be merged by current graph.
         """
         for var in graph.program.list_vars():
+            if var.persistable:
+                self.teacher_persistables[var.name] = var
             new_var = self.program.global_block()._clone_variable(
                 var, force_persistable=False)
             new_var.stop_gradient = var.stop_gradient
@@ -479,7 +482,7 @@ class GraphWrapper(object):
                     self.persistables[var.name] = var
         persistables = []
         for var in self.persistables:
-            if 'reader' not in var and 'double_buffer' not in var:
+            if 'reader' not in var and 'double_buffer' not in var and var not in self.teacher_persistables:
                 persistables.append(self.persistables[var])
         io.save_vars(exe.exe, path, vars=persistables)
@@ -143,6 +143,11 @@ class TestGraphWrapper(unittest.TestCase):
         self.build_program()
         self.assertEquals(self.train_graph.flops(), 354624)

+    def test_merge(self):
+        self.build_program()
+        self.train_graph.merge(self.eval_graph)
+        self.assertEquals(len(self.train_graph.ops()), 72)
+
 if __name__ == '__main__':
     unittest.main()
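The new test only asserts the merged op count. A companion check in the same style, sketched below with the existing fixture names (`train_graph`, `eval_graph`, `build_program`), could additionally confirm that the persistable teacher variables picked up during `merge()` are tracked; this is an illustration, not part of the commit.

```python
    def test_merge_records_teacher_persistables(self):
        # Illustrative companion check (not part of this commit): every
        # persistable variable of the merged (teacher) graph should be
        # recorded so that saving persistables can skip it.
        self.build_program()
        self.train_graph.merge(self.eval_graph)
        for var in self.eval_graph.program.list_vars():
            if var.persistable:
                self.assertIn(var.name, self.train_graph.teacher_persistables)
```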