未验证 提交 c1ce54bf 编写于 作者: J jiangcheng 提交者: GitHub

CINN add fetch op for skip gc vars (#49553)

* CINN add fetch op for skip gc vars

* perfect test annotation

* break if not is_only_used_internal

* move skip_gc_var_names get out of for loop
上级 414ca6b9
......@@ -485,7 +485,8 @@ void AnalyseClusterVariables(
GraphNodeSet* cluster_inputs,
GraphNodeSet* cluster_outputs,
GraphNodeSet* cluster_internals,
bool is_inference_stage) {
bool is_inference_stage,
const std::unordered_set<std::string>& skip_gc_var_names) {
// collecting all input and output of op
for (auto* op_node : cluster) {
const auto& op_name = op_node->Name();
......@@ -510,9 +511,11 @@ void AnalyseClusterVariables(
// An internal node must be an output node of the sub-graph,
// but not an input node of any op outside the sub-graph.
bool is_only_used_internal = true;
for (auto* next_op_node : var_node->outputs) {
is_only_used_internal &= (cluster.count(next_op_node) > 0);
// It must also not be listed in skip_gc_var_names.
bool is_only_used_internal = !skip_gc_var_names.count(var_node->Name());
for (size_t i = 0; i < var_node->outputs.size() && is_only_used_internal;
++i) {
is_only_used_internal &= (cluster.count(var_node->outputs[i]) > 0);
}
if (is_only_used_internal) {
cluster_internals->insert(var_node);
......@@ -672,6 +675,12 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
return res;
};
std::unordered_set<std::string> skip_gc_var_names;
if (graph->Has(kSkipGcVarNames)) {
skip_gc_var_names =
graph->Get<std::unordered_set<std::string>>(kSkipGcVarNames);
}
auto* cinn_compiler = CinnCompiler::GetInstance();
for (const auto& node_vec : clusters) {
// Classify var node to inputs, outputs, and internals.
......@@ -685,7 +694,8 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
&cluster_inputs,
&cluster_outputs,
&cluster_internals,
is_inference_stage);
is_inference_stage,
skip_gc_var_names);
VLOG(4) << "Cluster Ops: " << cluster_debug_info(cluster_set);
VLOG(4) << "Cluster input vars: " << cluster_debug_info(cluster_inputs);
......
......@@ -670,6 +670,47 @@ TEST(BuildCinnPassTest, NoNeedBufferInput) {
std::unordered_set<std::string>({"var5", "var6"}));
}
// Runs build_cinn_pass on the single-subgraph test graph with "var1" and
// "var3" registered under kSkipGcVarNames, then checks the node counts of the
// rewritten outer graph and of the compiled CINN subgraph.
TEST(BuildCinnPassTest, TestSkipGcVars) {
  auto graph = BuildGraphWithOneCinnSubgraph();

  // Registered with SetNotOwned, so the set is kept alive in this scope.
  std::unordered_set<std::string> skip_gc_vars{"var1", "var3"};
  graph->SetNotOwned(kSkipGcVarNames, &skip_gc_vars);

  auto build_cinn_pass =
      paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
  build_cinn_pass->Apply(graph.get());

  // Expected outer graph after the pass:
  //   fake1 --> v1 --
  //                  | --> kCinnLaunchOp --> v4 --> fake2
  //             v2 --
  // NOTE(review): the sketch above shows six nodes while seven are asserted;
  // presumably v3 also remains in the outer graph because it is in
  // kSkipGcVarNames -- confirm.
  const size_t expected_outer_node_count = 7;
  const auto& outer_nodes = graph->Nodes();
  ASSERT_EQ(outer_nodes.size(), expected_outer_node_count);
  ASSERT_TRUE(CheckGraphIndependence(outer_nodes));
  // The pass must have inserted an op named kCinnLaunchOp.
  ASSERT_TRUE(CheckNodeExisted(outer_nodes, kCinnLaunchOp));

  // Exactly one CINN subgraph is expected.
  // v3 gets a fetch op because it is listed in kSkipGcVarNames; v1 is a feed
  // var, so it needs no fetch even though it is listed as well.
  //   feed --> v1 --
  //                 | --> mul --> v3 --> relu --> v4 --> fetch
  //   feed --> v2 --               --> fetch
  const size_t expected_key_count = 1;
  auto keys = GetCompilationKeys(*graph);
  ASSERT_EQ(keys.size(), expected_key_count);

  auto* compiler = CinnCompiler::GetInstance();
  const auto& cinn_subgraph = compiler->FindGraph(keys[0]);
  const auto& inner_nodes = cinn_subgraph.Nodes();

  const size_t expected_inner_node_count = 10;
  ASSERT_EQ(inner_nodes.size(), expected_inner_node_count);
  ASSERT_TRUE(CheckGraphIndependence(inner_nodes));
  ASSERT_EQ(CountNode(inner_nodes, "feed"), 2);
  // Both var3 and var4 are expected to have a fetch op.
  ASSERT_EQ(CountNode(inner_nodes, "fetch"), 2);
}
} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册