提交 ac6ef06f 编写于 作者: Z Zhen Wang

Add the Clone method in Graph. test=develop

上级 01eddf12
...@@ -152,6 +152,39 @@ void Graph::ResolveHazard( ...@@ -152,6 +152,39 @@ void Graph::ResolveHazard(
} }
} }
std::shared_ptr<Graph> Graph::Clone() {
auto cloned_graph = std::make_shared<Graph>(this->program_);
cloned_graph->ReleaseNodes();
cloned_graph->num_node_created_ = 0;
std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned;
for (auto *n : this->node_set_) {
ir::Node *cloned_node = nullptr;
if (n->IsCtrlVar()) {
cloned_node = cloned_graph->CreateControlDepVar();
} else if (!n->var_desc_ && !n->op_desc_) { // empty node
cloned_node = cloned_graph->CreateEmptyNode(n->Name(), n->NodeType());
} else if (n->IsVar()) {
cloned_node = cloned_graph->CreateVarNode(n->Var());
} else if (n->IsOp()) {
cloned_node = cloned_graph->CreateOpNode(n->Op());
}
if (cloned_node) {
origin_to_cloned[n] = cloned_node;
} else {
PADDLE_THROW("The cloned node's type is not supported!");
}
}
for (auto *n : this->node_set_) {
for (auto it = n->inputs.begin(); it != n->inputs.end(); it++) {
origin_to_cloned[n]->inputs.push_back(origin_to_cloned[*it]);
}
for (auto it = n->outputs.begin(); it != n->outputs.end(); it++) {
origin_to_cloned[n]->outputs.push_back(origin_to_cloned[*it]);
}
}
return cloned_graph;
}
bool IsControlDepVar(const ir::Node &var) { bool IsControlDepVar(const ir::Node &var) {
return var.Name().find(ir::Node::kControlDepVarName) != std::string::npos; return var.Name().find(ir::Node::kControlDepVarName) != std::string::npos;
} }
......
...@@ -213,6 +213,10 @@ class Graph { ...@@ -213,6 +213,10 @@ class Graph {
void ResolveHazard( void ResolveHazard(
const std::map<std::string, std::vector<ir::Node *>> &var_nodes); const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
// Create a new graph that duplicates this one.
// WARNING: only the graph structure (nodes and edges) is cloned; graph
// attributes are not copied.
std::shared_ptr<Graph> Clone();
private: private:
std::map<std::string, std::vector<ir::Node *>> InitFromProgram( std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program); const ProgramDesc &program);
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <memory>
#include <string> #include <string>
#include <typeindex> #include <typeindex>
#include <typeinfo> #include <typeinfo>
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
...@@ -54,6 +55,7 @@ void BindGraph(py::module *m) { ...@@ -54,6 +55,7 @@ void BindGraph(py::module *m) {
"The graph is a Directed Acyclic Single Static Assignment Graph, see " "The graph is a Directed Acyclic Single Static Assignment Graph, see "
"`paddle::ir::Graph` for details.") "`paddle::ir::Graph` for details.")
.def(py::init<const ProgramDesc &>()) .def(py::init<const ProgramDesc &>())
.def("clone", &Graph::Clone)
.def("has", &Graph::Has) .def("has", &Graph::Has)
.def("get_int", &Graph::Get<int>) .def("get_int", &Graph::Get<int>)
.def("get_float", &Graph::Get<float>) .def("get_float", &Graph::Get<float>)
......
...@@ -60,20 +60,12 @@ class TestGraph(unittest.TestCase): ...@@ -60,20 +60,12 @@ class TestGraph(unittest.TestCase):
opt = fluid.optimizer.Adam(learning_rate=0.001) opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss) opt.minimize(loss)
graph = IrGraph(core.Graph(main.desc), for_test=False) graph = IrGraph(core.Graph(main.desc), for_test=False)
backup_graph = graph.clone()
self.assertEqual(len(graph.all_nodes()), len(backup_graph.all_nodes()))
marked_nodes = set() marked_nodes = set()
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name().find('conv2d') > -1: if op.name().find('conv2d') > -1:
marked_nodes.add(op) marked_nodes.add(op)
if not for_ci: if not for_ci:
graph.draw('.', 'residual', marked_nodes) graph.draw('.', 'residual', marked_nodes)
backup_marked_nodes = set()
for op in backup_graph.all_op_nodes():
if op.name().find('conv2d') > -1:
backup_marked_nodes.add(op)
backup_graph.draw('.', 'backup', backup_marked_nodes)
self.assertFalse(graph.has_circle()) self.assertFalse(graph.has_circle())
self.assertEqual(graph.graph_num(), 1) self.assertEqual(graph.graph_num(), 1)
nodes = graph.topology_sort() nodes = graph.topology_sort()
...@@ -83,6 +75,14 @@ class TestGraph(unittest.TestCase): ...@@ -83,6 +75,14 @@ class TestGraph(unittest.TestCase):
nodes_num = len(graph.all_nodes()) nodes_num = len(graph.all_nodes())
graph.safe_remove_nodes(marked_nodes) graph.safe_remove_nodes(marked_nodes)
self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes)) self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes))
backup_graph = graph.clone()
self.assertEqual(len(graph.all_nodes()), len(backup_graph.all_nodes()))
if not for_ci:
backup_marked_nodes = set()
for op in backup_graph.all_op_nodes():
if op.name().find('conv2d') > -1:
backup_marked_nodes.add(op)
backup_graph.draw('.', 'backup', backup_marked_nodes)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -2012,7 +2012,7 @@ class IrGraph(object): ...@@ -2012,7 +2012,7 @@ class IrGraph(object):
Returns: Returns:
IrGraph: A new and duplicated graph. IrGraph: A new and duplicated graph.
""" """
g = core.Graph(self.graph.origin_program_desc()) g = self.graph.clone()
return IrGraph(g, self._for_test) return IrGraph(g, self._for_test)
def is_test(self): def is_test(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册