Unverified commit 3ad495e8, authored by Zhen Wang and committed by GitHub

Add the `GetFetchNames` method in CinnGraphSymbolization. (#37218)

* Add the `GetFetchNames` method in CinnGraphSymbolization.

* Use unordered_set instead of vector as the type of fetch_var_names.

* Reuse the definition of kCompilationKey.

* Use CompileOptions to set fetch_var_ids.

* Update the argument passing of GraphCompiler.Build.

* Fix some bugs in CinnGraphSymbolization::GetFetchIds.
Parent c4862d99
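
Before the hunks themselves, here is a minimal, self-contained sketch of the fetch-name mapping that this commit introduces. The free function `GetFetchIdsSketch`, the toy variable names, and the local maps are hypothetical stand-ins for the `CinnGraphSymbolization` members touched in the diff below; only the mapping idea (Paddle fetch name → CINN program id, collected in an `unordered_set`) mirrors the actual change.

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <unordered_set>

// Illustrative stand-in for CinnGraphSymbolization::GetFetchIds: translate
// every Paddle-side fetch variable name into its CINN program id through the
// model-to-program name map recorded during symbolization.
std::unordered_set<std::string> GetFetchIdsSketch(
    const std::unordered_set<std::string>& fetch_var_names,
    const std::unordered_map<std::string, std::string>& var_model_to_program_map) {
  std::unordered_set<std::string> fetch_ids;
  fetch_ids.reserve(fetch_var_names.size());
  for (const auto& name : fetch_var_names) {
    // The real implementation enforces that every fetch name has a mapping
    // (PADDLE_ENFORCE_EQ); here .at() simply throws if the entry is missing.
    fetch_ids.insert(var_model_to_program_map.at(name));
  }
  return fetch_ids;
}

int main() {
  // Toy data; the variable names are made up for illustration only.
  std::unordered_set<std::string> fetch_var_names{"fc_0.tmp_1"};
  std::unordered_map<std::string, std::string> var_model_to_program_map{
      {"fc_0.tmp_1", "var_12"}};
  for (const auto& id : GetFetchIdsSketch(fetch_var_names, var_model_to_program_map)) {
    std::cout << id << "\n";  // ids later handed to GraphCompiler::Build
  }
  return 0;
}
```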
@@ -34,6 +34,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/subgraph_detector.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
+#include "paddle/fluid/operators/cinn_launch_op.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/errors.h"
@@ -381,7 +382,7 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster,
                     input_names.emplace_back(n->Name());
                   }
                 });
-  cinn_op_desc.SetInput("X", input_names);
+  cinn_op_desc.SetInput(operators::kX, input_names);
   std::vector<std::string> output_names;
   std::for_each(cluster_outputs.begin(), cluster_outputs.end(),
                 [&output_names, &deny_var_set](Node* n) {
@@ -389,8 +390,8 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster,
                     output_names.emplace_back(n->Name());
                   }
                 });
-  cinn_op_desc.SetOutput("Out", output_names);
-  cinn_op_desc.SetAttr(kCompilationKey, compilation_key);
+  cinn_op_desc.SetOutput(operators::kOutputs, output_names);
+  cinn_op_desc.SetAttr(operators::kCompilationKey, compilation_key);
   cinn_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                        ExtractOpRole(cluster));
   cinn_op_desc.Flush();
......
@@ -21,7 +21,6 @@ namespace framework {
 namespace paddle2cinn {
 constexpr char kCinnLaunchOp[] = "cinn_launch";
-constexpr char kCompilationKey[] = "compilation_key";
 // A pass named BuildCinnPass, the function of this pass is:
 //
......
@@ -27,6 +27,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
+#include "paddle/fluid/operators/cinn_launch_op.h"
 namespace paddle {
 namespace framework {
@@ -91,8 +92,8 @@ std::vector<std::string> GetCompilationKeys(const Graph& graph) {
   std::vector<std::string> compilation_keys;
   for (auto& node : graph.Nodes()) {
     if (node->IsOp() && node->Name() == kCinnLaunchOp) {
-      compilation_keys.emplace_back(
-          BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)));
+      compilation_keys.emplace_back(BOOST_GET_CONST(
+          std::string, node->Op()->GetAttr(operators::kCompilationKey)));
     }
   }
   return compilation_keys;
......
@@ -201,11 +201,15 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
   ApplyPass(cinn_graph.get(), "OpFusion");
   auto scope = BuildScope(target, cinn_graph);
+  auto fetch_ids = symbol.GetFetchIds();
+  VLOG(4) << "All fetch var ids in CINN: "
+          << string::join_strings(fetch_ids, ',');
   auto graph_compiler =
       std::make_unique<GraphCompiler>(target, scope, cinn_graph);
   GraphCompiler::CompileOptions options;
   options.with_instantiate_variables = false;
-  auto compiled_res = graph_compiler->Build(options);
+  auto compiled_res = graph_compiler->Build(options, std::move(fetch_ids));
   auto compiled_obj = std::make_unique<CinnCompiledObject>();
   *compiled_obj = {std::move(graph_compiler),
                    std::move(compiled_res.runtime_program), scope,
......
@@ -34,6 +34,7 @@
 #include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/operators/cinn_launch_op.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/place.h"
@@ -62,8 +63,8 @@ std::vector<std::string> GetCompilationKeys(const Graph& graph) {
   std::vector<std::string> compilation_keys;
   for (auto& node : graph.Nodes()) {
     if (node->IsOp() && node->Name() == kCinnLaunchOp) {
-      compilation_keys.emplace_back(
-          BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)));
+      compilation_keys.emplace_back(BOOST_GET_CONST(
+          std::string, node->Op()->GetAttr(operators::kCompilationKey)));
     }
   }
   return compilation_keys;
@@ -86,7 +87,8 @@ std::unordered_map<std::string, std::vector<int64_t>> GetInputsInfo(
   std::unordered_set<std::string> inputs;
   for (auto& node : graph.Nodes()) {
     if (node->IsOp() && node->Name() == kCinnLaunchOp) {
-      if (BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)) !=
+      if (BOOST_GET_CONST(std::string,
+                          node->Op()->GetAttr(operators::kCompilationKey)) !=
           key) {
         continue;
       }
......
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <algorithm>
 #include <queue>
+#include <string>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
@@ -225,6 +226,21 @@ void CinnGraphSymbolization::RunGraph(const OpMapperContext& ctx) const {
   }
 }
+std::unordered_set<std::string> CinnGraphSymbolization::GetFetchIds() const {
+  std::unordered_set<std::string> fetch_names;
+  fetch_names.reserve(fetch_var_names_.size());
+  std::for_each(
+      fetch_var_names_.begin(), fetch_var_names_.end(),
+      [this, &fetch_names](const std::string& name) {
+        PADDLE_ENFORCE_EQ(
+            var_model_to_program_map_.count(name), 1,
+            platform::errors::PreconditionNotMet(
+                "Cannot find %s in var_model_to_program_map_", name.c_str()));
+        fetch_names.insert(var_model_to_program_map_.at(name));
+      });
+  return fetch_names;
+}
 ::cinn::frontend::Program CinnGraphSymbolization::operator()() {
   std::string builder_name = "NetBuilder_of_graph_" + std::to_string(graph_id_);
   VLOG(4) << "NetBuilder Name " << builder_name;
@@ -235,7 +251,7 @@ void CinnGraphSymbolization::RunGraph(const OpMapperContext& ctx) const {
   auto cinn_scope = CreateCinnScope(feed_map);
   OpMapperContext ctx(*cinn_scope, target_, &builder, &var_map_,
-                      &var_model_to_program_map_);
+                      &var_model_to_program_map_, &fetch_var_names_);
   // add all tensor's feed info into context
   for (auto& feed_pair : feed_map) {
     ctx.AddFeedInfo(feed_pair.first, feed_pair.second);
......
@@ -15,8 +15,10 @@ limitations under the License. */
 #pragma once
 #include <map>
+#include <string>
 #include <unordered_map>
 #include <unordered_set>
+#include <vector>
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -84,6 +86,9 @@ class CinnGraphSymbolization {
     return var_model_to_program_map_;
   }
+  // get fetch var ids used in CINN
+  std::unordered_set<std::string> GetFetchIds() const;
   using OpMapperContext = ::cinn::frontend::OpMapperContext;
   using FeedInfoMap =
       std::unordered_map<std::string, OpMapperContext::FeedInfo>;
@@ -95,10 +100,13 @@ class CinnGraphSymbolization {
   const ::cinn::common::Target& target_;
   const std::map<std::string, const LoDTensor*>& input_tensors_;
-  // preserve local variable map
+  // preserve cinn variable map
   std::unordered_map<std::string, ::cinn::frontend::Variable> var_map_;
   std::unordered_map<std::string, std::string> var_model_to_program_map_;
+  // fetch var names used in paddle
+  std::unordered_set<std::string> fetch_var_names_;
   // transform all paddle var desc in feed list into cinn_var_descs_
   FeedInfoMap GetFeedInfoMapFromInput() const;
......
@@ -48,7 +48,8 @@ class CinnGraphSymbolizationForTest {
     return OpMapperContext(*cinn_symbol_->CreateCinnScope(feed_map),
                            cinn_symbol_->target_, builder,
                            &cinn_symbol_->var_map_,
-                           &cinn_symbol_->var_model_to_program_map_);
+                           &cinn_symbol_->var_model_to_program_map_,
+                           &cinn_symbol_->fetch_var_names_);
   }
   FeedInfoMap GetFeedInfoMapFromInput() {
@@ -292,6 +293,7 @@ TEST_F(CinnGraphSymbolizationTest, basic) {
   ASSERT_NO_THROW((*symbol_)());
   ASSERT_FALSE(symbol_->var_map().empty());
   ASSERT_FALSE(symbol_->var_model_to_program_map().empty());
+  ASSERT_TRUE(symbol_->GetFetchIds().empty());
 }
 }  // namespace paddle2cinn
......
@@ -30,9 +30,9 @@
 namespace paddle {
 namespace operators {
-static constexpr char kX[] = "X";
-static constexpr char kOutputs[] = "Out";
-static constexpr char kCompilationKey[] = "compilation_key";
+constexpr char kX[] = "X";
+constexpr char kOutputs[] = "Out";
+constexpr char kCompilationKey[] = "compilation_key";
 using LoDTensor = framework::LoDTensor;
 using CinnTensor = ::cinn::hlir::framework::Tensor;
......