diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 5e954fa9c419b249bb8a4be5a78c01da85b017b2..38852eb7d06c5442a8d2fa2ddfb3228b5b8ec247 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include #include #include "paddle/fluid/framework/ir/graph.h" @@ -29,6 +30,8 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { ResolveHazard(var_nodes); } +Graph::Graph(const Graph &o) : Graph(o.program_) {} + std::map> Graph::InitFromProgram( const ProgramDesc &program) { VLOG(3) << "block in program:" << program_.Size(); diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index cfd974e4bd679fdd06739f4c943bb197865020fb..9c5dbc0455f3317b5ed137118cf3c480de15ed9a 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include #include #include "paddle/fluid/framework/ir/node.h" @@ -71,6 +72,7 @@ namespace ir { class Graph { public: explicit Graph(const ProgramDesc &program); + Graph(const Graph &o); virtual ~Graph() { for (auto &attr : attrs_) { diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 68f74a8531fff0c49c8a62d12f5cde7af77faf8a..160d8d05c0a9501a884ea5c5bc1190e4d9663a1f 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -54,6 +54,8 @@ void BindGraph(py::module *m) { "The graph is a Directed Acyclic Single Static Assignment Graph, see " "`paddle::ir::Graph` for details.") .def(py::init()) + .def("__init__", + [](Graph &self, const Graph &other) { new (&self) Graph(other); }) .def("has", &Graph::Has) .def("get_int", &Graph::Get) .def("get_float", &Graph::Get) diff --git a/python/paddle/fluid/contrib/slim/tests/test_graph.py b/python/paddle/fluid/contrib/slim/tests/test_graph.py index 2d2f1384dec65ee19dcade8a46f80bd3f9eb7013..5165865e9b546fde09f80a06077c146edfbdc92f 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_graph.py +++ b/python/paddle/fluid/contrib/slim/tests/test_graph.py @@ -52,7 +52,7 @@ def residual_block(num): class TestGraph(unittest.TestCase): - def test_graph_functions(self): + def test_graph_functions(self, for_ci=True): main = fluid.Program() startup = fluid.Program() with fluid.program_guard(main, startup): @@ -60,11 +60,20 @@ class TestGraph(unittest.TestCase): opt = fluid.optimizer.Adam(learning_rate=0.001) opt.minimize(loss) graph = IrGraph(core.Graph(main.desc), for_test=False) + backup_graph = graph.clone() + self.assertEqual(len(graph.all_nodes()), len(backup_graph.all_nodes())) + marked_nodes = set() for op in graph.all_op_nodes(): if op.name().find('conv2d') > -1: marked_nodes.add(op) - graph.draw('.', 'residual', marked_nodes) + if not for_ci: + graph.draw('.', 'residual', marked_nodes) + backup_marked_nodes = set() + for op in backup_graph.all_op_nodes(): + if op.name().find('conv2d') > -1: + backup_marked_nodes.add(op) + backup_graph.draw('.', 'backup', backup_marked_nodes) self.assertFalse(graph.has_circle()) self.assertEqual(graph.graph_num(), 1) nodes = graph.topology_sort() diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 5b9dd8693190d1c9a6a71cc6bc6e6acf2a06aa6f..6cfc433d6485ff1fed43c8843ff48f3fa460c4ad 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -2002,6 +2002,16 @@ class IrGraph(object): self.graph = graph self._for_test = for_test + def clone(self): + """ + Create a new and duplicated IrGraph. + + Returns: + IrGraph: A new and duplicated graph. + """ + g = core.Graph(self.graph) + return IrGraph(g, self._for_test) + def is_test(self): """ If the graph is used for testing, the function returns true. Otherwise, returns false.