未验证 提交 9b662bef 编写于 作者: Z Zhen Wang 提交者: GitHub

Add a feed op before each input parameter var. (#44499)

* Add a feed op before each input parameter var.

* Fix some issues about the unit test build_cinn_pass_test.
上级 33cc0f7a
...@@ -141,14 +141,14 @@ int ExtractOpRole(const GraphNodeSet& cluster) { ...@@ -141,14 +141,14 @@ int ExtractOpRole(const GraphNodeSet& cluster) {
} }
} }
// Deal with subgraph's feed input var node: // Deal with input var nodes of the target subgraph:
// create a new input var node and it's feed op node // create a new input var node and it's feed op node
void AddFeedOpAndVar(const GraphNodeSet& feed_vars, void AddFeedOpAndVar(const GraphNodeSet& input_vars,
const GraphNodeSet& cluster, const GraphNodeSet& cluster,
const GraphNodeMap& old_op2new_op, const GraphNodeMap& old_op2new_op,
const GraphNodeMap& old_var2new_var, const GraphNodeMap& old_var2new_var,
Graph* graph) { Graph* graph) {
for (auto* old_var : feed_vars) { for (auto* old_var : input_vars) {
// create feed op // create feed op
OpDesc desc; OpDesc desc;
desc.SetType("feed"); desc.SetType("feed");
...@@ -157,7 +157,7 @@ void AddFeedOpAndVar(const GraphNodeSet& feed_vars, ...@@ -157,7 +157,7 @@ void AddFeedOpAndVar(const GraphNodeSet& feed_vars,
// get new feed var node // get new feed var node
auto* var = old_var2new_var.at(old_var); auto* var = old_var2new_var.at(old_var);
VLOG(4) << "Add Feed Op before: " << var->Name(); VLOG(4) << "Add Feed Op before the input var: " << var->Name();
// link feed op and feed var // link feed op and feed var
IR_NODE_LINK_TO(op, var); IR_NODE_LINK_TO(op, var);
...@@ -174,26 +174,6 @@ void AddFeedOpAndVar(const GraphNodeSet& feed_vars, ...@@ -174,26 +174,6 @@ void AddFeedOpAndVar(const GraphNodeSet& feed_vars,
} }
} }
// Deal with subgraph's parameter var node:
// create a new input var node, it's data will get by scope,
// so it don't need feed op
void AddParamVar(const GraphNodeSet& param_vars,
const GraphNodeSet& cluster,
const GraphNodeMap& old_op2new_op,
const GraphNodeMap& old_var2new_var,
Graph* graph) {
for (auto* old_var : param_vars) {
auto* var = old_var2new_var.at(old_var);
VLOG(4) << "Add Param Var Node: " << var->Name();
for (auto* old_op : old_var->outputs) {
if (cluster.count(old_op)) {
IR_NODE_LINK_TO(var, old_op2new_op.at(old_op));
}
}
}
}
// Deal with subgraph's outputs var node: // Deal with subgraph's outputs var node:
// create a new output var node and it's fetch op // create a new output var node and it's fetch op
void AddOutputVar(const GraphNodeSet& output_vars, void AddOutputVar(const GraphNodeSet& output_vars,
...@@ -389,7 +369,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster, ...@@ -389,7 +369,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
AddFeedOpAndVar( AddFeedOpAndVar(
need_feed_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get()); need_feed_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get());
AddParamVar( AddFeedOpAndVar(
param_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get()); param_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get());
AddOutputVar( AddOutputVar(
output_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get()); output_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get());
......
...@@ -277,13 +277,13 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { ...@@ -277,13 +277,13 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]); const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);
const auto& subnodes = subgraph.Nodes(); const auto& subnodes = subgraph.Nodes();
ASSERT_EQ(subnodes.size(), static_cast<size_t>(12)); ASSERT_EQ(subnodes.size(), static_cast<size_t>(13));
ASSERT_TRUE(CheckGraphIndependence(subnodes)); ASSERT_TRUE(CheckGraphIndependence(subnodes));
ASSERT_TRUE(CheckNodeExisted(subnodes, "mul")); ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add")); ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu")); ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
ASSERT_EQ(CountNode(subnodes, "feed"), 2); ASSERT_EQ(CountNode(subnodes, "feed"), 3);
ASSERT_EQ(CountNode(subnodes, "fetch"), 1); ASSERT_EQ(CountNode(subnodes, "fetch"), 1);
// No-parameter input should has feed op // No-parameter input should has feed op
...@@ -293,9 +293,10 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { ...@@ -293,9 +293,10 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
ASSERT_EQ(new_v1->inputs[0]->Name(), "feed"); ASSERT_EQ(new_v1->inputs[0]->Name(), "feed");
ASSERT_EQ(new_v1->outputs[0]->Name(), "mul"); ASSERT_EQ(new_v1->outputs[0]->Name(), "mul");
// Parameter input should not has feed op // Parameter input should also have the feed op
auto new_v2 = GetNode(subnodes, "var2"); auto new_v2 = GetNode(subnodes, "var2");
ASSERT_TRUE(new_v2->inputs.empty()); ASSERT_EQ(new_v2->inputs.size(), static_cast<size_t>(1));
ASSERT_EQ(new_v2->inputs[0]->Name(), "feed");
ASSERT_EQ(new_v2->outputs.size(), static_cast<size_t>(1)); ASSERT_EQ(new_v2->outputs.size(), static_cast<size_t>(1));
ASSERT_EQ(new_v2->outputs[0]->Name(), "mul"); ASSERT_EQ(new_v2->outputs[0]->Name(), "mul");
...@@ -400,12 +401,12 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) { ...@@ -400,12 +401,12 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) {
const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]); const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);
const auto& subnodes = subgraph.Nodes(); const auto& subnodes = subgraph.Nodes();
ASSERT_EQ(subnodes.size(), static_cast<size_t>(8)); ASSERT_EQ(subnodes.size(), static_cast<size_t>(9));
ASSERT_TRUE(CheckGraphIndependence(subnodes)); ASSERT_TRUE(CheckGraphIndependence(subnodes));
ASSERT_TRUE(CheckNodeExisted(subnodes, "mul")); ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu")); ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
ASSERT_EQ(CountNode(subnodes, "feed"), 1); ASSERT_EQ(CountNode(subnodes, "feed"), 2);
ASSERT_EQ(CountNode(subnodes, "fetch"), 1); ASSERT_EQ(CountNode(subnodes, "fetch"), 1);
} }
...@@ -526,10 +527,10 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) { ...@@ -526,10 +527,10 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) {
if (CheckNodeExisted(subnodes1, "relu")) { if (CheckNodeExisted(subnodes1, "relu")) {
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(5)); ASSERT_EQ(subnodes1.size(), static_cast<size_t>(5));
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(6)); ASSERT_EQ(subnodes2.size(), static_cast<size_t>(7));
} else { } else {
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(5)); ASSERT_EQ(subnodes2.size(), static_cast<size_t>(5));
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(6)); ASSERT_EQ(subnodes1.size(), static_cast<size_t>(7));
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册