未验证 提交 1221307b 编写于 作者: T TeFeng Chen 提交者: GitHub

delivery skip_gc_vars attr to cinn subgraph (#49471)

* delivery skip_gc_vars from the main graph to each subgraph compiled by CINN

* rearrange format and annotation

* fix lacking namespace

* fix segmentation fault cinn subgraph doesn't own kSkipGcVarNames

* deliver all skip_gc_vars of main graph

* add vlog for skip_gc_vars
上级 d5f1e300
...@@ -693,10 +693,18 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { ...@@ -693,10 +693,18 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
VLOG(4) << "Cluster internal vars: " VLOG(4) << "Cluster internal vars: "
<< cluster_debug_info(cluster_internals); << cluster_debug_info(cluster_internals);
// Create a new subgraph according to the found cluster and // Create a new subgraph with the cluster and save it into the CinnCompiler
// save it in CinnCompiler auto subgraph = CreateNewSubGraph(
auto compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph( cluster_set, cluster_internals, cluster_inputs, cluster_outputs);
cluster_set, cluster_internals, cluster_inputs, cluster_outputs)); // Deliver the kSkipGcVarNames attr (if exists) to the subgraph
if (graph->Has(kSkipGcVarNames)) {
const auto& all_skip_gc_vars =
graph->Get<std::unordered_set<std::string>>(kSkipGcVarNames);
auto& sub_skip_gc_vars =
subgraph->GetOrInit<std::unordered_set<std::string>>(kSkipGcVarNames);
sub_skip_gc_vars = all_skip_gc_vars;
}
auto compilation_key = cinn_compiler->AddGraph(std::move(subgraph));
VLOG(4) << "Compilation Key:\n" VLOG(4) << "Compilation Key:\n"
<< cinn_compiler->ReadableKey(compilation_key); << cinn_compiler->ReadableKey(compilation_key);
......
...@@ -38,6 +38,7 @@ constexpr char kInternalVars[] = "InternalVars"; ...@@ -38,6 +38,7 @@ constexpr char kInternalVars[] = "InternalVars";
constexpr char kOutputVars[] = "OutputVars"; constexpr char kOutputVars[] = "OutputVars";
constexpr char kMemOptVarInfoFromMainGraph[] = constexpr char kMemOptVarInfoFromMainGraph[] =
"mem_opt_var_info_from_main_graph"; "mem_opt_var_info_from_main_graph";
constexpr char kSkipGcVarNames[] = "skip_gc_vars";
using Name2VarInfoMap = using Name2VarInfoMap =
std::unordered_map<std::string, std::unordered_map<std::string,
......
...@@ -41,6 +41,7 @@ ...@@ -41,6 +41,7 @@
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h" #include "paddle/fluid/string/printf.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/utils/string/string_helper.h"
namespace paddle { namespace paddle {
namespace operators::details { namespace operators::details {
...@@ -50,6 +51,7 @@ using framework::Scope; ...@@ -50,6 +51,7 @@ using framework::Scope;
using CinnInstruction = ::cinn::hlir::framework::Instruction; using CinnInstruction = ::cinn::hlir::framework::Instruction;
using CinnRuntimeProgram = ::cinn::hlir::framework::Program; using CinnRuntimeProgram = ::cinn::hlir::framework::Program;
using framework::paddle2cinn::kMemOptVarInfoFromMainGraph; using framework::paddle2cinn::kMemOptVarInfoFromMainGraph;
using framework::paddle2cinn::kSkipGcVarNames;
using framework::paddle2cinn::Name2VarInfoMap; using framework::paddle2cinn::Name2VarInfoMap;
CinnLaunchContext::CinnLaunchContext(const framework::ir::Graph& graph, CinnLaunchContext::CinnLaunchContext(const framework::ir::Graph& graph,
...@@ -94,14 +96,24 @@ CinnLaunchContext::CinnLaunchContext(const framework::ir::Graph& graph, ...@@ -94,14 +96,24 @@ CinnLaunchContext::CinnLaunchContext(const framework::ir::Graph& graph,
auto& outer_varinfo = graph.Get<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph); auto& outer_varinfo = graph.Get<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph);
runtime_graph_->SetNotOwned<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph, runtime_graph_->SetNotOwned<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph,
&outer_varinfo); &outer_varinfo);
// collect skip_eager_vars // use kSkipGcVarNames attr of graph to initialize skip_gc_vars_
if (graph.Has(kSkipGcVarNames)) {
const auto& skip_gc_vars =
graph.Get<std::unordered_set<std::string>>(kSkipGcVarNames);
skip_gc_vars_.insert(skip_gc_vars.begin(), skip_gc_vars.end());
VLOG(4) << "Append skip_gc_vars:["
<< string::join_strings(skip_gc_vars, ',') << "]";
}
// collect variables name list to be skipped in GC
skip_eager_vars_.reserve(input_var_names.size() + output_var_names.size()); skip_eager_vars_.reserve(input_var_names.size() + output_var_names.size());
auto add_skip_var_fn = [&outer_varinfo, this](const std::string& var_name) { auto add_skip_var_fn = [&outer_varinfo, this](const std::string& var_name) {
// if a var exists at outer_varinfo map, // if a var exists at the outer_varinfo map, that means it will be
// that means it can be erased after graph execution // erased by the following eager_deletion_op of current cinn_launch op
if (!outer_varinfo.count(var_name)) { if (!outer_varinfo.count(var_name)) {
skip_eager_vars_.emplace_back(var_name); skip_eager_vars_.emplace_back(var_name);
skip_gc_vars_.insert(var_name); skip_gc_vars_.insert(var_name);
VLOG(4) << "Append a skip_gc_var:" << var_name;
} }
}; };
std::for_each( std::for_each(
...@@ -112,12 +124,13 @@ CinnLaunchContext::CinnLaunchContext(const framework::ir::Graph& graph, ...@@ -112,12 +124,13 @@ CinnLaunchContext::CinnLaunchContext(const framework::ir::Graph& graph,
"Distribution of variables in the graph compiled:" "Distribution of variables in the graph compiled:"
"input[%lu],internal[%lu],output[%lu]," "input[%lu],internal[%lu],output[%lu],"
"outer_eager_deletion[%lu],skip_eager_deletion[%lu]," "outer_eager_deletion[%lu],skip_eager_deletion[%lu],"
"initialized_beforehand[%lu]", "skip_gc_vars_[%lu],initialized_beforehand[%lu]",
input_var_names.size(), input_var_names.size(),
internal_var_names_.size(), internal_var_names_.size(),
output_var_names.size(), output_var_names.size(),
outer_varinfo.size(), outer_varinfo.size(),
skip_eager_vars_.size(), skip_eager_vars_.size(),
skip_gc_vars_.size(),
initialized_beforehand_vars_.size()); initialized_beforehand_vars_.size());
} }
......
...@@ -161,13 +161,14 @@ class CinnLaunchContext { ...@@ -161,13 +161,14 @@ class CinnLaunchContext {
std::unique_ptr<framework::ProgramDesc> runtime_program_desc_; std::unique_ptr<framework::ProgramDesc> runtime_program_desc_;
std::unique_ptr<framework::InterpreterCore> interpreter_core_; std::unique_ptr<framework::InterpreterCore> interpreter_core_;
// the name list of skip_gc_vars in runtime for InterpreterCore execution
std::set<std::string> skip_gc_vars_; std::set<std::string> skip_gc_vars_;
// the ir::Graph object converted from the program compiled by CINN // the ir::Graph object converted from the program compiled by CINN
std::unique_ptr<framework::ir::Graph> runtime_graph_; std::unique_ptr<framework::ir::Graph> runtime_graph_;
// a ParallelExecutor to execute the runtime graph // a ParallelExecutor to execute the runtime graph
std::unique_ptr<framework::ParallelExecutor> parallel_executor_; std::unique_ptr<framework::ParallelExecutor> parallel_executor_;
// the name list of skip_eager_vars in runtime // the name list of skip_eager_vars in runtime for ParallelExecutor execution
std::vector<std::string> skip_eager_vars_; std::vector<std::string> skip_eager_vars_;
// because a cinn_pod_value_t does not own a cinn_buffer_t object, // because a cinn_pod_value_t does not own a cinn_buffer_t object,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册