diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index efdeaf8d34efe3ac676d15418f9e71272f347029..4d438122d145d7ec1742f36061e379a1ccf76f06 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -144,8 +144,6 @@ std::unordered_set OpTransInfo::GetDenyVarNames( const auto& arg_names = desc->Input(param_name); for (const auto& arg_name : arg_names) { deny_var_set.insert(arg_name); - VLOG(4) << "deny param [" << param_name << "]'s argument name" - << " is [" << arg_name << "]."; } } @@ -153,8 +151,6 @@ std::unordered_set OpTransInfo::GetDenyVarNames( const auto& arg_names = desc->Output(param_name); for (const auto& arg_name : arg_names) { deny_var_set.insert(arg_name); - VLOG(4) << "deny param [" << param_name << "]'s argument name" - << " is [" << arg_name << "]."; } } } @@ -166,48 +162,27 @@ std::unordered_set OpTransInfo::GetDenyVarNames( return deny_var_set; } -std::unordered_set OpTransInfo::GetIgnoreInplaceVarNames( - const OpDesc& op_desc) const { - if (!ignore_inplace_param_cond_.count(op_desc.Type())) { - return {}; - } +std::unordered_set OpTransInfo::GetInplaceVarNames( + const GraphNodeSet& cluster) { + std::unordered_set inplace_var_set; - const auto& ignore_inplace_names = - ignore_inplace_param_cond_.at(op_desc.Type()); - VLOG(4) << "We found ignore inplace param " - << GetDebugInfo(ignore_inplace_names) << " in op [" << op_desc.Type() - << "]."; - - std::unordered_set ignore_inplace_set; - for (const auto& param_name : ignore_inplace_names) { - if (op_desc.HasOutput(param_name)) { - const auto& arg_names = op_desc.Output(param_name); - ignore_inplace_set.insert(arg_names.begin(), arg_names.end()); + for (auto* op : cluster) { + // skip if not op + if (!op->IsOp() || !op->Op()) { + continue; } - } - - VLOG(4) << "All ignore inplace var names are " - << GetDebugInfo(ignore_inplace_set); - - return ignore_inplace_set; -} - -bool OpTransInfo::IsInplaceOp( - const OpDesc& op_desc, - const std::unordered_set& deny_var_names) const { - const auto& ignore_inplace_set = GetIgnoreInplaceVarNames(op_desc); - - auto inputs = op_desc.InputArgumentNames(); - std::unordered_set input_set(inputs.begin(), inputs.end()); - for (auto& name : op_desc.OutputArgumentNames()) { - if (input_set.count(name) > 0 && !deny_var_names.count(name) && - !ignore_inplace_set.count(name)) { - VLOG(4) << "The argument " << name << " in op " << op_desc.Type() - << " is a inplace op, skip!"; - return true; + const auto& op_desc = *op->Op(); + + // check whether input and output have same argument + auto inputs = op_desc.InputArgumentNames(); + std::unordered_set input_set(inputs.begin(), inputs.end()); + for (auto& name : op_desc.OutputArgumentNames()) { + if (input_set.count(name)) { + inplace_var_set.insert(name); + } } } - return false; + return inplace_var_set; } namespace { @@ -503,6 +478,14 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, // initialize empty map for kMemOptVarInfoFromMainGraph attribute, // it will be filled on the share_mem_opt_info_to_subgraph pass subgraph->GetOrInit(kMemOptVarInfoFromMainGraph); + + auto inplace_var_names = std::make_unique>( + OpTransInfo::GetInplaceVarNames(cluster)); + VLOG_IF(4, !inplace_var_names->empty()) + << "Inplace var in cluster are: " << GetDebugInfo(*inplace_var_names); + subgraph->Set>(kInplaceVarNames, + inplace_var_names.release()); + return subgraph; } @@ -594,7 +577,6 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster, const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs, int64_t compilation_key, - const std::unordered_set& deny_var_set, Graph* graph) { // Add the cinn launch op framework::OpDesc cinn_op_desc; @@ -615,6 +597,7 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster, cinn_op_desc.SetAttr(operators::kCompilationKey, compilation_key); cinn_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), ExtractOpRole(cluster)); + cinn_op_desc.Flush(); auto* cinn_op_node = graph->CreateOpNode(&cinn_op_desc); // Add new links from or to the cinn launch op node @@ -639,21 +622,15 @@ void RemoveSubGraphFromGraph(const GraphNodeSet& cluster, // kCinnLaunchOp, and inputs ares cluster_inputs and outputs are // cluster_outputs. // Meanwhile, move all links of cluster to the cinn op. -void ReplaceSubGraphWithCinnOpNode( - const GraphNodeSet& cluster, - const GraphNodeSet& cluster_inputs, - const GraphNodeSet& cluster_outputs, - const GraphNodeSet& cluster_internals, - int64_t compilation_key, - const std::unordered_set& deny_var_set, - Graph* graph) { +void ReplaceSubGraphWithCinnOpNode(const GraphNodeSet& cluster, + const GraphNodeSet& cluster_inputs, + const GraphNodeSet& cluster_outputs, + const GraphNodeSet& cluster_internals, + int64_t compilation_key, + Graph* graph) { // Add the cinn op node whose name is "kCinnLaunchOp" into graph - AddCinnOpToGraph(cluster, - cluster_inputs, - cluster_outputs, - compilation_key, - deny_var_set, - graph); + AddCinnOpToGraph( + cluster, cluster_inputs, cluster_outputs, compilation_key, graph); // Remove the cinn subgraph from graph RemoveSubGraphFromGraph(cluster, cluster_internals, graph); } @@ -667,9 +644,7 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim); OpTransInfo trans_info; - const auto& deny_var_set = trans_info.GetDenyVarNames(graph->Nodes()); - auto teller = [&allow_ops, &deny_ops, &trans_info, &deny_var_set]( - const Node* node) { + auto teller = [&allow_ops, &deny_ops, &trans_info](const Node* node) { const auto& node_name = node->Name(); bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find( node_name) != nullptr; @@ -679,10 +654,9 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { is_dynamic = trans_info.dynamic_op_cond().at(node_name)(*node); } - bool is_support = - registered && !trans_info.default_deny_ops().count(node_name) && - !is_dynamic && - (node->IsOp() && !trans_info.IsInplaceOp(*node->Op(), deny_var_set)); + bool is_support = registered && + !trans_info.default_deny_ops().count(node_name) && + !is_dynamic; // if the op type is registered in CINN and allow_ops is not empty, return // true only when it is in allow_ops if (!allow_ops.empty()) { @@ -714,19 +688,23 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { return res; }; - std::unordered_set skip_gc_var_names; + std::unordered_set all_skip_gc_vars; if (graph->Has(kSkipGcVarNames)) { - skip_gc_var_names = + all_skip_gc_vars = graph->Get>(kSkipGcVarNames); + VLOG_IF(4, !all_skip_gc_vars.empty()) + << "All skip gc var names are: " << GetDebugInfo(all_skip_gc_vars); } + const auto& deny_var_set = trans_info.GetDenyVarNames(graph->Nodes()); + VLOG_IF(4, !deny_var_set.empty()) + << "All deny var names are: " << GetDebugInfo(deny_var_set); + auto* cinn_compiler = CinnCompiler::GetInstance(); for (const auto& node_vec : clusters) { // Classify var node to inputs, outputs, and internals. GraphNodeSet cluster_set(node_vec.begin(), node_vec.end()); - auto deny_var_set = trans_info.GetDenyVarNames(cluster_set); - GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals; AnalyseClusterVariables(cluster_set, deny_var_set, @@ -734,7 +712,7 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { &cluster_outputs, &cluster_internals, is_inference_stage, - skip_gc_var_names); + all_skip_gc_vars); VLOG(4) << "Cluster Ops: " << cluster_debug_info(cluster_set); VLOG(4) << "Cluster input vars: " << cluster_debug_info(cluster_inputs); @@ -747,8 +725,6 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { cluster_set, cluster_internals, cluster_inputs, cluster_outputs); // Deliver the kSkipGcVarNames attr (if exists) to the subgraph if (graph->Has(kSkipGcVarNames)) { - const auto& all_skip_gc_vars = - graph->Get>(kSkipGcVarNames); auto& sub_skip_gc_vars = subgraph->GetOrInit>(kSkipGcVarNames); sub_skip_gc_vars = all_skip_gc_vars; @@ -763,7 +739,6 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { cluster_outputs, cluster_internals, compilation_key, - deny_var_set, graph); } } diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h index 93e5186421725a9546d043804ec1e6b8b36c52bf..7e5152048d95467f954a46ffc19f18c0c8b9772a 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h @@ -39,6 +39,7 @@ constexpr char kOutputVars[] = "OutputVars"; constexpr char kMemOptVarInfoFromMainGraph[] = "mem_opt_var_info_from_main_graph"; constexpr char kSkipGcVarNames[] = "skip_gc_vars"; +constexpr char kInplaceVarNames[] = "InplaceVars"; using Name2VarInfoMap = std::unordered_map GetDenyVarNames( const GraphNodeSet& cluster) const; - std::unordered_set GetIgnoreInplaceVarNames( - const OpDesc& op_desc) const; - - bool IsInplaceOp(const OpDesc& op_desc, - const std::unordered_set& deny_var_names) const; + static std::unordered_set GetInplaceVarNames( + const GraphNodeSet& cluster); private: DyOpCondT dynamic_op_cond_; @@ -79,9 +77,6 @@ class OpTransInfo { DeParamCondT deny_param_cond_{{"batch_norm", {"ReserveSpace"}}, {"batch_norm_grad", {"ReserveSpace"}}}; - DeParamCondT ignore_inplace_param_cond_{ - {"batch_norm", {"MeanOut", "VarianceOut"}}}; - std::unordered_set default_deny_ops_{"feed", "fetch"}; }; diff --git a/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc b/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc index 94bc1241895ef618437e7903de8a7629080a0d8e..b703ca04f9274eae08e7141d80b9d18c54e367ed 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc @@ -258,17 +258,16 @@ void CinnGraphSymbolization::RunGraph(const OpMapperContext& ctx) const { std::unordered_set CinnGraphSymbolization::GetFetchIds() const { std::unordered_set fetch_names; fetch_names.reserve(fetch_var_names_.size()); - std::for_each( - fetch_var_names_.begin(), - fetch_var_names_.end(), - [this, &fetch_names](const std::string& name) { - PADDLE_ENFORCE_EQ( - var_model_to_program_map_.count(name), - 1, - platform::errors::PreconditionNotMet( - "Cannot find %s in var_model_to_program_map_", name.c_str())); - fetch_names.insert(var_model_to_program_map_.at(name)); - }); + std::for_each(fetch_var_names_.begin(), + fetch_var_names_.end(), + [this, &fetch_names](const std::string& name) { + PADDLE_ENFORCE_EQ( + var_map_.count(name), + 1, + platform::errors::PreconditionNotMet( + "Cannot find %s in var_map_", name.c_str())); + fetch_names.insert(var_map_.at(name)->id); + }); return fetch_names; }