From ac6ef06ffaddd198838986e0d421331b119fc1b9 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Thu, 7 Mar 2019 22:18:56 +0800 Subject: [PATCH] Add the Clone method in Graph. test=develop --- paddle/fluid/framework/ir/graph.cc | 33 +++++++++++++++++++ paddle/fluid/framework/ir/graph.h | 4 +++ paddle/fluid/framework/ir/node.h | 1 + paddle/fluid/pybind/ir.cc | 2 ++ .../fluid/contrib/slim/tests/test_graph.py | 16 ++++----- python/paddle/fluid/framework.py | 2 +- 6 files changed, 49 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 0721ee230..6a9340b87 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -152,6 +152,39 @@ void Graph::ResolveHazard( } } +std::shared_ptr Graph::Clone() { + auto cloned_graph = std::make_shared(this->program_); + cloned_graph->ReleaseNodes(); + cloned_graph->num_node_created_ = 0; + std::unordered_map origin_to_cloned; + for (auto *n : this->node_set_) { + ir::Node *cloned_node = nullptr; + if (n->IsCtrlVar()) { + cloned_node = cloned_graph->CreateControlDepVar(); + } else if (!n->var_desc_ && !n->op_desc_) { // empty node + cloned_node = cloned_graph->CreateEmptyNode(n->Name(), n->NodeType()); + } else if (n->IsVar()) { + cloned_node = cloned_graph->CreateVarNode(n->Var()); + } else if (n->IsOp()) { + cloned_node = cloned_graph->CreateOpNode(n->Op()); + } + if (cloned_node) { + origin_to_cloned[n] = cloned_node; + } else { + PADDLE_THROW("The cloned node's type is not supported!"); + } + } + for (auto *n : this->node_set_) { + for (auto it = n->inputs.begin(); it != n->inputs.end(); it++) { + origin_to_cloned[n]->inputs.push_back(origin_to_cloned[*it]); + } + for (auto it = n->outputs.begin(); it != n->outputs.end(); it++) { + origin_to_cloned[n]->outputs.push_back(origin_to_cloned[*it]); + } + } + return cloned_graph; +} + bool IsControlDepVar(const ir::Node &var) { return var.Name().find(ir::Node::kControlDepVarName) != std::string::npos; } diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index cfd82dcb0..44ba4d3d2 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -213,6 +213,10 @@ class Graph { void ResolveHazard( const std::map> &var_nodes); + // Create a new and duplicated graph. + // WARN: The method only clones the graph structure, not its attributes. + std::shared_ptr Clone(); + private: std::map> InitFromProgram( const ProgramDesc &program); diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h index 9eade9eaa..72fb876d9 100644 --- a/paddle/fluid/framework/ir/node.h +++ b/paddle/fluid/framework/ir/node.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 298f976bc..c69ccd507 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" @@ -54,6 +55,7 @@ void BindGraph(py::module *m) { "The graph is a Directed Acyclic Single Static Assignment Graph, see " "`paddle::ir::Graph` for details.") .def(py::init()) + .def("clone", &Graph::Clone) .def("has", &Graph::Has) .def("get_int", &Graph::Get) .def("get_float", &Graph::Get) diff --git a/python/paddle/fluid/contrib/slim/tests/test_graph.py b/python/paddle/fluid/contrib/slim/tests/test_graph.py index 5165865e9..3245dfc06 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_graph.py +++ b/python/paddle/fluid/contrib/slim/tests/test_graph.py @@ -60,20 +60,12 @@ class TestGraph(unittest.TestCase): opt = fluid.optimizer.Adam(learning_rate=0.001) opt.minimize(loss) graph = IrGraph(core.Graph(main.desc), for_test=False) - backup_graph = graph.clone() - self.assertEqual(len(graph.all_nodes()), len(backup_graph.all_nodes())) - marked_nodes = set() for op in graph.all_op_nodes(): if op.name().find('conv2d') > -1: marked_nodes.add(op) if not for_ci: graph.draw('.', 'residual', marked_nodes) - backup_marked_nodes = set() - for op in backup_graph.all_op_nodes(): - if op.name().find('conv2d') > -1: - backup_marked_nodes.add(op) - backup_graph.draw('.', 'backup', backup_marked_nodes) self.assertFalse(graph.has_circle()) self.assertEqual(graph.graph_num(), 1) nodes = graph.topology_sort() @@ -83,6 +75,14 @@ class TestGraph(unittest.TestCase): nodes_num = len(graph.all_nodes()) graph.safe_remove_nodes(marked_nodes) self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes)) + backup_graph = graph.clone() + self.assertEqual(len(graph.all_nodes()), len(backup_graph.all_nodes())) + if not for_ci: + backup_marked_nodes = set() + for op in backup_graph.all_op_nodes(): + if op.name().find('conv2d') > -1: + backup_marked_nodes.add(op) + backup_graph.draw('.', 'backup', backup_marked_nodes) if __name__ == '__main__': diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 3c36eb645..31f90f5f5 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -2012,7 +2012,7 @@ class IrGraph(object): Returns: IrGraph: A new and duplicated graph. """ - g = core.Graph(self.graph.origin_program_desc()) + g = self.graph.clone() return IrGraph(g, self._for_test) def is_test(self): -- GitLab