Unverified  Commit 6cdc5a4b  authored by J jiangcheng, committed by GitHub

Optimize the subgraph generated by BuildCinnPass (#36503)

* add feed op and new var for the generated subgraph

* improve the test script of build_cinn_pass

* remove a useless clear and improve some comments
Parent 7b67f398
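As a rough illustration of the first bullet, the sketch below shows how a feed op node and a new input var node are wired together inside the generated subgraph. It simply mirrors the new AddFeedOpAndVar helper in the diff below; the include paths and the helper name WireFeedVar are assumptions for illustration only, not part of the change.

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/op_desc.h"

// Illustrative sketch (hypothetical helper name): attach a feed op to a
// new input var node inside the generated subgraph, as AddFeedOpAndVar does.
void WireFeedVar(paddle::framework::ir::Graph* sub_graph,
                 paddle::framework::ir::Node* old_var) {
  paddle::framework::OpDesc desc;
  desc.SetType("feed");
  desc.SetOutput("Out", {old_var->Name()});
  auto* feed_op = sub_graph->CreateOpNode(&desc);             // feed op node
  auto* new_var = sub_graph->CreateVarNode(old_var->Var());   // new input var node
  feed_op->outputs = {new_var};  // feed --> var
  new_var->inputs = {feed_op};   // var <-- feed
}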
......@@ -64,10 +64,81 @@ using framework::ir::Node;
using GraphNodeVec = std::vector<Node*>;
using GraphNodeSet = std::unordered_set<Node*>;
// Deal with subgraph's feed input var nodes:
// create a new input var node and its feed op node
void AddFeedOpAndVar(const std::unordered_set<Node*>& feed_vars,
const GraphNodeSet& cluster,
const std::unordered_map<Node*, Node*>& old_op2new_op,
Graph* graph) {
for (auto* old_var : feed_vars) {
// create feed op
OpDesc desc;
desc.SetType("feed");
desc.SetOutput("Out", {old_var->Name()});
auto op = graph->CreateOpNode(&desc);
// create new feed var node (SSAGraph)
auto var = graph->CreateVarNode(old_var->Var());
// link feed op and feed var
op->outputs = {var};
var->inputs = {op};
// link feed var to cluster op
for (auto* old_op : old_var->outputs) {
if (cluster.count(old_op)) {
var->outputs.emplace_back(old_op2new_op.at(old_op));
old_op2new_op.at(old_op)->inputs.emplace_back(var);
}
// There is no need to relink the old op or old var here; they will be
// fixed in RemoveLinkFromCluster. Here we just deal with the
// new subgraph's nodes.
}
}
}
// Deal with subgraph's parameter var nodes:
// create a new input var node; its data will be fetched from the scope,
// so it does not need a feed op
void AddParamVar(const std::unordered_set<Node*>& param_vars,
const GraphNodeSet& cluster,
const std::unordered_map<Node*, Node*>& old_op2new_op,
Graph* graph) {
for (auto* old_var : param_vars) {
auto var = graph->CreateVarNode(old_var->Var());
for (auto* old_op : old_var->outputs) {
if (cluster.count(old_op)) {
var->outputs.emplace_back(old_op2new_op.at(old_op));
old_op2new_op.at(old_op)->inputs.emplace_back(var);
}
}
}
}
// Deal with subgraph's output var nodes:
// create a new output var node and link it to the cluster ops that produce it
void AddOutputVar(const std::unordered_set<Node*>& output_vars,
const GraphNodeSet& cluster,
const std::unordered_map<Node*, Node*>& old_op2new_op,
Graph* graph) {
for (auto* old_var : output_vars) {
auto var = graph->CreateVarNode(old_var->Var());
for (auto* old_op : old_var->inputs) {
if (cluster.count(old_op)) {
var->inputs.emplace_back(old_op2new_op.at(old_op));
old_op2new_op.at(old_op)->outputs.emplace_back(var);
}
}
}
}
// Create a new subgraph whose op nodes are the cluster nodes, and whose
// var nodes come from the internal nodes
std::unique_ptr<Graph> CreateNewSubGraph(
const GraphNodeSet& cluster, const GraphNodeSet& cluster_internals) {
std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_internals,
const GraphNodeSet& cluster_inputs) {
// Graph's constructor must have one parameter, and in our code
// the ProgramDesc is unused, so here we pass a temporary object.
auto sub_graph = std::make_unique<Graph>(framework::ProgramDesc());
......@@ -84,6 +155,8 @@ std::unique_ptr<Graph> CreateNewSubGraph(
old_var2new_var[var] = sub_node;
}
std::unordered_set<Node*> need_feed_vars;
std::unordered_set<Node *> param_vars, output_vars;
// the subgraph is independent, so here we only need to link
// to the nodes in the new subgraph, and discard the links to
// the out-graph.
......@@ -91,15 +164,36 @@ std::unique_ptr<Graph> CreateNewSubGraph(
for (auto* var : op->inputs) {
if (cluster_internals.count(var)) {
old_op2new_op[op]->inputs.emplace_back(old_var2new_var[var]);
} else if (cluster_inputs.count(var)) {
if (var->Var()->IsParameter()) {
// Parameters have been preserved in the scope; compared to a feed var,
// a param only needs a new var node and does not need a feed op.
// The var is used to check whether we need to preserve the tensor
// when transforming the Paddle scope to the CINN scope.
param_vars.insert(var);
} else {
// When the var is a subgraph input and is not a parameter,
// we need to add a new feed op to feed the var.
need_feed_vars.insert(var);
}
}
}
for (auto* var : op->outputs) {
if (cluster_internals.count(var)) {
old_op2new_op[op]->outputs.emplace_back(old_var2new_var[var]);
} else {
// Create a new output var node to guarantee the independence of
// the subgraph. In other words, the subgraph has no connection with
// any other graph, not even the input graph.
output_vars.insert(var);
}
}
}
AddFeedOpAndVar(need_feed_vars, cluster, old_op2new_op, sub_graph.get());
AddParamVar(param_vars, cluster, old_op2new_op, sub_graph.get());
AddOutputVar(output_vars, cluster, old_op2new_op, sub_graph.get());
for (auto* var : cluster_internals) {
for (auto* op : var->inputs) {
if (cluster.count(op)) {
......@@ -118,10 +212,12 @@ std::unique_ptr<Graph> CreateNewSubGraph(
// This interface is used to classify all variables involved in a cluster into
// three types: inputs, outputs, and internals.
// Specially, the internal node is a node that is only used by the sub-graph, and
// The input node is some subgraph op's input but not any subgraph op's output.
// The output node is some subgraph op's output and also some out-graph op's input.
// Specially, the internal node is a node that is only used by the subgraph, and
// the out-graph should not use this node at all.
// inputs & outputs & internals == NULL
// inputs | outputs | internals == all graph nodes
// cluster_inputs & cluster_outputs & cluster_internals == NULL
// cluster_outputs | cluster_internals == all graph ops' output nodes
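// As an illustrative example (derived from the AllOpSupportCinn test graph
// in build_cinn_pass_test.cc, not part of this pass): for the cluster
// {mul, add, relu} over v1,v2 -> mul -> v3; v3,v4 -> add -> v5;
// v5 -> relu -> v6, the classification would roughly give
//   cluster_inputs    = {v1, v2, v4}
//   cluster_internals = {v3, v5}
//   cluster_outputs   = {v6} (v3 and v5 are first collected as op outputs
//                             and then removed because they are internal)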
void AnalyseClusterVariables(const GraphNodeSet& cluster,
GraphNodeSet* cluster_inputs,
GraphNodeSet* cluster_outputs,
......@@ -154,10 +250,6 @@ void AnalyseClusterVariables(const GraphNodeSet& cluster,
}
}
// if an output node also exists in the input list, remove it.
for (auto* var_node : *cluster_inputs) {
cluster_outputs->erase(var_node);
}
// if an output node also exists in the internal list, remove it.
for (auto* var_node : *cluster_internals) {
cluster_outputs->erase(var_node);
......@@ -206,14 +298,23 @@ void RemoveLinkFromCluster(const GraphNodeSet& cluster,
// removing useless links from cluster_inputs to cluster
for (auto* var_node : cluster_inputs) {
auto preserved_nodes = get_preserved_ops(var_node->outputs);
var_node->outputs.assign(preserved_nodes.begin(), preserved_nodes.end());
auto preserved_ops = get_preserved_ops(var_node->outputs);
var_node->outputs.assign(preserved_ops.begin(), preserved_ops.end());
// According to SSA form, a var node must not be the output of two ops,
// and the cluster_inputs var nodes are defined as an out-graph op's
// output, so the cluster_inputs var nodes are not any subgraph op's
// output. Do not reassign the input list here.
}
// removing useless links from cluster to cluster_outputs
for (auto* var_node : cluster_outputs) {
auto preserved_nodes = get_preserved_ops(var_node->inputs);
var_node->inputs.assign(preserved_nodes.begin(), preserved_nodes.end());
auto preserved_ops = get_preserved_ops(var_node->inputs);
var_node->inputs.assign(preserved_ops.begin(), preserved_ops.end());
// Note that a cluster_outputs var node may be some subgraph op's input;
// here we need to remove those links.
preserved_ops = get_preserved_ops(var_node->outputs);
var_node->outputs.assign(preserved_ops.begin(), preserved_ops.end());
}
}
......@@ -272,7 +373,7 @@ void SearchAllSubgraphs(Graph* graph,
&cluster_internals);
cinn_subgraphs->emplace_back(
CreateNewSubGraph(cluster_set, cluster_internals));
CreateNewSubGraph(cluster_set, cluster_internals, cluster_inputs));
// replacing the subgraph with a new special op node
ReplaceSubGraphWithSpecialOpNode(cluster_set, cluster_inputs,
......
......@@ -54,6 +54,35 @@ inline Node* GetNode(const std::unordered_set<Node*>& nodes,
[&op_name](const Node* node) { return node->Name() == op_name; });
}
inline bool CheckGraphIndependence(const std::unordered_set<Node*>& nodes) {
auto check_node_ok = [&nodes](Node* n1, Node* n2) -> bool {
if (n1->IsOp() && !n2->IsVar()) {
return false;
}
if (n1->IsVar() && !n2->IsOp()) {
return false;
}
if (nodes.count(n2) == 0) {
return false;
}
return true;
};
for (auto node : nodes) {
for (auto in : node->inputs) {
if (!check_node_ok(node, in)) {
return false;
}
}
for (auto out : node->outputs) {
if (!check_node_ok(node, out)) {
return false;
}
}
}
return true;
}
std::unique_ptr<Graph> BuildNoCinnSubgraph() {
ProgramDesc prog;
auto g = std::make_unique<Graph>(prog);
......@@ -67,6 +96,8 @@ std::unique_ptr<Graph> BuildNoCinnSubgraph() {
VarDesc var1("var1");
VarDesc var2("var2");
var2.SetPersistable(true);
var2.SetIsParameter(true);
VarDesc var3("var3");
VarDesc var4("var4");
......@@ -109,6 +140,7 @@ TEST(BuildCinnPassTest, NoCinnSubgraph) {
// After search, the origin graph should not change
ASSERT_EQ(previous_nodes, g->Nodes());
ASSERT_TRUE(CheckGraphIndependence(g->Nodes()));
// After search, there should be no cinn subgraph
ASSERT_TRUE(cinn_subgraphs.empty());
......@@ -119,11 +151,8 @@ std::unique_ptr<Graph> BuildAllOpSupportCinnGraph() {
auto g = std::make_unique<Graph>(prog);
// v1 --
// |
// | --> mul --> v3 --
// | |
// v2 -- | --> add --> v5 --> relu --> v6
// |
// v4 --
OpDesc add_op;
......@@ -135,6 +164,8 @@ std::unique_ptr<Graph> BuildAllOpSupportCinnGraph() {
VarDesc var1("var1");
VarDesc var2("var2");
var2.SetPersistable(true);
var2.SetIsParameter(true);
VarDesc var3("var3");
VarDesc var4("var4");
VarDesc var5("var5");
......@@ -192,6 +223,7 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
// v4 --|
const auto& nodes = g->Nodes();
ASSERT_EQ(nodes.size(), static_cast<size_t>(5));
ASSERT_TRUE(CheckGraphIndependence(nodes));
// A new op named kCinnLaunchOp should be added
ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
......@@ -214,16 +246,34 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
ASSERT_FALSE(CheckNodeExisted(nodes, "relu"));
// After search, there should be just one cinn subgraph
// mul --> v3 --> add --> v5 --> relu
// feed --> v1 --
// | --> mul --> v3 --
// v2 -- | --> add --> v5 --> relu --> v6
// feed --> v4 --
ASSERT_EQ(cinn_subgraphs.size(), static_cast<size_t>(1));
const auto& subgraph = cinn_subgraphs.back();
const auto& subnodes = subgraph->Nodes();
ASSERT_EQ(subnodes.size(), static_cast<size_t>(5));
ASSERT_EQ(subnodes.size(), static_cast<size_t>(11));
ASSERT_TRUE(CheckGraphIndependence(subnodes));
ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "add"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
ASSERT_EQ(CountNode(subnodes, "feed"), 2);
// A non-parameter input should have a feed op
auto new_v1 = GetNode(subnodes, "var1");
ASSERT_EQ(new_v1->inputs.size(), static_cast<size_t>(1));
ASSERT_EQ(new_v1->outputs.size(), static_cast<size_t>(1));
ASSERT_EQ(new_v1->inputs[0]->Name(), "feed");
ASSERT_EQ(new_v1->outputs[0]->Name(), "mul");
// A parameter input should not have a feed op
auto new_v2 = GetNode(subnodes, "var2");
ASSERT_TRUE(new_v2->inputs.empty());
ASSERT_EQ(new_v2->outputs.size(), static_cast<size_t>(1));
ASSERT_EQ(new_v2->outputs[0]->Name(), "mul");
}
std::unique_ptr<Graph> BuildGraphWithOneCinnSubgraph() {
......@@ -231,9 +281,7 @@ std::unique_ptr<Graph> BuildGraphWithOneCinnSubgraph() {
auto g = std::make_unique<Graph>(prog);
// fake1 --> v1 --
// |
// | --> mul --> v3 --> relu --> v4 --> fake2
// |
// v2 --
OpDesc fake1_op;
......@@ -247,6 +295,8 @@ std::unique_ptr<Graph> BuildGraphWithOneCinnSubgraph() {
VarDesc var1("var1");
VarDesc var2("var2");
var2.SetPersistable(true);
var2.SetIsParameter(true);
VarDesc var3("var3");
VarDesc var4("var4");
......@@ -299,6 +349,7 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) {
// v2 --
const auto& nodes = g->Nodes();
ASSERT_EQ(nodes.size(), static_cast<size_t>(6));
ASSERT_TRUE(CheckGraphIndependence(nodes));
// A new op named kCinnLaunchOp should be added
ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
......@@ -312,15 +363,19 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) {
ASSERT_TRUE(CheckNodeExisted(nodes, "fake2"));
// After search, there should be just one cinn subgraph
// mul --> v3 --> relu
// feed --> v1 --
// | --> mul --> v3 --> relu --> v4
// v2 --
ASSERT_EQ(cinn_subgraphs.size(), static_cast<size_t>(1));
const auto& subgraph = cinn_subgraphs.back();
const auto& subnodes = subgraph->Nodes();
ASSERT_EQ(subnodes.size(), static_cast<size_t>(3));
ASSERT_EQ(subnodes.size(), static_cast<size_t>(7));
ASSERT_TRUE(CheckGraphIndependence(subnodes));
ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
ASSERT_EQ(CountNode(subnodes, "feed"), 1);
}
std::unique_ptr<Graph> BuildGraphWithMultiCinnSubgraph() {
......@@ -328,9 +383,7 @@ std::unique_ptr<Graph> BuildGraphWithMultiCinnSubgraph() {
auto g = std::make_unique<Graph>(prog);
// fake1 --> v1 --
// |
// | --> mul --> v3 --> fake2 --> v4 --> relu --> v5 --> fake3
// |
// v2 --
OpDesc fake1_op;
......@@ -346,6 +399,8 @@ std::unique_ptr<Graph> BuildGraphWithMultiCinnSubgraph() {
VarDesc var1("var1");
VarDesc var2("var2");
var2.SetPersistable(true);
var2.SetIsParameter(true);
VarDesc var3("var3");
VarDesc var4("var4");
VarDesc var5("var5");
......@@ -406,6 +461,7 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) {
// v2 -
const auto& nodes = g->Nodes();
ASSERT_EQ(nodes.size(), static_cast<size_t>(10));
ASSERT_TRUE(CheckGraphIndependence(nodes));
// A new op named kCinnLaunchOp should be added
ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
......@@ -424,15 +480,27 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) {
// and each of the subgraphs just has one node.
ASSERT_EQ(cinn_subgraphs.size(), static_cast<size_t>(2));
// subgraph1: relu
// subgraph1:
// feed --> v4 --> relu --> v5
// subgraph2:
// feed --> v1 --
// | --> mul --> v3
// v2 --
const auto& subgraph1 = cinn_subgraphs[0];
const auto& subnodes1 = subgraph1->Nodes();
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(1));
ASSERT_TRUE(CheckGraphIndependence(subnodes1));
// subgraph2: mul
const auto& subgraph2 = cinn_subgraphs[1];
const auto& subnodes2 = subgraph2->Nodes();
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(1));
ASSERT_TRUE(CheckGraphIndependence(subnodes2));
if (CheckNodeExisted(subnodes1, "relu")) {
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(4));
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(5));
} else {
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(4));
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(5));
}
}
} // namespace paddle2cinn
......