diff --git a/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py index b01c98aab9dae3296e19bf4108701e341d1f8ad9..fd248fc6f67d894358ca2f67b6d16b0bf14a391e 100644 --- a/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py +++ b/python/paddle/fluid/contrib/slim/graph/graph_wrapper.py @@ -205,6 +205,7 @@ class GraphWrapper(object): super(GraphWrapper, self).__init__() self.program = Program() if program is None else program self.persistables = {} + self.teacher_persistables = {} for var in self.program.list_vars(): if var.persistable: self.persistables[var.name] = var @@ -306,6 +307,8 @@ class GraphWrapper(object): graph(GraphWrapper): The graph to be merged by current graph. """ for var in graph.program.list_vars(): + if var.persistable: + self.teacher_persistables[var.name] = var new_var = self.program.global_block()._clone_variable( var, force_persistable=False) new_var.stop_gradient = var.stop_gradient @@ -479,7 +482,7 @@ class GraphWrapper(object): self.persistables[var.name] = var persistables = [] for var in self.persistables: - if 'reader' not in var and 'double_buffer' not in var: + if 'reader' not in var and 'double_buffer' not in var and var not in self.teacher_persistables: persistables.append(self.persistables[var]) io.save_vars(exe.exe, path, vars=persistables) diff --git a/python/paddle/fluid/contrib/slim/tests/test_graph_wrapper.py b/python/paddle/fluid/contrib/slim/tests/test_graph_wrapper.py index d8175e84124b4be54de8ec89ce792031c6845210..7d190ce0164315b2000e9e3c3bf26798b05b2bd3 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_graph_wrapper.py +++ b/python/paddle/fluid/contrib/slim/tests/test_graph_wrapper.py @@ -143,6 +143,11 @@ class TestGraphWrapper(unittest.TestCase): self.build_program() self.assertEquals(self.train_graph.flops(), 354624) + def test_merge(self): + self.build_program() + self.train_graph.merge(self.eval_graph) + self.assertEquals(len(self.train_graph.ops()), 72) + if __name__ == '__main__': unittest.main()