Commit 43c82376 authored by Qiao Longfei

use one graph

Parent 10393dd0
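The change replaces the one-graph-per-place design with a single graph: AsyncSSAGraphExecutor now owns one std::unique_ptr<ir::Graph> and hands every per-place ThreadedSSAGraphExecutor a non-owning raw pointer via graph_.get(). A minimal, self-contained C++ sketch of that ownership pattern (Graph, Worker, and Owner below are toy stand-ins, not Paddle classes):

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy stand-in for ir::Graph.
struct Graph {
  std::string name;
};

// Toy stand-in for ThreadedSSAGraphExecutor: borrows the graph, never owns it.
class Worker {
 public:
  explicit Worker(Graph *graph) : graph_(graph) {}
  void Run() const { std::cout << "running on graph " << graph_->name << "\n"; }

 private:
  Graph *graph_;  // non-owning; must not outlive the Owner that holds the graph
};

// Toy stand-in for AsyncSSAGraphExecutor: sole owner of the single graph.
class Owner {
 public:
  Owner(size_t num_places, std::unique_ptr<Graph> &&graph)
      : graph_(std::move(graph)) {
    for (size_t i = 0; i < num_places; ++i) {
      // Every per-place worker shares the same graph via a raw pointer.
      workers_.emplace_back(new Worker(graph_.get()));
    }
  }
  void RunAll() const {
    for (const auto &w : workers_) w->Run();
  }

 private:
  std::unique_ptr<Graph> graph_;
  std::vector<std::unique_ptr<Worker>> workers_;
};

int main() {
  Owner owner(2, std::unique_ptr<Graph>(new Graph{"main"}));
  owner.RunAll();  // both workers operate on the one shared graph
  return 0;
}

Because the workers only borrow the graph, the constructor in the diff below also drops the per-graph check PADDLE_ENFORCE_EQ(graphs_.size(), local_scopes_.size()).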
@@ -21,15 +21,14 @@ namespace details {
 AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
     const std::vector<platform::Place> &places,
-    std::vector<std::unique_ptr<ir::Graph>> &&graphs)
+    std::unique_ptr<ir::Graph> &&graph)
     : strategy_(std::move(strategy)),
      local_scopes_(std::move(local_scopes)),
      pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
      places_(std::move(places)),
-      graphs_(std::move(graphs)) {
+      graph_(std::move(graph)) {
   VLOG(3) << "build AsyncSSAGraphExecutor";
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
-  PADDLE_ENFORCE_EQ(graphs_.size(), local_scopes_.size());
   // set the correct size of thread pool to each device.
   strategy_.num_threads_ = strategy_.num_threads_ < places_.size()
@@ -39,7 +38,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
           << " to run the operators of the graph on each device.";
   for (size_t i = 0; i < places.size(); ++i) {
     executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
-        strategy_, {local_scopes_[i]}, {places_[i]}, std::move(graphs_[i])));
+        strategy_, {local_scopes_[i]}, {places_[i]}, graph_.get()));
   }
 }
...
@@ -29,9 +29,9 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
   AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
                         const std::vector<Scope *> &local_scopes,
                         const std::vector<platform::Place> &places,
-                        std::vector<std::unique_ptr<ir::Graph>> &&graphs);
+                        std::unique_ptr<ir::Graph> &&graph);
   ~AsyncSSAGraphExecutor() final = default;
-  const ir::Graph &Graph() const override { return *graphs_[0]; }
+  const ir::Graph &Graph() const override { return *graph_; }

   FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
@@ -40,7 +40,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
   std::vector<Scope *> local_scopes_;
   std::unique_ptr<::ThreadPool> pool_{nullptr};
   std::vector<platform::Place> places_;
-  std::vector<std::unique_ptr<ir::Graph>> graphs_;
+  std::unique_ptr<ir::Graph> graph_;
   std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
   ExceptionHolder exception_holder_;
...
@@ -264,53 +264,42 @@ ParallelExecutor::ParallelExecutor(
   // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
   // ncclOp
-  std::vector<std::unique_ptr<ir::Graph>> graphs;
+  std::unique_ptr<ir::Graph> graph;
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
     VLOG(3) << "use local async mode";
-    for (size_t i = 0; i < member_->places_.size(); ++i) {
-      std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
-          main_program, {member_->places_[i]}, loss_var_name,
-          {member_->local_scopes_[i]}, member_->nranks_, member_->use_cuda_,
-          member_->nccl_ctxs_.get());
-      graphs.push_back(std::move(graph));
-    }
+    graph =
+        build_strategy.Apply(main_program, {member_->places_[0]}, loss_var_name,
+                             {member_->local_scopes_[0]}, member_->nranks_,
+                             member_->use_cuda_, member_->nccl_ctxs_.get());
   } else {
-    std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
-        main_program, member_->places_, loss_var_name, member_->local_scopes_,
-        member_->nranks_, member_->use_cuda_, member_->nccl_ctxs_.get());
-    graphs.push_back(std::move(graph));
+    graph = build_strategy.Apply(main_program, member_->places_, loss_var_name,
+                                 member_->local_scopes_, member_->nranks_,
+                                 member_->use_cuda_, member_->nccl_ctxs_.get());
   }
 #else
   if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
     VLOG(3) << "use local async mode";
-    for (size_t i = 0; i < member_->places_.size(); ++i) {
-      std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
-          main_program, {member_->places_[i]}, loss_var_name,
-          {member_->local_scopes_[i]}, member_->nranks_, member_->use_cuda_);
-      graphs.push_back(std::move(graph));
-    }
+    graph = build_strategy.Apply(main_program, {member_->places_[0]},
+                                 loss_var_name, {member_->local_scopes_[0]},
+                                 member_->nranks_, member_->use_cuda_);
   } else {
-    std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
-        main_program, member_->places_, loss_var_name, member_->local_scopes_,
-        member_->nranks_, member_->use_cuda_);
-    graphs.push_back(std::move(graph));
+    graph = build_strategy.Apply(main_program, member_->places_, loss_var_name,
+                                 member_->local_scopes_, member_->nranks_,
+                                 member_->use_cuda_);
   }
 #endif

   auto max_memory_size = GetEagerDeletionThreshold();
   VLOG(10) << "Eager Deletion Threshold "
            << static_cast<float>(max_memory_size) / (1 << 30);
   if (max_memory_size >= 0) {
-    for (size_t i = 0; i < graphs.size(); ++i) {
-      graphs[i] = member_->PrepareGCAndRefCnts(
-          std::move(graphs[i]), static_cast<size_t>(max_memory_size));
-    }
+    graph = member_->PrepareGCAndRefCnts(std::move(graph),
+                                         static_cast<size_t>(max_memory_size));
   }

   // Step 3. Create vars in each scope. Passes may also create new vars.
   //         skip control vars and empty vars
   std::vector<details::VariableInfo> var_infos;
-  for (auto &graph : graphs) {
     for (auto &node : graph->Nodes()) {
       if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
         var_infos.emplace_back();
@@ -319,16 +308,15 @@ ParallelExecutor::ParallelExecutor(
         var_infos.back().persistable_ = node->Var()->Persistable();
       }
     }
-  }

   // If the loss_var_name is given, the number of graph should be only one.
   if (loss_var_name.size()) {
-    size_t graph_num = ir::GraphNum(*graphs[0]);
+    size_t graph_num = ir::GraphNum(*graph);
     if (graph_num > 1) {
       LOG(WARNING)
           << "The number of graph should be only one, "
              "but the current graph has "
-          << ir::GraphNum(*graphs[0])
+          << ir::GraphNum(*graph)
           << " sub_graphs. If you want to see the nodes of the "
              "sub_graphs, you should use 'FLAGS_print_sub_graph_dir' "
              "to specify the output dir. NOTES: if you not do training, "
@@ -340,7 +328,7 @@ ParallelExecutor::ParallelExecutor(
     VLOG(3) << "use AsyncSSAGraphExecutor";
     member_->executor_.reset(new details::AsyncSSAGraphExecutor(
         exec_strategy, member_->local_scopes_, member_->places_,
-        std::move(graphs)));
+        std::move(graph)));
   } else if (build_strategy.enable_parallel_graph_) {
     VLOG(3) << "use ParallelSSAGraphExecutor";
 #ifdef PADDLE_WITH_CUDA
@@ -358,12 +346,12 @@ ParallelExecutor::ParallelExecutor(
     VLOG(3) << "use ThreadedSSAGraphExecutor";
     member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
         exec_strategy, member_->local_scopes_, member_->places_,
-        std::move(graphs[0])));
+        std::move(graph)));
   } else {
     VLOG(3) << "use FastThreadedSSAGraphExecutor";
     member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
         exec_strategy, member_->local_scopes_, member_->places_,
-        std::move(graphs[0])));
+        std::move(graph)));
   }
 }
...
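The same single-graph idea shows up in parallel_executor.cc above: the graph is threaded through the build steps by repeatedly moving the unique_ptr into a transformation and reassigning the result (graph = build_strategy.Apply(...); then graph = member_->PrepareGCAndRefCnts(std::move(graph), ...)), instead of looping over a vector of graphs. A small self-contained sketch of that move-through-a-pipeline style, again with toy types rather than Paddle's:

#include <cstddef>
#include <iostream>
#include <memory>

// Toy stand-in for ir::Graph.
struct Graph {
  int passes_applied = 0;
};

// Toy stand-in for BuildStrategy::Apply: produces the initial graph.
std::unique_ptr<Graph> Apply() {
  return std::unique_ptr<Graph>(new Graph());
}

// Toy stand-in for PrepareGCAndRefCnts: takes ownership, annotates the graph,
// and returns ownership to the caller.
std::unique_ptr<Graph> PrepareGCAndRefCnts(std::unique_ptr<Graph> graph,
                                           size_t /*max_memory_size*/) {
  graph->passes_applied += 1;
  return graph;
}

int main() {
  std::unique_ptr<Graph> graph = Apply();
  const long long max_memory_size = 1;  // pretend the eager-deletion threshold is set
  if (max_memory_size >= 0) {
    // Same shape as the new code: move the one graph in, take it back out.
    graph = PrepareGCAndRefCnts(std::move(graph),
                                static_cast<size_t>(max_memory_size));
  }
  std::cout << graph->passes_applied << " extra pass(es) applied\n";
  return 0;
}

The point of this shape is that ownership stays unambiguous at every step: each helper either currently holds the graph or has already handed it back to the caller.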