diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index fe3c841186c35ea28c1d44007d91de5b997c1388..7476c199cfd073ec0962fa9a48f24750a6484bb5 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -83,7 +83,6 @@ void IRPassManager::CreatePasses(Argument *argument, new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir))); } - // graph_ = pass->Apply(std::move(graph_)); pre_pass = pass_name; passes_.emplace_back(std::move(pass)); @@ -97,8 +96,9 @@ std::unique_ptr IRPassManager::Apply(std::unique_ptr graph) { PADDLE_ENFORCE(graph.get()); // Apply all the passes for (const auto &pass : passes_) { - if (pass->Type() == "graph_viz_pass") continue; - PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type()); + if (pass->Type() != "graph_viz_pass") { + PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type()); + } graph = pass->Apply(std::move(graph)); } return std::move(graph); diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index eecab238a88e90399eb70f17caa57633af4e2a69..e92273b4dd94f11e0e90c91fd82dafe42bf158f3 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -318,4 +318,9 @@ NativeConfig AnalysisConfig::ToNativeConfig() const { return config; } +void AnalysisConfig::SwitchIrDebug(int x) { + ir_debug_ = x; + Update(); +} + } // namespace paddle diff --git a/paddle/fluid/inference/api/analysis_predictor_tester.cc b/paddle/fluid/inference/api/analysis_predictor_tester.cc index 6d11b461082d0ed8ba08c9e280bba86737b86e71..002ba90e40e69d565f5a54e374a3f0083b84273f 100644 --- a/paddle/fluid/inference/api/analysis_predictor_tester.cc +++ b/paddle/fluid/inference/api/analysis_predictor_tester.cc @@ -196,7 +196,7 @@ TEST(AnalysisPredictor, memory_optim) { AnalysisConfig config(FLAGS_dirname); config.DisableGpu(); config.EnableMemoryOptim(true); - config.pass_builder()->TurnOnDebug(); + config.SwitchIrDebug(); auto native_predictor = CreatePaddlePredictor(config.ToNativeConfig()); diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index 9d9ed6a39d8324002a8850deae9bb8dd5af7ef9b..47361b3279e14dd65a0e6e7f864e508ef1183045 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -140,9 +140,12 @@ struct AnalysisConfig { */ bool tensorrt_engine_enabled() const { return use_tensorrt_; } - /** Control whther to debug IR graph analysis phase. + /** \brief Control whether to debug IR graph analysis phase. + * + * This will generate DOT files for visualizing the computation graph after + * each analysis pass applied. */ - void SwitchIrDebug(int x = true) { ir_debug_ = x; } + void SwitchIrDebug(int x = true); /** Turn on MKLDNN. */ diff --git a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc index 8be2a6d79b2ede2c149aa523e38c3960ab30acb1..dd953e0dccbb3749bfcc87966453c6976dfefa10 100644 --- a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc @@ -142,7 +142,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) { cfg->SetModel(FLAGS_infer_model + "/model", FLAGS_infer_model + "/params"); cfg->DisableGpu(); cfg->SwitchSpecifyInputNames(); - cfg->pass_builder()->TurnOnDebug(); + cfg->SwitchIrDebug(); cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads); if (use_mkldnn) { cfg->EnableMKLDNN(); diff --git a/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc index 2db297e2005c6b657259187d6b6b76657d9e4388..2003be82019333ca97b9fa8ef83668825fe5710d 100644 --- a/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc @@ -69,7 +69,7 @@ void SetInput(std::vector> *inputs) { TEST(Analyzer_Text_Classification, profile) { AnalysisConfig cfg; SetConfig(&cfg); - cfg.pass_builder()->TurnOnDebug(); + cfg.SwitchIrDebug(); std::vector outputs; std::vector> input_slots_all;