diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index f7306bfc9a28e776229436cb0454fe95df93cc06..5de57fed5d54a57c263335c8347b2543e7619562 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -693,10 +693,18 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { VLOG(4) << "Cluster internal vars: " << cluster_debug_info(cluster_internals); - // Create a new subgraph according to the found cluster and - // save it in CinnCompiler - auto compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph( - cluster_set, cluster_internals, cluster_inputs, cluster_outputs)); + // Create a new subgraph with the cluster and save it into the CinnCompiler + auto subgraph = CreateNewSubGraph( + cluster_set, cluster_internals, cluster_inputs, cluster_outputs); + // Deliver the kSkipGcVarNames attr (if exists) to the subgraph + if (graph->Has(kSkipGcVarNames)) { + const auto& all_skip_gc_vars = + graph->Get>(kSkipGcVarNames); + auto& sub_skip_gc_vars = + subgraph->GetOrInit>(kSkipGcVarNames); + sub_skip_gc_vars = all_skip_gc_vars; + } + auto compilation_key = cinn_compiler->AddGraph(std::move(subgraph)); VLOG(4) << "Compilation Key:\n" << cinn_compiler->ReadableKey(compilation_key); diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h index 42b98b329833f8013b1d59c95fd4be2922b54bab..55caae596cee4eccff38673c874bb0ff831cb62f 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h @@ -38,6 +38,7 @@ constexpr char kInternalVars[] = "InternalVars"; constexpr char kOutputVars[] = "OutputVars"; constexpr char kMemOptVarInfoFromMainGraph[] = "mem_opt_var_info_from_main_graph"; +constexpr char kSkipGcVarNames[] = "skip_gc_vars"; using Name2VarInfoMap = std::unordered_map(kMemOptVarInfoFromMainGraph); runtime_graph_->SetNotOwned(kMemOptVarInfoFromMainGraph, &outer_varinfo); - // collect skip_eager_vars + // use kSkipGcVarNames attr of graph to initialize skip_gc_vars_ + if (graph.Has(kSkipGcVarNames)) { + const auto& skip_gc_vars = + graph.Get>(kSkipGcVarNames); + skip_gc_vars_.insert(skip_gc_vars.begin(), skip_gc_vars.end()); + VLOG(4) << "Append skip_gc_vars:[" + << string::join_strings(skip_gc_vars, ',') << "]"; + } + + // collect variables name list to be skipped in GC skip_eager_vars_.reserve(input_var_names.size() + output_var_names.size()); auto add_skip_var_fn = [&outer_varinfo, this](const std::string& var_name) { - // if a var exists at outer_varinfo map, - // that means it can be erased after graph execution + // if a var exists at the outer_varinfo map, that means it will be + // erased by the following eager_deletion_op of current cinn_launch op if (!outer_varinfo.count(var_name)) { skip_eager_vars_.emplace_back(var_name); skip_gc_vars_.insert(var_name); + VLOG(4) << "Append a skip_gc_var:" << var_name; } }; std::for_each( @@ -112,12 +124,13 @@ CinnLaunchContext::CinnLaunchContext(const framework::ir::Graph& graph, "Distribution of variables in the graph compiled:" "input[%lu],internal[%lu],output[%lu]," "outer_eager_deletion[%lu],skip_eager_deletion[%lu]," - "initialized_beforehand[%lu]", + "skip_gc_vars_[%lu],initialized_beforehand[%lu]", input_var_names.size(), internal_var_names_.size(), output_var_names.size(), outer_varinfo.size(), skip_eager_vars_.size(), + skip_gc_vars_.size(), initialized_beforehand_vars_.size()); } diff --git a/paddle/fluid/operators/cinn/cinn_launch_context.h b/paddle/fluid/operators/cinn/cinn_launch_context.h index e66658750bb230e0a81cb31095bacbf509a6d635..f4794e6335bb684c13c2b4cb6d84612395791e62 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_context.h +++ b/paddle/fluid/operators/cinn/cinn_launch_context.h @@ -161,13 +161,14 @@ class CinnLaunchContext { std::unique_ptr runtime_program_desc_; std::unique_ptr interpreter_core_; + // the name list of skip_gc_vars in runtime for InterpreterCore execution std::set skip_gc_vars_; // the ir::Graph object converted from the program compiled by CINN std::unique_ptr runtime_graph_; // a ParallelExecutor to execute the runtime graph std::unique_ptr parallel_executor_; - // the name list of skip_eager_vars in runtime + // the name list of skip_eager_vars in runtime for ParallelExecutor execution std::vector skip_eager_vars_; // because a cinn_pod_value_t does not own a cinn_buffer_t object,