From 3ad495e8043fc1498a81fc052dff310eb7ff1263 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Thu, 18 Nov 2021 10:12:53 +0800 Subject: [PATCH] Add the `GetFetchNames` method in CinnGraphSymbolization. (#37218) * Add the `GetFetchNames` method in CinnGraphSymbolization. * Use unordered_set instead vector as the type of fetch_var_names. * Reuse the definition of kCompilationKey. * Use CompileOptions to set fetch_var_ids. * Update the argument passing of GraphCompiler.Build. * Fix some bugs in CinnGraphSymbolization::GetFetchIds. --- .../framework/paddle2cinn/build_cinn_pass.cc | 7 ++++--- .../framework/paddle2cinn/build_cinn_pass.h | 1 - .../paddle2cinn/build_cinn_pass_test.cc | 5 +++-- .../framework/paddle2cinn/cinn_compiler.cc | 6 +++++- .../paddle2cinn/cinn_compiler_test.cc | 8 +++++--- .../paddle2cinn/cinn_graph_symbolization.cc | 18 +++++++++++++++++- .../paddle2cinn/cinn_graph_symbolization.h | 10 +++++++++- .../cinn_graph_symbolization_test.cc | 4 +++- paddle/fluid/operators/cinn_launch_op.h | 6 +++--- 9 files changed, 49 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index f280214ad0b..7e08d883625 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -34,6 +34,7 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/subgraph_detector.h" #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" +#include "paddle/fluid/operators/cinn_launch_op.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" @@ -381,7 +382,7 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster, input_names.emplace_back(n->Name()); } }); - cinn_op_desc.SetInput("X", input_names); + cinn_op_desc.SetInput(operators::kX, input_names); std::vector output_names; std::for_each(cluster_outputs.begin(), cluster_outputs.end(), [&output_names, &deny_var_set](Node* n) { @@ -389,8 +390,8 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster, output_names.emplace_back(n->Name()); } }); - cinn_op_desc.SetOutput("Out", output_names); - cinn_op_desc.SetAttr(kCompilationKey, compilation_key); + cinn_op_desc.SetOutput(operators::kOutputs, output_names); + cinn_op_desc.SetAttr(operators::kCompilationKey, compilation_key); cinn_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), ExtractOpRole(cluster)); cinn_op_desc.Flush(); diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h index 1c07fb314e9..8f7d5eb266e 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h @@ -21,7 +21,6 @@ namespace framework { namespace paddle2cinn { constexpr char kCinnLaunchOp[] = "cinn_launch"; -constexpr char kCompilationKey[] = "compilation_key"; // A pass named BuildCinnPass, the function of this pass is: // diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc index d76a855b122..1649cd24b34 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc @@ -27,6 +27,7 @@ limitations under the License. */ #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/var_desc.h" +#include "paddle/fluid/operators/cinn_launch_op.h" namespace paddle { namespace framework { @@ -91,8 +92,8 @@ std::vector GetCompilationKeys(const Graph& graph) { std::vector compilation_keys; for (auto& node : graph.Nodes()) { if (node->IsOp() && node->Name() == kCinnLaunchOp) { - compilation_keys.emplace_back( - BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey))); + compilation_keys.emplace_back(BOOST_GET_CONST( + std::string, node->Op()->GetAttr(operators::kCompilationKey))); } } return compilation_keys; diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc index dd7f1395a1c..561ebeedb30 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc @@ -201,11 +201,15 @@ std::unique_ptr CinnCompiler::CompileGraph( ApplyPass(cinn_graph.get(), "OpFusion"); auto scope = BuildScope(target, cinn_graph); + auto fetch_ids = symbol.GetFetchIds(); + VLOG(4) << "All fetch var ids in CINN: " + << string::join_strings(fetch_ids, ','); + auto graph_compiler = std::make_unique(target, scope, cinn_graph); GraphCompiler::CompileOptions options; options.with_instantiate_variables = false; - auto compiled_res = graph_compiler->Build(options); + auto compiled_res = graph_compiler->Build(options, std::move(fetch_ids)); auto compiled_obj = std::make_unique(); *compiled_obj = {std::move(graph_compiler), std::move(compiled_res.runtime_program), scope, diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc index 923607fe5ac..7d339f53f0f 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc @@ -34,6 +34,7 @@ #include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/operators/cinn_launch_op.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" @@ -62,8 +63,8 @@ std::vector GetCompilationKeys(const Graph& graph) { std::vector compilation_keys; for (auto& node : graph.Nodes()) { if (node->IsOp() && node->Name() == kCinnLaunchOp) { - compilation_keys.emplace_back( - BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey))); + compilation_keys.emplace_back(BOOST_GET_CONST( + std::string, node->Op()->GetAttr(operators::kCompilationKey))); } } return compilation_keys; @@ -86,7 +87,8 @@ std::unordered_map> GetInputsInfo( std::unordered_set inputs; for (auto& node : graph.Nodes()) { if (node->IsOp() && node->Name() == kCinnLaunchOp) { - if (BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)) != + if (BOOST_GET_CONST(std::string, + node->Op()->GetAttr(operators::kCompilationKey)) != key) { continue; } diff --git a/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc b/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc index 941e82cef1b..9bdaf61858f 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include #include #include #include @@ -225,6 +226,21 @@ void CinnGraphSymbolization::RunGraph(const OpMapperContext& ctx) const { } } +std::unordered_set CinnGraphSymbolization::GetFetchIds() const { + std::unordered_set fetch_names; + fetch_names.reserve(fetch_var_names_.size()); + std::for_each( + fetch_var_names_.begin(), fetch_var_names_.end(), + [this, &fetch_names](const std::string& name) { + PADDLE_ENFORCE_EQ( + var_model_to_program_map_.count(name), 1, + platform::errors::PreconditionNotMet( + "Cannot find %s in var_model_to_program_map_", name.c_str())); + fetch_names.insert(var_model_to_program_map_.at(name)); + }); + return fetch_names; +} + ::cinn::frontend::Program CinnGraphSymbolization::operator()() { std::string builder_name = "NetBuilder_of_graph_" + std::to_string(graph_id_); VLOG(4) << "NetBuilder Name " << builder_name; @@ -235,7 +251,7 @@ void CinnGraphSymbolization::RunGraph(const OpMapperContext& ctx) const { auto cinn_scope = CreateCinnScope(feed_map); OpMapperContext ctx(*cinn_scope, target_, &builder, &var_map_, - &var_model_to_program_map_); + &var_model_to_program_map_, &fetch_var_names_); // add all tensor's feed info into context for (auto& feed_pair : feed_map) { ctx.AddFeedInfo(feed_pair.first, feed_pair.second); diff --git a/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h b/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h index af60493044c..526eb65a56e 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h +++ b/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h @@ -15,8 +15,10 @@ limitations under the License. */ #pragma once #include +#include #include #include +#include #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/lod_tensor.h" @@ -84,6 +86,9 @@ class CinnGraphSymbolization { return var_model_to_program_map_; } + // get fetch var ids used in CINN + std::unordered_set GetFetchIds() const; + using OpMapperContext = ::cinn::frontend::OpMapperContext; using FeedInfoMap = std::unordered_map; @@ -95,10 +100,13 @@ class CinnGraphSymbolization { const ::cinn::common::Target& target_; const std::map& input_tensors_; - // preserve local variable map + // preserve cinn variable map std::unordered_map var_map_; std::unordered_map var_model_to_program_map_; + // fetch var names used in paddle + std::unordered_set fetch_var_names_; + // transform all paddle var desc in feed list into cinn_var_descs_ FeedInfoMap GetFeedInfoMapFromInput() const; diff --git a/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization_test.cc b/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization_test.cc index be2ca2f73e1..09df9a7ad2c 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization_test.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization_test.cc @@ -48,7 +48,8 @@ class CinnGraphSymbolizationForTest { return OpMapperContext(*cinn_symbol_->CreateCinnScope(feed_map), cinn_symbol_->target_, builder, &cinn_symbol_->var_map_, - &cinn_symbol_->var_model_to_program_map_); + &cinn_symbol_->var_model_to_program_map_, + &cinn_symbol_->fetch_var_names_); } FeedInfoMap GetFeedInfoMapFromInput() { @@ -292,6 +293,7 @@ TEST_F(CinnGraphSymbolizationTest, basic) { ASSERT_NO_THROW((*symbol_)()); ASSERT_FALSE(symbol_->var_map().empty()); ASSERT_FALSE(symbol_->var_model_to_program_map().empty()); + ASSERT_TRUE(symbol_->GetFetchIds().empty()); } } // namespace paddle2cinn diff --git a/paddle/fluid/operators/cinn_launch_op.h b/paddle/fluid/operators/cinn_launch_op.h index 348d1dda027..f7d1328bcef 100644 --- a/paddle/fluid/operators/cinn_launch_op.h +++ b/paddle/fluid/operators/cinn_launch_op.h @@ -30,9 +30,9 @@ namespace paddle { namespace operators { -static constexpr char kX[] = "X"; -static constexpr char kOutputs[] = "Out"; -static constexpr char kCompilationKey[] = "compilation_key"; +constexpr char kX[] = "X"; +constexpr char kOutputs[] = "Out"; +constexpr char kCompilationKey[] = "compilation_key"; using LoDTensor = framework::LoDTensor; using CinnTensor = ::cinn::hlir::framework::Tensor; -- GitLab