You need to sign in or sign up before continuing.
提交 aa63d5ac 编写于 作者: B Bai Yifan 提交者: whs

Make the distillation process not save teacher variables in PaddleSlim (#19633)

* split teacher checkpoints with student checkpoints, test=develop

* add unittest for graph.merge(), test=develop
上级 bb4f8dee
...@@ -205,6 +205,7 @@ class GraphWrapper(object): ...@@ -205,6 +205,7 @@ class GraphWrapper(object):
super(GraphWrapper, self).__init__() super(GraphWrapper, self).__init__()
self.program = Program() if program is None else program self.program = Program() if program is None else program
self.persistables = {} self.persistables = {}
self.teacher_persistables = {}
for var in self.program.list_vars(): for var in self.program.list_vars():
if var.persistable: if var.persistable:
self.persistables[var.name] = var self.persistables[var.name] = var
...@@ -306,6 +307,8 @@ class GraphWrapper(object): ...@@ -306,6 +307,8 @@ class GraphWrapper(object):
graph(GraphWrapper): The graph to be merged by current graph. graph(GraphWrapper): The graph to be merged by current graph.
""" """
for var in graph.program.list_vars(): for var in graph.program.list_vars():
if var.persistable:
self.teacher_persistables[var.name] = var
new_var = self.program.global_block()._clone_variable( new_var = self.program.global_block()._clone_variable(
var, force_persistable=False) var, force_persistable=False)
new_var.stop_gradient = var.stop_gradient new_var.stop_gradient = var.stop_gradient
...@@ -479,7 +482,7 @@ class GraphWrapper(object): ...@@ -479,7 +482,7 @@ class GraphWrapper(object):
self.persistables[var.name] = var self.persistables[var.name] = var
persistables = [] persistables = []
for var in self.persistables: for var in self.persistables:
if 'reader' not in var and 'double_buffer' not in var: if 'reader' not in var and 'double_buffer' not in var and var not in self.teacher_persistables:
persistables.append(self.persistables[var]) persistables.append(self.persistables[var])
io.save_vars(exe.exe, path, vars=persistables) io.save_vars(exe.exe, path, vars=persistables)
......
...@@ -143,6 +143,11 @@ class TestGraphWrapper(unittest.TestCase): ...@@ -143,6 +143,11 @@ class TestGraphWrapper(unittest.TestCase):
self.build_program() self.build_program()
self.assertEquals(self.train_graph.flops(), 354624) self.assertEquals(self.train_graph.flops(), 354624)
def test_merge(self):
self.build_program()
self.train_graph.merge(self.eval_graph)
self.assertEquals(len(self.train_graph.ops()), 72)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册