diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index c1a28f1d1dca520987a5c85c044927eb38521673..0bf05c3c112c467b9cdae091e0d2d38d4d299655 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -35,11 +35,17 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( FeedFetchList ThreadedSSAGraphExecutor::Run( const std::vector &fetch_tensors) { std::unordered_map pending_ops; - std::unordered_map> pending_vars; + std::unordered_set pending_vars; + + BlockingQueue ready_vars; + std::unordered_set ready_ops; - auto InsertPendingVar = [&pending_vars](VarHandleBase &var) { - pending_vars[&var] = var.generated_op_ == nullptr; + auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) { + pending_vars.insert(&var); + if (var.generated_op_ == nullptr) { + ready_vars.Push(&var); + } }; auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) { @@ -101,7 +107,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( auto run_all_ready_ops = [&] { for (auto *op : ready_ops) { - RunOp(pending_vars, op); + RunOp(ready_vars, op); } ready_ops.clear(); }; @@ -118,29 +124,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( run_all_ready_ops(); // 2. Find ready variable - VarHandleBase *ready_var = nullptr; - for (auto &pair : pending_vars) { - if (pair.second.load(std::memory_order_acquire)) { - ready_var = pair.first; - break; - } - } - - // if there is no variable ready - if (ready_var == nullptr) { - // FIXME use conditional var instead of busy wait. - // if there is an exception, throw it - if (exception_) { - throw * exception_; - } - - VLOG(10) << "============================="; - for (auto &op : pending_ops) { - VLOG(10) << op.first->DebugString(); - } - // keep waiting the ready variables - continue; - } + VarHandleBase *ready_var = ready_vars.Pop(); // 3. Remove the dependency of ready_var. // Find the ready_ops after the ready_var. @@ -189,23 +173,15 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( } void ThreadedSSAGraphExecutor::RunOp( - std::unordered_map> &pending_vars, - details::OpHandleBase *op) { - std::vector *> *ready_buffer = - new std::vector *>(); - for (auto *var : op->outputs_) { - ready_buffer->emplace_back(&pending_vars[var]); - } - - auto op_run = [ready_buffer, op, this] { + BlockingQueue &ready_var_q, details::OpHandleBase *op) { + auto op_run = [&ready_var_q, op, this] { try { VLOG(10) << op->Name() << " : " << op->DebugString(); op->Run(use_event_); - for (auto *ready : *ready_buffer) { - ready->store(true, std::memory_order_release); + for (auto &each : op->outputs_) { + ready_var_q.Push(each); } - delete ready_buffer; } catch (platform::EnforceNotMet ex) { exception_.reset(new platform::EnforceNotMet(ex)); } catch (...) { diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 14b10cd0eb501966136f20ea8bdfc6ab1cd1d179..26ff14786397933382795c1371e6ab68185b0abe 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -24,6 +24,33 @@ class Scope; namespace details { +template +class BlockingQueue { + public: + void Push(const T &v) { + { + std::lock_guard g(mutex_); + q_.emplace_back(v); + } + cv_.notify_one(); + } + + T Pop() { + std::unique_lock lock(mutex_); + while (q_.empty()) { + cv_.wait(lock); + } + T v = q_.front(); + q_.pop_front(); + return v; + } + + private: + std::mutex mutex_; + std::condition_variable cv_; + std::deque q_; +}; + class ThreadedSSAGraphExecutor : public SSAGraphExecutor { public: ThreadedSSAGraphExecutor(size_t num_threads, bool use_event, @@ -38,9 +65,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ~ThreadedSSAGraphExecutor() {} private: - void RunOp( - std::unordered_map> &pending_vars, - details::OpHandleBase *op); + void RunOp(BlockingQueue &ready_var_q, + details::OpHandleBase *op); private: std::unique_ptr<::ThreadPool> pool_;