未验证 提交 94b7c1ea 编写于 作者: Z Zhen Wang 提交者: GitHub

Merge pull request #16107 from wzzju/add_graph_clone

Add clone function for IrGraph.
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <unordered_set>
#include <unordered_map>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/op_proto_maker.h"
......@@ -152,6 +152,39 @@ void Graph::ResolveHazard(
}
}
std::shared_ptr<Graph> Graph::Clone() {
auto cloned_graph = std::make_shared<Graph>(this->program_);
cloned_graph->ReleaseNodes();
cloned_graph->num_node_created_ = 0;
std::unordered_map<ir::Node *, ir::Node *> origin_to_cloned;
for (auto *n : this->node_set_) {
ir::Node *cloned_node = nullptr;
if (n->IsCtrlVar()) {
cloned_node = cloned_graph->CreateControlDepVar();
} else if (!n->var_desc_ && !n->op_desc_) { // empty node
cloned_node = cloned_graph->CreateEmptyNode(n->Name(), n->NodeType());
} else if (n->IsVar()) {
cloned_node = cloned_graph->CreateVarNode(n->Var());
} else if (n->IsOp()) {
cloned_node = cloned_graph->CreateOpNode(n->Op());
}
if (cloned_node) {
origin_to_cloned[n] = cloned_node;
} else {
PADDLE_THROW("The cloned node's type is not supported!");
}
}
for (auto *n : this->node_set_) {
for (auto it = n->inputs.begin(); it != n->inputs.end(); it++) {
origin_to_cloned[n]->inputs.push_back(origin_to_cloned[*it]);
}
for (auto it = n->outputs.begin(); it != n->outputs.end(); it++) {
origin_to_cloned[n]->outputs.push_back(origin_to_cloned[*it]);
}
}
return cloned_graph;
}
bool IsControlDepVar(const ir::Node &var) {
return var.Name().find(ir::Node::kControlDepVarName) != std::string::npos;
}
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/node.h"
......@@ -212,6 +213,10 @@ class Graph {
void ResolveHazard(
const std::map<std::string, std::vector<ir::Node *>> &var_nodes);
// Create a new and duplicated graph.
// WARN: The method only clones the graph structure, not its attributes.
std::shared_ptr<Graph> Clone();
private:
std::map<std::string, std::vector<ir::Node *>> InitFromProgram(
const ProgramDesc &program);
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>
......
......@@ -18,6 +18,7 @@
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
......@@ -54,12 +55,14 @@ void BindGraph(py::module *m) {
"The graph is a Directed Acyclic Single Static Assignment Graph, see "
"`paddle::ir::Graph` for details.")
.def(py::init<const ProgramDesc &>())
.def("clone", &Graph::Clone)
.def("has", &Graph::Has)
.def("get_int", &Graph::Get<int>)
.def("get_float", &Graph::Get<float>)
.def("get_double", &Graph::Get<double>)
.def("get_string", &Graph::Get<std::string>)
.def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>)
.def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>,
return_value_policy::reference)
.def("set", [](Graph &self, const std::string &attr_name,
int attr) { return self.Set(attr_name, new int(attr)); })
.def("set",
......@@ -103,7 +106,8 @@ void BindGraph(py::module *m) {
.def("retrieve_node", &Graph::RetrieveNode,
return_value_policy::reference)
.def("resolve_hazard", &Graph::ResolveHazard)
.def("origin_program_desc", &Graph::OriginProgram);
.def("origin_program_desc", &Graph::OriginProgram,
return_value_policy::reference);
}
void BindNode(py::module *m) {
......
......@@ -13,58 +13,92 @@
# limitations under the license.
from __future__ import print_function
import os
import six
import unittest
import paddle
import paddle.fluid as fluid
import six
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CPU_NUM"] = "1"
def residual_block(num):
def conv_bn_layer(input,
ch_out,
filter_size,
stride,
padding,
act='relu',
bias_attr=False):
tmp = fluid.layers.conv2d(
input=input,
filter_size=filter_size,
num_filters=ch_out,
stride=stride,
padding=padding,
act=None,
bias_attr=bias_attr)
return fluid.layers.batch_norm(input=tmp, act=act)
data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
def conv_block():
img = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = data
for _ in six.moves.xrange(num):
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
fc = fluid.layers.fc(input=hidden, size=10)
loss = fluid.layers.cross_entropy(input=fc, label=label)
loss = fluid.layers.mean(loss)
return loss
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
act="relu")
conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
filter_size=5,
num_filters=50,
pool_size=2,
pool_stride=2,
act="relu")
prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
return [img, label], avg_loss
class TestGraph(unittest.TestCase):
def test_graph_functions(self):
def graph_apis(self, use_cuda=False, for_ci=True):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = residual_block(2)
feeds, loss = conv_block()
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
graph = IrGraph(core.Graph(main.desc), for_test=False)
backup_graph = graph.clone()
self.assertEqual(len(graph.all_nodes()), len(backup_graph.all_nodes()))
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = False
origin_binary = fluid.CompiledProgram(graph.graph).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
backup_binary = fluid.CompiledProgram(
backup_graph.graph).with_data_parallel(
loss_name=loss.name, build_strategy=build_strategy)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)
iters = 5
batch_size = 8
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=batch_size)
feeder = fluid.DataFeeder(feed_list=feeds, place=place)
def train(binary):
for _ in range(iters):
data = next(train_reader())
loss_v = exe.run(binary,
feed=feeder.feed(data),
fetch_list=[loss.name])
print('{}: {}'.format('loss', loss_v))
train(origin_binary)
train(backup_binary)
marked_nodes = set()
for op in graph.all_op_nodes():
if op.name().find('conv2d') > -1:
marked_nodes.add(op)
if not for_ci:
graph.draw('.', 'residual', marked_nodes)
backup_marked_nodes = set()
for op in backup_graph.all_op_nodes():
if op.name().find('conv2d') > -1:
backup_marked_nodes.add(op)
backup_graph.draw('.', 'backup', backup_marked_nodes)
self.assertFalse(graph.has_circle())
self.assertEqual(graph.graph_num(), 1)
nodes = graph.topology_sort()
......@@ -75,6 +109,13 @@ class TestGraph(unittest.TestCase):
graph.safe_remove_nodes(marked_nodes)
self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes))
def test_graph_apis_cpu(self):
self.graph_apis(use_cuda=False, for_ci=True)
def test_graph_apis_cuda(self):
if fluid.core.is_compiled_with_cuda():
self.graph_apis(use_cuda=True, for_ci=True)
if __name__ == '__main__':
unittest.main()
......@@ -2002,6 +2002,19 @@ class IrGraph(object):
self.graph = graph
self._for_test = for_test
def clone(self):
"""
Create a new and duplicated IrGraph.
Warns:
The method only clones the graph structure, not its attributes.
Returns:
IrGraph: A new and duplicated graph.
"""
g = self.graph.clone()
return IrGraph(g, self._for_test)
def is_test(self):
"""
If the graph is used for testing, the function returns true. Otherwise, returns false.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册