Unverified commit 94b7c1ea, authored by Zhen Wang, committed by GitHub

Merge pull request #16107 from wzzju/add_graph_clone

Add clone function for IrGraph.
...@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
-#include <unordered_set>
#include <unordered_map>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/op_proto_maker.h"
...@@ -152,6 +152,39 @@ void Graph::ResolveHazard(
  }
}

std::shared_ptr<Graph> Graph::Clone() {
  auto cloned_graph = std::make_shared<Graph>(this->program_);
  cloned_graph->ReleaseNodes();
  cloned_graph->num_node_created_ = 0;
  std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned;
  for (auto *n : this->node_set_) {
    ir::Node *cloned_node = nullptr;
    if (n->IsCtrlVar()) {
      cloned_node = cloned_graph->CreateControlDepVar();
    } else if (!n->var_desc_ && !n->op_desc_) {  // empty node
      cloned_node = cloned_graph->CreateEmptyNode(n->Name(), n->NodeType());
    } else if (n->IsVar()) {
      cloned_node = cloned_graph->CreateVarNode(n->Var());
    } else if (n->IsOp()) {
      cloned_node = cloned_graph->CreateOpNode(n->Op());
    }
    if (cloned_node) {
      origin_to_cloned[n] = cloned_node;
    } else {
      PADDLE_THROW("The cloned node's type is not supported!");
    }
  }
  for (auto *n : this->node_set_) {
    for (auto it = n->inputs.begin(); it != n->inputs.end(); it++) {
      origin_to_cloned[n]->inputs.push_back(origin_to_cloned[*it]);
    }
    for (auto it = n->outputs.begin(); it != n->outputs.end(); it++) {
      origin_to_cloned[n]->outputs.push_back(origin_to_cloned[*it]);
    }
  }
  return cloned_graph;
}

bool IsControlDepVar(const ir::Node &var) {
  return var.Name().find(ir::Node::kControlDepVarName) != std::string::npos;
}
......
...@@ -17,6 +17,7 @@ limitations under the License. */
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/node.h"
...@@ -212,6 +213,10 @@ class Graph {
  void ResolveHazard(
      const std::map<std::string, std::vector<ir::Node *>> &var_nodes);

  // Create a new and duplicated graph.
  // WARN: The method only clones the graph structure, not its attributes.
  std::shared_ptr<Graph> Clone();
 private:
  std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
      const ProgramDesc &program);
......
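The WARN note above is worth a concrete illustration: Clone() duplicates the node structure, but attributes attached to the original graph via Set() are not copied into the clone. A minimal Python-side sketch of this behavior, assuming a Program that already contains ops and using the pybind bindings shown later in this commit (the attribute name "illustrative_attr" is purely hypothetical):

import paddle.fluid as fluid
from paddle.fluid import core

prog = fluid.Program()  # assume ops have been appended to this program
graph = core.Graph(prog.desc)
graph.set("illustrative_attr", 1)       # hypothetical attribute on the original graph
backup = graph.clone()                  # duplicates nodes and edges only
print(graph.has("illustrative_attr"))   # True
print(backup.has("illustrative_attr"))  # expected False: attributes are not cloned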
...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>
......
...@@ -18,6 +18,7 @@
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
...@@ -54,12 +55,14 @@ void BindGraph(py::module *m) { ...@@ -54,12 +55,14 @@ void BindGraph(py::module *m) {
"The graph is a Directed Acyclic Single Static Assignment Graph, see " "The graph is a Directed Acyclic Single Static Assignment Graph, see "
"`paddle::ir::Graph` for details.") "`paddle::ir::Graph` for details.")
.def(py::init<const ProgramDesc &>()) .def(py::init<const ProgramDesc &>())
.def("clone", &Graph::Clone)
.def("has", &Graph::Has) .def("has", &Graph::Has)
.def("get_int", &Graph::Get<int>) .def("get_int", &Graph::Get<int>)
.def("get_float", &Graph::Get<float>) .def("get_float", &Graph::Get<float>)
.def("get_double", &Graph::Get<double>) .def("get_double", &Graph::Get<double>)
.def("get_string", &Graph::Get<std::string>) .def("get_string", &Graph::Get<std::string>)
.def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>) .def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>,
return_value_policy::reference)
.def("set", [](Graph &self, const std::string &attr_name, .def("set", [](Graph &self, const std::string &attr_name,
int attr) { return self.Set(attr_name, new int(attr)); }) int attr) { return self.Set(attr_name, new int(attr)); })
.def("set", .def("set",
...@@ -103,7 +106,8 @@ void BindGraph(py::module *m) {
      .def("retrieve_node", &Graph::RetrieveNode,
           return_value_policy::reference)
      .def("resolve_hazard", &Graph::ResolveHazard)
-     .def("origin_program_desc", &Graph::OriginProgram);
      .def("origin_program_desc", &Graph::OriginProgram,
           return_value_policy::reference);
}

void BindNode(py::module *m) {
......
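The two bindings changed above (get_marked_nodes and origin_program_desc) now use return_value_policy::reference, so the Python object that comes back wraps data owned by the underlying C++ Graph rather than a copy. A rough usage sketch under that assumption (the attribute name "illustrative_pass_attr" is hypothetical and must have been set on the graph beforehand):

marked = graph.get_marked_nodes("illustrative_pass_attr")
# `marked` references a set owned by `graph`; keep `graph` alive while iterating it.
for node in marked:
    print(node.name())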
...@@ -13,58 +13,92 @@
# limitations under the license.

from __future__ import print_function

import os
import six
import unittest
import paddle
import paddle.fluid as fluid
-import six
from paddle.fluid.framework import IrGraph
from paddle.fluid import core

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CPU_NUM"] = "1"
-def residual_block(num):
-    def conv_bn_layer(input,
-                      ch_out,
-                      filter_size,
-                      stride,
-                      padding,
-                      act='relu',
-                      bias_attr=False):
-        tmp = fluid.layers.conv2d(
-            input=input,
-            filter_size=filter_size,
-            num_filters=ch_out,
-            stride=stride,
-            padding=padding,
-            act=None,
-            bias_attr=bias_attr)
-        return fluid.layers.batch_norm(input=tmp, act=act)
-    data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
-    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
-    hidden = data
-    for _ in six.moves.xrange(num):
-        conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
-        short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
-        hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
-    fc = fluid.layers.fc(input=hidden, size=10)
-    loss = fluid.layers.cross_entropy(input=fc, label=label)
-    loss = fluid.layers.mean(loss)
-    return loss


def conv_block():
    img = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    conv_pool_1 = fluid.nets.simple_img_conv_pool(
        input=img,
        filter_size=5,
        num_filters=20,
        pool_size=2,
        pool_stride=2,
        act="relu")
    conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
    conv_pool_2 = fluid.nets.simple_img_conv_pool(
        input=conv_pool_1,
        filter_size=5,
        num_filters=50,
        pool_size=2,
        pool_stride=2,
        act="relu")
    prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
    loss = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_loss = fluid.layers.mean(loss)
    return [img, label], avg_loss

class TestGraph(unittest.TestCase):
-    def test_graph_functions(self):
    def graph_apis(self, use_cuda=False, for_ci=True):
        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
-            loss = residual_block(2)
            feeds, loss = conv_block()
            opt = fluid.optimizer.Adam(learning_rate=0.001)
            opt.minimize(loss)
        graph = IrGraph(core.Graph(main.desc), for_test=False)
        backup_graph = graph.clone()
        self.assertEqual(len(graph.all_nodes()), len(backup_graph.all_nodes()))

        build_strategy = fluid.BuildStrategy()
        build_strategy.memory_optimize = False
        build_strategy.enable_inplace = False
        origin_binary = fluid.CompiledProgram(graph.graph).with_data_parallel(
            loss_name=loss.name, build_strategy=build_strategy)
        backup_binary = fluid.CompiledProgram(
            backup_graph.graph).with_data_parallel(
                loss_name=loss.name, build_strategy=build_strategy)

        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup)
        iters = 5
        batch_size = 8
        train_reader = paddle.batch(
            paddle.dataset.mnist.train(), batch_size=batch_size)
        feeder = fluid.DataFeeder(feed_list=feeds, place=place)

        def train(binary):
            for _ in range(iters):
                data = next(train_reader())
                loss_v = exe.run(binary,
                                 feed=feeder.feed(data),
                                 fetch_list=[loss.name])
                print('{}: {}'.format('loss', loss_v))

        train(origin_binary)
        train(backup_binary)
        marked_nodes = set()
        for op in graph.all_op_nodes():
            if op.name().find('conv2d') > -1:
                marked_nodes.add(op)
-        graph.draw('.', 'residual', marked_nodes)
        if not for_ci:
            graph.draw('.', 'residual', marked_nodes)
            backup_marked_nodes = set()
            for op in backup_graph.all_op_nodes():
                if op.name().find('conv2d') > -1:
                    backup_marked_nodes.add(op)
            backup_graph.draw('.', 'backup', backup_marked_nodes)
        self.assertFalse(graph.has_circle())
        self.assertEqual(graph.graph_num(), 1)
        nodes = graph.topology_sort()
...@@ -75,6 +109,13 @@ class TestGraph(unittest.TestCase):
        graph.safe_remove_nodes(marked_nodes)
        self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes))

    def test_graph_apis_cpu(self):
        self.graph_apis(use_cuda=False, for_ci=True)

    def test_graph_apis_cuda(self):
        if fluid.core.is_compiled_with_cuda():
            self.graph_apis(use_cuda=True, for_ci=True)

if __name__ == '__main__':
    unittest.main()
...@@ -2002,6 +2002,19 @@ class IrGraph(object):
        self.graph = graph
        self._for_test = for_test

    def clone(self):
        """
        Create a new and duplicated IrGraph.

        Warns:
            The method only clones the graph structure, not its attributes.

        Returns:
            IrGraph: A new and duplicated graph.
        """
        g = self.graph.clone()
        return IrGraph(g, self._for_test)

    def is_test(self):
        """
        If the graph is used for testing, the function returns true. Otherwise, returns false.
......
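A small usage sketch of the new IrGraph.clone method above, assuming `main` is a Program that already contains ops (as in the updated test): the clone carries the same node structure and keeps the original's for_test flag, while attributes set on the original graph are not copied.

from paddle.fluid import core
from paddle.fluid.framework import IrGraph

graph = IrGraph(core.Graph(main.desc), for_test=False)
backup = graph.clone()
assert len(graph.all_nodes()) == len(backup.all_nodes())  # same node structure
assert backup.is_test() == graph.is_test()                 # for_test flag is preserved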