提交 0779e355 编写于 作者: N nhzlx

fix two bug:

1. graph and program_desc alignment
2. trt stream

test=develop
上级 027d24c8
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/graph_traits.h"
#include <set>
#include <vector> #include <vector>
namespace paddle { namespace paddle {
...@@ -79,7 +80,7 @@ NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) { ...@@ -79,7 +80,7 @@ NodesTSIterator::NodesTSIterator(const std::vector<Node *> &source) {
} }
std::unordered_set<Node *> visited; std::unordered_set<Node *> visited;
std::unordered_set<Node *> to_visit{source.begin(), source.end()}; std::set<Node *> to_visit{source.begin(), source.end()};
std::vector<Node *> inlink_visited; std::vector<Node *> inlink_visited;
while (!to_visit.empty()) { while (!to_visit.empty()) {
......
...@@ -74,8 +74,9 @@ void IRPassManager::CreatePasses(Argument *argument, ...@@ -74,8 +74,9 @@ void IRPassManager::CreatePasses(Argument *argument,
bool enable_int8 = false; bool enable_int8 = false;
if (argument->tensorrt_precision_mode() == if (argument->tensorrt_precision_mode() ==
contrib::AnalysisConfig::Precision::kInt8) contrib::AnalysisConfig::Precision::kInt8) {
enable_int8 = true; enable_int8 = true;
}
pass->Set("enable_int8", new bool(enable_int8)); pass->Set("enable_int8", new bool(enable_int8));
pass->Set("model_dir", new std::string(argument->model_path())); pass->Set("model_dir", new std::string(argument->model_path()));
...@@ -103,12 +104,14 @@ std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) { ...@@ -103,12 +104,14 @@ std::unique_ptr<Graph> IRPassManager::Apply(std::unique_ptr<Graph> graph) {
} }
framework::proto::ProgramDesc IRPassManager::AcquireProgram( framework::proto::ProgramDesc IRPassManager::AcquireProgram(
std::unique_ptr<Graph> *graph, const ProgramDesc &program) const { std::unique_ptr<Graph> *graph, ProgramDesc *program) const {
auto pass = auto pass =
framework::ir::PassRegistry::Instance().Get("graph_to_program_pass"); framework::ir::PassRegistry::Instance().Get("graph_to_program_pass");
// Direct using ProgramDesc desc(argument->main_program()) may cause
// incomplete copies of information.
ProgramDesc desc; ProgramDesc desc;
desc.CopyFrom(*const_cast<ProgramDesc &>(program).Proto()); desc.CopyFrom(*program->Proto());
pass->SetNotOwned("program", &desc); pass->SetNotOwned("program", &desc);
auto *the_graph = graph->release(); auto *the_graph = graph->release();
*graph = pass->Apply(std::unique_ptr<Graph>(the_graph)); *graph = pass->Apply(std::unique_ptr<Graph>(the_graph));
......
...@@ -42,8 +42,8 @@ class IRPassManager final { ...@@ -42,8 +42,8 @@ class IRPassManager final {
std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph); std::unique_ptr<Graph> Apply(std::unique_ptr<Graph> graph);
framework::proto::ProgramDesc AcquireProgram( framework::proto::ProgramDesc AcquireProgram(std::unique_ptr<Graph> *graph,
std::unique_ptr<Graph> *graph, const ProgramDesc &program) const; ProgramDesc *program) const;
framework::ir::Graph &graph() const { return *graph_; } framework::ir::Graph &graph() const { return *graph_; }
......
...@@ -31,7 +31,11 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) { ...@@ -31,7 +31,11 @@ void IrGraphToProgramPass::RunImpl(Argument *argument) {
} }
std::unique_ptr<Graph> graph(argument->main_graph_ptr()); std::unique_ptr<Graph> graph(argument->main_graph_ptr());
framework::ProgramDesc desc(argument->main_program());
// Direct using ProgramDesc desc(argument->main_program()) may cause
// incomplete copies of information.
framework::ProgramDesc desc;
desc.CopyFrom(*argument->main_program().Proto());
pass->SetNotOwned("program", &desc); pass->SetNotOwned("program", &desc);
auto thegraph = pass->Apply(std::move(graph)); auto thegraph = pass->Apply(std::move(graph));
thegraph.release(); // the argument still own the graph. thegraph.release(); // the argument still own the graph.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册