From 1b9c8d5f061d01c022bc8136defb7e681d16f57a Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Thu, 7 Mar 2019 17:05:40 +0800 Subject: [PATCH] add clone function for IrGraph. test=develop --- paddle/fluid/framework/ir/graph.cc | 3 +++ paddle/fluid/framework/ir/graph.h | 2 ++ paddle/fluid/pybind/ir.cc | 2 ++ .../paddle/fluid/contrib/slim/tests/test_graph.py | 13 +++++++++++-- python/paddle/fluid/framework.py | 10 ++++++++++ 5 files changed, 28 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 5e954fa9c..38852eb7d 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include #include #include "paddle/fluid/framework/ir/graph.h" @@ -29,6 +30,8 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { ResolveHazard(var_nodes); } +Graph::Graph(const Graph &o) : Graph(o.program_) {} + std::map> Graph::InitFromProgram( const ProgramDesc &program) { VLOG(3) << "block in program:" << program_.Size(); diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index cfd974e4b..9c5dbc045 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include #include #include "paddle/fluid/framework/ir/node.h" @@ -71,6 +72,7 @@ namespace ir { class Graph { public: explicit Graph(const ProgramDesc &program); + Graph(const Graph &o); virtual ~Graph() { for (auto &attr : attrs_) { diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 68f74a853..160d8d05c 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -54,6 +54,8 @@ void BindGraph(py::module *m) { "The graph is a Directed Acyclic Single Static Assignment Graph, see " "`paddle::ir::Graph` for details.") .def(py::init()) + .def("__init__", + [](Graph &self, const Graph &other) { new (&self) Graph(other); }) .def("has", &Graph::Has) .def("get_int", &Graph::Get) .def("get_float", &Graph::Get) diff --git a/python/paddle/fluid/contrib/slim/tests/test_graph.py b/python/paddle/fluid/contrib/slim/tests/test_graph.py index 2d2f1384d..5165865e9 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_graph.py +++ b/python/paddle/fluid/contrib/slim/tests/test_graph.py @@ -52,7 +52,7 @@ def residual_block(num): class TestGraph(unittest.TestCase): - def test_graph_functions(self): + def test_graph_functions(self, for_ci=True): main = fluid.Program() startup = fluid.Program() with fluid.program_guard(main, startup): @@ -60,11 +60,20 @@ class TestGraph(unittest.TestCase): opt = fluid.optimizer.Adam(learning_rate=0.001) opt.minimize(loss) graph = IrGraph(core.Graph(main.desc), for_test=False) + backup_graph = graph.clone() + self.assertEqual(len(graph.all_nodes()), len(backup_graph.all_nodes())) + marked_nodes = set() for op in graph.all_op_nodes(): if op.name().find('conv2d') > -1: marked_nodes.add(op) - graph.draw('.', 'residual', marked_nodes) + if not for_ci: + graph.draw('.', 'residual', marked_nodes) + backup_marked_nodes = set() + for op in backup_graph.all_op_nodes(): + if op.name().find('conv2d') > -1: + backup_marked_nodes.add(op) + backup_graph.draw('.', 'backup', backup_marked_nodes) self.assertFalse(graph.has_circle()) self.assertEqual(graph.graph_num(), 1) nodes = graph.topology_sort() diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 5b9dd8693..6cfc433d6 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -2002,6 +2002,16 @@ class IrGraph(object): self.graph = graph self._for_test = for_test + def clone(self): + """ + Create a new and duplicated IrGraph. + + Returns: + IrGraph: A new and duplicated graph. + """ + g = core.Graph(self.graph) + return IrGraph(g, self._for_test) + def is_test(self): """ If the graph is used for testing, the function returns true. Otherwise, returns false. -- GitLab