From f768fbf7157e4b500de3aa456beddaa138f00cd5 Mon Sep 17 00:00:00 2001
From: Qiao Longfei
Date: Tue, 26 Feb 2019 15:01:59 +0800
Subject: [PATCH] support multi graph test=develop

---
 .../details/async_ssa_graph_executor.cc       |  6 +--
 .../details/async_ssa_graph_executor.h        |  6 +--
 paddle/fluid/framework/parallel_executor.cc   | 40 ++++++++++++++-----
 paddle/fluid/framework/parallel_executor.h    |  2 +-
 .../fluid/operators/reader/blocking_queue.h   |  1 +
 .../operators/reader/create_py_reader_op.cc   |  5 ++-
 paddle/fluid/pybind/pybind.cc                 |  2 +-
 python/paddle/fluid/parallel_executor.py      |  9 ++++-
 8 files changed, 50 insertions(+), 21 deletions(-)

diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
index 21741667a3a..dfb9d73dcbe 100644
--- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc
@@ -20,12 +20,12 @@ namespace details {
 
 AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
     const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
-    const std::vector<platform::Place> &places, ir::Graph *graph)
+    const std::vector<platform::Place> &places, std::vector<ir::Graph *> graphs)
     : strategy_(std::move(strategy)),
       local_scopes_(std::move(local_scopes)),
       pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr),
       places_(std::move(places)),
-      graph_(graph) {
+      graphs_(std::move(graphs)) {
   VLOG(3) << "build AsyncSSAGraphExecutor";
   PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size());
 
@@ -37,7 +37,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor(
       << " to run the operators of the graph on each device.";
   for (size_t i = 0; i < places.size(); ++i) {
     executors_.emplace_back(new details::ThreadedSSAGraphExecutor(
-        strategy_, {local_scopes_[i]}, {places_[i]}, graph_));
+        strategy_, {local_scopes_[i]}, {places_[i]}, graphs_[i]));
   }
 }
 
diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.h b/paddle/fluid/framework/details/async_ssa_graph_executor.h
index 8536852a00f..ff85ba2c6cf 100644
--- a/paddle/fluid/framework/details/async_ssa_graph_executor.h
+++ b/paddle/fluid/framework/details/async_ssa_graph_executor.h
@@ -29,9 +29,9 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor {
   AsyncSSAGraphExecutor(const ExecutionStrategy &strategy,
                         const std::vector<Scope *> &local_scopes,
                         const std::vector<platform::Place> &places,
-                        ir::Graph *graph);
+                        std::vector<ir::Graph *> graphs);
   ~AsyncSSAGraphExecutor() final = default;
-  const ir::Graph &Graph() const override { return *graph_; }
+  const ir::Graph &Graph() const override { return *graphs_[0]; }
 
   FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
 
@@ -40,7 +40,7 @@
  std::vector<Scope *> local_scopes_;
  std::unique_ptr<::ThreadPool> pool_{nullptr};
  std::vector<platform::Place> places_;
- ir::Graph *graph_;
+ std::vector<ir::Graph *> graphs_;
 
  std::vector<std::unique_ptr<details::ThreadedSSAGraphExecutor>> executors_;
  ExceptionHolder exception_holder_;
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 081d06b6aa2..b1f40911487 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -188,7 +188,7 @@ ParallelExecutor::ParallelExecutor(
     const std::string &loss_var_name, Scope *scope,
     const std::vector<Scope *> &local_scopes,
     const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy,
-    ir::Graph *graph)
+    std::vector<ir::Graph *> graphs)
     : member_(new ParallelExecutorPrivate(places)) {
   member_->global_scope_ = scope;
   member_->use_cuda_ = exec_strategy.use_cuda_;
@@ -222,6 +222,8 @@ ParallelExecutor::ParallelExecutor(
     PADDLE_ENFORCE(!member_->use_cuda_,
                    "gpu mode does not support async_mode_ now!");
   }
+
+  ir::Graph *graph = graphs[0];
   std::unique_ptr<ir::Graph> temp_owned_graph(graph);
 
   // FIXME(Yancey1989): parallel graph mode get better performance
@@ -262,17 +264,26 @@ ParallelExecutor::ParallelExecutor(
   if (member_->local_scopes_.size() != 1 && local_scopes.empty()) {
     BCastParamsToDevices(bcast_vars);
   }
-// Startup Program has been run. All local scopes has correct parameters.
+  // Startup Program has been run. All local scopes has correct parameters.
 
-// Step 2. Convert main_program to SSA form and dependency graph. Also, insert
-// ncclOp
+  // Step 2. Convert main_program to SSA form and dependency graph. Also, insert
+  // ncclOp
+  std::vector<ir::Graph *> async_graphs(places.size());
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
     VLOG(3) << "use local async mode";
-    temp_owned_graph = build_strategy.Apply(
-        std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name,
-        {member_->local_scopes_[0]}, member_->nranks_, member_->use_cuda_,
-        member_->nccl_ctxs_.get());
+    temp_owned_graph =
+        build_strategy.Apply(std::move(temp_owned_graph), {member_->places_[0]},
+                             loss_var_name, {member_->local_scopes_[0]}, 1,
+                             member_->use_cuda_, member_->nccl_ctxs_.get());
+    for (int i = 1; i < member_->places_.size(); ++i) {
+      std::unique_ptr<ir::Graph> temp_graph(graphs[i]);
+      temp_graph =
+          build_strategy.Apply(std::move(temp_graph), {member_->places_[i]},
+                               loss_var_name, {member_->local_scopes_[i]}, 1,
+                               member_->use_cuda_, member_->nccl_ctxs_.get());
+      async_graphs[i] = temp_graph.release();
+    }
   } else {
     temp_owned_graph = build_strategy.Apply(
         std::move(temp_owned_graph), member_->places_, loss_var_name,
@@ -284,7 +295,14 @@ ParallelExecutor::ParallelExecutor(
     VLOG(3) << "use local async mode";
     temp_owned_graph = build_strategy.Apply(
         std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name,
-        {member_->local_scopes_[0]}, member_->nranks_, member_->use_cuda_);
+        {member_->local_scopes_[0]}, 1, member_->use_cuda_);
+    for (int i = 1; i < member_->places_.size(); ++i) {
+      std::unique_ptr<ir::Graph> temp_graph(graphs[i]);
+      temp_graph = build_strategy.Apply(
+          std::move(temp_graph), {member_->places_[i]}, loss_var_name,
+          {member_->local_scopes_[i]}, 1, member_->use_cuda_);
+      async_graphs[i] = temp_graph.release();
+    }
   } else {
     temp_owned_graph = build_strategy.Apply(
         std::move(temp_owned_graph), member_->places_, loss_var_name,
@@ -304,6 +322,8 @@ ParallelExecutor::ParallelExecutor(
     graph = temp_owned_graph.release();
   }
 
+  async_graphs[0] = graph;
+
   // Step 3. Create vars in each scope. Passes may also create new vars.
   // skip control vars and empty vars
   std::vector<details::VariableInfo> var_infos;
@@ -334,7 +354,7 @@ ParallelExecutor::ParallelExecutor(
   if (build_strategy.async_mode_ && !build_strategy.is_distribution_) {
     VLOG(3) << "use AsyncSSAGraphExecutor";
     member_->executor_.reset(new details::AsyncSSAGraphExecutor(
-        exec_strategy, member_->local_scopes_, member_->places_, graph));
+        exec_strategy, member_->local_scopes_, member_->places_, async_graphs));
   } else if (build_strategy.enable_parallel_graph_) {
     VLOG(3) << "use ParallelSSAGraphExecutor";
 #ifdef PADDLE_WITH_CUDA
diff --git a/paddle/fluid/framework/parallel_executor.h b/paddle/fluid/framework/parallel_executor.h
index ddf60b39466..0e05b2a460a 100644
--- a/paddle/fluid/framework/parallel_executor.h
+++ b/paddle/fluid/framework/parallel_executor.h
@@ -50,7 +50,7 @@ class ParallelExecutor {
                             const std::vector<Scope *> &local_scopes,
                             const ExecutionStrategy &exec_strategy,
                             const BuildStrategy &build_strategy,
-                            ir::Graph *graph);
+                            std::vector<ir::Graph *> graphs);
 
   ~ParallelExecutor();
 
diff --git a/paddle/fluid/operators/reader/blocking_queue.h b/paddle/fluid/operators/reader/blocking_queue.h
index 45c3ad802fc..c99b2bc593b 100644
--- a/paddle/fluid/operators/reader/blocking_queue.h
+++ b/paddle/fluid/operators/reader/blocking_queue.h
@@ -95,6 +95,7 @@ class BlockingQueue {
 
   void Close() {
     std::lock_guard<std::mutex> lock(mutex_);
+    VLOG(3) << "close queue";
     closed_ = true;
     send_cv_.notify_all();
     receive_cv_.notify_all();
diff --git a/paddle/fluid/operators/reader/create_py_reader_op.cc b/paddle/fluid/operators/reader/create_py_reader_op.cc
index 901a92ab5b5..b2469ad0eb2 100644
--- a/paddle/fluid/operators/reader/create_py_reader_op.cc
+++ b/paddle/fluid/operators/reader/create_py_reader_op.cc
@@ -35,7 +35,10 @@ class PyReader : public framework::FileReader {
 
   ~PyReader() { queue_->Close(); }
 
-  void Shutdown() override { queue_->Close(); }
+  void Shutdown() override {
+    VLOG(3) << "PyReader shutdown!";
+    queue_->Close();
+  }
 
   void Start() override { queue_->ReOpen(); }
 
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index f9e73667794..fdee5a6d665 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -1230,7 +1230,7 @@ All parameter, weight, gradient are variables in Paddle.
   pe.def(py::init<const std::vector<platform::Place> &,
                   const std::unordered_set<std::string> &, const std::string &,
                   Scope *, std::vector<Scope *> &, const ExecutionStrategy &,
-                  const BuildStrategy &, ir::Graph *>())
+                  const BuildStrategy &, std::vector<ir::Graph *>>())
       // NOTE: even we return a vec<Scope*>* to Python use reference policy.
       // We still cannot get local_scope from this vector, since the element
       // of vec<Scope*> will be freed by Python GC. We can only return Scope*
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index 889156ff74d..9c578ef662b 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -177,12 +177,17 @@ class ParallelExecutor(object):
 
         # step7: init ParallelExecutor
         # ParallelExecutor API will be deprecated, don't support parallel graph.
-        self._graph = core.Graph(main.desc)
+        self._graphs = []
+        if build_strategy.async_mode:
+            for _ in range(cpu_num):
+                self._graphs.append(core.Graph(main.desc))
+        else:
+            self._graphs.append(core.Graph(main.desc))
 
         self.executor = core.ParallelExecutor(
             places, persistable_vars,
             cpt.to_text(loss_name) if loss_name else six.u(''), scope,
-            local_scopes, exec_strategy, build_strategy, self._graph)
+            local_scopes, exec_strategy, build_strategy, self._graphs)
 
         self.scope = scope
 
-- 
GitLab
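
For context, the execution pattern this patch introduces is: in async mode the Python
wrapper builds one core.Graph per CPU core, ParallelExecutor applies the build strategy
to each graph with nranks fixed to 1 (each graph is compiled as if for a single device),
and AsyncSSAGraphExecutor creates one ThreadedSSAGraphExecutor per (place, graph) pair so
the devices run without synchronizing with each other. Below is a minimal sketch of that
one-graph-per-device fan-out; Graph and DeviceExecutor are hypothetical stand-ins for
illustration only, not the real Paddle types.

#include <iostream>
#include <memory>
#include <vector>

// Hypothetical stand-in for ir::Graph: every device gets its own copy.
struct Graph {
  explicit Graph(int device) : device_id(device) {}
  int device_id;
};

// Hypothetical stand-in for ThreadedSSAGraphExecutor: owns exactly one
// graph and one device, and runs independently of its siblings.
class DeviceExecutor {
 public:
  explicit DeviceExecutor(Graph *g) : graph_(g) {}
  void Run() {
    std::cout << "run graph on device " << graph_->device_id << "\n";
  }

 private:
  Graph *graph_;  // not owned, mirroring the raw ir::Graph* in the patch
};

int main() {
  const size_t num_places = 4;  // analogous to cpu_num in parallel_executor.py

  // Step 1: build one graph per place, as the Python wrapper now does in
  // async mode (the loop over cpu_num appending core.Graph(main.desc)).
  std::vector<std::unique_ptr<Graph>> graphs;
  for (size_t i = 0; i < num_places; ++i) {
    graphs.emplace_back(new Graph(static_cast<int>(i)));
  }

  // Step 2: create one executor per (place, graph) pair, as the
  // AsyncSSAGraphExecutor constructor loop now does with graphs_[i].
  std::vector<std::unique_ptr<DeviceExecutor>> executors;
  for (size_t i = 0; i < num_places; ++i) {
    executors.emplace_back(new DeviceExecutor(graphs[i].get()));
  }

  // In Paddle each executor would run on its own ThreadPool worker; here
  // they run sequentially just to show the fan-out.
  for (auto &e : executors) {
    e->Run();
  }
  return 0;
}

The point of giving each device a private graph is that the per-device executors never
touch shared graph state, so the fully asynchronous CPU mode needs no cross-device
locking or synchronization during a step.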