未验证 提交 67abfc15 编写于 作者: L liuyuhui 提交者: GitHub

[Kunlun] fix dead lock for exec_op_count_ (#30718)

上级 13ef444f
...@@ -122,8 +122,11 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream( ...@@ -122,8 +122,11 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
for (auto cur_op : ready_fetch_ops) { for (auto cur_op : ready_fetch_ops) {
ready_ops->Push(cur_op); ready_ops->Push(cur_op);
} }
// Atomic variable, no need to lock
exec_op_count_ = 0; {
std::lock_guard<std::mutex> lock(mutex_);
exec_op_count_ = 0;
}
platform::XPUPlace cur_place; platform::XPUPlace cur_place;
std::size_t cur_count = 0; std::size_t cur_count = 0;
...@@ -133,6 +136,7 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream( ...@@ -133,6 +136,7 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
auto cur_op = ready_ops->Pop(); auto cur_op = ready_ops->Pop();
// when execption, get cur_op == nullptr // when execption, get cur_op == nullptr
if (cur_op == nullptr) { if (cur_op == nullptr) {
std::lock_guard<std::mutex> lock(mutex_);
exec_op_count_ = op_deps_.size(); exec_op_count_ = op_deps_.size();
break; break;
} }
...@@ -151,6 +155,7 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream( ...@@ -151,6 +155,7 @@ FetchResultType BindThreadedSSAGraphExecutor::RunMainStream(
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [&] { return exec_op_count_ >= op_deps_.size(); }); cv_.wait(lock, [&] { return exec_op_count_ >= op_deps_.size(); });
} }
if (exception_.IsCaught()) { if (exception_.IsCaught()) {
ExecutionFinal(&fetch_ops); ExecutionFinal(&fetch_ops);
} }
...@@ -255,9 +260,11 @@ void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync( ...@@ -255,9 +260,11 @@ void BindThreadedSSAGraphExecutor::RunMultiDeviceOpAsync(
ready_ops->Push(nullptr); ready_ops->Push(nullptr);
exception_.Catch(std::current_exception()); exception_.Catch(std::current_exception());
} }
// Atomic variable, no need to lock {
exec_op_count_++; std::lock_guard<std::mutex> lock(mutex_);
cv_.notify_all(); exec_op_count_++;
cv_.notify_all();
}
}); });
} }
// RunOpAsyncMainStream function is used for computed OPs // RunOpAsyncMainStream function is used for computed OPs
...@@ -286,9 +293,11 @@ void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream( ...@@ -286,9 +293,11 @@ void BindThreadedSSAGraphExecutor::RunOpAsyncMainStream(
ready_ops->Push(nullptr); ready_ops->Push(nullptr);
exception_.Catch(std::current_exception()); exception_.Catch(std::current_exception());
} }
// Atomic variable, no need to lock {
exec_op_count_++; std::lock_guard<std::mutex> lock(mutex_);
cv_.notify_all(); exec_op_count_++;
cv_.notify_all();
}
}); });
} }
......
...@@ -80,7 +80,7 @@ class BindThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -80,7 +80,7 @@ class BindThreadedSSAGraphExecutor : public SSAGraphExecutor {
std::mutex mutex_; std::mutex mutex_;
std::condition_variable cv_; std::condition_variable cv_;
std::atomic<unsigned int> exec_op_count_; uint32_t exec_op_count_;
std::atomic<int> error_state; std::atomic<int> error_state;
void RunOpAsyncMainStream( void RunOpAsyncMainStream(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册