diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc index 5ddf331cfca39a4e81a42d9ff8efd5af7bcf6829..55b5f113589e090386d287e228349f22fb94a7ab 100644 --- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc @@ -76,7 +76,7 @@ void NCCLAllReduceOpHandle::RunImpl() { } } -std::string NCCLAllReduceOpHandle::Name() const { return "NCCL AllReduce"; } +std::string NCCLAllReduceOpHandle::Name() const { return "nccl_all_reduce"; } } // namespace details } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h index 045070bb6a97e90600cd24d9f43cd2a10a4bc1f5..3d61fa79f759b9c57ea3e60cc6639143bfbb1eac 100644 --- a/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h +++ b/paddle/fluid/framework/details/nccl_all_reduce_op_handle.h @@ -14,6 +14,9 @@ #pragma once +#include +#include + #include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" @@ -34,6 +37,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase { std::string Name() const override; + bool IsDelayedOp() override { return true; }; + protected: void RunImpl() override; }; diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 71672fd24c65ee654fb9f703ea5808c31ee8fbb0..54c2d627ff30461a93e765241652d42e204c6689 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -13,6 +13,8 @@ // limitations under the License. #pragma once +#include +#include #include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/platform/device_context.h" @@ -53,6 +55,8 @@ class OpHandleBase { void AddOutput(VarHandleBase *out); + virtual bool IsDelayedOp() { return false; } + protected: virtual void RunImpl() = 0; }; diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 3f8655147b688239509dea98925df310a46cbef8..075eed4ecca0f38f8868cb6742e89020d05bae97 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -29,17 +29,27 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( local_scopes_(local_scopes), places_(places), fetch_ctxs_(places), - use_event_(use_event) {} + use_event_(use_event), + running_ops_(0) {} + +void ThreadedSSAGraphExecutor::RunDelayedOps( + const std::unordered_set &delayed_ops) { + for (auto op : delayed_ops) { + op->Run(use_event_); + } +} FeedFetchList ThreadedSSAGraphExecutor::Run( const std::vector &fetch_tensors) { std::unordered_map pending_ops; std::unordered_set pending_vars; - BlockingQueue ready_vars; - std::unordered_set ready_ops; + std::unordered_set delayed_ops; + std::unordered_set after_delayed_ops; + std::unordered_set delayed_vars; + auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) { pending_vars.insert(&var); if (var.generated_op_ == nullptr) { @@ -106,7 +116,14 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( auto run_all_ready_ops = [&] { for (auto *op : ready_ops) { - RunOp(ready_vars, op); + if (op->IsDelayedOp()) { + delayed_ops.insert(op); + delayed_vars.insert(op->outputs_.begin(), op->outputs_.end()); + ready_vars.Extend(op->outputs_); + continue; + } + running_ops_++; + RunOp(&ready_vars, op); } ready_ops.clear(); }; @@ -124,7 +141,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( // 2. Find ready variable bool timeout; - auto cur_ready_vars = ready_vars.PopAll(1000, &timeout); + auto cur_ready_vars = ready_vars.PopAll(1, &timeout); if (timeout) { if (exception_) { @@ -141,13 +158,24 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( auto &deps = pending_ops[op]; --deps; if (deps == 0) { - ready_ops.insert(op); + if (delayed_vars.find(ready_var) != delayed_vars.end()) { + after_delayed_ops.insert(op); + } else { + ready_ops.insert(op); + } } } } + if (ready_ops.empty() && !delayed_ops.empty() && running_ops_ == 0) { + RunDelayedOps(delayed_ops); + delayed_ops.clear(); + for (auto *op : after_delayed_ops) { + ready_ops.insert(op); + } + after_delayed_ops.clear(); + } // Keep loop until all vars are ready. } - ++computation_count_; auto sync_computation = [&] { @@ -182,12 +210,13 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( } void ThreadedSSAGraphExecutor::RunOp( - BlockingQueue &ready_var_q, details::OpHandleBase *op) { - auto op_run = [&ready_var_q, op, this] { + BlockingQueue *ready_var_q, details::OpHandleBase *op) { + auto op_run = [ready_var_q, op, this] { try { VLOG(10) << op->Name() << " : " << op->DebugString(); op->Run(use_event_); - ready_var_q.Extend(op->outputs_); + running_ops_--; + ready_var_q->Extend(op->outputs_); } catch (platform::EnforceNotMet ex) { exception_.reset(new platform::EnforceNotMet(ex)); } catch (...) { diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 2ea57ac8f96bc9c2b5c98bcd25d9ce921c3683cd..6193b897e45f02e8f59e8427a89472177b56f398 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -14,7 +14,12 @@ #pragma once -#include +#include +#include +#include +#include +#include + #include #include "ThreadPool.h" // ThreadPool in thrird party #include "paddle/fluid/framework/details/ssa_graph_executor.h" @@ -79,9 +84,11 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ~ThreadedSSAGraphExecutor() {} private: - void RunOp(BlockingQueue &ready_var_q, + void RunOp(BlockingQueue *ready_var_q, details::OpHandleBase *op); + void RunDelayedOps(const std::unordered_set &delayed_ops); + private: std::unique_ptr<::ThreadPool> pool_; std::vector local_scopes_; @@ -89,6 +96,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { platform::DeviceContextPool fetch_ctxs_; const bool use_event_; std::unique_ptr exception_; + std::atomic running_ops_; size_t computation_count_{0}; size_t max_async_computation{100}; diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 577eea92d217f9feb54085b4d9f071494c3c5165..002a6d362f2d8b771634eec0f4aeedebde61e375 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/parallel_executor.h" +#include "paddle/fluid/platform/profiler.h" #include #include @@ -151,6 +152,7 @@ void ParallelExecutor::BCastParamsToGPUs( void ParallelExecutor::Run(const std::vector &fetch_tensors, const std::string &fetched_var_name) { + platform::RecordBlock b(0); auto fetch_data = member_->executor_->Run(fetch_tensors); *member_->global_scope_->Var(fetched_var_name)->GetMutable() = fetch_data; diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py index 5e0588fa73241a8752e1b3195a123820165f070d..33e8d3bf210c72f5cfcf7d2f884b1397c25d7eae 100644 --- a/python/paddle/fluid/parallel_executor.py +++ b/python/paddle/fluid/parallel_executor.py @@ -16,6 +16,7 @@ import core import multiprocessing import framework import executor +import sys __all__ = ['ParallelExecutor'] @@ -35,7 +36,7 @@ class ParallelExecutor(object): places.append(p) if num_threads is None: - num_threads = min(len(places) * 2, multiprocessing.cpu_count()) + num_threads = len(places) startup = framework.default_startup_program() main = framework.default_main_program()