diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
index e97a56a743e25cb03e41e778e0439f1573d74000..f7306bfc9a28e776229436cb0454fe95df93cc06 100644
--- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
+++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
@@ -484,7 +484,8 @@ void AnalyseClusterVariables(
     const std::unordered_set<std::string>& deny_var_set,
     GraphNodeSet* cluster_inputs,
     GraphNodeSet* cluster_outputs,
-    GraphNodeSet* cluster_internals) {
+    GraphNodeSet* cluster_internals,
+    bool is_inference_stage) {
   // collecting all input and output of op
   for (auto* op_node : cluster) {
     const auto& op_name = op_node->Name();
@@ -523,6 +524,18 @@ void AnalyseClusterVariables(
   for (auto* var_node : *cluster_internals) {
     cluster_outputs->erase(var_node);
   }
+
+  if (is_inference_stage) {
+    // If an output of an op is not consumed by any other operator, demote it
+    // to an internal variable, e.g. the transpose2 op's XShape output.
+    auto outs = *cluster_outputs;
+    for (auto* node : outs) {
+      if (node->outputs.empty()) {
+        cluster_outputs->erase(node);
+        cluster_internals->insert(node);
+      }
+    }
+  }
 }
 
 void AddLinkToCinnOp(const GraphNodeSet& cluster_inputs,
@@ -611,7 +624,7 @@ void ReplaceSubGraphWithCinnOpNode(
 // Here we using SubgraphDetector to detecte the subgraph that
 // all of op node supported by CINN. We using OpMapperRegistry
 // to check whether the op node supported by CINN.
-void SearchAllSubgraphs(Graph* graph) {
+void SearchAllSubgraphs(Graph* graph, bool is_inference_stage) {
   auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
   auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
   OpTransInfo trans_info;
@@ -671,7 +684,8 @@ void SearchAllSubgraphs(Graph* graph) {
                             deny_var_set,
                             &cluster_inputs,
                             &cluster_outputs,
-                            &cluster_internals);
+                            &cluster_internals,
+                            is_inference_stage);
 
     VLOG(4) << "Cluster Ops: " << cluster_debug_info(cluster_set);
     VLOG(4) << "Cluster input vars: " << cluster_debug_info(cluster_inputs);
@@ -698,7 +712,13 @@ void SearchAllSubgraphs(Graph* graph) {
   }
 }
 }  // namespace
 
-void BuildCinnPass::ApplyImpl(Graph* graph) const { SearchAllSubgraphs(graph); }
+void BuildCinnPass::ApplyImpl(Graph* graph) const {
+  bool is_inference_stage{false};
+  if (Has("is_inference_stage")) {
+    is_inference_stage = Get<bool>("is_inference_stage");
+  }
+  SearchAllSubgraphs(graph, is_inference_stage);
+}
 
 }  // namespace paddle2cinn
 }  // namespace framework
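The is_inference_stage branch added to AnalyseClusterVariables above is the behavioral core of this patch, so here is the same rule restated as a self-contained sketch. Node and GraphNodeSet are simplified stand-ins for the framework::ir types, and DemoteDeadOutputs is a hypothetical name used only for illustration:

#include <unordered_set>
#include <vector>

struct Node {
  std::vector<Node*> outputs;  // operators that consume this variable
};
using GraphNodeSet = std::unordered_set<Node*>;

// A cluster output that no downstream operator reads is useful only for
// training (e.g. transpose2's XShape, which exists to serve the backward
// pass), so at inference time it can be demoted to a cluster-internal
// variable and need not be materialized as an output of the CINN subgraph.
void DemoteDeadOutputs(GraphNodeSet* cluster_outputs,
                       GraphNodeSet* cluster_internals) {
  GraphNodeSet outs = *cluster_outputs;  // copy: the loop mutates the set
  for (Node* node : outs) {
    if (node->outputs.empty()) {
      cluster_outputs->erase(node);
      cluster_internals->insert(node);
    }
  }
}
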
"memory_optimize_pass") { pass->Set("root_predictor_id", new int(argument->root_predictor_id())); + } else if (pass_name == "build_cinn_pass") { + pass->Set("is_inference_stage", new bool(argument->use_cinn_compiler())); } if (pass_name == "lite_subgraph_pass") { bool lite_enable_int8 = diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc old mode 100755 new mode 100644 index c5e648dffc0bfc8b0ab939ca897e69b4a2883c47..17afc4f840e7dcec20e1562232ebaebfde1771be --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -477,6 +477,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { // profile related. CP_MEMBER(with_profile_); + // cinn compiler related. + CP_MEMBER(use_cinn_compiler_); + // glog related. CP_MEMBER(with_glog_info_); @@ -542,7 +545,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { #undef CP_MEMBER Update(); - if (use_tensorrt_) { + if (use_tensorrt_ || use_cinn_compiler_) { // Update() will reset all the passes, when some tensorRT pass is deleted in // other.pass_builder(), it will set again, so we just remove the // deleted_pass. @@ -872,6 +875,14 @@ void AnalysisConfig::Update() { } } + // TODO(wilber): An ugly method to update pass, need to be fixed. + if (use_cinn_compiler_) { + pass_builder()->ClearPasses(); + for (const auto &pass : kCINNCompilerPasses) { + pass_builder()->AppendPass(pass); + } + } + if (use_dlnne_) { pass_builder()->ClearPasses(); for (const auto &pass : kDlnneSubgraphPasses) { @@ -1316,6 +1327,9 @@ std::string AnalysisConfig::Summary() { os.InsertRow({"use_lite", use_lite_ ? "true" : "false"}); } + // cinn compiler + os.InsertRow({"use_cinn_compiler", use_cinn_compiler_ ? "true" : "false"}); + // ir info os.InsertRow({"ir_optim", enable_ir_optim_ ? "true" : "false"}); os.InsertRow({"ir_debug", ir_debug_ ? "true" : "false"}); @@ -1429,4 +1443,19 @@ void AnalysisConfig::Exp_DisableMixedInferOps( mixed_black_list_ = black_list; } +void AnalysisConfig::Exp_EnableCINNCompiler() { +#ifdef PADDLE_WITH_CINN + use_cinn_compiler_ = true; + Update(); +#else + PADDLE_THROW(platform::errors::Unavailable( + "You tried to use CINN compiler, but Paddle was not compiled " + "with CINN.")); +#endif +} + +bool AnalysisConfig::cinn_compiler_enabled() const { + return use_cinn_compiler_; +} + } // namespace paddle diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc old mode 100755 new mode 100644 index bf89db83dd4aea59d3ea69759d774fde0de7bbc4..13dba59492b556d92d7032c5ed74eeaa97d9591a --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1217,6 +1217,10 @@ void AnalysisPredictor::PrepareArgument() { argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_); } + if (config_.use_cinn_compiler_) { + argument_.SetUseCinnCompiler(config_.use_cinn_compiler_); + } + #ifdef PADDLE_WITH_MKLDNN if (config_.mkldnn_quantizer_enabled()) { LOG(INFO) << "Quantization is enabled"; @@ -1239,21 +1243,25 @@ void AnalysisPredictor::PrepareArgument() { #endif auto *pass_builder = config_.pass_builder(); + // TODO(inference): Need to reconstruct the pass_builder, pass should be + // processed in a single if (model_precision_ != phi::DataType::FLOAT32) { LOG(INFO) << "Model is mixed precision type with " << model_precision_ << ", we will use a new PassStrategy. 
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
old mode 100755
new mode 100644
index bf89db83dd4aea59d3ea69759d774fde0de7bbc4..13dba59492b556d92d7032c5ed74eeaa97d9591a
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1217,6 +1217,10 @@ void AnalysisPredictor::PrepareArgument() {
     argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
   }
 
+  if (config_.use_cinn_compiler_) {
+    argument_.SetUseCinnCompiler(config_.use_cinn_compiler_);
+  }
+
 #ifdef PADDLE_WITH_MKLDNN
   if (config_.mkldnn_quantizer_enabled()) {
     LOG(INFO) << "Quantization is enabled";
@@ -1239,21 +1243,25 @@ void AnalysisPredictor::PrepareArgument() {
 #endif
 
   auto *pass_builder = config_.pass_builder();
+  // TODO(inference): Need to reconstruct the pass_builder; passes should be
+  // processed in a single place.
   if (model_precision_ != phi::DataType::FLOAT32) {
     LOG(INFO) << "Model is mixed precision type with " << model_precision_
               << ", we will use a new PassStrategy. Note that only the GPU "
                  "backend is supported for now.";
-    pass_builder->ClearPasses();
-    const auto &deleted_passes = pass_builder->GetAllDeletedPasses();
-    if (config_.tensorrt_engine_enabled()) {
-      for (const auto &pass : kTrtLowerPrecisionPasses) {
-        if (deleted_passes.count(pass)) continue;
-        pass_builder->AppendPass(pass);
-      }
-    } else if (config_.use_gpu()) {
-      for (const auto &pass : kGpuLowerPrecisionPasses) {
-        if (deleted_passes.count(pass)) continue;
-        pass_builder->AppendPass(pass);
+    if (!config_.use_cinn_compiler_) {
+      pass_builder->ClearPasses();
+      const auto &deleted_passes = pass_builder->GetAllDeletedPasses();
+      if (config_.tensorrt_engine_enabled()) {
+        for (const auto &pass : kTrtLowerPrecisionPasses) {
+          if (deleted_passes.count(pass)) continue;
+          pass_builder->AppendPass(pass);
+        }
+      } else if (config_.use_gpu()) {
+        for (const auto &pass : kGpuLowerPrecisionPasses) {
+          if (deleted_passes.count(pass)) continue;
+          pass_builder->AppendPass(pass);
+        }
       }
     }
   }
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index 5521caee9f4307462c748bad6a4512c84763d371..5bf5d3de7b0f00dac516a31d1c3baec7601978da 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -1016,6 +1016,19 @@ struct PD_INFER_DECL AnalysisConfig {
 
   void SetSkipLoadParams(bool value) { skip_load_params_ = value; }
 
+  ///
+  /// \brief Enable the CINN compiler optimization.
+  ///
+  void Exp_EnableCINNCompiler();
+
+  ///
+  /// \brief A boolean state telling whether the CINN compiler optimization is
+  /// turned on.
+  ///
+  /// \return bool Whether the CINN compiler optimization is turned on.
+  ///
+  bool cinn_compiler_enabled() const;
+
  protected:
   // Update the config.
   void Update();
@@ -1143,6 +1156,9 @@ struct PD_INFER_DECL AnalysisConfig {
   Precision lite_precision_mode_;
   bool lite_zero_copy_;
 
+  // CINN compiler related.
+  bool use_cinn_compiler_{false};
+
   // XPU related.
   bool use_xpu_{false};
   int xpu_device_id_{0};
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 4ac91231121d13f76c0f9f32c5a2913964ee5b34..4e397fbd041c7a6b11f00ec7350df2abcc94fad3 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -204,6 +204,13 @@ const std::vector<std::string> kTrtLowerPrecisionPasses{
     "tensorrt_subgraph_pass",
 };
 
+const std::vector<std::string> kCINNCompilerPasses{
+    "gpu_cpu_map_matmul_v2_to_mul_pass",
+    "gpu_cpu_map_matmul_v2_to_matmul_pass",
+    "gpu_cpu_map_matmul_to_mul_pass",
+    "build_cinn_pass",
+};
+
 GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
   passes_.assign({
       //   "identity_scale_op_clean_pass",             //
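The ordering of kCINNCompilerPasses looks deliberate: the three gpu_cpu_map_matmul* passes rewrite matmul_v2/matmul variants, presumably into forms CINN's op mapper handles, before build_cinn_pass clusters the graph. One way to check which passes are active after enabling CINN; a sketch that relies on the existing pass_builder() and AllPasses() accessors:

#include <iostream>
#include "paddle_inference_api.h"

void DumpPasses(const paddle_infer::Config& config) {
  // After Exp_EnableCINNCompiler(), this is expected to print the four
  // entries of kCINNCompilerPasses, ending with build_cinn_pass.
  for (const auto& pass : config.pass_builder()->AllPasses()) {
    std::cout << pass << "\n";
  }
}
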
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h
index 0990a61da34e16fc2e9b32c259f3a9b34a44fb7f..8dea84400e8e1ae141737d747813c5d23c5e18cb 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.h
+++ b/paddle/fluid/inference/api/paddle_pass_builder.h
@@ -349,6 +349,9 @@ PD_INFER_DECL extern const std::vector<std::string> kDlnneSubgraphPasses;
 /// \brief List of lite subgraph passes.
 PD_INFER_DECL extern const std::vector<std::string> kLiteSubgraphPasses;
 
+/// \brief List of cinn compiler passes.
+PD_INFER_DECL extern const std::vector<std::string> kCINNCompilerPasses;
+
 /// \brief TODO(inference): Most of the existing pass fusion operators do not
 /// support fp16/bf16 precision, temporarily use low precision pass to prevent
 /// running errors. After fusion operator supports low precision, delete this.
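Taken end to end, the new flag travels config -> Argument -> pass attribute: Exp_EnableCINNCompiler() sets use_cinn_compiler_, PrepareArgument() copies it into the Argument through the accessors generated by DECL_ARGUMENT_FIELD, IRPassManager::CreatePasses() attaches it to build_cinn_pass as the is_inference_stage attribute, and BuildCinnPass::ApplyImpl() reads it back with Get<bool>. A simplified stand-in for what the DECL_ARGUMENT_FIELD line in argument.h provides (an assumption about the macro's shape, not its verbatim expansion):

#include <string>
#include <unordered_set>

struct Argument {
  // Rough equivalent of DECL_ARGUMENT_FIELD(use_cinn_compiler,
  // UseCinnCompiler, bool): a getter, a setter, and a validity query
  // backed by the valid_fields_ set.
  bool use_cinn_compiler() const { return use_cinn_compiler_; }
  void SetUseCinnCompiler(bool v) {
    use_cinn_compiler_ = v;
    valid_fields_.insert("use_cinn_compiler");
  }
  bool use_cinn_compiler_valid() const {
    return valid_fields_.count("use_cinn_compiler") > 0;
  }

 private:
  bool use_cinn_compiler_{false};
  std::unordered_set<std::string> valid_fields_;
};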