提交 1b9c8d5f 编写于 作者: Z Zhen Wang

add clone function for IrGraph. test=develop

上级 08e75731
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
...@@ -29,6 +30,8 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { ...@@ -29,6 +30,8 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
ResolveHazard(var_nodes); ResolveHazard(var_nodes);
} }
Graph::Graph(const Graph &o) : Graph(o.program_) {}
std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram( std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
const ProgramDesc &program) { const ProgramDesc &program) {
VLOG(3) << "block in program:" << program_.Size(); VLOG(3) << "block in program:" << program_.Size();
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <map> #include <map>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/node.h"
...@@ -71,6 +72,7 @@ namespace ir { ...@@ -71,6 +72,7 @@ namespace ir {
class Graph { class Graph {
public: public:
explicit Graph(const ProgramDesc &program); explicit Graph(const ProgramDesc &program);
Graph(const Graph &o);
virtual ~Graph() { virtual ~Graph() {
for (auto &attr : attrs_) { for (auto &attr : attrs_) {
......
...@@ -54,6 +54,8 @@ void BindGraph(py::module *m) { ...@@ -54,6 +54,8 @@ void BindGraph(py::module *m) {
"The graph is a Directed Acyclic Single Static Assignment Graph, see " "The graph is a Directed Acyclic Single Static Assignment Graph, see "
"`paddle::ir::Graph` for details.") "`paddle::ir::Graph` for details.")
.def(py::init<const ProgramDesc &>()) .def(py::init<const ProgramDesc &>())
.def("__init__",
[](Graph &self, const Graph &other) { new (&self) Graph(other); })
.def("has", &Graph::Has) .def("has", &Graph::Has)
.def("get_int", &Graph::Get<int>) .def("get_int", &Graph::Get<int>)
.def("get_float", &Graph::Get<float>) .def("get_float", &Graph::Get<float>)
......
...@@ -52,7 +52,7 @@ def residual_block(num): ...@@ -52,7 +52,7 @@ def residual_block(num):
class TestGraph(unittest.TestCase): class TestGraph(unittest.TestCase):
def test_graph_functions(self): def test_graph_functions(self, for_ci=True):
main = fluid.Program() main = fluid.Program()
startup = fluid.Program() startup = fluid.Program()
with fluid.program_guard(main, startup): with fluid.program_guard(main, startup):
...@@ -60,11 +60,20 @@ class TestGraph(unittest.TestCase): ...@@ -60,11 +60,20 @@ class TestGraph(unittest.TestCase):
opt = fluid.optimizer.Adam(learning_rate=0.001) opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss) opt.minimize(loss)
graph = IrGraph(core.Graph(main.desc), for_test=False) graph = IrGraph(core.Graph(main.desc), for_test=False)
backup_graph = graph.clone()
self.assertEqual(len(graph.all_nodes()), len(backup_graph.all_nodes()))
marked_nodes = set() marked_nodes = set()
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name().find('conv2d') > -1: if op.name().find('conv2d') > -1:
marked_nodes.add(op) marked_nodes.add(op)
if not for_ci:
graph.draw('.', 'residual', marked_nodes) graph.draw('.', 'residual', marked_nodes)
backup_marked_nodes = set()
for op in backup_graph.all_op_nodes():
if op.name().find('conv2d') > -1:
backup_marked_nodes.add(op)
backup_graph.draw('.', 'backup', backup_marked_nodes)
self.assertFalse(graph.has_circle()) self.assertFalse(graph.has_circle())
self.assertEqual(graph.graph_num(), 1) self.assertEqual(graph.graph_num(), 1)
nodes = graph.topology_sort() nodes = graph.topology_sort()
......
...@@ -2002,6 +2002,16 @@ class IrGraph(object): ...@@ -2002,6 +2002,16 @@ class IrGraph(object):
self.graph = graph self.graph = graph
self._for_test = for_test self._for_test = for_test
def clone(self):
"""
Create a new and duplicated IrGraph.
Returns:
IrGraph: A new and duplicated graph.
"""
g = core.Graph(self.graph)
return IrGraph(g, self._for_test)
def is_test(self): def is_test(self):
""" """
If the graph is used for testing, the function returns true. Otherwise, returns false. If the graph is used for testing, the function returns true. Otherwise, returns false.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册