Unverified commit 3a387df6, authored by Wilber, committed by GitHub

[Inference] inference add cinn interface (#48741)

Parent 379216ae
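For context, a minimal usage sketch of the new interface from the C++ inference API (not part of this diff). The model directory and GPU settings are placeholders, and the public header name assumes a standard Paddle inference install:

```cpp
#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config("./model_dir");  // placeholder model directory
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/100, /*device_id=*/0);
  // New API from this PR; throws if Paddle was not built with CINN support.
  config.Exp_EnableCINNCompiler();
  // New getter from this PR; now reports true.
  if (!config.cinn_compiler_enabled()) return 1;
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor ? 0 : 1;
}
```

As the guard added at the bottom of analysis_config below shows, Exp_EnableCINNCompiler() raises platform::errors::Unavailable when Paddle is compiled without PADDLE_WITH_CINN.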
@@ -484,7 +484,8 @@ void AnalyseClusterVariables(
const std::unordered_set<std::string>& deny_var_set,
GraphNodeSet* cluster_inputs,
GraphNodeSet* cluster_outputs,
GraphNodeSet* cluster_internals) {
GraphNodeSet* cluster_internals,
bool is_inference_stage) {
// collect all inputs and outputs of the ops
for (auto* op_node : cluster) {
const auto& op_name = op_node->Name();
@@ -523,6 +524,18 @@ void AnalyseClusterVariables(
for (auto* var_node : *cluster_internals) {
cluster_outputs->erase(var_node);
}
if (is_inference_stage) {
// If some outputs of an op are not consumed by any other operator, move
// them to the internal set, e.g. the transpose2 op's XShape output.
auto outs = *cluster_outputs;
for (auto* node : outs) {
if (node->outputs.empty()) {
cluster_outputs->erase(node);
cluster_internals->insert(node);
}
}
}
}
void AddLinkToCinnOp(const GraphNodeSet& cluster_inputs,
@@ -611,7 +624,7 @@ void ReplaceSubGraphWithCinnOpNode(
// Here we use SubgraphDetector to detect subgraphs in which every
// op node is supported by CINN, and OpMapperRegistry to check
// whether an op node is supported by CINN.
void SearchAllSubgraphs(Graph* graph) {
void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
OpTransInfo trans_info;
@@ -671,7 +684,8 @@ void SearchAllSubgraphs(Graph* graph) {
deny_var_set,
&cluster_inputs,
&cluster_outputs,
&cluster_internals);
&cluster_internals,
is_inference_stage);
VLOG(4) << "Cluster Ops: " << cluster_debug_info(cluster_set);
VLOG(4) << "Cluster input vars: " << cluster_debug_info(cluster_inputs);
@@ -698,7 +712,13 @@ void SearchAllSubgraphs(Graph* graph) {
}
} // namespace
void BuildCinnPass::ApplyImpl(Graph* graph) const { SearchAllSubgraphs(graph); }
void BuildCinnPass::ApplyImpl(Graph* graph) const {
bool is_inference_stage{false};
if (Has("is_inference_stage")) {
is_inference_stage = Get<bool>("is_inference_stage");
}
SearchAllSubgraphs(graph, is_inference_stage);
}
} // namespace paddle2cinn
} // namespace framework
......
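To make the new is_inference_stage branch in AnalyseClusterVariables concrete, here is a self-contained toy sketch. Node and GraphNodeSet below are minimal stand-ins for Paddle's real IR types, and PruneUnusedClusterOutputs is a hypothetical name for the excerpted logic:

```cpp
#include <string>
#include <unordered_set>
#include <vector>

// Minimal stand-ins for Paddle's ir::Node and GraphNodeSet; illustrative only.
struct Node {
  std::string name;
  std::vector<Node*> outputs;  // ops that consume this variable
};
using GraphNodeSet = std::unordered_set<Node*>;

// Mirrors the is_inference_stage branch above: an output variable that no
// downstream op consumes (e.g. transpose2's XShape) is moved from the
// cluster's outputs into its internals.
void PruneUnusedClusterOutputs(GraphNodeSet* cluster_outputs,
                               GraphNodeSet* cluster_internals) {
  // Iterate over a copy, as the original does with `outs`: erasing from the
  // set being traversed would invalidate its iterators.
  GraphNodeSet outs = *cluster_outputs;
  for (Node* node : outs) {
    if (node->outputs.empty()) {
      cluster_outputs->erase(node);
      cluster_internals->insert(node);
    }
  }
}

int main() {
  Node x_shape{"transpose2_xshape", {}};  // no consumers
  GraphNodeSet outputs{&x_shape}, internals;
  PruneUnusedClusterOutputs(&outputs, &internals);
  return outputs.empty() && internals.count(&x_shape) ? 0 : 1;
}
```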
@@ -368,6 +368,9 @@ struct Argument {
DECL_ARGUMENT_FIELD(enable_gpu_half, EnableGPUHalf, bool);
DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
// cinn compiler related
DECL_ARGUMENT_FIELD(use_cinn_compiler, UseCinnCompiler, bool);
private:
std::unordered_set<std::string> valid_fields_;
};
......
@@ -235,6 +235,8 @@ void IRPassManager::CreatePasses(Argument *argument,
new framework::ProgramDesc *(&argument->main_program()));
} else if (pass_name == "memory_optimize_pass") {
pass->Set("root_predictor_id", new int(argument->root_predictor_id()));
} else if (pass_name == "build_cinn_pass") {
pass->Set("is_inference_stage", new bool(argument->use_cinn_compiler()));
}
if (pass_name == "lite_subgraph_pass") {
bool lite_enable_int8 =
......
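The flag travels from IRPassManager::CreatePasses into BuildCinnPass as a pass attribute. The toy below imitates that Set/Has/Get round-trip with a simplified attribute map; ToyPass is illustrative only, not the framework's real Pass class:

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

// Toy stand-in for the framework's pass attribute storage: Set() takes
// ownership of a heap-allocated value, Get() reads it back by name.
class ToyPass {
 public:
  void Set(const std::string &name, bool *value) {
    attrs_[name] = std::unique_ptr<bool>(value);
  }
  bool Has(const std::string &name) const { return attrs_.count(name) > 0; }
  bool Get(const std::string &name) const { return *attrs_.at(name); }

 private:
  std::unordered_map<std::string, std::unique_ptr<bool>> attrs_;
};

int main() {
  ToyPass pass;
  pass.Set("is_inference_stage", new bool(true));  // as IRPassManager does
  bool is_inference_stage = pass.Has("is_inference_stage")
                                ? pass.Get("is_inference_stage")
                                : false;  // as BuildCinnPass::ApplyImpl does
  assert(is_inference_stage);
  return 0;
}
```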
@@ -477,6 +477,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
// profile related.
CP_MEMBER(with_profile_);
// cinn compiler related.
CP_MEMBER(use_cinn_compiler_);
// glog related.
CP_MEMBER(with_glog_info_);
@@ -542,7 +545,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
#undef CP_MEMBER
Update();
if (use_tensorrt_) {
if (use_tensorrt_ || use_cinn_compiler_) {
// Update() will reset all the passes, so a TensorRT pass deleted from
// other.pass_builder() would be added back; remove the deleted passes
// again here.
@@ -872,6 +875,14 @@ void AnalysisConfig::Update() {
}
}
// TODO(wilber): An ugly way to update the passes; needs to be fixed.
if (use_cinn_compiler_) {
pass_builder()->ClearPasses();
for (const auto &pass : kCINNCompilerPasses) {
pass_builder()->AppendPass(pass);
}
}
if (use_dlnne_) {
pass_builder()->ClearPasses();
for (const auto &pass : kDlnneSubgraphPasses) {
@@ -1316,6 +1327,9 @@ std::string AnalysisConfig::Summary() {
os.InsertRow({"use_lite", use_lite_ ? "true" : "false"});
}
// cinn compiler
os.InsertRow({"use_cinn_compiler", use_cinn_compiler_ ? "true" : "false"});
// ir info
os.InsertRow({"ir_optim", enable_ir_optim_ ? "true" : "false"});
os.InsertRow({"ir_debug", ir_debug_ ? "true" : "false"});
@@ -1429,4 +1443,19 @@ void AnalysisConfig::Exp_DisableMixedInferOps(
mixed_black_list_ = black_list;
}
void AnalysisConfig::Exp_EnableCINNCompiler() {
#ifdef PADDLE_WITH_CINN
use_cinn_compiler_ = true;
Update();
#else
PADDLE_THROW(platform::errors::Unavailable(
"You tried to use CINN compiler, but Paddle was not compiled "
"with CINN."));
#endif
}
bool AnalysisConfig::cinn_compiler_enabled() const {
return use_cinn_compiler_;
}
} // namespace paddle
@@ -1217,6 +1217,10 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
}
if (config_.use_cinn_compiler_) {
argument_.SetUseCinnCompiler(config_.use_cinn_compiler_);
}
#ifdef PADDLE_WITH_MKLDNN
if (config_.mkldnn_quantizer_enabled()) {
LOG(INFO) << "Quantization is enabled";
@@ -1239,21 +1243,25 @@ void AnalysisPredictor::PrepareArgument() {
#endif
auto *pass_builder = config_.pass_builder();
// TODO(inference): Need to reconstruct the pass_builder, pass should be
// processed in a single
if (model_precision_ != phi::DataType::FLOAT32) {
LOG(INFO) << "Model is mixed precision type with " << model_precision_
<< ", we will use a new PassStrategy. Note that only the GPU "
"backend is supported for now.";
pass_builder->ClearPasses();
const auto &deleted_passes = pass_builder->GetAllDeletedPasses();
if (config_.tensorrt_engine_enabled()) {
for (const auto &pass : kTrtLowerPrecisionPasses) {
if (deleted_passes.count(pass)) continue;
pass_builder->AppendPass(pass);
}
} else if (config_.use_gpu()) {
for (const auto &pass : kGpuLowerPrecisionPasses) {
if (deleted_passes.count(pass)) continue;
pass_builder->AppendPass(pass);
if (!config_.use_cinn_compiler_) {
pass_builder->ClearPasses();
const auto &deleted_passes = pass_builder->GetAllDeletedPasses();
if (config_.tensorrt_engine_enabled()) {
for (const auto &pass : kTrtLowerPrecisionPasses) {
if (deleted_passes.count(pass)) continue;
pass_builder->AppendPass(pass);
}
} else if (config_.use_gpu()) {
for (const auto &pass : kGpuLowerPrecisionPasses) {
if (deleted_passes.count(pass)) continue;
pass_builder->AppendPass(pass);
}
}
}
}
......
@@ -1016,6 +1016,19 @@ struct PD_INFER_DECL AnalysisConfig {
void SetSkipLoadParams(bool value) { skip_load_params_ = value; }
///
/// \brief Enable the CINN compiler optimization.
///
void Exp_EnableCINNCompiler();
///
/// \brief A boolean state telling whether the CINN compiler optimization is
/// turned on.
///
/// \return bool Whether the CINN compiler optimization is turned on.
///
bool cinn_compiler_enabled() const;
protected:
// Update the config.
void Update();
@@ -1143,6 +1156,9 @@
Precision lite_precision_mode_;
bool lite_zero_copy_;
// CINN compiler related.
bool use_cinn_compiler_{false};
// XPU related.
bool use_xpu_{false};
int xpu_device_id_{0};
......
@@ -204,6 +204,13 @@ const std::vector<std::string> kTrtLowerPrecisionPasses{
"tensorrt_subgraph_pass",
};
const std::vector<std::string> kCINNCompilerPasses{
"gpu_cpu_map_matmul_v2_to_mul_pass",
"gpu_cpu_map_matmul_v2_to_matmul_pass",
"gpu_cpu_map_matmul_to_mul_pass",
"build_cinn_pass",
};
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
passes_.assign({
// "identity_scale_op_clean_pass", //
......
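Given kCINNCompilerPasses above and the override added to AnalysisConfig::Update(), enabling the CINN compiler leaves exactly these four passes in the pipeline. A hedged way to check, assuming the existing pass_builder()/AllPasses() accessors:

```cpp
#include <iostream>
#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config("./model_dir");  // placeholder model directory
  config.Exp_EnableCINNCompiler();
  // Expected order: the three gpu_cpu_map_matmul* normalization passes,
  // then build_cinn_pass, which clusters CINN-supported ops.
  for (const auto &pass : config.pass_builder()->AllPasses()) {
    std::cout << pass << std::endl;
  }
  return 0;
}
```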
@@ -349,6 +349,9 @@ PD_INFER_DECL extern const std::vector<std::string> kDlnneSubgraphPasses;
/// \brief List of lite subgraph passes.
PD_INFER_DECL extern const std::vector<std::string> kLiteSubgraphPasses;
/// \brief List of cinn compiler passes.
PD_INFER_DECL extern const std::vector<std::string> kCINNCompilerPasses;
/// \brief TODO(inference): Most of the existing pass fusion operators do not
/// support fp16/bf16 precision, temporarily use low precision pass to prevent
/// running errors. After fusion operator supports low precision, delete this.
......