未验证 提交 3ad495e8 编写于 作者: Z Zhen Wang 提交者: GitHub

Add the `GetFetchNames` method in CinnGraphSymbolization. (#37218)

* Add the `GetFetchNames` method in CinnGraphSymbolization.

* Use unordered_set instead vector as the type of fetch_var_names.

* Reuse the definition of kCompilationKey.

* Use CompileOptions to set fetch_var_ids.

* Update the argument passing of GraphCompiler.Build.

* Fix some bugs in CinnGraphSymbolization::GetFetchIds.
上级 c4862d99
......@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/subgraph_detector.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/operators/cinn_launch_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
......@@ -381,7 +382,7 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster,
input_names.emplace_back(n->Name());
}
});
cinn_op_desc.SetInput("X", input_names);
cinn_op_desc.SetInput(operators::kX, input_names);
std::vector<std::string> output_names;
std::for_each(cluster_outputs.begin(), cluster_outputs.end(),
[&output_names, &deny_var_set](Node* n) {
......@@ -389,8 +390,8 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster,
output_names.emplace_back(n->Name());
}
});
cinn_op_desc.SetOutput("Out", output_names);
cinn_op_desc.SetAttr(kCompilationKey, compilation_key);
cinn_op_desc.SetOutput(operators::kOutputs, output_names);
cinn_op_desc.SetAttr(operators::kCompilationKey, compilation_key);
cinn_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
ExtractOpRole(cluster));
cinn_op_desc.Flush();
......
......@@ -21,7 +21,6 @@ namespace framework {
namespace paddle2cinn {
constexpr char kCinnLaunchOp[] = "cinn_launch";
constexpr char kCompilationKey[] = "compilation_key";
// A pass named BuildCinnPass, the function of this pass is:
//
......
......@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/operators/cinn_launch_op.h"
namespace paddle {
namespace framework {
......@@ -91,8 +92,8 @@ std::vector<std::string> GetCompilationKeys(const Graph& graph) {
std::vector<std::string> compilation_keys;
for (auto& node : graph.Nodes()) {
if (node->IsOp() && node->Name() == kCinnLaunchOp) {
compilation_keys.emplace_back(
BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)));
compilation_keys.emplace_back(BOOST_GET_CONST(
std::string, node->Op()->GetAttr(operators::kCompilationKey)));
}
}
return compilation_keys;
......
......@@ -201,11 +201,15 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
ApplyPass(cinn_graph.get(), "OpFusion");
auto scope = BuildScope(target, cinn_graph);
auto fetch_ids = symbol.GetFetchIds();
VLOG(4) << "All fetch var ids in CINN: "
<< string::join_strings(fetch_ids, ',');
auto graph_compiler =
std::make_unique<GraphCompiler>(target, scope, cinn_graph);
GraphCompiler::CompileOptions options;
options.with_instantiate_variables = false;
auto compiled_res = graph_compiler->Build(options);
auto compiled_res = graph_compiler->Build(options, std::move(fetch_ids));
auto compiled_obj = std::make_unique<CinnCompiledObject>();
*compiled_obj = {std::move(graph_compiler),
std::move(compiled_res.runtime_program), scope,
......
......@@ -34,6 +34,7 @@
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/cinn_launch_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
......@@ -62,8 +63,8 @@ std::vector<std::string> GetCompilationKeys(const Graph& graph) {
std::vector<std::string> compilation_keys;
for (auto& node : graph.Nodes()) {
if (node->IsOp() && node->Name() == kCinnLaunchOp) {
compilation_keys.emplace_back(
BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)));
compilation_keys.emplace_back(BOOST_GET_CONST(
std::string, node->Op()->GetAttr(operators::kCompilationKey)));
}
}
return compilation_keys;
......@@ -86,7 +87,8 @@ std::unordered_map<std::string, std::vector<int64_t>> GetInputsInfo(
std::unordered_set<std::string> inputs;
for (auto& node : graph.Nodes()) {
if (node->IsOp() && node->Name() == kCinnLaunchOp) {
if (BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)) !=
if (BOOST_GET_CONST(std::string,
node->Op()->GetAttr(operators::kCompilationKey)) !=
key) {
continue;
}
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <algorithm>
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
......@@ -225,6 +226,21 @@ void CinnGraphSymbolization::RunGraph(const OpMapperContext& ctx) const {
}
}
std::unordered_set<std::string> CinnGraphSymbolization::GetFetchIds() const {
std::unordered_set<std::string> fetch_names;
fetch_names.reserve(fetch_var_names_.size());
std::for_each(
fetch_var_names_.begin(), fetch_var_names_.end(),
[this, &fetch_names](const std::string& name) {
PADDLE_ENFORCE_EQ(
var_model_to_program_map_.count(name), 1,
platform::errors::PreconditionNotMet(
"Cannot find %s in var_model_to_program_map_", name.c_str()));
fetch_names.insert(var_model_to_program_map_.at(name));
});
return fetch_names;
}
::cinn::frontend::Program CinnGraphSymbolization::operator()() {
std::string builder_name = "NetBuilder_of_graph_" + std::to_string(graph_id_);
VLOG(4) << "NetBuilder Name " << builder_name;
......@@ -235,7 +251,7 @@ void CinnGraphSymbolization::RunGraph(const OpMapperContext& ctx) const {
auto cinn_scope = CreateCinnScope(feed_map);
OpMapperContext ctx(*cinn_scope, target_, &builder, &var_map_,
&var_model_to_program_map_);
&var_model_to_program_map_, &fetch_var_names_);
// add all tensor's feed info into context
for (auto& feed_pair : feed_map) {
ctx.AddFeedInfo(feed_pair.first, feed_pair.second);
......
......@@ -15,8 +15,10 @@ limitations under the License. */
#pragma once
#include <map>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -84,6 +86,9 @@ class CinnGraphSymbolization {
return var_model_to_program_map_;
}
// get fetch var ids used in CINN
std::unordered_set<std::string> GetFetchIds() const;
using OpMapperContext = ::cinn::frontend::OpMapperContext;
using FeedInfoMap =
std::unordered_map<std::string, OpMapperContext::FeedInfo>;
......@@ -95,10 +100,13 @@ class CinnGraphSymbolization {
const ::cinn::common::Target& target_;
const std::map<std::string, const LoDTensor*>& input_tensors_;
// preserve local variable map
// preserve cinn variable map
std::unordered_map<std::string, ::cinn::frontend::Variable> var_map_;
std::unordered_map<std::string, std::string> var_model_to_program_map_;
// fetch var names used in paddle
std::unordered_set<std::string> fetch_var_names_;
// transform all paddle var desc in feed list into cinn_var_descs_
FeedInfoMap GetFeedInfoMapFromInput() const;
......
......@@ -48,7 +48,8 @@ class CinnGraphSymbolizationForTest {
return OpMapperContext(*cinn_symbol_->CreateCinnScope(feed_map),
cinn_symbol_->target_, builder,
&cinn_symbol_->var_map_,
&cinn_symbol_->var_model_to_program_map_);
&cinn_symbol_->var_model_to_program_map_,
&cinn_symbol_->fetch_var_names_);
}
FeedInfoMap GetFeedInfoMapFromInput() {
......@@ -292,6 +293,7 @@ TEST_F(CinnGraphSymbolizationTest, basic) {
ASSERT_NO_THROW((*symbol_)());
ASSERT_FALSE(symbol_->var_map().empty());
ASSERT_FALSE(symbol_->var_model_to_program_map().empty());
ASSERT_TRUE(symbol_->GetFetchIds().empty());
}
} // namespace paddle2cinn
......
......@@ -30,9 +30,9 @@
namespace paddle {
namespace operators {
static constexpr char kX[] = "X";
static constexpr char kOutputs[] = "Out";
static constexpr char kCompilationKey[] = "compilation_key";
constexpr char kX[] = "X";
constexpr char kOutputs[] = "Out";
constexpr char kCompilationKey[] = "compilation_key";
using LoDTensor = framework::LoDTensor;
using CinnTensor = ::cinn::hlir::framework::Tensor;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册