From e887d71958d1db99a8766f2a79cc481b51663e95 Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Thu, 31 Jan 2019 09:20:41 +0800 Subject: [PATCH] fix ir debug config (#15571) --- paddle/fluid/inference/analysis/ir_pass_manager.cc | 6 +++--- paddle/fluid/inference/api/analysis_config.cc | 5 +++++ paddle/fluid/inference/api/analysis_predictor_tester.cc | 2 +- paddle/fluid/inference/api/paddle_analysis_config.h | 7 +++++-- .../fluid/inference/tests/api/analyzer_seq_pool1_tester.cc | 2 +- .../tests/api/analyzer_text_classification_tester.cc | 2 +- 6 files changed, 16 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index fe3c84118..7476c199c 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -83,7 +83,6 @@ void IRPassManager::CreatePasses(Argument *argument, new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir))); } - // graph_ = pass->Apply(std::move(graph_)); pre_pass = pass_name; passes_.emplace_back(std::move(pass)); @@ -97,8 +96,9 @@ std::unique_ptr IRPassManager::Apply(std::unique_ptr graph) { PADDLE_ENFORCE(graph.get()); // Apply all the passes for (const auto &pass : passes_) { - if (pass->Type() == "graph_viz_pass") continue; - PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type()); + if (pass->Type() != "graph_viz_pass") { + PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type()); + } graph = pass->Apply(std::move(graph)); } return std::move(graph); diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index eecab238a..e92273b4d 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -318,4 +318,9 @@ NativeConfig AnalysisConfig::ToNativeConfig() const { return config; } +void AnalysisConfig::SwitchIrDebug(int x) { + ir_debug_ = x; + Update(); +} + } // namespace paddle diff --git a/paddle/fluid/inference/api/analysis_predictor_tester.cc b/paddle/fluid/inference/api/analysis_predictor_tester.cc index 6d11b4610..002ba90e4 100644 --- a/paddle/fluid/inference/api/analysis_predictor_tester.cc +++ b/paddle/fluid/inference/api/analysis_predictor_tester.cc @@ -196,7 +196,7 @@ TEST(AnalysisPredictor, memory_optim) { AnalysisConfig config(FLAGS_dirname); config.DisableGpu(); config.EnableMemoryOptim(true); - config.pass_builder()->TurnOnDebug(); + config.SwitchIrDebug(); auto native_predictor = CreatePaddlePredictor(config.ToNativeConfig()); diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index 9d9ed6a39..47361b327 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -140,9 +140,12 @@ struct AnalysisConfig { */ bool tensorrt_engine_enabled() const { return use_tensorrt_; } - /** Control whther to debug IR graph analysis phase. + /** \brief Control whether to debug IR graph analysis phase. + * + * This will generate DOT files for visualizing the computation graph after + * each analysis pass applied. */ - void SwitchIrDebug(int x = true) { ir_debug_ = x; } + void SwitchIrDebug(int x = true); /** Turn on MKLDNN. */ diff --git a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc index 8be2a6d79..dd953e0dc 100644 --- a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc @@ -142,7 +142,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) { cfg->SetModel(FLAGS_infer_model + "/model", FLAGS_infer_model + "/params"); cfg->DisableGpu(); cfg->SwitchSpecifyInputNames(); - cfg->pass_builder()->TurnOnDebug(); + cfg->SwitchIrDebug(); cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads); if (use_mkldnn) { cfg->EnableMKLDNN(); diff --git a/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc index 2db297e20..2003be820 100644 --- a/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc @@ -69,7 +69,7 @@ void SetInput(std::vector> *inputs) { TEST(Analyzer_Text_Classification, profile) { AnalysisConfig cfg; SetConfig(&cfg); - cfg.pass_builder()->TurnOnDebug(); + cfg.SwitchIrDebug(); std::vector outputs; std::vector> input_slots_all; -- GitLab