未验证 提交 e887d719 编写于 作者: Y Yan Chunwei 提交者: GitHub

fix ir debug config (#15571)

上级 897789b1
...@@ -83,7 +83,6 @@ void IRPassManager::CreatePasses(Argument *argument, ...@@ -83,7 +83,6 @@ void IRPassManager::CreatePasses(Argument *argument,
new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir))); new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir)));
} }
// graph_ = pass->Apply(std::move(graph_));
pre_pass = pass_name; pre_pass = pass_name;
passes_.emplace_back(std::move(pass)); passes_.emplace_back(std::move(pass));
...@@ -97,8 +96,9 @@ std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) { ...@@ -97,8 +96,9 @@ std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
PADDLE_ENFORCE(graph.get()); PADDLE_ENFORCE(graph.get());
// Apply all the passes // Apply all the passes
for (const auto &pass : passes_) { for (const auto &pass : passes_) {
if (pass->Type() == "graph_viz_pass") continue; if (pass->Type() != "graph_viz_pass") {
PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type()); PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass->Type());
}
graph = pass->Apply(std::move(graph)); graph = pass->Apply(std::move(graph));
} }
return std::move(graph); return std::move(graph);
......
...@@ -318,4 +318,9 @@ NativeConfig AnalysisConfig::ToNativeConfig() const { ...@@ -318,4 +318,9 @@ NativeConfig AnalysisConfig::ToNativeConfig() const {
return config; return config;
} }
void AnalysisConfig::SwitchIrDebug(int x) {
ir_debug_ = x;
Update();
}
} // namespace paddle } // namespace paddle
...@@ -196,7 +196,7 @@ TEST(AnalysisPredictor, memory_optim) { ...@@ -196,7 +196,7 @@ TEST(AnalysisPredictor, memory_optim) {
AnalysisConfig config(FLAGS_dirname); AnalysisConfig config(FLAGS_dirname);
config.DisableGpu(); config.DisableGpu();
config.EnableMemoryOptim(true); config.EnableMemoryOptim(true);
config.pass_builder()->TurnOnDebug(); config.SwitchIrDebug();
auto native_predictor = auto native_predictor =
CreatePaddlePredictor<NativeConfig>(config.ToNativeConfig()); CreatePaddlePredictor<NativeConfig>(config.ToNativeConfig());
......
...@@ -140,9 +140,12 @@ struct AnalysisConfig { ...@@ -140,9 +140,12 @@ struct AnalysisConfig {
*/ */
bool tensorrt_engine_enabled() const { return use_tensorrt_; } bool tensorrt_engine_enabled() const { return use_tensorrt_; }
/** Control whther to debug IR graph analysis phase. /** \brief Control whether to debug IR graph analysis phase.
*
* This will generate DOT files for visualizing the computation graph after
* each analysis pass applied.
*/ */
void SwitchIrDebug(int x = true) { ir_debug_ = x; } void SwitchIrDebug(int x = true);
/** Turn on MKLDNN. /** Turn on MKLDNN.
*/ */
......
...@@ -142,7 +142,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) { ...@@ -142,7 +142,7 @@ void SetConfig(AnalysisConfig *cfg, bool use_mkldnn = false) {
cfg->SetModel(FLAGS_infer_model + "/model", FLAGS_infer_model + "/params"); cfg->SetModel(FLAGS_infer_model + "/model", FLAGS_infer_model + "/params");
cfg->DisableGpu(); cfg->DisableGpu();
cfg->SwitchSpecifyInputNames(); cfg->SwitchSpecifyInputNames();
cfg->pass_builder()->TurnOnDebug(); cfg->SwitchIrDebug();
cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads); cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
if (use_mkldnn) { if (use_mkldnn) {
cfg->EnableMKLDNN(); cfg->EnableMKLDNN();
......
...@@ -69,7 +69,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) { ...@@ -69,7 +69,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
TEST(Analyzer_Text_Classification, profile) { TEST(Analyzer_Text_Classification, profile) {
AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
cfg.pass_builder()->TurnOnDebug(); cfg.SwitchIrDebug();
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs;
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册