From aa63d5ac6df57661f73ab7f6b6b93c88d759a115 Mon Sep 17 00:00:00 2001 From: Bai Yifan Date: Wed, 11 Sep 2019 11:30:28 +0800 Subject: [PATCH] Make the distillation process not save teacher variables in PaddleSlim (#19633) * split teacher checkpoints with student checkpoints, test=develop * add unittest for graph.merge(), test=develop --- python/paddle/fluid/contrib/slim/graph/graph_wrapper.py | 5 ++++- python/paddle/fluid/contrib/slim/tests/test_graph_wrapper.py | 5 +++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py index b01c98aab9d..fd248fc6f67 100644 --- a/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py +++ b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py @@ -205,6 +205,7 @@ class GraphWrapper(object): super(GraphWrapper, self).__init__() self.program = Program() if program is None else program self.persistables = {} + self.teacher_persistables = {} for var in self.program.list_vars(): if var.persistable: self.persistables[var.name] = var @@ -306,6 +307,8 @@ class GraphWrapper(object): graph(GraphWrapper): The graph to be merged by current graph. """ for var in graph.program.list_vars(): + if var.persistable: + self.teacher_persistables[var.name] = var new_var = self.program.global_block()._clone_variable( var, force_persistable=False) new_var.stop_gradient = var.stop_gradient @@ -479,7 +482,7 @@ class GraphWrapper(object): self.persistables[var.name] = var persistables = [] for var in self.persistables: - if 'reader' not in var and 'double_buffer' not in var: + if 'reader' not in var and 'double_buffer' not in var and var not in self.teacher_persistables: persistables.append(self.persistables[var]) io.save_vars(exe.exe, path, vars=persistables) diff --git a/python/paddle/fluid/contrib/slim/tests/test_graph_wrapper.py b/python/paddle/fluid/contrib/slim/tests/test_graph_wrapper.py index d8175e84124..7d190ce0164 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_graph_wrapper.py +++ b/python/paddle/fluid/contrib/slim/tests/test_graph_wrapper.py @@ -143,6 +143,11 @@ class TestGraphWrapper(unittest.TestCase): self.build_program() self.assertEquals(self.train_graph.flops(), 354624) + def test_merge(self): + self.build_program() + self.train_graph.merge(self.eval_graph) + self.assertEquals(len(self.train_graph.ops()), 72) + if __name__ == '__main__': unittest.main() -- GitLab