提交 50f71f50 编写于 作者: Y Yu Yang

Using blocking queue

上级 7dcb217e
...@@ -35,11 +35,17 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( ...@@ -35,11 +35,17 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
FeedFetchList ThreadedSSAGraphExecutor::Run( FeedFetchList ThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) { const std::vector<std::string> &fetch_tensors) {
std::unordered_map<OpHandleBase *, size_t> pending_ops; std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars; std::unordered_set<VarHandleBase *> pending_vars;
BlockingQueue<VarHandleBase *> ready_vars;
std::unordered_set<OpHandleBase *> ready_ops; std::unordered_set<OpHandleBase *> ready_ops;
auto InsertPendingVar = [&pending_vars](VarHandleBase &var) { auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
pending_vars[&var] = var.generated_op_ == nullptr; pending_vars.insert(&var);
if (var.generated_op_ == nullptr) {
ready_vars.Push(&var);
}
}; };
auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) { auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) {
...@@ -101,7 +107,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -101,7 +107,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto run_all_ready_ops = [&] { auto run_all_ready_ops = [&] {
for (auto *op : ready_ops) { for (auto *op : ready_ops) {
RunOp(pending_vars, op); RunOp(ready_vars, op);
} }
ready_ops.clear(); ready_ops.clear();
}; };
...@@ -118,29 +124,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -118,29 +124,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
run_all_ready_ops(); run_all_ready_ops();
// 2. Find ready variable // 2. Find ready variable
VarHandleBase *ready_var = nullptr; VarHandleBase *ready_var = ready_vars.Pop();
for (auto &pair : pending_vars) {
if (pair.second.load(std::memory_order_acquire)) {
ready_var = pair.first;
break;
}
}
// if there is no variable ready
if (ready_var == nullptr) {
// FIXME use conditional var instead of busy wait.
// if there is an exception, throw it
if (exception_) {
throw * exception_;
}
VLOG(10) << "=============================";
for (auto &op : pending_ops) {
VLOG(10) << op.first->DebugString();
}
// keep waiting the ready variables
continue;
}
// 3. Remove the dependency of ready_var. // 3. Remove the dependency of ready_var.
// Find the ready_ops after the ready_var. // Find the ready_ops after the ready_var.
...@@ -189,23 +173,15 @@ FeedFetchList ThreadedSSAGraphExecutor::Run( ...@@ -189,23 +173,15 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
} }
void ThreadedSSAGraphExecutor::RunOp( void ThreadedSSAGraphExecutor::RunOp(
std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars, BlockingQueue<VarHandleBase *> &ready_var_q, details::OpHandleBase *op) {
details::OpHandleBase *op) { auto op_run = [&ready_var_q, op, this] {
std::vector<std::atomic<bool> *> *ready_buffer =
new std::vector<std::atomic<bool> *>();
for (auto *var : op->outputs_) {
ready_buffer->emplace_back(&pending_vars[var]);
}
auto op_run = [ready_buffer, op, this] {
try { try {
VLOG(10) << op->Name() << " : " << op->DebugString(); VLOG(10) << op->Name() << " : " << op->DebugString();
op->Run(use_event_); op->Run(use_event_);
for (auto *ready : *ready_buffer) { for (auto &each : op->outputs_) {
ready->store(true, std::memory_order_release); ready_var_q.Push(each);
} }
delete ready_buffer;
} catch (platform::EnforceNotMet ex) { } catch (platform::EnforceNotMet ex) {
exception_.reset(new platform::EnforceNotMet(ex)); exception_.reset(new platform::EnforceNotMet(ex));
} catch (...) { } catch (...) {
......
...@@ -24,6 +24,33 @@ class Scope; ...@@ -24,6 +24,33 @@ class Scope;
namespace details { namespace details {
// Unbounded first-in/first-out queue whose Pop() blocks while the queue is
// empty. Every access to the underlying deque is made while holding mutex_.
template <typename T>
class BlockingQueue {
 public:
  // Appends a copy of `v` to the back of the queue, then wakes one waiter.
  void Push(const T &v) {
    std::unique_lock<std::mutex> guard(mutex_);
    q_.push_back(v);
    // Drop the mutex before signalling so the notification is issued
    // outside the critical section.
    guard.unlock();
    cv_.notify_one();
  }

  // Removes and returns the element at the front of the queue. If the queue
  // is empty, the caller waits on the condition variable until it is not.
  T Pop() {
    std::unique_lock<std::mutex> guard(mutex_);
    // The predicate is re-evaluated after every wake-up, so a wake-up that
    // finds the queue still empty simply goes back to waiting.
    cv_.wait(guard, [this] { return !q_.empty(); });
    T head = q_.front();
    q_.pop_front();
    return head;
  }

 private:
  std::mutex mutex_;            // guards q_
  std::condition_variable cv_;  // signalled once per Push()
  std::deque<T> q_;             // pending elements, oldest at the front
};
class ThreadedSSAGraphExecutor : public SSAGraphExecutor { class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
public: public:
ThreadedSSAGraphExecutor(size_t num_threads, bool use_event, ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
...@@ -38,9 +65,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -38,9 +65,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
~ThreadedSSAGraphExecutor() {} ~ThreadedSSAGraphExecutor() {}
private: private:
void RunOp( void RunOp(BlockingQueue<VarHandleBase *> &ready_var_q,
std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars, details::OpHandleBase *op);
details::OpHandleBase *op);
private: private:
std::unique_ptr<::ThreadPool> pool_; std::unique_ptr<::ThreadPool> pool_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册