提交 1b9c8d5f 编写于 作者: Z Zhen Wang

add clone function for IrGraph. test=develop

上级 08e75731
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph.h"
......@@ -29,6 +30,8 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
ResolveHazard(var_nodes);
}
Graph::Graph(const Graph &o) : Graph(o.program_) {}
std::map<std::string, std::vector<ir::Node *>> Graph::InitFromProgram(
const ProgramDesc &program) {
VLOG(3) << "block in program:" << program_.Size();
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/node.h"
......@@ -71,6 +72,7 @@ namespace ir {
class Graph {
public:
explicit Graph(const ProgramDesc &program);
Graph(const Graph &o);
virtual ~Graph() {
for (auto &attr : attrs_) {
......
......@@ -54,6 +54,8 @@ void BindGraph(py::module *m) {
"The graph is a Directed Acyclic Single Static Assignment Graph, see "
"`paddle::ir::Graph` for details.")
.def(py::init<const ProgramDesc &>())
.def("__init__",
[](Graph &self, const Graph &other) { new (&self) Graph(other); })
.def("has", &Graph::Has)
.def("get_int", &Graph::Get<int>)
.def("get_float", &Graph::Get<float>)
......
......@@ -52,7 +52,7 @@ def residual_block(num):
class TestGraph(unittest.TestCase):
def test_graph_functions(self):
def test_graph_functions(self, for_ci=True):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
......@@ -60,11 +60,20 @@ class TestGraph(unittest.TestCase):
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
graph = IrGraph(core.Graph(main.desc), for_test=False)
backup_graph = graph.clone()
self.assertEqual(len(graph.all_nodes()), len(backup_graph.all_nodes()))
marked_nodes = set()
for op in graph.all_op_nodes():
if op.name().find('conv2d') > -1:
marked_nodes.add(op)
graph.draw('.', 'residual', marked_nodes)
if not for_ci:
graph.draw('.', 'residual', marked_nodes)
backup_marked_nodes = set()
for op in backup_graph.all_op_nodes():
if op.name().find('conv2d') > -1:
backup_marked_nodes.add(op)
backup_graph.draw('.', 'backup', backup_marked_nodes)
self.assertFalse(graph.has_circle())
self.assertEqual(graph.graph_num(), 1)
nodes = graph.topology_sort()
......
......@@ -2002,6 +2002,16 @@ class IrGraph(object):
self.graph = graph
self._for_test = for_test
def clone(self):
"""
Create a new and duplicated IrGraph.
Returns:
IrGraph: A new and duplicated graph.
"""
g = core.Graph(self.graph)
return IrGraph(g, self._for_test)
def is_test(self):
"""
If the graph is used for testing, the function returns true. Otherwise, returns false.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册