未验证 提交 bad49b51 编写于 作者: J jiangcheng 提交者: GitHub

[CINN] collect inplace var into cinn op desc's kInplaceVarNames attribute (#49898)

* [CINN] collect inplace var into cinn op desc's kInplaceVarNames attribute

* attr move from op desc to subgraph

* GetFetchIds from var_map instead of var_model_to_program_map_
上级 8e02f290
......@@ -144,8 +144,6 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
const auto& arg_names = desc->Input(param_name);
for (const auto& arg_name : arg_names) {
deny_var_set.insert(arg_name);
VLOG(4) << "deny param [" << param_name << "]'s argument name"
<< " is [" << arg_name << "].";
}
}
......@@ -153,8 +151,6 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
const auto& arg_names = desc->Output(param_name);
for (const auto& arg_name : arg_names) {
deny_var_set.insert(arg_name);
VLOG(4) << "deny param [" << param_name << "]'s argument name"
<< " is [" << arg_name << "].";
}
}
}
......@@ -166,48 +162,27 @@ std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
return deny_var_set;
}
std::unordered_set<std::string> OpTransInfo::GetIgnoreInplaceVarNames(
const OpDesc& op_desc) const {
if (!ignore_inplace_param_cond_.count(op_desc.Type())) {
return {};
}
std::unordered_set<std::string> OpTransInfo::GetInplaceVarNames(
const GraphNodeSet& cluster) {
std::unordered_set<std::string> inplace_var_set;
const auto& ignore_inplace_names =
ignore_inplace_param_cond_.at(op_desc.Type());
VLOG(4) << "We found ignore inplace param "
<< GetDebugInfo(ignore_inplace_names) << " in op [" << op_desc.Type()
<< "].";
std::unordered_set<std::string> ignore_inplace_set;
for (const auto& param_name : ignore_inplace_names) {
if (op_desc.HasOutput(param_name)) {
const auto& arg_names = op_desc.Output(param_name);
ignore_inplace_set.insert(arg_names.begin(), arg_names.end());
for (auto* op : cluster) {
// skip if not op
if (!op->IsOp() || !op->Op()) {
continue;
}
}
VLOG(4) << "All ignore inplace var names are "
<< GetDebugInfo(ignore_inplace_set);
return ignore_inplace_set;
}
bool OpTransInfo::IsInplaceOp(
const OpDesc& op_desc,
const std::unordered_set<std::string>& deny_var_names) const {
const auto& ignore_inplace_set = GetIgnoreInplaceVarNames(op_desc);
auto inputs = op_desc.InputArgumentNames();
std::unordered_set<std::string> input_set(inputs.begin(), inputs.end());
for (auto& name : op_desc.OutputArgumentNames()) {
if (input_set.count(name) > 0 && !deny_var_names.count(name) &&
!ignore_inplace_set.count(name)) {
VLOG(4) << "The argument " << name << " in op " << op_desc.Type()
<< " is a inplace op, skip!";
return true;
const auto& op_desc = *op->Op();
// check whether input and output have same argument
auto inputs = op_desc.InputArgumentNames();
std::unordered_set<std::string> input_set(inputs.begin(), inputs.end());
for (auto& name : op_desc.OutputArgumentNames()) {
if (input_set.count(name)) {
inplace_var_set.insert(name);
}
}
}
return false;
return inplace_var_set;
}
namespace {
......@@ -503,6 +478,14 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
// initialize empty map for kMemOptVarInfoFromMainGraph attribute,
// it will be filled on the share_mem_opt_info_to_subgraph pass
subgraph->GetOrInit<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph);
auto inplace_var_names = std::make_unique<std::unordered_set<std::string>>(
OpTransInfo::GetInplaceVarNames(cluster));
VLOG_IF(4, !inplace_var_names->empty())
<< "Inplace var in cluster are: " << GetDebugInfo(*inplace_var_names);
subgraph->Set<std::unordered_set<std::string>>(kInplaceVarNames,
inplace_var_names.release());
return subgraph;
}
......@@ -594,7 +577,6 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs,
int64_t compilation_key,
const std::unordered_set<std::string>& deny_var_set,
Graph* graph) {
// Add the cinn launch op
framework::OpDesc cinn_op_desc;
......@@ -615,6 +597,7 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster,
cinn_op_desc.SetAttr(operators::kCompilationKey, compilation_key);
cinn_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
ExtractOpRole(cluster));
cinn_op_desc.Flush();
auto* cinn_op_node = graph->CreateOpNode(&cinn_op_desc);
// Add new links from or to the cinn launch op node
......@@ -639,21 +622,15 @@ void RemoveSubGraphFromGraph(const GraphNodeSet& cluster,
// kCinnLaunchOp, and inputs ares cluster_inputs and outputs are
// cluster_outputs.
// Meanwhile, move all links of cluster to the cinn op.
void ReplaceSubGraphWithCinnOpNode(
const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs,
const GraphNodeSet& cluster_internals,
int64_t compilation_key,
const std::unordered_set<std::string>& deny_var_set,
Graph* graph) {
void ReplaceSubGraphWithCinnOpNode(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs,
const GraphNodeSet& cluster_internals,
int64_t compilation_key,
Graph* graph) {
// Add the cinn op node whose name is "kCinnLaunchOp" into graph
AddCinnOpToGraph(cluster,
cluster_inputs,
cluster_outputs,
compilation_key,
deny_var_set,
graph);
AddCinnOpToGraph(
cluster, cluster_inputs, cluster_outputs, compilation_key, graph);
// Remove the cinn subgraph from graph
RemoveSubGraphFromGraph(cluster, cluster_internals, graph);
}
......@@ -667,9 +644,7 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
OpTransInfo trans_info;
const auto& deny_var_set = trans_info.GetDenyVarNames(graph->Nodes());
auto teller = [&allow_ops, &deny_ops, &trans_info, &deny_var_set](
const Node* node) {
auto teller = [&allow_ops, &deny_ops, &trans_info](const Node* node) {
const auto& node_name = node->Name();
bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find(
node_name) != nullptr;
......@@ -679,10 +654,9 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
is_dynamic = trans_info.dynamic_op_cond().at(node_name)(*node);
}
bool is_support =
registered && !trans_info.default_deny_ops().count(node_name) &&
!is_dynamic &&
(node->IsOp() && !trans_info.IsInplaceOp(*node->Op(), deny_var_set));
bool is_support = registered &&
!trans_info.default_deny_ops().count(node_name) &&
!is_dynamic;
// if the op type is registered in CINN and allow_ops is not empty, return
// true only when it is in allow_ops
if (!allow_ops.empty()) {
......@@ -714,19 +688,23 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
return res;
};
std::unordered_set<std::string> skip_gc_var_names;
std::unordered_set<std::string> all_skip_gc_vars;
if (graph->Has(kSkipGcVarNames)) {
skip_gc_var_names =
all_skip_gc_vars =
graph->Get<std::unordered_set<std::string>>(kSkipGcVarNames);
VLOG_IF(4, !all_skip_gc_vars.empty())
<< "All skip gc var names are: " << GetDebugInfo(all_skip_gc_vars);
}
const auto& deny_var_set = trans_info.GetDenyVarNames(graph->Nodes());
VLOG_IF(4, !deny_var_set.empty())
<< "All deny var names are: " << GetDebugInfo(deny_var_set);
auto* cinn_compiler = CinnCompiler::GetInstance();
for (const auto& node_vec : clusters) {
// Classify var node to inputs, outputs, and internals.
GraphNodeSet cluster_set(node_vec.begin(), node_vec.end());
auto deny_var_set = trans_info.GetDenyVarNames(cluster_set);
GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals;
AnalyseClusterVariables(cluster_set,
deny_var_set,
......@@ -734,7 +712,7 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
&cluster_outputs,
&cluster_internals,
is_inference_stage,
skip_gc_var_names);
all_skip_gc_vars);
VLOG(4) << "Cluster Ops: " << cluster_debug_info(cluster_set);
VLOG(4) << "Cluster input vars: " << cluster_debug_info(cluster_inputs);
......@@ -747,8 +725,6 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
cluster_set, cluster_internals, cluster_inputs, cluster_outputs);
// Deliver the kSkipGcVarNames attr (if exists) to the subgraph
if (graph->Has(kSkipGcVarNames)) {
const auto& all_skip_gc_vars =
graph->Get<std::unordered_set<std::string>>(kSkipGcVarNames);
auto& sub_skip_gc_vars =
subgraph->GetOrInit<std::unordered_set<std::string>>(kSkipGcVarNames);
sub_skip_gc_vars = all_skip_gc_vars;
......@@ -763,7 +739,6 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
cluster_outputs,
cluster_internals,
compilation_key,
deny_var_set,
graph);
}
}
......
......@@ -39,6 +39,7 @@ constexpr char kOutputVars[] = "OutputVars";
constexpr char kMemOptVarInfoFromMainGraph[] =
"mem_opt_var_info_from_main_graph";
constexpr char kSkipGcVarNames[] = "skip_gc_vars";
constexpr char kInplaceVarNames[] = "InplaceVars";
using Name2VarInfoMap =
std::unordered_map<std::string,
......@@ -67,11 +68,8 @@ class OpTransInfo {
std::unordered_set<std::string> GetDenyVarNames(
const GraphNodeSet& cluster) const;
std::unordered_set<std::string> GetIgnoreInplaceVarNames(
const OpDesc& op_desc) const;
bool IsInplaceOp(const OpDesc& op_desc,
const std::unordered_set<std::string>& deny_var_names) const;
static std::unordered_set<std::string> GetInplaceVarNames(
const GraphNodeSet& cluster);
private:
DyOpCondT dynamic_op_cond_;
......@@ -79,9 +77,6 @@ class OpTransInfo {
DeParamCondT deny_param_cond_{{"batch_norm", {"ReserveSpace"}},
{"batch_norm_grad", {"ReserveSpace"}}};
DeParamCondT ignore_inplace_param_cond_{
{"batch_norm", {"MeanOut", "VarianceOut"}}};
std::unordered_set<std::string> default_deny_ops_{"feed", "fetch"};
};
......
......@@ -258,17 +258,16 @@ void CinnGraphSymbolization::RunGraph(const OpMapperContext& ctx) const {
std::unordered_set<std::string> CinnGraphSymbolization::GetFetchIds() const {
std::unordered_set<std::string> fetch_names;
fetch_names.reserve(fetch_var_names_.size());
std::for_each(
fetch_var_names_.begin(),
fetch_var_names_.end(),
[this, &fetch_names](const std::string& name) {
PADDLE_ENFORCE_EQ(
var_model_to_program_map_.count(name),
1,
platform::errors::PreconditionNotMet(
"Cannot find %s in var_model_to_program_map_", name.c_str()));
fetch_names.insert(var_model_to_program_map_.at(name));
});
std::for_each(fetch_var_names_.begin(),
fetch_var_names_.end(),
[this, &fetch_names](const std::string& name) {
PADDLE_ENFORCE_EQ(
var_map_.count(name),
1,
platform::errors::PreconditionNotMet(
"Cannot find %s in var_map_", name.c_str()));
fetch_names.insert(var_map_.at(name)->id);
});
return fetch_names;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册