diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index ab259a0fc85abb7600cc49123f242fdfd8dc147b..3516e71b837917cae2d60193ec5e3798c9d1a211 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -368,10 +368,32 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, subgraph.get()); AddOutputVar(output_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get()); - // Store the input variables whose buffer are not needed as - // attribute of the graph. + // Save lists of input variables, internal variables and output variables + // of the cluster as attributes of the subgraph for convenience. + auto collect_names_fn = []( + const GraphNodeSet& nodes, + const std::unordered_set& ignore_names) { + auto result = std::make_unique>(); + for (auto* node : nodes) { + if (ignore_names.count(node->Name())) { + continue; + } + result->emplace_back(node->Name()); + } + return result; + }; + subgraph->Set>( + kInternalVars, collect_names_fn(cluster_internals, {}).release()); + subgraph->Set>( + kOutputVars, collect_names_fn(cluster_outputs, {}).release()); + // Divide input variables into two parts: one is common and will be used + // in execution, the other may be empty and it is those variables whose + // buffer are not needed and only be used in graph symbolization auto no_need_buffer_feeds = std::make_unique>( ExtractNoNeedBufferFeeds(cluster, cluster_inputs)); + subgraph->Set>( + kInputVars, + collect_names_fn(cluster_inputs, *no_need_buffer_feeds).release()); subgraph->Set>( kNoNeedBufferFeeds, no_need_buffer_feeds.release()); // initialize empty map for kMemOptVarInfoFromMainGraph attribute, @@ -458,33 +480,18 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster, framework::OpDesc cinn_op_desc; cinn_op_desc.SetType(kCinnLaunchOp); - // Divide input variables as two parts: - // the ones that data buffer are not needed and remain ones - std::vector op_kx_inputs, no_need_buffer_inputs; const auto& subgraph = CinnCompiler::GetInstance()->FindGraph(compilation_key); - auto& no_need_buffer_feeds = + const auto& no_need_buffer_feeds = subgraph.Get>(kNoNeedBufferFeeds); - for (const auto* n : cluster_inputs) { - const auto& var_name = n->Name(); - if (no_need_buffer_feeds.count(var_name)) { - no_need_buffer_inputs.emplace_back(var_name); - } else { - op_kx_inputs.emplace_back(var_name); - } - } - cinn_op_desc.SetInput(operators::kX, op_kx_inputs); - cinn_op_desc.SetInput(operators::kNoNeedBufferX, no_need_buffer_inputs); - - std::vector output_names; - std::for_each(cluster_outputs.begin(), cluster_outputs.end(), - [&output_names, &deny_var_set](Node* n) { - if (n->Var() != nullptr && !deny_var_set.count(n->Name())) { - output_names.emplace_back(n->Name()); - } - }); - cinn_op_desc.SetOutput(operators::kOutputs, output_names); + cinn_op_desc.SetInput(operators::kX, + subgraph.Get>(kInputVars)); + cinn_op_desc.SetInput(operators::kNoNeedBufferX, + std::vector(no_need_buffer_feeds.begin(), + no_need_buffer_feeds.end())); + cinn_op_desc.SetOutput(operators::kOutputs, + subgraph.Get>(kOutputVars)); cinn_op_desc.SetAttr(operators::kCompilationKey, compilation_key); cinn_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), ExtractOpRole(cluster)); diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h index 9bb25b6b52e5466b1665fc080511fbe63d8011df..8cb920831cc543a073652051c1ba234e974179c3 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h @@ -21,7 +21,10 @@ namespace framework { namespace paddle2cinn { constexpr char kCinnLaunchOp[] = "cinn_launch"; -constexpr char kNoNeedBufferFeeds[] = "no_need_buffer_feeds"; +constexpr char kInputVars[] = "InputVars"; +constexpr char kNoNeedBufferFeeds[] = "NoNeedBufferFeeds"; +constexpr char kInternalVars[] = "InternalVars"; +constexpr char kOutputVars[] = "OutputVars"; constexpr char kMemOptVarInfoFromMainGraph[] = "mem_opt_var_info_from_main_graph"; diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc index 79e6da987ef09db5ed43dfb8168dd13fa0cf885e..919fc60d4cb61b6079965e3c8ab7d43ca9a2b211 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc @@ -652,6 +652,19 @@ TEST(BuildCinnPassTest, NoNeedBufferInput) { ASSERT_EQ(no_need_buffer_feeds.size(), 2); ASSERT_EQ(no_need_buffer_feeds, std::unordered_set({"var2", "var3"})); + + // check the attributes of variable lists are saved correctly + ASSERT_TRUE(subgraph.Has(kInputVars)); + EXPECT_EQ(subgraph.Get>(kInputVars), + std::vector({"var1"})); + ASSERT_TRUE(subgraph.Has(kInternalVars)); + EXPECT_EQ(subgraph.Get>(kInternalVars), + std::vector({"var4"})); + ASSERT_TRUE(subgraph.Has(kOutputVars)); + const auto& output_vars = subgraph.Get>(kOutputVars); + EXPECT_EQ( + std::unordered_set(output_vars.begin(), output_vars.end()), + std::unordered_set({"var5", "var6"})); } } // namespace paddle2cinn