提交 50f71f50 编写于 作者: Y Yu Yang

Using blocking queue

上级 7dcb217e
......@@ -35,11 +35,17 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
FeedFetchList ThreadedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
std::unordered_map<OpHandleBase *, size_t> pending_ops;
std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
std::unordered_set<VarHandleBase *> pending_vars;
BlockingQueue<VarHandleBase *> ready_vars;
std::unordered_set<OpHandleBase *> ready_ops;
auto InsertPendingVar = [&pending_vars](VarHandleBase &var) {
pending_vars[&var] = var.generated_op_ == nullptr;
auto InsertPendingVar = [&pending_vars, &ready_vars](VarHandleBase &var) {
pending_vars.insert(&var);
if (var.generated_op_ == nullptr) {
ready_vars.Push(&var);
}
};
auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) {
......@@ -101,7 +107,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
auto run_all_ready_ops = [&] {
for (auto *op : ready_ops) {
RunOp(pending_vars, op);
RunOp(ready_vars, op);
}
ready_ops.clear();
};
......@@ -118,29 +124,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
run_all_ready_ops();
// 2. Find ready variable
VarHandleBase *ready_var = nullptr;
for (auto &pair : pending_vars) {
if (pair.second.load(std::memory_order_acquire)) {
ready_var = pair.first;
break;
}
}
// if there is no variable ready
if (ready_var == nullptr) {
// FIXME: use a condition variable instead of busy-waiting.
// if there is an exception, throw it
if (exception_) {
throw * exception_;
}
VLOG(10) << "=============================";
for (auto &op : pending_ops) {
VLOG(10) << op.first->DebugString();
}
// keep waiting for a variable to become ready
continue;
}
VarHandleBase *ready_var = ready_vars.Pop();
// 3. Remove the dependency of ready_var.
// Find the ready_ops after the ready_var.
......@@ -189,23 +173,15 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
}
void ThreadedSSAGraphExecutor::RunOp(
std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
details::OpHandleBase *op) {
std::vector<std::atomic<bool> *> *ready_buffer =
new std::vector<std::atomic<bool> *>();
for (auto *var : op->outputs_) {
ready_buffer->emplace_back(&pending_vars[var]);
}
auto op_run = [ready_buffer, op, this] {
BlockingQueue<VarHandleBase *> &ready_var_q, details::OpHandleBase *op) {
auto op_run = [&ready_var_q, op, this] {
try {
VLOG(10) << op->Name() << " : " << op->DebugString();
op->Run(use_event_);
for (auto *ready : *ready_buffer) {
ready->store(true, std::memory_order_release);
for (auto &each : op->outputs_) {
ready_var_q.Push(each);
}
delete ready_buffer;
} catch (platform::EnforceNotMet ex) {
exception_.reset(new platform::EnforceNotMet(ex));
} catch (...) {
......
......@@ -24,6 +24,33 @@ class Scope;
namespace details {
// Unbounded FIFO queue whose every access to the underlying deque is guarded
// by a mutex. Push() never waits for space; Pop() blocks the calling thread
// until at least one element is available.
template <typename T>
class BlockingQueue {
 public:
  // Appends a copy of `v` to the back of the queue and wakes one waiter.
  void Push(const T &v) {
    {
      std::lock_guard<std::mutex> g(mutex_);
      q_.emplace_back(v);
    }
    // Notify after the lock has been released so the woken consumer can
    // acquire the mutex without immediately blocking on it again.
    cv_.notify_one();
  }

  // Rvalue overload: moves `v` into the queue, which avoids a copy and lets
  // the queue hold move-only element types.
  void Push(T &&v) {
    {
      std::lock_guard<std::mutex> g(mutex_);
      q_.emplace_back(std::move(v));
    }
    cv_.notify_one();
  }

  // Removes and returns the element at the front of the queue, blocking
  // while the queue is empty.
  T Pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    // The predicate form re-checks emptiness after every wake-up, so
    // spurious wake-ups are handled exactly like the explicit while-loop.
    cv_.wait(lock, [this] { return !q_.empty(); });
    // Move the element out instead of copying it; the moved-from slot is
    // discarded by pop_front() right away.
    T v = std::move(q_.front());
    q_.pop_front();
    return v;
  }

 private:
  std::mutex mutex_;            // guards q_
  std::condition_variable cv_;  // signalled once per Push()
  std::deque<T> q_;             // pending elements, oldest at the front
};
class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
public:
ThreadedSSAGraphExecutor(size_t num_threads, bool use_event,
......@@ -38,9 +65,8 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
~ThreadedSSAGraphExecutor() {}
private:
void RunOp(
std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
details::OpHandleBase *op);
void RunOp(BlockingQueue<VarHandleBase *> &ready_var_q,
details::OpHandleBase *op);
private:
std::unique_ptr<::ThreadPool> pool_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册