From d0ac92531dacfa09c8142ad515eb04ee9a8a8ef9 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Sun, 1 Apr 2018 02:35:07 -0700 Subject: [PATCH] Improve ParallelExecutor performance --- .../details/nccl_all_reduce_op_handle.cc | 2 +- .../details/nccl_all_reduce_op_handle.h | 5 ++ .../fluid/framework/details/op_handle_base.h | 4 ++ .../details/threaded_ssa_graph_executor.cc | 49 +++++++++++++++---- .../details/threaded_ssa_graph_executor.h | 12 ++++- paddle/fluid/framework/parallel_executor.cc | 2 + python/paddle/fluid/parallel_executor.py | 3 +- 7 files changed, 63 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc index 5ddf331cfc..55b5f11358 100644 --- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc @@ -76,7 +76,7 @@ void NCCLAllReduceOpHandle::RunImpl() { } } -std::string NCCLAllReduceOpHandle::Name() const { return "NCCL AllReduce"; } +std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; } } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h index 045070bb6a..3d61fa79f7 100644 --- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h @@ -14,6 +14,9 @@ #pragma once +#include +#include + #include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" @@ -34,6 +37,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase { std::string Name() const override; + bool IsDelayedOp() override { return true; }; + protected: void RunImpl() override; }; diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 71672fd24c..54c2d627ff 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -13,6 +13,8 @@ // limitations under the License. #pragma once +#include +#include #include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/platform/device_context.h" @@ -53,6 +55,8 @@ class OpHandleBase { void AddOutput(VarHandleBase *out); + virtual bool IsDelayedOp() { return false; } + protected: virtual void RunImpl() = 0; }; diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 3f8655147b..075eed4ecc 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -29,17 +29,27 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( local_scopes_(local_scopes), places_(places), fetch_ctxs_(places), - use_event_(use_event) {} + use_event_(use_event), + running_ops_(0) {} + +void ThreadedSSAGraphExecutor::RunDelayedOps( + const std::unordered_set &delayed_ops) { + for (auto op : delayed_ops) { + op->Run(use_event_); + } +} FeedFetchList ThreadedSSAGraphExecutor::Run( const std::vector &fetch_tensors) { std::unordered_map pending_ops; std::unordered_set pending_vars; - BlockingQueue ready_vars; - std::unordered_set ready_ops; + std::unordered_set delayed_ops; + std::unordered_set after_delayed_ops; + std::unordered_set delayed_vars; + auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) { pending_vars.insert(&var); if (var.generated_op_ == nullptr) { @@ -106,7 +116,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( auto run_all_ready_ops = [&] { for (auto *op : ready_ops) { - RunOp(ready_vars, op); + if (op->IsDelayedOp()) { + delayed_ops.insert(op); + delayed_vars.insert(op->outputs_.begin(), op->outputs_.end()); + ready_vars.Extend(op->outputs_); + continue; + } + running_ops_++; + RunOp(&ready_vars, op); } ready_ops.clear(); }; @@ -124,7 +141,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( // 2. Find ready variable bool timeout; - auto cur_ready_vars = ready_vars.PopAll(1000, &timeout); + auto cur_ready_vars = ready_vars.PopAll(1, &timeout); if (timeout) { if (exception_) { @@ -141,13 +158,24 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( auto &deps = pending_ops[op]; --deps; if (deps == 0) { - ready_ops.insert(op); + if (delayed_vars.find(ready_var) != delayed_vars.end()) { + after_delayed_ops.insert(op); + } else { + ready_ops.insert(op); + } } } } + if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) { + RunDelayedOps(delayed_ops); + delayed_ops.clear(); + for (auto *op : after_delayed_ops) { + ready_ops.insert(op); + } + after_delayed_ops.clear(); + } // Keep loop until all vars are ready. } - ++computation_count_; auto sync_computation = [&] { @@ -182,12 +210,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( } void ThreadedSSAGraphExecutor::RunOp( - BlockingQueue &ready_var_q, details::OpHandleBase *op) { - auto op_run = [&ready_var_q, op, this] { + BlockingQueue *ready_var_q, details::OpHandleBase *op) { + auto op_run = [ready_var_q, op, this] { try { VLOG(10) << op->Name() << " : " << op->DebugString(); op->Run(use_event_); - ready_var_q.Extend(op->outputs_); + running_ops_--; + ready_var_q->Extend(op->outputs_); } catch (platform::EnforceNotMet ex) { exception_.reset(new platform::EnforceNotMet(ex)); } catch (...) { diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 2ea57ac8f9..6193b897e4 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -14,7 +14,12 @@ #pragma once -#include +#include +#include +#include +#include +#include + #include #include "ThreadPool.h" // ThreadPool in thrird party #include "paddle/fluid/framework/details/ssa_graph_executor.h" @@ -79,9 +84,11 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ~ThreadedSSAGraphExecutor() {} private: - void RunOp(BlockingQueue &ready_var_q, + void RunOp(BlockingQueue *ready_var_q, details::OpHandleBase *op); + void RunDelayedOps(const std::unordered_set &delayed_ops); + private: std::unique_ptr<::ThreadPool> pool_; std::vector local_scopes_; @@ -89,6 +96,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { platform::DeviceContextPool fetch_ctxs_; const bool use_event_; std::unique_ptr exception_; + std::atomic running_ops_; size_t computation_count_{0}; size_t max_async_computation{100}; diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 577eea92d2..002a6d362f 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/parallel_executor.h" +#include "paddle/fluid/platform/profiler.h" #include #include @@ -151,6 +152,7 @@ void ParallelExecutor::BCastParamsToGPUs( void ParallelExecutor::Run(const std::vector &fetch_tensors, const std::string &fetched_var_name) { + platform::RecordBlock b(0); auto fetch_data = member_->executor_->Run(fetch_tensors); *member_->global_scope_->Var(fetched_var_name)->GetMutable() = fetch_data; diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 5e0588fa73..33e8d3bf21 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -16,6 +16,7 @@ import core import multiprocessing import framework import executor +import sys __all__ = ['ParallelExecutor'] @@ -35,7 +36,7 @@ class ParallelExecutor(object): places.append(p) if num_threads is None: - num_threads = min(len(places) * 2, multiprocessing.cpu_count()) + num_threads = len(places) startup = framework.default_startup_program() main = framework.default_main_program() -- GitLab