diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc
index 5e954fa9c419b249bb8a4be5a78c01da85b017b2..6a9340b870df324f7dea03181bdb2b097e13e705 100644
--- a/paddle/fluid/framework/ir/graph.cc
+++ b/paddle/fluid/framework/ir/graph.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <algorithm>
-#include <unordered_set>
+#include <memory>
 
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
@@ -152,6 +152,39 @@ void Graph::ResolveHazard(
   }
 }
 
+std::shared_ptr<Graph> Graph::Clone() {
+  auto cloned_graph = std::make_shared<Graph>(this->program_);
+  cloned_graph->ReleaseNodes();
+  cloned_graph->num_node_created_ = 0;
+  std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned;
+  for (auto *n : this->node_set_) {
+    ir::Node *cloned_node = nullptr;
+    if (n->IsCtrlVar()) {
+      cloned_node = cloned_graph->CreateControlDepVar();
+    } else if (!n->var_desc_ && !n->op_desc_) {  // empty node
+      cloned_node = cloned_graph->CreateEmptyNode(n->Name(), n->NodeType());
+    } else if (n->IsVar()) {
+      cloned_node = cloned_graph->CreateVarNode(n->Var());
+    } else if (n->IsOp()) {
+      cloned_node = cloned_graph->CreateOpNode(n->Op());
+    }
+    if (cloned_node) {
+      origin_to_cloned[n] = cloned_node;
+    } else {
+      PADDLE_THROW("The cloned node's type is not supported!");
+    }
+  }
+  for (auto *n : this->node_set_) {
+    for (auto it = n->inputs.begin(); it != n->inputs.end(); it++) {
+      origin_to_cloned[n]->inputs.push_back(origin_to_cloned[*it]);
+    }
+    for (auto it = n->outputs.begin(); it != n->outputs.end(); it++) {
+      origin_to_cloned[n]->outputs.push_back(origin_to_cloned[*it]);
+    }
+  }
+  return cloned_graph;
+}
+
 bool IsControlDepVar(const ir::Node &var) {
   return var.Name().find(ir::Node::kControlDepVarName) != std::string::npos;
 }
diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h
index cfd974e4bd679fdd06739f4c943bb197865020fb..44ba4d3d2c528d1dbe261b0723d2a40ce3a70cf2 100644
--- a/paddle/fluid/framework/ir/graph.h
+++ b/paddle/fluid/framework/ir/graph.h
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <map>
 #include <memory>
 #include <string>
+#include <unordered_set>
 #include <vector>
 
 #include "paddle/fluid/framework/ir/node.h"
@@ -212,6 +213,10 @@ class Graph {
   void ResolveHazard(
       const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
 
+  // Create a new and duplicated graph.
+  // WARN: The method only clones the graph structure, not its attributes.
+  std::shared_ptr<Graph> Clone();
+
  private:
   std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
       const ProgramDesc &program);
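For reference, `Graph::Clone()` above is a standard two-pass structural copy: pass one creates a counterpart for every node (preserving its kind: control-dep var, empty node, var, or op) and records an origin-to-clone mapping; pass two rewires every `inputs`/`outputs` edge through that mapping, so the clone never points back into the original graph. A minimal Python sketch of the same pattern; the `Node` class here is a hypothetical stand-in, not Paddle's `ir::Node`:

```python
# Two-pass graph clone, mirroring Graph::Clone() above.
# `Node` is a hypothetical stand-in for illustration only.
class Node:
    def __init__(self, name):
        self.name = name
        self.inputs = []
        self.outputs = []


def clone_graph(nodes):
    # Pass 1: create a fresh counterpart for every original node.
    origin_to_cloned = {n: Node(n.name) for n in nodes}
    # Pass 2: rewire edges through the mapping so the clone only
    # references cloned nodes, never the originals.
    for n in nodes:
        c = origin_to_cloned[n]
        c.inputs = [origin_to_cloned[i] for i in n.inputs]
        c.outputs = [origin_to_cloned[o] for o in n.outputs]
    return list(origin_to_cloned.values())
```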
diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h
index 9eade9eaa8f00fe6e76063344f47968f4e97cf7f..72fb876d98dc84164398583baf22c49014af483a 100644
--- a/paddle/fluid/framework/ir/node.h
+++ b/paddle/fluid/framework/ir/node.h
@@ -14,6 +14,7 @@ limitations under the License.
 */
 
 #pragma once
+#include <memory>
 #include <string>
 #include <typeindex>
 #include <vector>
diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc
index 68f74a8531fff0c49c8a62d12f5cde7af77faf8a..c69ccd507210f976c1cb8ad072928b96693a948d 100644
--- a/paddle/fluid/pybind/ir.cc
+++ b/paddle/fluid/pybind/ir.cc
@@ -18,6 +18,7 @@
 #include <algorithm>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
@@ -54,12 +55,14 @@ void BindGraph(py::module *m) {
       "The graph is a Directed Acyclic Single Static Assignment Graph, see "
       "`paddle::ir::Graph` for details.")
       .def(py::init<const ProgramDesc &>())
+      .def("clone", &Graph::Clone)
       .def("has", &Graph::Has)
       .def("get_int", &Graph::Get<int>)
       .def("get_float", &Graph::Get<float>)
       .def("get_double", &Graph::Get<double>)
       .def("get_string", &Graph::Get<std::string>)
-      .def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>)
+      .def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>,
+           return_value_policy::reference)
       .def("set", [](Graph &self, const std::string &attr_name, int attr) {
         return self.Set(attr_name, new int(attr));
       })
       .def("set",
@@ -103,7 +106,8 @@ void BindGraph(py::module *m) {
       .def("retrieve_node", &Graph::RetrieveNode,
            return_value_policy::reference)
       .def("resolve_hazard", &Graph::ResolveHazard)
-      .def("origin_program_desc", &Graph::OriginProgram);
+      .def("origin_program_desc", &Graph::OriginProgram,
+           return_value_policy::reference);
 }
 
 void BindNode(py::module *m) {
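Two binding details are worth calling out: `py::init<const ProgramDesc &>` means a `Graph` is always constructed from a program desc, and the new `return_value_policy::reference` on `get_marked_nodes` and `origin_program_desc` makes pybind11 hand Python a non-owning view, valid only while the C++ `Graph` is alive. A small sketch of how these bindings surface in Python (same import paths the test below uses):

```python
import paddle.fluid as fluid
from paddle.fluid import core

prog = fluid.Program()
# core.Graph wraps ir::Graph; clone() calls the new Graph::Clone()
# and returns an independent copy of the node/edge structure.
graph = core.Graph(prog.desc)
cloned = graph.clone()

# origin_program_desc() is bound with return_value_policy::reference:
# the returned desc is owned by `graph`, so keep `graph` alive while
# the desc is in use.
desc = graph.origin_program_desc()
```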
diff --git a/python/paddle/fluid/contrib/slim/tests/test_graph.py b/python/paddle/fluid/contrib/slim/tests/test_graph.py
index 2d2f1384dec65ee19dcade8a46f80bd3f9eb7013..3629fed160ed657cfe8ce370a606d72b1d310f87 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_graph.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_graph.py
@@ -13,58 +13,92 @@
 # limitations under the license.
 
 from __future__ import print_function
+import os
+import six
 import unittest
+import paddle
 import paddle.fluid as fluid
-import six
 from paddle.fluid.framework import IrGraph
 from paddle.fluid import core
 
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+os.environ["CPU_NUM"] = "1"
 
-def residual_block(num):
-    def conv_bn_layer(input,
-                      ch_out,
-                      filter_size,
-                      stride,
-                      padding,
-                      act='relu',
-                      bias_attr=False):
-        tmp = fluid.layers.conv2d(
-            input=input,
-            filter_size=filter_size,
-            num_filters=ch_out,
-            stride=stride,
-            padding=padding,
-            act=None,
-            bias_attr=bias_attr)
-        return fluid.layers.batch_norm(input=tmp, act=act)
 
-    data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
+def conv_block():
+    img = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
     label = fluid.layers.data(name='label', shape=[1], dtype='int64')
-    hidden = data
-    for _ in six.moves.xrange(num):
-        conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
-        short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
-        hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
-    fc = fluid.layers.fc(input=hidden, size=10)
-    loss = fluid.layers.cross_entropy(input=fc, label=label)
-    loss = fluid.layers.mean(loss)
-    return loss
+    conv_pool_1 = fluid.nets.simple_img_conv_pool(
+        input=img,
+        filter_size=5,
+        num_filters=20,
+        pool_size=2,
+        pool_stride=2,
+        act="relu")
+    conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
+    conv_pool_2 = fluid.nets.simple_img_conv_pool(
+        input=conv_pool_1,
+        filter_size=5,
+        num_filters=50,
+        pool_size=2,
+        pool_stride=2,
+        act="relu")
+    prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
+    loss = fluid.layers.cross_entropy(input=prediction, label=label)
+    avg_loss = fluid.layers.mean(loss)
+    return [img, label], avg_loss
 
 
 class TestGraph(unittest.TestCase):
-    def test_graph_functions(self):
+    def graph_apis(self, use_cuda=False, for_ci=True):
         main = fluid.Program()
         startup = fluid.Program()
         with fluid.program_guard(main, startup):
-            loss = residual_block(2)
+            feeds, loss = conv_block()
             opt = fluid.optimizer.Adam(learning_rate=0.001)
             opt.minimize(loss)
         graph = IrGraph(core.Graph(main.desc), for_test=False)
+        backup_graph = graph.clone()
+        self.assertEqual(len(graph.all_nodes()), len(backup_graph.all_nodes()))
+        build_strategy = fluid.BuildStrategy()
+        build_strategy.memory_optimize = False
+        build_strategy.enable_inplace = False
+        origin_binary = fluid.CompiledProgram(graph.graph).with_data_parallel(
+            loss_name=loss.name, build_strategy=build_strategy)
+        backup_binary = fluid.CompiledProgram(
+            backup_graph.graph).with_data_parallel(
+                loss_name=loss.name, build_strategy=build_strategy)
+        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        exe.run(startup)
+        iters = 5
+        batch_size = 8
+        train_reader = paddle.batch(
+            paddle.dataset.mnist.train(), batch_size=batch_size)
+        feeder = fluid.DataFeeder(feed_list=feeds, place=place)
+
+        def train(binary):
+            for _ in range(iters):
+                data = next(train_reader())
+                loss_v = exe.run(binary,
+                                 feed=feeder.feed(data),
+                                 fetch_list=[loss.name])
+                print('{}: {}'.format('loss', loss_v))
+
+        train(origin_binary)
+        train(backup_binary)
+
         marked_nodes = set()
         for op in graph.all_op_nodes():
             if op.name().find('conv2d') > -1:
                 marked_nodes.add(op)
-        graph.draw('.', 'residual', marked_nodes)
+        if not for_ci:
+            graph.draw('.', 'residual', marked_nodes)
+            backup_marked_nodes = set()
+            for op in backup_graph.all_op_nodes():
+                if op.name().find('conv2d') > -1:
+                    backup_marked_nodes.add(op)
+            backup_graph.draw('.', 'backup', backup_marked_nodes)
         self.assertFalse(graph.has_circle())
         self.assertEqual(graph.graph_num(), 1)
         nodes = graph.topology_sort()
@@ -75,6 +109,13 @@ class TestGraph(unittest.TestCase):
         graph.safe_remove_nodes(marked_nodes)
         self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes))
 
+    def test_graph_apis_cpu(self):
+        self.graph_apis(use_cuda=False, for_ci=True)
+
+    def test_graph_apis_cuda(self):
+        if fluid.core.is_compiled_with_cuda():
+            self.graph_apis(use_cuda=True, for_ci=True)
+
 
 if __name__ == '__main__':
     unittest.main()
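Beyond the node-count parity check, what the clone buys you is independence: destructive edits to one graph leave the other intact, which is the point of snapshotting before a pass mutates the IR. A sketch using the same `IrGraph` APIs the test exercises (the tiny one-conv program is illustrative, not the test's network):

```python
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph

# A minimal program with one conv2d op, for illustration only.
main = fluid.Program()
with fluid.program_guard(main, fluid.Program()):
    img = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
    conv = fluid.layers.conv2d(img, num_filters=4, filter_size=3)

graph = IrGraph(core.Graph(main.desc), for_test=False)
backup_graph = graph.clone()
total = len(graph.all_nodes())

# Remove the conv2d ops from the original graph only.
to_remove = {op for op in graph.all_op_nodes() if op.name() == 'conv2d'}
graph.safe_remove_nodes(to_remove)

assert len(graph.all_nodes()) == total - len(to_remove)
# The clone holds its own nodes, so it is unaffected by the removal.
assert len(backup_graph.all_nodes()) == total
```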
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index 5b9dd8693190d1c9a6a71cc6bc6e6acf2a06aa6f..31f90f5f5f271433b8741468f8a7bf42f48d5ac7 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -2002,6 +2002,19 @@ class IrGraph(object):
         self.graph = graph
         self._for_test = for_test
 
+    def clone(self):
+        """
+        Create a new and duplicated IrGraph.
+
+        Warns:
+            The method only clones the graph structure, not its attributes.
+
+        Returns:
+            IrGraph: A new and duplicated graph.
+        """
+        g = self.graph.clone()
+        return IrGraph(g, self._for_test)
+
     def is_test(self):
         """
         If the graph is used for testing, the function returns true. Otherwise, returns false.
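One caveat follows directly from the `Warns` section above: `clone()` duplicates the node/edge structure only, so graph attributes registered through the C++ `set`/`has` bindings do not travel with the clone and must be re-attached if a pass needs them. A sketch of that behavior (the attribute name `"my_attr"` is illustrative):

```python
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph

main = fluid.Program()
graph = IrGraph(core.Graph(main.desc), for_test=False)
graph.graph.set("my_attr", 1)  # attribute on the original core.Graph

backup = graph.clone()
assert graph.graph.has("my_attr")
# Per the documented warning, attributes are not cloned:
assert not backup.graph.has("my_attr")
```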