未验证 提交 9b662bef 编写于 作者: Z Zhen Wang 提交者: GitHub

Add a feed op before each input parameter var. (#44499)

* Add a feed op before each input parameter var.

* Fix some issues about the unit test build_cinn_pass_test.
上级 33cc0f7a
......@@ -141,14 +141,14 @@ int ExtractOpRole(const GraphNodeSet& cluster) {
}
}
// Deal with subgraph's feed input var node:
// Deal with input var nodes of the target subgraph:
// create a new input var node and it's feed op node
void AddFeedOpAndVar(const GraphNodeSet& feed_vars,
void AddFeedOpAndVar(const GraphNodeSet& input_vars,
const GraphNodeSet& cluster,
const GraphNodeMap& old_op2new_op,
const GraphNodeMap& old_var2new_var,
Graph* graph) {
for (auto* old_var : feed_vars) {
for (auto* old_var : input_vars) {
// create feed op
OpDesc desc;
desc.SetType("feed");
......@@ -157,7 +157,7 @@ void AddFeedOpAndVar(const GraphNodeSet& feed_vars,
// get new feed var node
auto* var = old_var2new_var.at(old_var);
VLOG(4) << "Add Feed Op before: " << var->Name();
VLOG(4) << "Add Feed Op before the input var: " << var->Name();
// link feed op and feed var
IR_NODE_LINK_TO(op, var);
......@@ -174,26 +174,6 @@ void AddFeedOpAndVar(const GraphNodeSet& feed_vars,
}
}
// Deal with subgraph's parameter var node:
// create a new input var node, it's data will get by scope,
// so it don't need feed op
void AddParamVar(const GraphNodeSet& param_vars,
const GraphNodeSet& cluster,
const GraphNodeMap& old_op2new_op,
const GraphNodeMap& old_var2new_var,
Graph* graph) {
for (auto* old_var : param_vars) {
auto* var = old_var2new_var.at(old_var);
VLOG(4) << "Add Param Var Node: " << var->Name();
for (auto* old_op : old_var->outputs) {
if (cluster.count(old_op)) {
IR_NODE_LINK_TO(var, old_op2new_op.at(old_op));
}
}
}
}
// Deal with subgraph's outputs var node:
// create a new output var node and it's fetch op
void AddOutputVar(const GraphNodeSet& output_vars,
......@@ -389,7 +369,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
AddFeedOpAndVar(
need_feed_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get());
AddParamVar(
AddFeedOpAndVar(
param_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get());
AddOutputVar(
output_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get());
......
......@@ -277,13 +277,13 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);
const auto& subnodes = subgraph.Nodes();
ASSERT_EQ(subnodes.size(), static_cast<size_t>(12));
ASSERT_EQ(subnodes.size(), static_cast<size_t>(13));
ASSERT_TRUE(CheckGraphIndependence(subnodes));
ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
ASSERT_EQ(CountNode(subnodes, "feed"), 2);
ASSERT_EQ(CountNode(subnodes, "feed"), 3);
ASSERT_EQ(CountNode(subnodes, "fetch"), 1);
// No-parameter input should has feed op
......@@ -293,9 +293,10 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
ASSERT_EQ(new_v1->inputs[0]->Name(), "feed");
ASSERT_EQ(new_v1->outputs[0]->Name(), "mul");
// Parameter input should not has feed op
// Parameter input should also have the feed op
auto new_v2 = GetNode(subnodes, "var2");
ASSERT_TRUE(new_v2->inputs.empty());
ASSERT_EQ(new_v2->inputs.size(), static_cast<size_t>(1));
ASSERT_EQ(new_v2->inputs[0]->Name(), "feed");
ASSERT_EQ(new_v2->outputs.size(), static_cast<size_t>(1));
ASSERT_EQ(new_v2->outputs[0]->Name(), "mul");
......@@ -400,12 +401,12 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) {
const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);
const auto& subnodes = subgraph.Nodes();
ASSERT_EQ(subnodes.size(), static_cast<size_t>(8));
ASSERT_EQ(subnodes.size(), static_cast<size_t>(9));
ASSERT_TRUE(CheckGraphIndependence(subnodes));
ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
ASSERT_EQ(CountNode(subnodes, "feed"), 1);
ASSERT_EQ(CountNode(subnodes, "feed"), 2);
ASSERT_EQ(CountNode(subnodes, "fetch"), 1);
}
......@@ -526,10 +527,10 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) {
if (CheckNodeExisted(subnodes1, "relu")) {
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(5));
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(6));
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(7));
} else {
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(5));
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(6));
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(7));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册