From 0779e355442619f80a94617caf10c8cb7abab1ab Mon Sep 17 00:00:00 2001 From: nhzlx Date: Thu, 24 Jan 2019 13:27:55 +0000 Subject: [PATCH] fix two bug: 1. graph and program_desc alignment 2. trt stream test=develop --- paddle/fluid/framework/ir/graph_traits.cc | 3 ++- paddle/fluid/inference/analysis/ir_pass_manager.cc | 9 ++++++--- paddle/fluid/inference/analysis/ir_pass_manager.h | 4 ++-- .../analysis/passes/ir_graph_to_program_pass.cc | 6 +++++- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_traits.cc b/paddle/fluid/framework/ir/graph_traits.cc index 2ee12cc41..929d9edc3 100644 --- a/paddle/fluid/framework/ir/graph_traits.cc +++ b/paddle/fluid/framework/ir/graph_traits.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/ir/graph_traits.h" +#include #include namespace paddle { @@ -79,7 +80,7 @@ NodesTSIterator::NodesTSIterator(const std::vector &source) { } std::unordered_set visited; - std::unordered_set to_visit{source.begin(), source.end()}; + std::set to_visit{source.begin(), source.end()}; std::vector inlink_visited; while (!to_visit.empty()) { diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index 376d6aef2..9aaae1614 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -74,8 +74,9 @@ void IRPassManager::CreatePasses(Argument *argument, bool enable_int8 = false; if (argument->tensorrt_precision_mode() == - contrib::AnalysisConfig::Precision::kInt8) + contrib::AnalysisConfig::Precision::kInt8) { enable_int8 = true; + } pass->Set("enable_int8", new bool(enable_int8)); pass->Set("model_dir", new std::string(argument->model_path())); @@ -103,12 +104,14 @@ std::unique_ptr IRPassManager::Apply(std::unique_ptr graph) { } framework::proto::ProgramDesc IRPassManager::AcquireProgram( - std::unique_ptr *graph, const ProgramDesc &program) const { + std::unique_ptr *graph, ProgramDesc *program) const { auto pass = framework::ir::PassRegistry::Instance().Get("graph_to_program_pass"); + // Direct using ProgramDesc desc(argument->main_program()) may cause + // incomplete copies of information. ProgramDesc desc; - desc.CopyFrom(*const_cast(program).Proto()); + desc.CopyFrom(*program->Proto()); pass->SetNotOwned("program", &desc); auto *the_graph = graph->release(); *graph = pass->Apply(std::unique_ptr(the_graph)); diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.h b/paddle/fluid/inference/analysis/ir_pass_manager.h index 983a58264..f378d35d9 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.h +++ b/paddle/fluid/inference/analysis/ir_pass_manager.h @@ -42,8 +42,8 @@ class IRPassManager final { std::unique_ptr Apply(std::unique_ptr graph); - framework::proto::ProgramDesc AcquireProgram( - std::unique_ptr *graph, const ProgramDesc &program) const; + framework::proto::ProgramDesc AcquireProgram(std::unique_ptr *graph, + ProgramDesc *program) const; framework::ir::Graph &graph() const { return *graph_; } diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc index f1da37af3..6b3d80fce 100644 --- a/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_graph_to_program_pass.cc @@ -31,7 +31,11 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) { } std::unique_ptr graph(argument->main_graph_ptr()); - framework::ProgramDesc desc(argument->main_program()); + + // Direct using ProgramDesc desc(argument->main_program()) may cause + // incomplete copies of information. + framework::ProgramDesc desc; + desc.CopyFrom(*argument->main_program().Proto()); pass->SetNotOwned("program", &desc); auto thegraph = pass->Apply(std::move(graph)); thegraph.release(); // the argument still own the graph. -- GitLab