Unverified commit 3aa6bd57 authored by Zhen Wang, committed by GitHub

Enable PaddleInference to use CINN. (#45009)

* Use CINN in Paddle Inference.

* Fix some CMake errors.

* Avoid division by zero in the arange_kernel.

* Avoid dynamic ops.

* Remove some useless code.

* Use OpTransInfo to encapsulate some code used in the build_cinn_pass.
Parent face8f1f
@@ -29,6 +29,53 @@ function(windows_symbolic TARGET)
endforeach()
endfunction()
# Usage: pass_library(target inference) will append to paddle_inference_pass.h
set(pass_file
${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h.tmp)
set(pass_file_final
${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h)
file(
WRITE ${pass_file}
"// Generated by the paddle/fluid/framework/ir/CMakeLists.txt. DO NOT EDIT!\n\n"
)
file(APPEND ${pass_file} "\#pragma once\n")
file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
copy_if_different(${pass_file} ${pass_file_final})
function(pass_library TARGET DEST)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS DIR)
set(targetPrefix "")
cmake_parse_arguments(pass_library "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})
if(pass_library_DIR)
cc_library(
${TARGET}
SRCS ${pass_library_DIR}/${TARGET}.cc
DEPS graph_pattern_detector pass fuse_pass_base op_version_registry
${pass_library_DEPS})
else()
cc_library(
${TARGET}
SRCS ${TARGET}.cc
DEPS graph_pattern_detector pass fuse_pass_base op_version_registry
${pass_library_DEPS})
endif()
# add more DEST here, such as train, dist and collect USE_PASS into a file automatically.
if(${DEST} STREQUAL "base" OR ${DEST} STREQUAL "inference")
if(NOT CMAKE_BUILD_TYPE STREQUAL "Release")
message(STATUS "add pass ${TARGET} ${DEST}")
endif()
file(APPEND ${pass_file} "USE_PASS(${TARGET});\n")
set(INFER_IR_PASSES
${INFER_IR_PASSES} ${TARGET}
CACHE INTERNAL "")
endif()
endfunction()
add_subdirectory(ir)
add_subdirectory(details)
add_subdirectory(fleet)
......
set(pass_file
${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h.tmp)
set(pass_file_final
${PADDLE_BINARY_DIR}/paddle/fluid/inference/api/paddle_inference_pass.h)
file(
WRITE ${pass_file}
"// Generated by the paddle/fluid/framework/ir/CMakeLists.txt. DO NOT EDIT!\n\n"
)
file(APPEND ${pass_file} "\#pragma once\n")
file(APPEND ${pass_file} "\#include \"paddle/fluid/framework/ir/pass.h\"\n")
copy_if_different(${pass_file} ${pass_file_final})
add_subdirectory(fuse_optimizer_ops_pass)
add_subdirectory(memory_optimize_pass)
add_subdirectory(multi_devices_graph_pass)
@@ -20,42 +7,7 @@ if(NOT APPLE
add_subdirectory(fusion_group)
endif()
# Usage: pass_library(target inference) will append to paddle_inference_pass.h
unset(INFER_IR_PASSES CACHE) # clear the global variable
function(pass_library TARGET DEST)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS DIR)
set(targetPrefix "")
cmake_parse_arguments(pass_library "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN})
if(pass_library_DIR)
cc_library(
${TARGET}
SRCS ${pass_library_DIR}/${TARGET}.cc
DEPS graph_pattern_detector pass fuse_pass_base op_version_registry
${pass_library_DEPS})
else()
cc_library(
${TARGET}
SRCS ${TARGET}.cc
DEPS graph_pattern_detector pass fuse_pass_base op_version_registry
${pass_library_DEPS})
endif()
# add more DEST here, such as train, dist and collect USE_PASS into a file automatically.
if(${DEST} STREQUAL "base" OR ${DEST} STREQUAL "inference")
if(NOT CMAKE_BUILD_TYPE STREQUAL "Release")
message(STATUS "add pass ${TARGET} ${DEST}")
endif()
file(APPEND ${pass_file} "USE_PASS(${TARGET});\n")
set(INFER_IR_PASSES
${INFER_IR_PASSES} ${TARGET}
CACHE INTERNAL "")
endif()
endfunction()
cc_library(
node
SRCS node.cc
@@ -266,7 +218,7 @@ cc_library(
DEPS pass graph_pattern_detector)
set(GLOB_PASS_LIB
${PASS_LIBRARY}
${INFER_IR_PASSES}
CACHE INTERNAL "Global PASS library")
cc_library(
......
pass_library(
build_cinn_pass
base
DEPS
subgraph_detector
cinn_compiler
errors
enforce)
cc_library(
cinn_cache_key
SRCS cinn_cache_key.cc
DEPS graph graph_helper lod_tensor proto_desc)
cc_library(
build_cinn_pass
SRCS build_cinn_pass.cc
DEPS pass subgraph_detector graph_pattern_detector cinn_compiler errors
enforce)
cc_library(
transform_desc
SRCS transform_desc.cc
......
@@ -50,21 +50,10 @@ using framework::ir::Graph;
using framework::ir::Node;
using GraphNodeVec = std::vector<Node*>;
using GraphNodeSet = std::unordered_set<Node*>;
using GraphNodeMap = std::unordered_map<Node*, Node*>;
namespace {
// The delim(`;`) that is used to split the FLAGS_allow_cinn_ops
// & FLAGS_deny_cinn_ops.
constexpr char kDelim[] = ";";
const std::unordered_map<std::string, std::unordered_set<std::string>>
kDenyParamMap = {{"batch_norm", {"ReserveSpace"}},
{"batch_norm_grad", {"ReserveSpace"}}};
const std::unordered_set<std::string> kDefaultDenyOps = {"feed", "fetch"};
std::unordered_set<std::string> GetDenyVarNames(const GraphNodeSet& cluster) {
std::unordered_set<std::string> OpTransInfo::GetDenyVarNames(
const GraphNodeSet& cluster) const {
std::unordered_set<std::string> deny_var_set;
auto get_debug_info = [](const std::unordered_set<std::string>& var_names) {
@@ -78,16 +67,16 @@ std::unordered_set<std::string> GetDenyVarNames(const GraphNodeSet& cluster) {
};
for (auto* op : cluster) {
if (kDenyParamMap.count(op->Name())) {
if (deny_param_cond.count(op->Name())) {
const auto* desc = op->Op();
PADDLE_ENFORCE_NE(desc,
nullptr,
platform::errors::PreconditionNotMet(
"The Op %s's OpDesc should not be NULL, which has "
"a parameter in kDenyParamMap.",
"a parameter in deny_param_cond.",
op->Name().c_str()));
auto deny_param_names = kDenyParamMap.at(op->Name());
auto deny_param_names = deny_param_cond.at(op->Name());
VLOG(4) << "We found deny param " << get_debug_info(deny_param_names)
<< " in op [" << op->Name() << "].";
@@ -118,6 +107,11 @@ std::unordered_set<std::string> GetDenyVarNames(const GraphNodeSet& cluster) {
return deny_var_set;
}
namespace {
// The delim(`;`) that is used to split the FLAGS_allow_cinn_ops
// & FLAGS_deny_cinn_ops.
constexpr char kDelim[] = ";";
std::unordered_set<std::string> StringSplit(const std::string& str,
const std::string& delim) {
std::regex reg(delim);
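The body of StringSplit is truncated by this hunk. As a point of reference, here is a minimal self-contained sketch of regex-based splitting consistent with how the `;`-separated flag values are consumed at the call sites; `SplitTokens` is a hypothetical stand-in, not the actual implementation:

```cpp
#include <regex>
#include <string>
#include <unordered_set>

// Split `str` on the regex `delim` and collect the non-empty tokens, the way
// FLAGS_allow_cinn_ops / FLAGS_deny_cinn_ops values such as "op1;op2" are used.
std::unordered_set<std::string> SplitTokens(const std::string& str,
                                            const std::string& delim) {
  std::regex reg(delim);
  std::unordered_set<std::string> tokens(
      std::sregex_token_iterator(str.begin(), str.end(), reg, -1),
      std::sregex_token_iterator());
  tokens.erase("");  // drop the empty token left by leading/trailing delimiters
  return tokens;
}
```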
@@ -561,30 +555,38 @@ static bool IsInplaceOp(const OpDesc& op_desc) {
void SearchAllSubgraphs(Graph* graph) {
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
auto teller = [&allow_ops, &deny_ops](const Node* node) {
OpTransInfo trans_info;
auto teller = [&allow_ops, &deny_ops, &trans_info](const Node* node) {
const auto& node_name = node->Name();
bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find(
node_name) != nullptr;
// skip the dynamic ops
bool is_dynamic = false;
if (trans_info.dynamic_op_cond.count(node_name)) {
is_dynamic = trans_info.dynamic_op_cond.at(node_name)(node);
}
// if the op type is registered in CINN and allow_ops is not empty, return
// true only when it is in allow_ops
if (!allow_ops.empty()) {
return registered && allow_ops.count(node_name);
return registered && !is_dynamic && allow_ops.count(node_name);
}
// if the op type is registered in CINN and deny_ops is not empty, return
// true only when it is not in deny_ops
if (!deny_ops.empty()) {
return registered && !deny_ops.count(node_name);
return registered && !is_dynamic && !deny_ops.count(node_name);
}
// if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops,
// return true only when it is registered in CINN
return registered && !kDefaultDenyOps.count(node_name) &&
(node->IsOp() && !IsInplaceOp(*node->Op()));
return registered && !trans_info.default_deny_ops.count(node_name) &&
!is_dynamic && (node->IsOp() && !IsInplaceOp(*node->Op()));
};
VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops;
VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops;
std::vector<GraphNodeVec> clusters =
framework::ir::SubgraphDetector(graph, teller)();
LOG(INFO) << "--- [build_cinn_pass] detected " << clusters.size()
<< " cinn supported subgraphs";
auto cluster_debug_info = [](const GraphNodeSet& cluster) {
std::string res = "(";
@@ -601,7 +603,7 @@ void SearchAllSubgraphs(Graph* graph) {
// Classify var node to inputs, outputs, and internals.
GraphNodeSet cluster_set(node_vec.begin(), node_vec.end());
auto deny_var_set = GetDenyVarNames(cluster_set);
auto deny_var_set = trans_info.GetDenyVarNames(cluster_set);
GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals;
AnalyseClusterVariables(cluster_set,
......
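To make the teller's precedence concrete, here is a minimal sketch of the same decision logic. `Decide` is a hypothetical helper: `registered` stands for the OpMapperRegistry lookup, `is_dynamic` for the `dynamic_op_cond` check, and the in-place-op rejection of the default branch is omitted for brevity:

```cpp
#include <string>
#include <unordered_set>

using StrSet = std::unordered_set<std::string>;

// Mirrors SearchAllSubgraphs: a non-empty allow list takes precedence, then a
// non-empty deny list, then the built-in default deny set (feed/fetch).
bool Decide(const std::string& op, bool registered, bool is_dynamic,
            const StrSet& allow_ops, const StrSet& deny_ops,
            const StrSet& default_deny_ops) {
  if (!registered || is_dynamic) return false;
  if (!allow_ops.empty()) return allow_ops.count(op) > 0;
  if (!deny_ops.empty()) return deny_ops.count(op) == 0;
  return default_deny_ops.count(op) == 0;
}

// With FLAGS_allow_cinn_ops="relu;conv2d":
//   Decide("relu",  true, false, {"relu", "conv2d"}, {}, {"feed", "fetch"})  -> true
//   Decide("scale", true, false, {"relu", "conv2d"}, {}, {"feed", "fetch"})  -> false
```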
@@ -14,12 +14,21 @@ limitations under the License. */
#pragma once
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle {
namespace framework {
namespace ir {
class MemOptVarInfo;
class Node;
} // namespace ir
namespace paddle2cinn {
@@ -31,9 +40,40 @@ constexpr char kInternalVars[] = "InternalVars";
constexpr char kOutputVars[] = "OutputVars";
constexpr char kMemOptVarInfoFromMainGraph[] =
"mem_opt_var_info_from_main_graph";
using Name2VarInfoMap =
std::unordered_map<std::string,
std::shared_ptr<framework::ir::MemOptVarInfo>>;
using GraphNodeSet = std::unordered_set<ir::Node*>;
struct OpTransInfo {
const std::unordered_set<std::string> default_deny_ops{"feed", "fetch"};
const std::unordered_map<std::string, std::function<bool(const ir::Node*)>>
dynamic_op_cond{
{"slice", [](const ir::Node* node) -> bool {
if (!node->IsOp()) {
return false;
}
auto* op_desc = node->Op();
auto infer_flags =
op_desc->GetAttrIfExists<std::vector<int>>("infer_flags");
if (std::find_if(
infer_flags.begin(), infer_flags.end(), [](int v) {
return v < 0;
}) != infer_flags.end()) {
return true;
}
return false;
}}};
const std::unordered_map<std::string, std::unordered_set<std::string>>
deny_param_cond{{"batch_norm", {"ReserveSpace"}},
{"batch_norm_grad", {"ReserveSpace"}}};
std::unordered_set<std::string> GetDenyVarNames(
const GraphNodeSet& cluster) const;
};
// A pass named BuildCinnPass, the function of this pass is:
//
......
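The `dynamic_op_cond` entry above marks a slice op as dynamic whenever any entry of its `infer_flags` attribute is negative, i.e. static shape inference could not pin down that dimension. A self-contained sketch of just that predicate (`IsDynamicSlice` is a hypothetical helper; the real condition reads `infer_flags` from the node's OpDesc):

```cpp
#include <algorithm>
#include <vector>

// True when any infer_flags entry is negative, meaning the slice bounds are
// only known at runtime, so the op is kept out of CINN subgraphs.
bool IsDynamicSlice(const std::vector<int>& infer_flags) {
  return std::any_of(infer_flags.begin(), infer_flags.end(),
                     [](int v) { return v < 0; });
}

// IsDynamicSlice({1, 1})  == false  -> shapes fully inferred, eligible for CINN
// IsDynamicSlice({1, -1}) == true   -> dynamic, rejected by the teller
```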
@@ -4,6 +4,7 @@
*Pass*;
*profile*;
*phi*;
*cinn*;
local:
*;
};
@@ -14,6 +14,7 @@
#pragma once
#include <algorithm>
#include <chrono>
namespace infrt {
......
@@ -15,6 +15,8 @@
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/range_function.h"
@@ -56,6 +58,8 @@ void ArangeKernel(const Context& dev_ctx,
auto stream = dev_ctx.stream();
int block = std::min(size, static_cast<int64_t>(256));
PADDLE_ENFORCE_NE(
block, 0, errors::OutOfRange("The value of block cannot be 0."));
int grid = (size + block - 1) / block;
Range<T><<<grid, block, 0, stream>>>(start_value, step_value, size, out_data);
}
......
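The PADDLE_ENFORCE_NE guard above matters because an empty range (e.g. start == end) gives size == 0, so block = min(0, 256) = 0 and the following grid computation would divide by zero. A host-side sketch of the launch-config arithmetic, with a hypothetical `LaunchConfig` helper:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Same arithmetic as ArangeKernel's launch configuration: up to 256 threads
// per block, and enough blocks to cover `size` elements.
void LaunchConfig(int64_t size, int* grid, int* block) {
  *block = static_cast<int>(std::min(size, static_cast<int64_t>(256)));
  // The kernel's PADDLE_ENFORCE_NE(block, 0, ...) rejects size == 0 here,
  // before the division below can run with a zero denominator.
  assert(*block != 0 && "size == 0 would make block 0");
  *grid = static_cast<int>((size + *block - 1) / *block);
}
```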