From dab7f36909a61af51beacd145228bb2a4acc4db5 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 25 Feb 2019 22:49:03 +0800 Subject: [PATCH] optimize code test=develop --- .../details/async_ssa_graph_executor.cc | 6 ++-- .../details/async_ssa_graph_executor.h | 4 +-- paddle/fluid/framework/parallel_executor.cc | 30 +++++++++---------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index b6d1ee50739..8757842996f 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -21,12 +21,12 @@ namespace details { AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( const ExecutionStrategy &strategy, const std::vector &local_scopes, const std::vector &places, - std::unique_ptr &&graph) + ir::Graph* graph) : strategy_(std::move(strategy)), local_scopes_(std::move(local_scopes)), pool_(places.size() >= 2 ? new ::ThreadPool(places.size()) : nullptr), places_(std::move(places)), - graph_(std::move(graph)) { + graph_(graph) { VLOG(3) << "build AsyncSSAGraphExecutor"; PADDLE_ENFORCE_EQ(places_.size(), local_scopes_.size()); @@ -38,7 +38,7 @@ AsyncSSAGraphExecutor::AsyncSSAGraphExecutor( << " to run the operators of the graph on each device."; for (size_t i = 0; i < places.size(); ++i) { executors_.emplace_back(new details::ThreadedSSAGraphExecutor( - strategy_, {local_scopes_[i]}, {places_[i]}, graph_.get())); + strategy_, {local_scopes_[i]}, {places_[i]}, graph_)); } } diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.h b/paddle/fluid/framework/details/async_ssa_graph_executor.h index 50f207361fb..8536852a00f 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.h @@ -29,7 +29,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor { AsyncSSAGraphExecutor(const ExecutionStrategy &strategy, const std::vector &local_scopes, const std::vector &places, - std::unique_ptr &&graph); + ir::Graph *graph); ~AsyncSSAGraphExecutor() final = default; const ir::Graph &Graph() const override { return *graph_; } @@ -40,7 +40,7 @@ class AsyncSSAGraphExecutor : public SSAGraphExecutor { std::vector local_scopes_; std::unique_ptr<::ThreadPool> pool_{nullptr}; std::vector places_; - std::unique_ptr graph_; + ir::Graph *graph_; std::vector> executors_; ExceptionHolder exception_holder_; diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index a498ec5b0b5..081d06b6aa2 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -269,25 +269,26 @@ ParallelExecutor::ParallelExecutor( #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) if (build_strategy.async_mode_ && !build_strategy.is_distribution_) { VLOG(3) << "use local async mode"; - temp_owned_graph = - build_strategy.Apply(std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name, - {member_->local_scopes_[0]}, member_->nranks_, - member_->use_cuda_, member_->nccl_ctxs_.get()); + temp_owned_graph = build_strategy.Apply( + std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name, + {member_->local_scopes_[0]}, member_->nranks_, member_->use_cuda_, + member_->nccl_ctxs_.get()); } else { - temp_owned_graph = build_strategy.Apply(std::move(temp_owned_graph), member_->places_, loss_var_name, - member_->local_scopes_, member_->nranks_, - member_->use_cuda_, member_->nccl_ctxs_.get()); + temp_owned_graph = build_strategy.Apply( + std::move(temp_owned_graph), member_->places_, loss_var_name, + member_->local_scopes_, member_->nranks_, member_->use_cuda_, + member_->nccl_ctxs_.get()); } #else if (build_strategy.async_mode_ && !build_strategy.is_distribution_) { VLOG(3) << "use local async mode"; - temp_owned_graph = build_strategy.Apply(std::move(temp_owned_graph), {member_->places_[0]}, - loss_var_name, {member_->local_scopes_[0]}, - member_->nranks_, member_->use_cuda_); + temp_owned_graph = build_strategy.Apply( + std::move(temp_owned_graph), {member_->places_[0]}, loss_var_name, + {member_->local_scopes_[0]}, member_->nranks_, member_->use_cuda_); } else { - temp_owned_graph = build_strategy.Apply(std::move(temp_owned_graph), member_->places_, loss_var_name, - member_->local_scopes_, member_->nranks_, - member_->use_cuda_); + temp_owned_graph = build_strategy.Apply( + std::move(temp_owned_graph), member_->places_, loss_var_name, + member_->local_scopes_, member_->nranks_, member_->use_cuda_); } #endif @@ -333,8 +334,7 @@ ParallelExecutor::ParallelExecutor( if (build_strategy.async_mode_ && !build_strategy.is_distribution_) { VLOG(3) << "use AsyncSSAGraphExecutor"; member_->executor_.reset(new details::AsyncSSAGraphExecutor( - exec_strategy, member_->local_scopes_, member_->places_, - graph)); + exec_strategy, member_->local_scopes_, member_->places_, graph)); } else if (build_strategy.enable_parallel_graph_) { VLOG(3) << "use ParallelSSAGraphExecutor"; #ifdef PADDLE_WITH_CUDA -- GitLab