提交 7f3be090 编写于 作者: Q Qiao Longfei

fix multi graph test=develop

上级 9465c3d0
...@@ -249,6 +249,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply( ...@@ -249,6 +249,7 @@ std::unique_ptr<ir::Graph> BuildStrategy::Apply(
graph = pass->Apply(std::move(graph)); graph = pass->Apply(std::move(graph));
VLOG(3) << "Finish Apply Pass " << pass->Type(); VLOG(3) << "Finish Apply Pass " << pass->Type();
} }
VLOG(3) << "All Passes Applied";
return graph; return graph;
} }
......
...@@ -259,14 +259,15 @@ ParallelExecutor::ParallelExecutor( ...@@ -259,14 +259,15 @@ ParallelExecutor::ParallelExecutor(
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp // ncclOp
std::unique_ptr<ir::Graph> graph;
std::vector<std::unique_ptr<ir::Graph>> graphs; std::vector<std::unique_ptr<ir::Graph>> graphs;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
graph = build_strategy.Apply(main_program, member_->places_, loss_var_name, std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
member_->local_scopes_, member_->nranks_, main_program, member_->places_, loss_var_name, member_->local_scopes_,
member_->use_cuda_, member_->nccl_ctxs_.get()); member_->nranks_, member_->use_cuda_, member_->nccl_ctxs_.get());
graphs.push_back(std::move(graph));
#else #else
if (build_strategy.async_mode_ && !build_strategy.is_distribution_) { if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
VLOG(3) << "use local async mode";
for (size_t i = 0; i < member_->places_.size(); ++i) { for (size_t i = 0; i < member_->places_.size(); ++i) {
std::unique_ptr<ir::Graph> graph = build_strategy.Apply( std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, {member_->places_[i]}, loss_var_name, main_program, {member_->places_[i]}, loss_var_name,
...@@ -274,22 +275,26 @@ ParallelExecutor::ParallelExecutor( ...@@ -274,22 +275,26 @@ ParallelExecutor::ParallelExecutor(
graphs.push_back(std::move(graph)); graphs.push_back(std::move(graph));
} }
} else { } else {
graph = build_strategy.Apply(main_program, member_->places_, loss_var_name, std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
member_->local_scopes_, member_->nranks_, main_program, member_->places_, loss_var_name, member_->local_scopes_,
member_->use_cuda_); member_->nranks_, member_->use_cuda_);
graphs.push_back(std::move(graph));
} }
#endif #endif
auto max_memory_size = GetEagerDeletionThreshold(); auto max_memory_size = GetEagerDeletionThreshold();
VLOG(10) << "Eager Deletion Threshold " VLOG(10) << "Eager Deletion Threshold "
<< static_cast<float>(max_memory_size) / (1 << 30); << static_cast<float>(max_memory_size) / (1 << 30);
if (max_memory_size >= 0) { if (max_memory_size >= 0) {
graph = member_->PrepareGCAndRefCnts(std::move(graph), for (size_t i = 0; i < graphs.size(); ++i) {
static_cast<size_t>(max_memory_size)); graphs[i] = member_->PrepareGCAndRefCnts(
std::move(graphs[i]), static_cast<size_t>(max_memory_size));
}
} }
// Step 3. Create vars in each scope. Passes may also create new vars. // Step 3. Create vars in each scope. Passes may also create new vars.
// skip control vars and empty vars // skip control vars and empty vars
std::vector<details::VariableInfo> var_infos; std::vector<details::VariableInfo> var_infos;
for (auto &graph : graphs) {
for (auto &node : graph->Nodes()) { for (auto &node : graph->Nodes()) {
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) { if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
var_infos.emplace_back(); var_infos.emplace_back();
...@@ -298,15 +303,16 @@ ParallelExecutor::ParallelExecutor( ...@@ -298,15 +303,16 @@ ParallelExecutor::ParallelExecutor(
var_infos.back().persistable_ = node->Var()->Persistable(); var_infos.back().persistable_ = node->Var()->Persistable();
} }
} }
}
// If the loss_var_name is given, the number of graph should be only one. // If the loss_var_name is given, the number of graph should be only one.
if (loss_var_name.size()) { if (loss_var_name.size()) {
size_t graph_num = ir::GraphNum(*graph); size_t graph_num = ir::GraphNum(*graphs[0]);
if (graph_num > 1) { if (graph_num > 1) {
LOG(WARNING) LOG(WARNING)
<< "The number of graph should be only one, " << "The number of graph should be only one, "
"but the current graph has " "but the current graph has "
<< ir::GraphNum(*graph) << ir::GraphNum(*graphs[0])
<< " sub_graphs. If you want to see the nodes of the " << " sub_graphs. If you want to see the nodes of the "
"sub_graphs, you should use 'FLAGS_print_sub_graph_dir' " "sub_graphs, you should use 'FLAGS_print_sub_graph_dir' "
"to specify the output dir. NOTES: if you not do training, " "to specify the output dir. NOTES: if you not do training, "
...@@ -326,7 +332,7 @@ ParallelExecutor::ParallelExecutor( ...@@ -326,7 +332,7 @@ ParallelExecutor::ParallelExecutor(
// allreduce_seq_pass doesn't need it as the attr. // allreduce_seq_pass doesn't need it as the attr.
member_->executor_.reset(new details::ParallelSSAGraphExecutor( member_->executor_.reset(new details::ParallelSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, main_program, exec_strategy, member_->local_scopes_, member_->places_, main_program,
std::move(graph))); std::move(graphs[0])));
#else #else
PADDLE_THROW( PADDLE_THROW(
"Paddle should be compiled with CUDA for ParallelGraph Execution."); "Paddle should be compiled with CUDA for ParallelGraph Execution.");
...@@ -336,12 +342,12 @@ ParallelExecutor::ParallelExecutor( ...@@ -336,12 +342,12 @@ ParallelExecutor::ParallelExecutor(
VLOG(3) << "use ThreadedSSAGraphExecutor"; VLOG(3) << "use ThreadedSSAGraphExecutor";
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, exec_strategy, member_->local_scopes_, member_->places_,
std::move(graph))); std::move(graphs[0])));
} else { } else {
VLOG(3) << "use FastThreadedSSAGraphExecutor"; VLOG(3) << "use FastThreadedSSAGraphExecutor";
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor( member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->places_, exec_strategy, member_->local_scopes_, member_->places_,
std::move(graph))); std::move(graphs[0])));
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册