diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index b25d3a7f3af92ff8d7f605fa32c741ce3f0414f9..593646164940b1f02f55f66f2a26e79991c69885 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -141,14 +141,14 @@ int ExtractOpRole(const GraphNodeSet& cluster) { } } -// Deal with subgraph's feed input var node: +// Deal with input var nodes of the target subgraph: // create a new input var node and it's feed op node -void AddFeedOpAndVar(const GraphNodeSet& feed_vars, +void AddFeedOpAndVar(const GraphNodeSet& input_vars, const GraphNodeSet& cluster, const GraphNodeMap& old_op2new_op, const GraphNodeMap& old_var2new_var, Graph* graph) { - for (auto* old_var : feed_vars) { + for (auto* old_var : input_vars) { // create feed op OpDesc desc; desc.SetType("feed"); @@ -157,7 +157,7 @@ void AddFeedOpAndVar(const GraphNodeSet& feed_vars, // get new feed var node auto* var = old_var2new_var.at(old_var); - VLOG(4) << "Add Feed Op before: " << var->Name(); + VLOG(4) << "Add Feed Op before the input var: " << var->Name(); // link feed op and feed var IR_NODE_LINK_TO(op, var); @@ -174,26 +174,6 @@ void AddFeedOpAndVar(const GraphNodeSet& feed_vars, } } -// Deal with subgraph's parameter var node: -// create a new input var node, it's data will get by scope, -// so it don't need feed op -void AddParamVar(const GraphNodeSet& param_vars, - const GraphNodeSet& cluster, - const GraphNodeMap& old_op2new_op, - const GraphNodeMap& old_var2new_var, - Graph* graph) { - for (auto* old_var : param_vars) { - auto* var = old_var2new_var.at(old_var); - VLOG(4) << "Add Param Var Node: " << var->Name(); - - for (auto* old_op : old_var->outputs) { - if (cluster.count(old_op)) { - IR_NODE_LINK_TO(var, old_op2new_op.at(old_op)); - } - } - } -} - // Deal with subgraph's outputs var node: // create a new output var node and it's fetch op void AddOutputVar(const GraphNodeSet& output_vars, @@ -389,7 +369,7 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, AddFeedOpAndVar( need_feed_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get()); - AddParamVar( + AddFeedOpAndVar( param_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get()); AddOutputVar( output_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get()); diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc index f951a09cfd56af4c81130eab00545432cbe27aa2..7d4d856e4cbf2b0b351ea1c1a423bc31acb10c5a 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc @@ -277,13 +277,13 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]); const auto& subnodes = subgraph.Nodes(); - ASSERT_EQ(subnodes.size(), static_cast(12)); + ASSERT_EQ(subnodes.size(), static_cast(13)); ASSERT_TRUE(CheckGraphIndependence(subnodes)); ASSERT_TRUE(CheckNodeExisted(subnodes, "mul")); ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add")); ASSERT_TRUE(CheckNodeExisted(subnodes, "relu")); - ASSERT_EQ(CountNode(subnodes, "feed"), 2); + ASSERT_EQ(CountNode(subnodes, "feed"), 3); ASSERT_EQ(CountNode(subnodes, "fetch"), 1); // No-parameter input should has feed op @@ -293,9 +293,10 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { ASSERT_EQ(new_v1->inputs[0]->Name(), "feed"); ASSERT_EQ(new_v1->outputs[0]->Name(), "mul"); - // Parameter input should not has feed op + // Parameter input should also have the feed op auto new_v2 = GetNode(subnodes, "var2"); - ASSERT_TRUE(new_v2->inputs.empty()); + ASSERT_EQ(new_v2->inputs.size(), static_cast(1)); + ASSERT_EQ(new_v2->inputs[0]->Name(), "feed"); ASSERT_EQ(new_v2->outputs.size(), static_cast(1)); ASSERT_EQ(new_v2->outputs[0]->Name(), "mul"); @@ -400,12 +401,12 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) { const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]); const auto& subnodes = subgraph.Nodes(); - ASSERT_EQ(subnodes.size(), static_cast(8)); + ASSERT_EQ(subnodes.size(), static_cast(9)); ASSERT_TRUE(CheckGraphIndependence(subnodes)); ASSERT_TRUE(CheckNodeExisted(subnodes, "mul")); ASSERT_TRUE(CheckNodeExisted(subnodes, "relu")); - ASSERT_EQ(CountNode(subnodes, "feed"), 1); + ASSERT_EQ(CountNode(subnodes, "feed"), 2); ASSERT_EQ(CountNode(subnodes, "fetch"), 1); } @@ -526,10 +527,10 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) { if (CheckNodeExisted(subnodes1, "relu")) { ASSERT_EQ(subnodes1.size(), static_cast(5)); - ASSERT_EQ(subnodes2.size(), static_cast(6)); + ASSERT_EQ(subnodes2.size(), static_cast(7)); } else { ASSERT_EQ(subnodes2.size(), static_cast(5)); - ASSERT_EQ(subnodes1.size(), static_cast(6)); + ASSERT_EQ(subnodes1.size(), static_cast(7)); } }