From 3aa6bd57d0eb3bf22fc391adbfaf650592b25987 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Wed, 21 Sep 2022 18:51:16 +0800 Subject: [PATCH] Enable PaddleInference to use CINN. (#45009) * use cinn in the paddle inference * fix some cmake errors * Avoid division by zero in the arange_kernel. * Avoid dynamic ops. * Remove some useless codes. * Use OpTransInfo to encapsulate some codes used in the build_cinn_pass. --- paddle/fluid/framework/CMakeLists.txt | 47 +++++++++++++++++ paddle/fluid/framework/ir/CMakeLists.txt | 50 +------------------ .../framework/paddle2cinn/CMakeLists.txt | 14 ++++-- .../framework/paddle2cinn/build_cinn_pass.cc | 46 +++++++++-------- .../framework/paddle2cinn/build_cinn_pass.h | 40 +++++++++++++++ paddle/fluid/inference/paddle_inference.map | 1 + paddle/infrt/tests/timer.h | 1 + paddle/phi/kernels/gpu/arange_kernel.cu | 4 ++ 8 files changed, 127 insertions(+), 76 deletions(-) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 85806014312..7434557a3ca 100755 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -29,6 +29,53 @@ function(windows_symbolic TARGET) endforeach() endfunction() +# Usage: pass_library(target inference) will append to paddle_inference_pass.h +set(pass_file + ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h.tmp) +set(pass_file_final + ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h) +file( + WRITE ${pass_file} + "// Generated by the paddle/fluid/framework/ir/CMakeLists.txt. DO NOT EDIT!\n\n" +) +file(APPEND ${pass_file} "\#pragma once\n") +file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n") +copy_if_different(${pass_file} ${pass_file_final}) + +function(pass_library TARGET DEST) + set(options "") + set(oneValueArgs "") + set(multiValueArgs SRCS DEPS DIR) + set(targetPrefix "") + + cmake_parse_arguments(pass_library "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN}) + if(pass_library_DIR) + cc_library( + ${TARGET} + SRCS ${pass_library_DIR}/${TARGET}.cc + DEPS graph_pattern_detector pass fuse_pass_base op_version_registry + ${pass_library_DEPS}) + else() + cc_library( + ${TARGET} + SRCS ${TARGET}.cc + DEPS graph_pattern_detector pass fuse_pass_base op_version_registry + ${pass_library_DEPS}) + endif() + + # add more DEST here, such as train, dist and collect USE_PASS into a file automatically. + if(${DEST} STREQUAL "base" OR ${DEST} STREQUAL "inference") + if(NOT CMAKE_BUILD_TYPE STREQUAL "Release") + message(STATUS "add pass ${TARGET} ${DEST}") + endif() + file(APPEND ${pass_file} "USE_PASS(${TARGET});\n") + set(INFER_IR_PASSES + ${INFER_IR_PASSES} ${TARGET} + CACHE INTERNAL "") + endif() +endfunction() + add_subdirectory(ir) add_subdirectory(details) add_subdirectory(fleet) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 3a2ae0ff217..a58434eed61 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -1,16 +1,3 @@ -set(pass_file - ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h.tmp) -set(pass_file_final - ${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h) -file( - WRITE ${pass_file} - "// Generated by the paddle/fluid/framework/ir/CMakeLists.txt. DO NOT EDIT!\n\n" -) -file(APPEND ${pass_file} "\#pragma once\n") -file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n") - -copy_if_different(${pass_file} ${pass_file_final}) - add_subdirectory(fuse_optimizer_ops_pass) add_subdirectory(memory_optimize_pass) add_subdirectory(multi_devices_graph_pass) @@ -20,42 +7,7 @@ if(NOT APPLE add_subdirectory(fusion_group) endif() -# Usage: pass_library(target inference) will append to paddle_inference_pass.h unset(INFER_IR_PASSES CACHE) # clear the global variable -function(pass_library TARGET DEST) - set(options "") - set(oneValueArgs "") - set(multiValueArgs SRCS DEPS DIR) - set(targetPrefix "") - - cmake_parse_arguments(pass_library "${options}" "${oneValueArgs}" - "${multiValueArgs}" ${ARGN}) - if(pass_library_DIR) - cc_library( - ${TARGET} - SRCS ${pass_library_DIR}/${TARGET}.cc - DEPS graph_pattern_detector pass fuse_pass_base op_version_registry - ${pass_library_DEPS}) - else() - cc_library( - ${TARGET} - SRCS ${TARGET}.cc - DEPS graph_pattern_detector pass fuse_pass_base op_version_registry - ${pass_library_DEPS}) - endif() - - # add more DEST here, such as train, dist and collect USE_PASS into a file automatically. - if(${DEST} STREQUAL "base" OR ${DEST} STREQUAL "inference") - if(NOT CMAKE_BUILD_TYPE STREQUAL "Release") - message(STATUS "add pass ${TARGET} ${DEST}") - endif() - file(APPEND ${pass_file} "USE_PASS(${TARGET});\n") - set(INFER_IR_PASSES - ${INFER_IR_PASSES} ${TARGET} - CACHE INTERNAL "") - endif() -endfunction() - cc_library( node SRCS node.cc @@ -266,7 +218,7 @@ cc_library( DEPS pass graph_pattern_detector) set(GLOB_PASS_LIB - ${PASS_LIBRARY} + ${INFER_IR_PASSES} CACHE INTERNAL "Global PASS library") cc_library( diff --git a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt index 5b8e62d4f07..9049da91792 100644 --- a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt +++ b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt @@ -1,12 +1,16 @@ +pass_library( + build_cinn_pass + base + DEPS + subgraph_detector + cinn_compiler + errors + enforce) + cc_library( cinn_cache_key SRCS cinn_cache_key.cc DEPS graph graph_helper lod_tensor proto_desc) -cc_library( - build_cinn_pass - SRCS build_cinn_pass.cc - DEPS pass subgraph_detector graph_pattern_detector cinn_compiler errors - enforce) cc_library( transform_desc SRCS transform_desc.cc diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index f0f35ea28cc..647ea868de6 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -50,21 +50,10 @@ using framework::ir::Graph; using framework::ir::Node; using GraphNodeVec = std::vector; -using GraphNodeSet = std::unordered_set; using GraphNodeMap = std::unordered_map; -namespace { -// The delim(`;`) that is used to split the FLAGS_allow_cinn_ops -// & FLAGS_deny_cinn_ops. -constexpr char kDelim[] = ";"; - -const std::unordered_map> - kDenyParamMap = {{"batch_norm", {"ReserveSpace"}}, - {"batch_norm_grad", {"ReserveSpace"}}}; - -const std::unordered_set kDefaultDenyOps = {"feed", "fetch"}; - -std::unordered_set GetDenyVarNames(const GraphNodeSet& cluster) { +std::unordered_set OpTransInfo::GetDenyVarNames( + const GraphNodeSet& cluster) const { std::unordered_set deny_var_set; auto get_debug_info = [](const std::unordered_set& var_names) { @@ -78,16 +67,16 @@ std::unordered_set GetDenyVarNames(const GraphNodeSet& cluster) { }; for (auto* op : cluster) { - if (kDenyParamMap.count(op->Name())) { + if (deny_param_cond.count(op->Name())) { const auto* desc = op->Op(); PADDLE_ENFORCE_NE(desc, nullptr, platform::errors::PreconditionNotMet( "The Op %s's OpDesc should not be NULL, which has " - "a parameter in kDenyParamMap.", + "a parameter in deny_param_cond.", op->Name().c_str())); - auto deny_param_names = kDenyParamMap.at(op->Name()); + auto deny_param_names = deny_param_cond.at(op->Name()); VLOG(4) << "We found deny param " << get_debug_info(deny_param_names) << " in op [" << op->Name() << "]."; @@ -118,6 +107,11 @@ std::unordered_set GetDenyVarNames(const GraphNodeSet& cluster) { return deny_var_set; } +namespace { +// The delim(`;`) that is used to split the FLAGS_allow_cinn_ops +// & FLAGS_deny_cinn_ops. +constexpr char kDelim[] = ";"; + std::unordered_set StringSplit(const std::string& str, const std::string& delim) { std::regex reg(delim); @@ -561,30 +555,38 @@ static bool IsInplaceOp(const OpDesc& op_desc) { void SearchAllSubgraphs(Graph* graph) { auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim); auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim); - auto teller = [&allow_ops, &deny_ops](const Node* node) { + OpTransInfo trans_info; + auto teller = [&allow_ops, &deny_ops, &trans_info](const Node* node) { const auto& node_name = node->Name(); bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find( node_name) != nullptr; + // skip the dynamic ops + bool is_dynamic = false; + if (trans_info.dynamic_op_cond.count(node_name)) { + is_dynamic = trans_info.dynamic_op_cond.at(node_name)(node); + } // if the op type is registered in CINN and allow_ops is not empty, return // true only when it is in allow_ops if (!allow_ops.empty()) { - return registered && allow_ops.count(node_name); + return registered && !is_dynamic && allow_ops.count(node_name); } // if the op type is registered in CINN and deny_ops is not empty, return // true only when it is not in deny_ops if (!deny_ops.empty()) { - return registered && !deny_ops.count(node_name); + return registered && !is_dynamic && !deny_ops.count(node_name); } // if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops, // return true only when it is registered in CINN - return registered && !kDefaultDenyOps.count(node_name) && - (node->IsOp() && !IsInplaceOp(*node->Op())); + return registered && !trans_info.default_deny_ops.count(node_name) && + !is_dynamic && (node->IsOp() && !IsInplaceOp(*node->Op())); }; VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops; VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops; std::vector clusters = framework::ir::SubgraphDetector(graph, teller)(); + LOG(INFO) << "--- [build_cinn_pass] detected " << clusters.size() + << " cinn supported subgraphs"; auto cluster_debug_info = [](const GraphNodeSet& cluster) { std::string res = "("; @@ -601,7 +603,7 @@ void SearchAllSubgraphs(Graph* graph) { // Classify var node to inputs, outputs, and internals. GraphNodeSet cluster_set(node_vec.begin(), node_vec.end()); - auto deny_var_set = GetDenyVarNames(cluster_set); + auto deny_var_set = trans_info.GetDenyVarNames(cluster_set); GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals; AnalyseClusterVariables(cluster_set, diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h index a902eacde82..8f11d069344 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h @@ -14,12 +14,21 @@ limitations under the License. */ #pragma once +#include +#include +#include +#include +#include + #include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/errors.h" namespace paddle { namespace framework { namespace ir { class MemOptVarInfo; +class Node; } // namespace ir namespace paddle2cinn { @@ -31,9 +40,40 @@ constexpr char kInternalVars[] = "InternalVars"; constexpr char kOutputVars[] = "OutputVars"; constexpr char kMemOptVarInfoFromMainGraph[] = "mem_opt_var_info_from_main_graph"; + using Name2VarInfoMap = std::unordered_map>; +using GraphNodeSet = std::unordered_set; + +struct OpTransInfo { + const std::unordered_set default_deny_ops{"feed", "fetch"}; + + const std::unordered_map> + dynamic_op_cond{ + {"slice", [](const ir::Node* node) -> bool { + if (!node->IsOp()) { + return false; + } + auto* op_desc = node->Op(); + auto infer_flags = + op_desc->GetAttrIfExists>("infer_flags"); + if (std::find_if( + infer_flags.begin(), infer_flags.end(), [](int v) { + return v < 0; + }) != infer_flags.end()) { + return true; + } + return false; + }}}; + + const std::unordered_map> + deny_param_cond{{"batch_norm", {"ReserveSpace"}}, + {"batch_norm_grad", {"ReserveSpace"}}}; + + std::unordered_set GetDenyVarNames( + const GraphNodeSet& cluster) const; +}; // A pass named BuildCinnPass, the function of this pass is: // diff --git a/paddle/fluid/inference/paddle_inference.map b/paddle/fluid/inference/paddle_inference.map index f8d7fb582b8..0d2b0d659b3 100644 --- a/paddle/fluid/inference/paddle_inference.map +++ b/paddle/fluid/inference/paddle_inference.map @@ -4,6 +4,7 @@ *Pass*; *profile*; *phi*; + *cinn*; local: *; }; diff --git a/paddle/infrt/tests/timer.h b/paddle/infrt/tests/timer.h index 18372cbe541..cfd382d456c 100644 --- a/paddle/infrt/tests/timer.h +++ b/paddle/infrt/tests/timer.h @@ -14,6 +14,7 @@ #pragma once +#include #include namespace infrt { diff --git a/paddle/phi/kernels/gpu/arange_kernel.cu b/paddle/phi/kernels/gpu/arange_kernel.cu index 858191c44ee..26f85446deb 100644 --- a/paddle/phi/kernels/gpu/arange_kernel.cu +++ b/paddle/phi/kernels/gpu/arange_kernel.cu @@ -15,6 +15,8 @@ #include "paddle/phi/kernels/arange_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/range_function.h" @@ -56,6 +58,8 @@ void ArangeKernel(const Context& dev_ctx, auto stream = dev_ctx.stream(); int block = std::min(size, static_cast(256)); + PADDLE_ENFORCE_NE( + block, 0, errors::OutOfRange("The value of block cannot be 0.")); int grid = (size + block - 1) / block; Range<<>>(start_value, step_value, size, out_data); } -- GitLab