From a19b082ee1bd0b9cf62ce34bee29acc3b2cf92a3 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Tue, 11 Oct 2022 11:27:21 +0800 Subject: [PATCH] Fix some bugs hidden in build_cinn_pass. (#46843) * Fix some bugs hidden in build_cinn_pass. * Update codes about OpTransInfo. * Only support for the static reshape/reshape2 op. --- .../eager/auto_code_generator/CMakeLists.txt | 2 +- .../framework/paddle2cinn/build_cinn_pass.cc | 70 ++++++++++++++----- .../framework/paddle2cinn/build_cinn_pass.h | 54 +++++++------- 3 files changed, 81 insertions(+), 45 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/CMakeLists.txt b/paddle/fluid/eager/auto_code_generator/CMakeLists.txt index 3c1f6835c30..99f5e789081 100644 --- a/paddle/fluid/eager/auto_code_generator/CMakeLists.txt +++ b/paddle/fluid/eager/auto_code_generator/CMakeLists.txt @@ -100,7 +100,7 @@ else() legacy_eager_codegen COMMAND ${CMAKE_COMMAND} -E env - "LD_LIBRARY_PATH=$ENV{LD_LIBRARY_PATH}:${CMAKE_CURRENT_BINARY_DIR}/../../pybind" + "LD_LIBRARY_PATH=$ENV{LD_LIBRARY_PATH}:${CMAKE_CURRENT_BINARY_DIR}/../../pybind:${PADDLE_BINARY_DIR}/third_party/install/mklml/lib" "${CMAKE_CURRENT_BINARY_DIR}/eager_generator" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated" "${CODE_GEN_SPLIT_FILE_COUNT}" diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index 647ea868de6..8becee8d485 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -52,6 +52,37 @@ using framework::ir::Node; using GraphNodeVec = std::vector; using GraphNodeMap = std::unordered_map; +OpTransInfo::OpTransInfo() { + // judgment condition for the dynamic slice + dynamic_op_cond_.emplace("slice", [](const ir::Node& node) -> bool { + if (!node.IsOp()) { + return false; + } + auto* op_desc = node.Op(); + auto infer_flags = + op_desc->GetAttrIfExists>("infer_flags"); + return 
std::find_if(infer_flags.begin(), infer_flags.end(), [](int v) { + return v < 0; + }) != infer_flags.end(); + }); + + // judgment condition for the dynamic reshape + dynamic_op_cond_.emplace("reshape", [](const ir::Node& node) -> bool { + if (!node.IsOp()) { + return false; + } + auto* op_desc = node.Op(); + bool has_shape_tensor = op_desc->Inputs().count("ShapeTensor") && + op_desc->Inputs().at("ShapeTensor").size(); + bool has_shape = op_desc->Inputs().count("Shape") && + op_desc->Inputs().at("Shape").size(); + return has_shape_tensor || has_shape; + }); + + // judgment condition for the dynamic reshape2 + dynamic_op_cond_.emplace("reshape2", dynamic_op_cond_.at("reshape")); +} + std::unordered_set OpTransInfo::GetDenyVarNames( const GraphNodeSet& cluster) const { std::unordered_set deny_var_set; @@ -67,16 +98,16 @@ std::unordered_set OpTransInfo::GetDenyVarNames( }; for (auto* op : cluster) { - if (deny_param_cond.count(op->Name())) { + if (deny_param_cond_.count(op->Name())) { const auto* desc = op->Op(); PADDLE_ENFORCE_NE(desc, nullptr, platform::errors::PreconditionNotMet( "The Op %s's OpDesc should not be NULL, which has " - "a parameter in deny_param_cond.", + "a parameter in deny_param_cond_.", op->Name().c_str())); - auto deny_param_names = deny_param_cond.at(op->Name()); + auto deny_param_names = deny_param_cond_.at(op->Name()); VLOG(4) << "We found deny param " << get_debug_info(deny_param_names) << " in op [" << op->Name() << "]."; @@ -107,6 +138,15 @@ std::unordered_set OpTransInfo::GetDenyVarNames( return deny_var_set; } +bool OpTransInfo::IsInplaceOp(const OpDesc& op_desc) { + auto inputs = op_desc.InputArgumentNames(); + std::unordered_set input_set(inputs.begin(), inputs.end()); + for (auto& name : op_desc.OutputArgumentNames()) { + if (input_set.count(name) > 0) return true; + } + return false; +} + namespace { // The delim(`;`) that is used to split the FLAGS_allow_cinn_ops // & FLAGS_deny_cinn_ops. 
@@ -539,15 +579,6 @@ void ReplaceSubGraphWithCinnOpNode( RemoveSubGraphFromGraph(cluster, cluster_internals, graph); } -static bool IsInplaceOp(const OpDesc& op_desc) { - auto inputs = op_desc.InputArgumentNames(); - std::unordered_set input_set(inputs.begin(), inputs.end()); - for (auto& name : op_desc.OutputArgumentNames()) { - if (input_set.count(name) > 0) return true; - } - return false; -} - // Search all subgraphs which all op node supported by CINN, // Here we using SubgraphDetector to detecte the subgraph that // all of op node supported by CINN. We using OpMapperRegistry @@ -562,24 +593,27 @@ void SearchAllSubgraphs(Graph* graph) { node_name) != nullptr; // skip the dynamic ops bool is_dynamic = false; - if (trans_info.dynamic_op_cond.count(node_name)) { - is_dynamic = trans_info.dynamic_op_cond.at(node_name)(node); + if (trans_info.dynamic_op_cond().count(node_name)) { + is_dynamic = trans_info.dynamic_op_cond().at(node_name)(*node); } + + bool is_support = + registered && !trans_info.default_deny_ops().count(node_name) && + !is_dynamic && (node->IsOp() && !trans_info.IsInplaceOp(*node->Op())); // if the op type is registered in CINN and allow_ops is not empty, return // true only when it is in allow_ops if (!allow_ops.empty()) { - return registered && !is_dynamic && allow_ops.count(node_name); + return is_support && allow_ops.count(node_name); } // if the op type is registered in CINN and deny_ops is not empty, return // true only when it is not in deny_ops if (!deny_ops.empty()) { - return registered && !is_dynamic && !deny_ops.count(node_name); + return is_support && !deny_ops.count(node_name); } // if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops, // return true only when it is registered in CINN - return registered && !trans_info.default_deny_ops.count(node_name) && - !is_dynamic && (node->IsOp() && !IsInplaceOp(*node->Op())); + return is_support; }; VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops; VLOG(4) << "The 
denied Cinn Ops: " << FLAGS_deny_cinn_ops; diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h index 8f11d069344..42b98b32983 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h @@ -21,8 +21,6 @@ limitations under the License. */ #include #include "paddle/fluid/framework/ir/pass.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/errors.h" namespace paddle { namespace framework { @@ -46,33 +44,37 @@ using Name2VarInfoMap = std::shared_ptr>; using GraphNodeSet = std::unordered_set; -struct OpTransInfo { - const std::unordered_set default_deny_ops{"feed", "fetch"}; - - const std::unordered_map> - dynamic_op_cond{ - {"slice", [](const ir::Node* node) -> bool { - if (!node->IsOp()) { - return false; - } - auto* op_desc = node->Op(); - auto infer_flags = - op_desc->GetAttrIfExists>("infer_flags"); - if (std::find_if( - infer_flags.begin(), infer_flags.end(), [](int v) { - return v < 0; - }) != infer_flags.end()) { - return true; - } - return false; - }}}; - - const std::unordered_map> - deny_param_cond{{"batch_norm", {"ReserveSpace"}}, - {"batch_norm_grad", {"ReserveSpace"}}}; +// OpTransInfo contains information used to detect subgraphs +// supported by the CINN compiler. 
+class OpTransInfo { + using DyOpCondT = + std::unordered_map>; + using DeParamCondT = + std::unordered_map>; + + public: + OpTransInfo(); + + const DyOpCondT& dynamic_op_cond() const { return dynamic_op_cond_; } + + const DeParamCondT& deny_param_cond() const { return deny_param_cond_; } + + const std::unordered_set& default_deny_ops() const { + return default_deny_ops_; + } std::unordered_set GetDenyVarNames( const GraphNodeSet& cluster) const; + + static bool IsInplaceOp(const OpDesc& op_desc); + + private: + DyOpCondT dynamic_op_cond_; + + DeParamCondT deny_param_cond_{{"batch_norm", {"ReserveSpace"}}, + {"batch_norm_grad", {"ReserveSpace"}}}; + + std::unordered_set default_deny_ops_{"feed", "fetch"}; }; // A pass named BuildCinnPass, the function of this pass is: -- GitLab