diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 105e21cab600b642aafc2eb3c619a801fb4c40d7..a6998f45df2d1ad32edbf191f6fbc5552142d0f6 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -124,16 +124,26 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( run_all_ready_ops(); // 2. Find ready variable - VarHandleBase *ready_var = ready_vars.Pop(); - + bool timeout; + auto cur_ready_vars = ready_vars.PopAll(100, &timeout); + + if (timeout) { + if (exception_) { + throw * exception_; + } else { + continue; + } + } // 3. Remove the dependency of ready_var. // Find the ready_ops after the ready_var. - pending_vars.erase(ready_var); - for (auto *op : ready_var->pending_ops_) { - auto &deps = pending_ops[op]; - --deps; - if (deps == 0) { - ready_ops.insert(op); + for (auto ready_var : cur_ready_vars) { + pending_vars.erase(ready_var); + for (auto *op : ready_var->pending_ops_) { + auto &deps = pending_ops[op]; + --deps; + if (deps == 0) { + ready_ops.insert(op); + } } } // Keep loop until all vars are ready. diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 839217031145a276b25992e334657608c3637758..da559d85535197389fd2e19bffde85aa223d38e9 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include "ThreadPool.h" // ThreadPool in thrird party #include "paddle/fluid/framework/details/ssa_graph_executor.h" @@ -27,10 +28,10 @@ namespace details { template class BlockingQueue { public: - void Push(const T &v) { + void Push(const T &item) { { std::lock_guard g(mutex_); - q_.emplace_back(v); + q_.emplace_back(item); } cv_.notify_one(); } @@ -56,6 +57,18 @@ class BlockingQueue { return v; } + std::deque PopAll(size_t ms, bool *timeout) { + auto time = + std::chrono::system_clock::now() + std::chrono::milliseconds(ms); + std::unique_lock lock(mutex_); + *timeout = !cv_.wait_until(lock, time, [this] { return !q_.empty(); }); + std::deque ret; + if (!*timeout) { + std::swap(ret, q_); + } + return ret; + } + private: std::mutex mutex_; std::condition_variable cv_;