未验证 提交 67abfc15 编写于 作者: L liuyuhui 提交者: GitHub

[Kunlun] fix dead lock for exec_op_count_ (#30718)

上级 13ef444f
......@@ -122,8 +122,11 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
for (auto cur_op : ready_fetch_ops) {
ready_ops->Push(cur_op);
}
// exec_op_count_ was formerly atomic ("no need to lock"); it is now guarded by mutex_
{
std::lock_guard<std::mutex> lock(mutex_);
exec_op_count_ = 0;
}
platform::XPUPlace cur_place;
std::size_t cur_count = 0;
......@@ -133,6 +136,7 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
auto cur_op = ready_ops->Pop();
// when an exception occurs, cur_op == nullptr
if (cur_op == nullptr) {
std::lock_guard<std::mutex> lock(mutex_);
exec_op_count_ = op_deps_.size();
break;
}
......@@ -151,6 +155,7 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return exec_op_count_ >= op_deps_.size(); });
}
if (exception_.IsCaught()) {
ExecutionFinal(&fetch_ops);
}
......@@ -255,9 +260,11 @@ void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync(
ready_ops->Push(nullptr);
exception_.Catch(std::current_exception());
}
// exec_op_count_ was formerly atomic ("no need to lock"); it is now guarded by mutex_
{
std::lock_guard<std::mutex> lock(mutex_);
exec_op_count_++;
cv_.notify_all();
}
});
}
// RunOpAsyncMainStream function is used for computed OPs
......@@ -286,9 +293,11 @@ void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream(
ready_ops->Push(nullptr);
exception_.Catch(std::current_exception());
}
// exec_op_count_ was formerly atomic ("no need to lock"); it is now guarded by mutex_
{
std::lock_guard<std::mutex> lock(mutex_);
exec_op_count_++;
cv_.notify_all();
}
});
}
......
......@@ -80,7 +80,7 @@ class BindThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::mutex mutex_;
std::condition_variable cv_;
std::atomic<unsigned int> exec_op_count_;
uint32_t exec_op_count_;
std::atomic<int> error_state;
void RunOpAsyncMainStream(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册