未验证 提交 bad49b51 编写于 作者: J jiangcheng 提交者: GitHub

[CINN] collect inplace var into cinn op desc's kInplaceVarNames attribute (#49898)

* [CINN] collect inplace var into cinn op desc's kInplaceVarNames attribute

* attr move from op desc to subgraph

* GetFetchIds from var_map instead of var_model_to_program_map_
上级 8e02f290
...@@ -144,8 +144,6 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames( ...@@ -144,8 +144,6 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
const auto& arg_names = desc->Input(param_name); const auto& arg_names = desc->Input(param_name);
for (const auto& arg_name : arg_names) { for (const auto& arg_name : arg_names) {
deny_var_set.insert(arg_name); deny_var_set.insert(arg_name);
VLOG(4) << "deny param [" << param_name << "]'s argument name"
<< " is [" << arg_name << "].";
} }
} }
...@@ -153,8 +151,6 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames( ...@@ -153,8 +151,6 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
const auto& arg_names = desc->Output(param_name); const auto& arg_names = desc->Output(param_name);
for (const auto& arg_name : arg_names) { for (const auto& arg_name : arg_names) {
deny_var_set.insert(arg_name); deny_var_set.insert(arg_name);
VLOG(4) << "deny param [" << param_name << "]'s argument name"
<< " is [" << arg_name << "].";
} }
} }
} }
...@@ -166,48 +162,27 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames( ...@@ -166,48 +162,27 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
return deny_var_set; return deny_var_set;
} }
std::unordered_set<std::string> OpTransInfo::GetIgnoreInplaceVarNames( std::unordered_set<std::string> OpTransInfo::GetInplaceVarNames(
const OpDesc& op_desc) const { const GraphNodeSet& cluster) {
if (!ignore_inplace_param_cond_.count(op_desc.Type())) { std::unordered_set<std::string> inplace_var_set;
return {};
}
const auto& ignore_inplace_names = for (auto* op : cluster) {
ignore_inplace_param_cond_.at(op_desc.Type()); // skip if not op
VLOG(4) << "We found ignore inplace param " if (!op->IsOp() || !op->Op()) {
<< GetDebugInfo(ignore_inplace_names) << " in op [" << op_desc.Type() continue;
<< "].";
std::unordered_set<std::string> ignore_inplace_set;
for (const auto& param_name : ignore_inplace_names) {
if (op_desc.HasOutput(param_name)) {
const auto& arg_names = op_desc.Output(param_name);
ignore_inplace_set.insert(arg_names.begin(), arg_names.end());
} }
} const auto& op_desc = *op->Op();
VLOG(4) << "All ignore inplace var names are " // check whether input and output have same argument
<< GetDebugInfo(ignore_inplace_set); auto inputs = op_desc.InputArgumentNames();
std::unordered_set<std::string> input_set(inputs.begin(), inputs.end());
return ignore_inplace_set; for (auto& name : op_desc.OutputArgumentNames()) {
} if (input_set.count(name)) {
inplace_var_set.insert(name);
bool OpTransInfo::IsInplaceOp( }
const OpDesc& op_desc,
const std::unordered_set<std::string>& deny_var_names) const {
const auto& ignore_inplace_set = GetIgnoreInplaceVarNames(op_desc);
auto inputs = op_desc.InputArgumentNames();
std::unordered_set<std::string> input_set(inputs.begin(), inputs.end());
for (auto& name : op_desc.OutputArgumentNames()) {
if (input_set.count(name) > 0 && !deny_var_names.count(name) &&
!ignore_inplace_set.count(name)) {
VLOG(4) << "The argument " << name << " in op " << op_desc.Type()
<< " is a inplace op, skip!";
return true;
} }
} }
return false; return inplace_var_set;
} }
namespace { namespace {
...@@ -503,6 +478,14 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster, ...@@ -503,6 +478,14 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
// initialize empty map for kMemOptVarInfoFromMainGraph attribute, // initialize empty map for kMemOptVarInfoFromMainGraph attribute,
// it will be filled on the share_mem_opt_info_to_subgraph pass // it will be filled on the share_mem_opt_info_to_subgraph pass
subgraph->GetOrInit<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph); subgraph->GetOrInit<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph);
auto inplace_var_names = std::make_unique<std::unordered_set<std::string>>(
OpTransInfo::GetInplaceVarNames(cluster));
VLOG_IF(4, !inplace_var_names->empty())
<< "Inplace var in cluster are: " << GetDebugInfo(*inplace_var_names);
subgraph->Set<std::unordered_set<std::string>>(kInplaceVarNames,
inplace_var_names.release());
return subgraph; return subgraph;
} }
...@@ -594,7 +577,6 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster, ...@@ -594,7 +577,6 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs, const GraphNodeSet& cluster_outputs,
int64_t compilation_key, int64_t compilation_key,
const std::unordered_set<std::string>& deny_var_set,
Graph* graph) { Graph* graph) {
// Add the cinn launch op // Add the cinn launch op
framework::OpDesc cinn_op_desc; framework::OpDesc cinn_op_desc;
...@@ -615,6 +597,7 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster, ...@@ -615,6 +597,7 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster,
cinn_op_desc.SetAttr(operators::kCompilationKey, compilation_key); cinn_op_desc.SetAttr(operators::kCompilationKey, compilation_key);
cinn_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), cinn_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
ExtractOpRole(cluster)); ExtractOpRole(cluster));
cinn_op_desc.Flush(); cinn_op_desc.Flush();
auto* cinn_op_node = graph->CreateOpNode(&cinn_op_desc); auto* cinn_op_node = graph->CreateOpNode(&cinn_op_desc);
// Add new links from or to the cinn launch op node // Add new links from or to the cinn launch op node
...@@ -639,21 +622,15 @@ void RemoveSubGraphFromGraph(const GraphNodeSet& cluster, ...@@ -639,21 +622,15 @@ void RemoveSubGraphFromGraph(const GraphNodeSet& cluster,
// kCinnLaunchOp, and inputs ares cluster_inputs and outputs are // kCinnLaunchOp, and inputs ares cluster_inputs and outputs are
// cluster_outputs. // cluster_outputs.
// Meanwhile, move all links of cluster to the cinn op. // Meanwhile, move all links of cluster to the cinn op.
void ReplaceSubGraphWithCinnOpNode( void ReplaceSubGraphWithCinnOpNode(const GraphNodeSet& cluster,
const GraphNodeSet& cluster, const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs,
const GraphNodeSet& cluster_outputs, const GraphNodeSet& cluster_internals,
const GraphNodeSet& cluster_internals, int64_t compilation_key,
int64_t compilation_key, Graph* graph) {
const std::unordered_set<std::string>& deny_var_set,
Graph* graph) {
// Add the cinn op node whose name is "kCinnLaunchOp" into graph // Add the cinn op node whose name is "kCinnLaunchOp" into graph
AddCinnOpToGraph(cluster, AddCinnOpToGraph(
cluster_inputs, cluster, cluster_inputs, cluster_outputs, compilation_key, graph);
cluster_outputs,
compilation_key,
deny_var_set,
graph);
// Remove the cinn subgraph from graph // Remove the cinn subgraph from graph
RemoveSubGraphFromGraph(cluster, cluster_internals, graph); RemoveSubGraphFromGraph(cluster, cluster_internals, graph);
} }
...@@ -667,9 +644,7 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { ...@@ -667,9 +644,7 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim); auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
OpTransInfo trans_info; OpTransInfo trans_info;
const auto& deny_var_set = trans_info.GetDenyVarNames(graph->Nodes()); auto teller = [&allow_ops, &deny_ops, &trans_info](const Node* node) {
auto teller = [&allow_ops, &deny_ops, &trans_info, &deny_var_set](
const Node* node) {
const auto& node_name = node->Name(); const auto& node_name = node->Name();
bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find( bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find(
node_name) != nullptr; node_name) != nullptr;
...@@ -679,10 +654,9 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { ...@@ -679,10 +654,9 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
is_dynamic = trans_info.dynamic_op_cond().at(node_name)(*node); is_dynamic = trans_info.dynamic_op_cond().at(node_name)(*node);
} }
bool is_support = bool is_support = registered &&
registered && !trans_info.default_deny_ops().count(node_name) && !trans_info.default_deny_ops().count(node_name) &&
!is_dynamic && !is_dynamic;
(node->IsOp() && !trans_info.IsInplaceOp(*node->Op(), deny_var_set));
// if the op type is registered in CINN and allow_ops is not empty, return // if the op type is registered in CINN and allow_ops is not empty, return
// true only when it is in allow_ops // true only when it is in allow_ops
if (!allow_ops.empty()) { if (!allow_ops.empty()) {
...@@ -714,19 +688,23 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { ...@@ -714,19 +688,23 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
return res; return res;
}; };
std::unordered_set<std::string> skip_gc_var_names; std::unordered_set<std::string> all_skip_gc_vars;
if (graph->Has(kSkipGcVarNames)) { if (graph->Has(kSkipGcVarNames)) {
skip_gc_var_names = all_skip_gc_vars =
graph->Get<std::unordered_set<std::string>>(kSkipGcVarNames); graph->Get<std::unordered_set<std::string>>(kSkipGcVarNames);
VLOG_IF(4, !all_skip_gc_vars.empty())
<< "All skip gc var names are: " << GetDebugInfo(all_skip_gc_vars);
} }
const auto& deny_var_set = trans_info.GetDenyVarNames(graph->Nodes());
VLOG_IF(4, !deny_var_set.empty())
<< "All deny var names are: " << GetDebugInfo(deny_var_set);
auto* cinn_compiler = CinnCompiler::GetInstance(); auto* cinn_compiler = CinnCompiler::GetInstance();
for (const auto& node_vec : clusters) { for (const auto& node_vec : clusters) {
// Classify var node to inputs, outputs, and internals. // Classify var node to inputs, outputs, and internals.
GraphNodeSet cluster_set(node_vec.begin(), node_vec.end()); GraphNodeSet cluster_set(node_vec.begin(), node_vec.end());
auto deny_var_set = trans_info.GetDenyVarNames(cluster_set);
GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals; GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals;
AnalyseClusterVariables(cluster_set, AnalyseClusterVariables(cluster_set,
deny_var_set, deny_var_set,
...@@ -734,7 +712,7 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { ...@@ -734,7 +712,7 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
&cluster_outputs, &cluster_outputs,
&cluster_internals, &cluster_internals,
is_inference_stage, is_inference_stage,
skip_gc_var_names); all_skip_gc_vars);
VLOG(4) << "Cluster Ops: " << cluster_debug_info(cluster_set); VLOG(4) << "Cluster Ops: " << cluster_debug_info(cluster_set);
VLOG(4) << "Cluster input vars: " << cluster_debug_info(cluster_inputs); VLOG(4) << "Cluster input vars: " << cluster_debug_info(cluster_inputs);
...@@ -747,8 +725,6 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { ...@@ -747,8 +725,6 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
cluster_set, cluster_internals, cluster_inputs, cluster_outputs); cluster_set, cluster_internals, cluster_inputs, cluster_outputs);
// Deliver the kSkipGcVarNames attr (if exists) to the subgraph // Deliver the kSkipGcVarNames attr (if exists) to the subgraph
if (graph->Has(kSkipGcVarNames)) { if (graph->Has(kSkipGcVarNames)) {
const auto& all_skip_gc_vars =
graph->Get<std::unordered_set<std::string>>(kSkipGcVarNames);
auto& sub_skip_gc_vars = auto& sub_skip_gc_vars =
subgraph->GetOrInit<std::unordered_set<std::string>>(kSkipGcVarNames); subgraph->GetOrInit<std::unordered_set<std::string>>(kSkipGcVarNames);
sub_skip_gc_vars = all_skip_gc_vars; sub_skip_gc_vars = all_skip_gc_vars;
...@@ -763,7 +739,6 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { ...@@ -763,7 +739,6 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
cluster_outputs, cluster_outputs,
cluster_internals, cluster_internals,
compilation_key, compilation_key,
deny_var_set,
graph); graph);
} }
} }
......
...@@ -39,6 +39,7 @@ constexpr char kOutputVars[] = "OutputVars"; ...@@ -39,6 +39,7 @@ constexpr char kOutputVars[] = "OutputVars";
constexpr char kMemOptVarInfoFromMainGraph[] = constexpr char kMemOptVarInfoFromMainGraph[] =
"mem_opt_var_info_from_main_graph"; "mem_opt_var_info_from_main_graph";
constexpr char kSkipGcVarNames[] = "skip_gc_vars"; constexpr char kSkipGcVarNames[] = "skip_gc_vars";
constexpr char kInplaceVarNames[] = "InplaceVars";
using Name2VarInfoMap = using Name2VarInfoMap =
std::unordered_map<std::string, std::unordered_map<std::string,
...@@ -67,11 +68,8 @@ class OpTransInfo { ...@@ -67,11 +68,8 @@ class OpTransInfo {
std::unordered_set<std::string> GetDenyVarNames( std::unordered_set<std::string> GetDenyVarNames(
const GraphNodeSet& cluster) const; const GraphNodeSet& cluster) const;
std::unordered_set<std::string> GetIgnoreInplaceVarNames( static std::unordered_set<std::string> GetInplaceVarNames(
const OpDesc& op_desc) const; const GraphNodeSet& cluster);
bool IsInplaceOp(const OpDesc& op_desc,
const std::unordered_set<std::string>& deny_var_names) const;
private: private:
DyOpCondT dynamic_op_cond_; DyOpCondT dynamic_op_cond_;
...@@ -79,9 +77,6 @@ class OpTransInfo { ...@@ -79,9 +77,6 @@ class OpTransInfo {
DeParamCondT deny_param_cond_{{"batch_norm", {"ReserveSpace"}}, DeParamCondT deny_param_cond_{{"batch_norm", {"ReserveSpace"}},
{"batch_norm_grad", {"ReserveSpace"}}}; {"batch_norm_grad", {"ReserveSpace"}}};
DeParamCondT ignore_inplace_param_cond_{
{"batch_norm", {"MeanOut", "VarianceOut"}}};
std::unordered_set<std::string> default_deny_ops_{"feed", "fetch"}; std::unordered_set<std::string> default_deny_ops_{"feed", "fetch"};
}; };
......
...@@ -258,17 +258,16 @@ void CinnGraphSymbolization::RunGraph(const OpMapperContext& ctx) const { ...@@ -258,17 +258,16 @@ void CinnGraphSymbolization::RunGraph(const OpMapperContext& ctx) const {
std::unordered_set<std::string> CinnGraphSymbolization::GetFetchIds() const { std::unordered_set<std::string> CinnGraphSymbolization::GetFetchIds() const {
std::unordered_set<std::string> fetch_names; std::unordered_set<std::string> fetch_names;
fetch_names.reserve(fetch_var_names_.size()); fetch_names.reserve(fetch_var_names_.size());
std::for_each( std::for_each(fetch_var_names_.begin(),
fetch_var_names_.begin(), fetch_var_names_.end(),
fetch_var_names_.end(), [this, &fetch_names](const std::string& name) {
[this, &fetch_names](const std::string& name) { PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ( var_map_.count(name),
var_model_to_program_map_.count(name), 1,
1, platform::errors::PreconditionNotMet(
platform::errors::PreconditionNotMet( "Cannot find %s in var_map_", name.c_str()));
"Cannot find %s in var_model_to_program_map_", name.c_str())); fetch_names.insert(var_map_.at(name)->id);
fetch_names.insert(var_model_to_program_map_.at(name)); });
});
return fetch_names; return fetch_names;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册