From ce57365deb5680a37d735019ed1e437bbc60fc1c Mon Sep 17 00:00:00 2001 From: jiangcheng Date: Thu, 12 Jan 2023 10:53:09 +0800 Subject: [PATCH] [CINN] temp fix batch_norm check as inplace op bug (#49738) --- .../framework/paddle2cinn/build_cinn_pass.cc | 75 ++++++++++++++----- .../framework/paddle2cinn/build_cinn_pass.h | 9 ++- 2 files changed, 65 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index a87d4fa148..efdeaf8d34 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -52,6 +52,16 @@ using framework::ir::Node; using GraphNodeVec = std::vector; using GraphNodeMap = std::unordered_map; +std::string GetDebugInfo(const std::unordered_set& var_names) { + std::string debug_info = "["; + for (auto& var : var_names) { + debug_info.append(var); + debug_info.append(", "); + } + debug_info.append("]"); + return debug_info; +} + OpTransInfo::OpTransInfo() { // judgment condition for the dynamic slice dynamic_op_cond_.emplace("slice", [](const ir::Node& node) -> bool { @@ -115,16 +125,6 @@ std::unordered_set OpTransInfo::GetDenyVarNames( const GraphNodeSet& cluster) const { std::unordered_set deny_var_set; - auto get_debug_info = [](const std::unordered_set& var_names) { - std::string debug_info = "["; - for (auto& var : var_names) { - debug_info.append(var); - debug_info.append(", "); - } - debug_info.append("]"); - return debug_info; - }; - for (auto* op : cluster) { if (deny_param_cond_.count(op->Name())) { const auto* desc = op->Op(); @@ -136,7 +136,7 @@ std::unordered_set OpTransInfo::GetDenyVarNames( op->Name().c_str())); auto deny_param_names = deny_param_cond_.at(op->Name()); - VLOG(4) << "We found deny param " << get_debug_info(deny_param_names) + VLOG(4) << "We found deny param " << GetDebugInfo(deny_param_names) << " in op [" << op->Name() << "]."; for (const auto& param_name : deny_param_names) { @@ -161,16 +161,51 @@ std::unordered_set OpTransInfo::GetDenyVarNames( } } - VLOG(4) << "All deny var names are " << get_debug_info(deny_var_set); + VLOG(4) << "All deny var names are " << GetDebugInfo(deny_var_set); return deny_var_set; } -bool OpTransInfo::IsInplaceOp(const OpDesc& op_desc) { +std::unordered_set OpTransInfo::GetIgnoreInplaceVarNames( + const OpDesc& op_desc) const { + if (!ignore_inplace_param_cond_.count(op_desc.Type())) { + return {}; + } + + const auto& ignore_inplace_names = + ignore_inplace_param_cond_.at(op_desc.Type()); + VLOG(4) << "We found ignore inplace param " + << GetDebugInfo(ignore_inplace_names) << " in op [" << op_desc.Type() + << "]."; + + std::unordered_set ignore_inplace_set; + for (const auto& param_name : ignore_inplace_names) { + if (op_desc.HasOutput(param_name)) { + const auto& arg_names = op_desc.Output(param_name); + ignore_inplace_set.insert(arg_names.begin(), arg_names.end()); + } + } + + VLOG(4) << "All ignore inplace var names are " + << GetDebugInfo(ignore_inplace_set); + + return ignore_inplace_set; +} + +bool OpTransInfo::IsInplaceOp( + const OpDesc& op_desc, + const std::unordered_set& deny_var_names) const { + const auto& ignore_inplace_set = GetIgnoreInplaceVarNames(op_desc); + auto inputs = op_desc.InputArgumentNames(); std::unordered_set input_set(inputs.begin(), inputs.end()); for (auto& name : op_desc.OutputArgumentNames()) { - if (input_set.count(name) > 0) return true; + if (input_set.count(name) > 0 && !deny_var_names.count(name) && + !ignore_inplace_set.count(name)) { + VLOG(4) << "The argument " << name << " in op " << op_desc.Type() + << " is a inplace op, skip!"; + return true; + } } return false; } @@ -630,8 +665,11 @@ void ReplaceSubGraphWithCinnOpNode( void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim); auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim); + OpTransInfo trans_info; - auto teller = [&allow_ops, &deny_ops, &trans_info](const Node* node) { + const auto& deny_var_set = trans_info.GetDenyVarNames(graph->Nodes()); + auto teller = [&allow_ops, &deny_ops, &trans_info, &deny_var_set]( + const Node* node) { const auto& node_name = node->Name(); bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find( node_name) != nullptr; @@ -643,7 +681,8 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { bool is_support = registered && !trans_info.default_deny_ops().count(node_name) && - !is_dynamic && (node->IsOp() && !trans_info.IsInplaceOp(*node->Op())); + !is_dynamic && + (node->IsOp() && !trans_info.IsInplaceOp(*node->Op(), deny_var_set)); // if the op type is registered in CINN and allow_ops is not empty, return // true only when it is in allow_ops if (!allow_ops.empty()) { @@ -659,8 +698,8 @@ void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) { // return true only when it is registered in CINN return is_support; }; - VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops; - VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops; + VLOG(4) << "The allowed Cinn Ops: " << GetDebugInfo(allow_ops); + VLOG(4) << "The denied Cinn Ops: " << GetDebugInfo(deny_ops); std::vector clusters = CinnSubgraphDetector(graph, teller)(); LOG(INFO) << "--- [build_cinn_pass] detected " << clusters.size() << " cinn supported subgraphs"; diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h index 55caae596c..93e5186421 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h @@ -67,7 +67,11 @@ class OpTransInfo { std::unordered_set GetDenyVarNames( const GraphNodeSet& cluster) const; - static bool IsInplaceOp(const OpDesc& op_desc); + std::unordered_set GetIgnoreInplaceVarNames( + const OpDesc& op_desc) const; + + bool IsInplaceOp(const OpDesc& op_desc, + const std::unordered_set& deny_var_names) const; private: DyOpCondT dynamic_op_cond_; @@ -75,6 +79,9 @@ class OpTransInfo { DeParamCondT deny_param_cond_{{"batch_norm", {"ReserveSpace"}}, {"batch_norm_grad", {"ReserveSpace"}}}; + DeParamCondT ignore_inplace_param_cond_{ + {"batch_norm", {"MeanOut", "VarianceOut"}}}; + std::unordered_set default_deny_ops_{"feed", "fetch"}; }; -- GitLab