未验证 提交 a1ad003c 编写于 作者: T TeFeng Chen 提交者: GitHub

save the name lists of variables of a cinn subgraph as its attributes (#39622)

* save the name lists of the input,internal and output variables of a subgraph as its attribute

* fix compile error
上级 4501abd6
......@@ -368,10 +368,32 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
subgraph.get());
AddOutputVar(output_vars, cluster, old_op2new_op, old_var2new_var,
subgraph.get());
// Store the input variables whose buffer are not needed as
// attribute of the graph.
// Save lists of input variables, internal variables and output variables
// of the cluster as attributes of the subgraph for convenience.
auto collect_names_fn = [](
const GraphNodeSet& nodes,
const std::unordered_set<std::string>& ignore_names) {
auto result = std::make_unique<std::vector<std::string>>();
for (auto* node : nodes) {
if (ignore_names.count(node->Name())) {
continue;
}
result->emplace_back(node->Name());
}
return result;
};
subgraph->Set<std::vector<std::string>>(
kInternalVars, collect_names_fn(cluster_internals, {}).release());
subgraph->Set<std::vector<std::string>>(
kOutputVars, collect_names_fn(cluster_outputs, {}).release());
// Divide input variables into two parts: one is common and will be used
// in execution, the other may be empty and it is those variables whose
// buffer are not needed and only be used in graph symbolization
auto no_need_buffer_feeds = std::make_unique<std::unordered_set<std::string>>(
ExtractNoNeedBufferFeeds(cluster, cluster_inputs));
subgraph->Set<std::vector<std::string>>(
kInputVars,
collect_names_fn(cluster_inputs, *no_need_buffer_feeds).release());
subgraph->Set<std::unordered_set<std::string>>(
kNoNeedBufferFeeds, no_need_buffer_feeds.release());
// initialize empty map for kMemOptVarInfoFromMainGraph attribute,
......@@ -458,33 +480,18 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster,
framework::OpDesc cinn_op_desc;
cinn_op_desc.SetType(kCinnLaunchOp);
// Divide input variables as two parts:
// the ones that data buffer are not needed and remain ones
std::vector<std::string> op_kx_inputs, no_need_buffer_inputs;
const auto& subgraph =
CinnCompiler::GetInstance()->FindGraph(compilation_key);
auto& no_need_buffer_feeds =
const auto& no_need_buffer_feeds =
subgraph.Get<std::unordered_set<std::string>>(kNoNeedBufferFeeds);
for (const auto* n : cluster_inputs) {
const auto& var_name = n->Name();
if (no_need_buffer_feeds.count(var_name)) {
no_need_buffer_inputs.emplace_back(var_name);
} else {
op_kx_inputs.emplace_back(var_name);
}
}
cinn_op_desc.SetInput(operators::kX, op_kx_inputs);
cinn_op_desc.SetInput(operators::kNoNeedBufferX, no_need_buffer_inputs);
std::vector<std::string> output_names;
std::for_each(cluster_outputs.begin(), cluster_outputs.end(),
[&output_names, &deny_var_set](Node* n) {
if (n->Var() != nullptr && !deny_var_set.count(n->Name())) {
output_names.emplace_back(n->Name());
}
});
cinn_op_desc.SetOutput(operators::kOutputs, output_names);
cinn_op_desc.SetInput(operators::kX,
subgraph.Get<std::vector<std::string>>(kInputVars));
cinn_op_desc.SetInput(operators::kNoNeedBufferX,
std::vector<std::string>(no_need_buffer_feeds.begin(),
no_need_buffer_feeds.end()));
cinn_op_desc.SetOutput(operators::kOutputs,
subgraph.Get<std::vector<std::string>>(kOutputVars));
cinn_op_desc.SetAttr(operators::kCompilationKey, compilation_key);
cinn_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
ExtractOpRole(cluster));
......
......@@ -21,7 +21,10 @@ namespace framework {
namespace paddle2cinn {
constexpr char kCinnLaunchOp[] = "cinn_launch";
constexpr char kNoNeedBufferFeeds[] = "no_need_buffer_feeds";
constexpr char kInputVars[] = "InputVars";
constexpr char kNoNeedBufferFeeds[] = "NoNeedBufferFeeds";
constexpr char kInternalVars[] = "InternalVars";
constexpr char kOutputVars[] = "OutputVars";
constexpr char kMemOptVarInfoFromMainGraph[] =
"mem_opt_var_info_from_main_graph";
......
......@@ -652,6 +652,19 @@ TEST(BuildCinnPassTest, NoNeedBufferInput) {
ASSERT_EQ(no_need_buffer_feeds.size(), 2);
ASSERT_EQ(no_need_buffer_feeds,
std::unordered_set<std::string>({"var2", "var3"}));
// check the attributes of variable lists are saved correctly
ASSERT_TRUE(subgraph.Has(kInputVars));
EXPECT_EQ(subgraph.Get<std::vector<std::string>>(kInputVars),
std::vector<std::string>({"var1"}));
ASSERT_TRUE(subgraph.Has(kInternalVars));
EXPECT_EQ(subgraph.Get<std::vector<std::string>>(kInternalVars),
std::vector<std::string>({"var4"}));
ASSERT_TRUE(subgraph.Has(kOutputVars));
const auto& output_vars = subgraph.Get<std::vector<std::string>>(kOutputVars);
EXPECT_EQ(
std::unordered_set<std::string>(output_vars.begin(), output_vars.end()),
std::unordered_set<std::string>({"var5", "var6"}));
}
} // namespace paddle2cinn
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册